A Very Long Span of Attention
A primer on important optimizations built on top of the attention mechanism, including MHA, MQA, GQA, SWA, sparse and low-rank attention, Flash-Attention v1 and v2, Paged-Attention, and more.
Table of Contents
Advanced Attention Architectures (intro)
0. Introduction
Attention has quietly become the workhorse of modern AI. What started as a clever trick in sequence-to-sequence translation has now grown into a universal recipe for intelligence: from Large Language Models (LLMs) like GPT and LLaMA, to Vision-Language Models (VLMs) like CLIP and BLIP, to speech, audio, protein folding, medical imaging, even satellite data processing. Everywhere you look, attention is the common thread.
But what is it about attention that makes it so powerful? And why does every new paper, whether on efficient Transformers, long-context models, or diffusion architectures, seem to circle back to rethinking attention?
In this article, we’ll peel back the layers of the attention mechanism. We’ll start from the basics: how compute is measured, why memory is a bottleneck, and what parallelism tricks GPUs use to survive attention-heavy workloads. We’ll then dive into the core equations of attention, explore its pros and cons, and revisit the first additive and multiplicative forms (Bahdanau, Luong). From there, we’ll walk across architectures like CNNs, RNNs, and Transformers, before exploring efficient variants like MHA, MQA, GQA, and SWA. We’ll also look at approximation tricks like sparse and low-rank attention, and end with bleeding-edge ideas like Flash-Attention and Paged-Attention.
By the end, you should not only understand how attention works, but also why it struggles to scale and what researchers are doing to fix it.
1. Prerequisites
Before diving into attention, we need to build some quick intuition about compute: what it really means, how it’s measured, and why GPUs and memory systems dictate so much of how modern large-scale models are designed.
How Compute is Measured
In deep learning, compute is often reported in FLOPs (Floating Point Operations). One FLOP is a single addition or multiplication of two floating-point numbers. Since models run billions or even trillions of these operations, we usually scale up to:
GFLOPs (10⁹) : billions of operations
TFLOPs (10¹²) : trillions of operations
PFLOPs (10¹⁵) : quadrillions of operations
When people say “this model costs 1000 PFLOPs to train,” they mean that if you count every multiply and add across all layers, in both the forward and backward passes, over the whole training run, the total is on that scale.
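To make this concrete, here is a small sketch that counts FLOPs for the two big matmuls in one attention head. It assumes the common convention that one multiply-add counts as 2 FLOPs. It ignores the softmax and the Q/K/V/output projections. `matmul_flops` and `attention_core_flops` are illustrative helpers, not a library API.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m x k) @ (k x n) matmul: m*n outputs, each needing k multiplies and ~k adds."""
    return 2 * m * k * n

def attention_core_flops(n: int, d: int) -> int:
    """One head, one layer: S = Q K^T (n x d @ d x n) plus O = P V (n x n @ n x d).
    Ignores the softmax and the Q/K/V/output projections."""
    return matmul_flops(n, d, n) + matmul_flops(n, n, d)

for n in (1_024, 4_096, 32_768):
    print(f"n = {n:>6}: {attention_core_flops(n, d=128):.2e} FLOPs per head, per layer")
```

Doubling n roughly quadruples the count. This is the quadratic cost that recurs throughout the article.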
But FLOPs only tell part of the story. What really matters is:
Throughput (how many FLOPs per second a GPU can sustain).
Memory bandwidth (how fast data can be moved between memory and compute cores).
Often, workloads aren’t compute-bound; they’re memory-bound. The GPU has enough ALUs to perform trillions of FLOPs per second, but if the data doesn’t arrive fast enough, the cores sit idle. This matters a lot for attention, where we shuffle large matrices around.
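A roofline-style check makes the compute-bound vs memory-bound distinction concrete. It compares how long a kernel would take if limited only by arithmetic versus only by memory traffic. Several inputs are assumptions for illustration:
- the peak throughput (312 TFLOP/s, A100-like FP16/BF16 tensor cores) is an assumed round number;
- the bandwidth (1.5 TB/s) is likewise an assumed round number;
- the ~5 FLOPs per element for softmax is a rough guess.

```python
PEAK_FLOPS = 312e12   # assumed peak FP16/BF16 tensor-core throughput (FLOP/s), A100-like
BANDWIDTH = 1.5e12    # assumed HBM bandwidth (bytes/s)

def classify(flops: float, bytes_moved: float) -> str:
    """Roofline-style check: is the kernel limited by math or by memory traffic?"""
    t_compute = flops / PEAK_FLOPS
    t_memory = bytes_moved / BANDWIDTH
    return "compute-bound" if t_compute >= t_memory else "memory-bound"

n = 4096
# Large FP16 matmul (n x n) @ (n x n): 2n^3 FLOPs, reads two n x n inputs and writes one (2 bytes each).
matmul = classify(flops=2 * n**3, bytes_moved=3 * n * n * 2)
# Elementwise softmax over an n x n FP16 score matrix: assume ~5 FLOPs per element,
# reading and writing each element once (4 bytes of traffic per element).
softmax = classify(flops=5 * n * n, bytes_moved=4 * n * n)
print(matmul, softmax)  # compute-bound memory-bound
```

The large matmul performs over a thousand FLOPs per byte it moves, while the softmax performs only about one. The softmax is therefore limited by HBM bandwidth no matter how fast the ALUs are.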
A Quick Look at GPU Architecture
Modern GPUs (like NVIDIA A100, H100) are built as massively parallel processors:
They contain thousands of CUDA cores, organized into Streaming Multiprocessors (SMs). Each SM can handle many threads in parallel.
They also have specialized units like Tensor Cores which are designed to accelerate dense matrix multiplications, the heart of deep learning.
Memory hierarchy:
SRAM (GPU On-Chip Memory)
SRAM is the top of the pyramid: fast, small, and used directly for computations. This includes registers, shared memory, and the L1/L2 caches built right into the GPU chip. With bandwidths up to ~19 TB/s but only tens of MBs in capacity, SRAM is where active math happens. Every multiply-add and every attention score flows through here first. The speed is unmatched, but the trade-off is size: once data spills beyond SRAM, performance slows dramatically.
HBM (High Bandwidth Memory)
The middle tier is HBM, the GPU’s dedicated off-chip memory. Think of it as the GPU’s “working desk” for large tensors like model weights, activations, and the KV cache in transformers. Bandwidth is still massive (≈1.5 TB/s on an A100, higher on newer GPUs), and capacity goes into the tens of GBs (40–80 GB typical). However, it’s much slower than SRAM (roughly 10–100× slower). Efficient models and kernels, such as Flash Attention, are designed to minimize round-trips to HBM by squeezing as much work as possible into SRAM before going back.
Main Memory (CPU DRAM)
At the bottom of the pyramid sits system RAM, connected to the GPU through PCIe or NVLink. Its bandwidth is only around 12.8 GB/s, far below HBM. This is where data overflows when GPU memory runs out, or where offloaded parameters live in techniques like ZeRO-offload. While it gives you scale (hundreds of GBs to multiple TBs), every transfer across this bus is expensive. Training slows significantly if your workload constantly swaps between HBM and DRAM.
In short: SRAM is lightning-fast but tiny, HBM balances size and speed, and DRAM gives scale but at a heavy cost. Optimizing attention and large-model training is all about reshaping computations so they live closer to the top of this pyramid. When you run attention, QKᵀ involves huge matrix multiplications. If intermediate results don’t fit into cache, they spill into HBM, costing time and energy. That’s why many efficient attention methods exist: to reduce these transfers.
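A back-of-the-envelope sketch, using the bandwidth figures quoted above, shows how long moving 1 GB would take at each tier. This is an idealized calculation: it ignores latency and assumes peak bandwidth. SRAM could not actually hold 1 GB at once.

```python
# Bandwidths quoted in the text above, in bytes per second.
TIERS = {
    "SRAM": 19e12,
    "HBM": 1.5e12,
    "CPU DRAM": 12.8e9,
}

def transfer_ms(num_bytes: float, bandwidth: float) -> float:
    """Idealized transfer time in milliseconds at a given bandwidth."""
    return 1e3 * num_bytes / bandwidth

size = 1e9  # a 1 GB tensor (illustrative)
for tier, bw in TIERS.items():
    print(f"{tier:>8}: {transfer_ms(size, bw):8.3f} ms")
```

Roughly 0.05 ms, 0.67 ms, and 78 ms. That gap is why kernels try to keep the hot loop of attention in SRAM.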
The Out-of-Memory (OOM) Problem
As models scale, the biggest headache is OOM errors: “CUDA out of memory.” Why does this happen?
Activations from the forward pass need to be stored for backprop (unless they are recomputed).
Attention is quadratic in sequence length (O(n²)). At n = 32k tokens, the attention score matrices alone can reach hundreds of GB across heads and layers.
GPUs typically have 40–80 GB of HBM. Training GPT-4-scale models requires model-parallel setups with hundreds of GPUs.
So compute isn’t just about how many FLOPs we could do; it’s also about fitting the workload inside limited GPU memory efficiently.
Tensor Parallelism (and Why It Helps)
To deal with these limits, we split tensors across devices:
Data Parallelism: split the batch across GPUs (each GPU gets different samples).
Model Parallelism: split the model parameters.
Tensor Parallelism: split large matrix multiplications across GPUs.
Example: in attention, multiplying Q (batch × seq × d) by Kᵀ (batch × d × seq) can be too large for one GPU. With tensor parallelism, each GPU holds a shard of K and computes part of the dot product, and the partial results are then aggregated.
This allows us to train models with 100B+ parameters, at the cost of more communication overhead between GPUs.
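The sketch below simulates the idea on one machine with plain Python lists. As an assumption for illustration, it shards the hidden dimension of both Q and K. Each simulated device computes a partial QKᵀ, and the partials are summed, which is the role an all-reduce plays across real GPUs. Production systems (e.g., Megatron-style tensor parallelism) more commonly split attention by heads. The helper names here are hypothetical.

```python
def matmul_t(a, b):
    """Return a @ b^T for row-major nested lists (a: m x d, b: n x d)."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]

def split_hidden(m, num_shards):
    """Split every row of m into num_shards equal slices along the hidden dimension."""
    d = len(m[0])
    assert d % num_shards == 0, "hidden dim must divide evenly across shards"
    step = d // num_shards
    return [[row[s * step:(s + 1) * step] for row in m] for s in range(num_shards)]

def sharded_scores(q, k, num_shards):
    """Each simulated 'device' s computes Q_s @ K_s^T on its slice of the hidden dim;
    summing the partial results (an all-reduce on real GPUs) recovers Q @ K^T."""
    partials = [matmul_t(q_s, k_s)
                for q_s, k_s in zip(split_hidden(q, num_shards), split_hidden(k, num_shards))]
    return [[sum(p[i][j] for p in partials) for j in range(len(k))] for i in range(len(q))]

Q = [[1, 2, 3, 4], [5, 6, 7, 8]]                # 2 queries, d = 4
K = [[1, 0, 1, 0], [0, 1, 0, 1], [2, 2, 2, 2]]  # 3 keys, d = 4
print(sharded_scores(Q, K, num_shards=2))       # [[4, 6, 20], [12, 14, 52]]
print(matmul_t(Q, K))                           # same result on a single "device"
```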
Why does all this matter for attention?
Because attention layers sit right at the intersection of FLOPs and memory:
They’re FLOP-heavy (big matrix multiplications).
They’re memory-heavy (O(n²) storage for attention scores).
That’s why efficient attention is not just a theoretical exercise; it’s a survival tactic for scaling models.
2. Setting up the Context
Why do we even need attention?
Before attention, most architectures (CNNs, RNNs) struggled with long-range dependencies. RNNs, for example, had to process sequences step by step, making it hard to remember things said 100 tokens ago. Attention fixes this by allowing direct interactions between any two positions in the sequence, regardless of distance. Instead of “remembering” everything through hidden states, attention explicitly looks back at all relevant tokens and decides which ones matter.
The Core Attention Equation
At the heart of attention lies a simple weighted lookup table.
Given queries (Q), keys (K), and values (V):
$$\text{Attention}(Q, K, V) = \text{softmax}\left(QK^\top\right)V$$
QK^T: Computes similarity between queries and keys.
softmax: Normalizes these similarities into a probability distribution.
Multiplying with V: Extracts information from values weighted by importance.
Let’s also set up some quick terminology:
$$S = f(Q, K) = QK^\top, \qquad P = \text{softmax}(S), \qquad O = g(P, V) = PV$$
We will use S, P, and O throughout the article. In standard attention, f and g are dot products (matrix multiplications), but they could be any aggregation method (keep this in mind).
This gives the model a flexible way to focus on specific tokens while ignoring irrelevant ones.
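Here is a tiny, dependency-free sketch of the S, P, O pipeline with one query and two keys. It uses hand-picked toy numbers and no 1/√d scaling. Note that the less similar key still receives a non-zero weight. This is exactly the softmax behavior discussed under the disadvantages below.

```python
import math

def softmax(row):
    m = max(row)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    S = [[sum(q * k for q, k in zip(qi, kj)) for kj in K] for qi in Q]     # S = Q K^T
    P = [softmax(row) for row in S]                                        # P = softmax(S), row-wise
    O = [[sum(p * v[c] for p, v in zip(pi, V)) for c in range(len(V[0]))]  # O = P V
         for pi in P]
    return S, P, O

Q = [[1.0, 0.0]]                  # one query
K = [[1.0, 0.0], [0.0, 1.0]]      # key 0 aligns with the query, key 1 is orthogonal
V = [[10.0, 0.0], [0.0, 10.0]]
S, P, O = attention(Q, K, V)
print(S, P, O)
```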
Advantages of Attention
Parallelization: Unlike RNNs, attention allows looking at all tokens at once.
Long-Range Dependency Capture: Any token can attend to any other.
Interpretability: Attention scores show what the model is focusing on.
Disadvantages of Attention
Quadratic Complexity: Computing QK^T for all token pairs is O(n^2) in both time and memory.
Memory Bottlenecks: Storing the entire attention matrix becomes infeasible for very long contexts.
Redundant Attention: Not all tokens are equally important, yet the mechanism wastes compute on all of them.
But why do these disadvantages exist? (The role of softmax and trade-offs)
At the core, softmax attention requires normalization over all tokens. Every query must “see” every key to compute valid probabilities. This global normalization step is what forces quadratic scaling.
Time vs memory trade-off: the KV cache accelerates computation but shifts the burden to memory.
Softmax rigidity: even if most tokens are irrelevant, softmax doesn’t skip them; it still computes attention scores for all of them.
Example at Scale
Imagine a 100K-token context window (like some of today’s long-context LLMs).
Standard attention requires computing a 100K × 100K matrix = 10 billion interactions per head, per layer.
Storing this matrix in half precision (2 bytes per entry) takes ~20 GB per head, per layer (~40 GB in FP32). That is already a large fraction of a single A100’s 40–80 GB.
And that’s just one head in one layer. Multiply by the number of heads and by 48 layers, and materializing it becomes impossible without approximation or smarter kernels.
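A small calculator reproduces these numbers. It counts only the materialized score matrix (no activations, weights, or KV cache). The 32-head, 32k-token configuration in the last line is hypothetical. It shows how the OOM example from the prerequisites section reaches tens of GiB per layer.

```python
def score_matrix_bytes(n_tokens: int, bytes_per_elem: int = 2, heads: int = 1,
                       layers: int = 1, batch: int = 1) -> int:
    """Bytes needed to materialize the n x n score matrix for every head, layer, and sample."""
    return n_tokens * n_tokens * bytes_per_elem * heads * layers * batch

GB, GiB = 1e9, 2**30
print(score_matrix_bytes(100_000) / GB)                    # 20.0 (FP16, one head, one layer)
print(score_matrix_bytes(100_000, bytes_per_elem=4) / GB)  # 40.0 (FP32)
print(score_matrix_bytes(100_000, layers=48) / GB)         # 960.0 (one head in each of 48 layers)
print(score_matrix_bytes(32_768, heads=32) / GiB)          # 64.0 (32k tokens, 32 heads, one layer)
```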
KV Caching: The inference-time Fix
During inference, recomputing keys and values (and the full QKᵀ) for every previous token at each generation step would be extremely costly. KV caching solves this by storing past keys and values. When a new token arrives, only its query interacts with the cached K and V, rather than recomputing everything from scratch. There is a trade-off, though: the KV cache saves time but eats memory, especially for long sequences (every token adds more keys and values to memory). You can read a more detailed treatment in our article on MHLA.
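Here is a minimal single-head sketch of cached decoding. It assumes random projection weights and no batching. Each step projects only the newest token, appends its key and value to the cache, and attends over everything cached so far. The output matches recomputing full causal attention over the whole prefix. The cache, meanwhile, grows by one K row and one V row per token.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, steps = 16, 6
W_q, W_k, W_v = (torch.randn(d, d) / d**0.5 for _ in range(3))  # hypothetical projection weights
x = torch.randn(steps, d)                                         # one token embedding per step

def full_causal_attention(x):
    """Reference: recompute Q, K, V for the whole prefix and apply a causal mask."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.T / d**0.5
    future = torch.triu(torch.ones(len(x), len(x), dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

def decode_with_kv_cache(x):
    """Each step projects only the new token; past keys and values come from the cache."""
    k_cache, v_cache, outputs = [], [], []
    for t in range(len(x)):
        q = x[t] @ W_q
        k_cache.append(x[t] @ W_k)
        v_cache.append(x[t] @ W_v)
        K, V = torch.stack(k_cache), torch.stack(v_cache)  # [t+1, d]: grows every step
        weights = F.softmax(K @ q / d**0.5, dim=-1)         # one query vs all cached keys
        outputs.append(weights @ V)
    return torch.stack(outputs)

print(torch.allclose(full_causal_attention(x), decode_with_kv_cache(x), atol=1e-5))
```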
In short: Attention is powerful because it breaks the sequential bottlenecks, but its quadratic nature makes it unsustainable at scale. The rest of this article is essentially about clever tricks (software or hardware level) to keep its benefits while taming its costs.
3. Additive and Multiplicative Attention
Now that we know why attention exists and how the vanilla QKᵀ softmax formulation works (and why it struggles at scale), let’s rewind to the early forms of attention. These are the first flavors that made sequence-to-sequence models actually work: Additive (Bahdanau) Attention and Multiplicative (Luong) Attention. Remember the f and g functions I mentioned earlier? This is where that idea comes into the picture.
Think of this as the first experiments in “how do we make a model focus?”, long before Transformers became mainstream.
3.1 Additive Attention (Bahdanau)
Additive attention was introduced by Bahdanau et al. (2015) for machine translation. Its key idea: instead of a simple dot product, compute a learned scoring function that combines the query and key. It is essentially a shifting (translation) operation. The projected encoder hidden state is shifted in feature space by the projected decoder state (imagine walking in a Cartesian space).
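Mathematically:
$$e_{ij} = v^\top \tanh(W_1 h_j + W_2 s_i), \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k} \exp(e_{ik})}, \qquad c_i = \sum_{j} \alpha_{ij} h_j$$
Where:
h_j = encoder hidden states (keys/values)
s_i = decoder hidden state (query)
v, W₁, W₂ = learnable parameters
α_ij = attention weights
c_i = context vector fed to the decoder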
Imagine each query s_i asking “how much should I care about each encoder state h_j?” The tanh followed by the linear projection v acts as a small learned scoring network, deciding attention scores flexibly.
Pros:
Flexible: learnable scoring function can capture complex relations.
Works well on small-to-medium sequences.
Cons:
Slower at scale: computing tanh(W₁h_j + W₂s_i) for every (i, j) pair is O(n²) in sequence length.
More parameters to train (v, W₁, W₂).
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    def __init__(self, query_dim, key_dim, hidden_dim):
        super().__init__()
        self.W_q = nn.Linear(query_dim, hidden_dim)
        self.W_k = nn.Linear(key_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1)

    def forward(self, query, keys, values):
        # query: [batch, query_dim], keys/values: [batch, seq, key_dim]
        q_proj = self.W_q(query).unsqueeze(1)  # [batch, 1, hidden_dim]
        k_proj = self.W_k(keys)  # [batch, seq, hidden_dim]
        scores = self.v(torch.tanh(q_proj + k_proj)).squeeze(-1)  # [batch, seq]
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)
        return context, attn_weights
3.2 Multiplicative Attention (Luong)
Luong et al. (2015) proposed multiplicative (dot-product) attention, again playing with our f (the scoring function). It is computationally cheaper, since there is just one set of parameters to learn (W). Without W, the plain dot product works whenever the query and key dimensions match. Unlike additive attention, multiplying by W scales and rotates the feature vector, which is a more sophisticated operation than a shift.
Instead of learning a fancy scoring function, we just measure “alignment” via dot products, like checking how parallel two vectors are. This is less flexible than Bahdanau’s approach, but much faster.
class MultiplicativeAttention(nn.Module):
    def __init__(self, query_dim, key_dim):
        super().__init__()
        self.W = nn.Linear(query_dim, key_dim, bias=False)

    def forward(self, query, keys, values):
        scores = torch.bmm(self.W(query).unsqueeze(1), keys.transpose(1, 2)).squeeze(1)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)
        return context, attn_weights
3.3 Bahdanau vs Luong: A Quick Comparison
In summary, Bahdanau’s method shifts and Luong’s method scales. Both work, but for normalized representations (which is often the case), a Luong-like mechanism makes more sense. It is as if we are working in a polar coordinate system, versus the Cartesian one of Bahdanau’s method. This paved the way for scaled dot-product attention in Transformers, which combines speed with reasonable flexibility. Some food for thought: how would another operation, like concatenation, work here?
4. Attention Across Architectures
Attention isn’t just a Transformer thing; it is a universal idea for focusing computation on what matters. Let’s see how it appears across different architectures and why it behaves differently in each.
4.1 CNNs + Attention
Traditionally, CNNs focus on local neighborhoods using convolutional kernels. But what if we want global context?
Attention in CNNs:
Called “Self-Attention” or Non-Local Blocks in vision.
Each spatial position attends to all others to capture long-range dependencies.
Mathematically, for an image feature map X ∈ ℝ^(H×W×C):
$$Y = \text{softmax}\left(QK^\top\right)V, \qquad Q = XW_Q,\; K = XW_K,\; V = XW_V$$
where Q, K, V are linear projections of the flattened X (shape: HW × C).
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # 1x1 convolutions act as per-pixel linear projections
        self.q = nn.Conv2d(in_channels, in_channels, 1)
        self.k = nn.Conv2d(in_channels, in_channels, 1)
        self.v = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        q = self.q(x).view(B, C, -1).permute(0, 2, 1)  # [B, HW, C]
        k = self.k(x).view(B, C, -1)  # [B, C, HW]
        v = self.v(x).view(B, C, -1).permute(0, 2, 1)  # [B, HW, C]
        attn = F.softmax(torch.bmm(q, k), dim=-1)  # [B, HW, HW]: every pixel vs every pixel
        out = torch.bmm(attn, v).permute(0, 2, 1).view(B, C, H, W)
        return out
Intuitively, instead of a pixel only seeing its neighbors, it can “peek” at all other pixels. This is useful for tasks like segmentation or super-resolution (think about saliency maps).
4.2 RNNs + Attention
RNNs process sequences step by step, making long-range dependencies hard to capture. Attention was first introduced here (Bahdanau & Luong, which we saw above) to help.
Self-Attention: Each hidden state can attend to all its previous states (attention over samples from the same distribution).
Decoder (cross-like) Attention: Decoder queries attend over all encoder states (sequence-to-sequence), i.e., attention over samples from two distributions (which may or may not be alike).
Intuition: Attention acts like a “shortcut memory,” bypassing the step-by-step bottleneck of RNNs.
class RNNAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W_q = nn.Linear(hidden_dim, hidden_dim)
        self.W_k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1)

    def forward(self, query, keys, values):
        q = self.W_q(query).unsqueeze(1)
        k = self.W_k(keys)
        scores = self.v(torch.tanh(q + k)).squeeze(-1)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)
        return context, attn_weights
4.3 Transformers
This is the architecture most associated with attention. Key characteristics:
Self-Attention: Each token attends to all others.
Multi-Head Attention: Multiple independent attention heads capture different relationships.
Position Encoding: Adds order information, since self-attention on its own is permutation-equivariant (it has no notion of token order).
Scaled Dot-Product Attention Equation:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Scaling by √d_k prevents the softmax from saturating when dot products grow large. Multiple heads let the model “look at different aspects” of the sequence simultaneously.
class TransformerAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)

    def forward(self, x):
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (x.size(-1)**0.5)
        attn_weights = F.softmax(scores, dim=-1)
        out = torch.bmm(attn_weights, V)
        return out, attn_weights
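To see why the √d_k factor matters, the sketch below measures the spread of q·k for random vectors. It assumes q and k have i.i.d. standard-normal entries; real activations are not this simple. The unscaled standard deviation grows like √d, which pushes softmax inputs into the saturated regime. Dividing by √d keeps the spread near 1.

```python
import random
import statistics

def dot_product_std(d: int, samples: int = 1000, scale: bool = False, seed: int = 0) -> float:
    """Std of q.k for random q, k with i.i.d. N(0, 1) entries, optionally divided by sqrt(d)."""
    rng = random.Random(seed)
    vals = []
    for _ in range(samples):
        q = [rng.gauss(0, 1) for _ in range(d)]
        k = [rng.gauss(0, 1) for _ in range(d)]
        s = sum(a * b for a, b in zip(q, k))
        vals.append(s / d**0.5 if scale else s)
    return statistics.pstdev(vals)

for d in (16, 256):
    # unscaled std is roughly sqrt(d); scaled std stays roughly 1
    print(d, round(dot_product_std(d), 1), round(dot_product_std(d, scale=True), 2))
```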
4.4 Self-Attention vs Cross-Attention
Attention isn’t one-sequence-fits-all. Whether queries, keys, and values come from the same sequence (in-distribution) or from different sequences (out-of-distribution) gives rise to self-attention and cross-attention. Understanding this distinction is critical, especially for Transformers and multi-modal models. Let’s look at a code block first:
import torch
import torch.nn as nn
import torch.nn.functional as F

# apologies for mixing up Snake and Camel case :-C
class SelfAttention_vs_CrossAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)

    def forward_self_attention(self, x):
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (x.size(-1)**0.5)
        attn_weights = F.softmax(scores, dim=-1)
        out = torch.bmm(attn_weights, V)
        return out, attn_weights

    def forward_cross_attention(self, x_target, x_source):
        Q = self.q(x_target)  # queries from the target sequence
        K = self.k(x_source)  # keys and values from the source sequence
        V = self.v(x_source)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (x_target.size(-1)**0.5)
        attn_weights = F.softmax(scores, dim=-1)
        out = torch.bmm(attn_weights, V)
        return out, attn_weights

# self-attention
B, seq_len, d_model = 2, 5, 16
x = torch.randn(B, seq_len, d_model)
attn = SelfAttention_vs_CrossAttention(d_model)
out, weights = attn.forward_self_attention(x)
print("Self-Attention output:", out.shape)  # [B, seq_len, d_model]

# cross-attention
B, tgt_len, src_len, d_model = 2, 3, 5, 16
x_target = torch.randn(B, tgt_len, d_model)
x_source = torch.randn(B, src_len, d_model)
attn = SelfAttention_vs_CrossAttention(d_model)
out, weights = attn.forward_cross_attention(x_target, x_source)
print("Cross-Attention output:", out.shape)  # [B, tgt_len, d_model]
Using the code above as a reference:
Self-Attention
Each element in the input sequence attends to other elements in the same sequence. Queries, keys, and values all come from the same source.
Imagine you are writing a summary of a book chapter. Each sentence looks at all the other sentences in the same chapter to decide which parts are important. That’s self-attention: the sequence talking to itself.
Use Cases:
Transformer encoder layers (BERT) and decoder-only models (GPT, with a causal mask)
Intra-sequence reasoning in text, vision, or audio
Cross-Attention
Queries come from one sequence (e.g., decoder), while keys and values come from another (e.g., encoder).
You’re again writing a summary, but this time you also have a reference book. Each sentence in your draft (query) looks at all sentences in the reference (keys/values) to decide what to include. That’s cross-attention: an inter-sequence consultation.
Use Cases:
Decoder layers of encoder-decoder Transformers (T5, BART, the original Transformer)
Multi-modal tasks: text-to-image (Stable Diffusion), audio-visual fusion
5. Ideal Case: Scaling Attention Efficiently
We’ve seen attention’s power: it lets every token look at every other token. But this comes at a quadratic cost in memory and compute:
$$\text{Time} = O(n^2 \cdot d), \qquad \text{Memory} = O(n^2)$$
where n = sequence length and d = hidden/embedding dimension.
For extremely long sequences (think 32k+ tokens or high-res images), vanilla attention becomes infeasible.
The ideal attention mechanism should satisfy:
Large context space preservation: every token can attend to all relevant tokens.
Linear or constant scaling: memory and compute scale linearly or sub-quadratically with sequence length.
Fast computation: leverage GPU efficiently.
Approximation friendly: allow sparsity, low-rank approximations, or streaming.
Let’s explore the ways to approach this.
5.1 Linear / Constant Scaling Approaches
Several tricks allow us to reduce the quadratic bottleneck:
Recurrent / Streaming Attention
Instead of storing full QKᵀ, maintain running summaries (like an RNN) to update context.
Complexity reduces to O(n · d²) instead of O(n² · d).
Low-Rank Factorization
Approximate the attention matrix A = QKᵀ with a low-rank decomposition:
$$A \approx UV^\top, \qquad U, V \in \mathbb{R}^{n \times k},\; k \ll n$$
Reduces memory and compute cost to O(n · k · d).
Sliding Window / Local Attention
Each token attends to nearby tokens only.
Complexity becomes O(n · w · d), where w = window size.
Think of this as “zooming in” on a local neighborhood first, then gradually expanding your view for long-range dependencies across layers. This is similar to how CNNs start from short-range features (small receptive fields) and aggregate them over the depth of the model.
5.2 Surface-Level vs Root-Level Fixes
Researchers tackle scaling problems in two ways:
Surface-level fixes are like rearranging your desk to work faster; root-level fixes are like redesigning your office layout to eliminate bottlenecks.
5.3 Approximating Attention in Practice
Example: Linear Attention
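Vanilla attention:
$$O = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$
Linear approximation (division applied row-wise):
$$O \approx \frac{\phi(Q)\left(\phi(K)^\top V\right)}{\phi(Q)\left(\phi(K)^\top \mathbf{1}_n\right)}$$
Where ϕ(⋅) is a kernel feature map that, once the softmax is dropped, allows associativity: (ϕ(Q)ϕ(K)ᵀ)V = ϕ(Q)(ϕ(K)ᵀV). Computing ϕ(K)ᵀV first yields a small d × d matrix. Compute therefore drops from O(n²·d) to O(n·d²), and memory from O(n²) to O(n·d + d²), i.e., linear in n.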
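Here is a sketch of the linear-attention idea in PyTorch, assuming the elu(x) + 1 feature map from Katharopoulos et al. (2020) as ϕ. It checks that the O(n²) and O(n) orderings give the same output. It also adds the causal, streaming form, which keeps running sums like an RNN (the “Recurrent / Streaming Attention” idea above). Function names are illustrative.

```python
import torch
import torch.nn.functional as F

def phi(x):
    # Positive feature map elu(x) + 1 (Katharopoulos et al., 2020); an assumed choice here.
    return F.elu(x) + 1

def linear_attention_quadratic(Q, K, V, causal=False):
    """Reference ordering: (phi(Q) phi(K)^T) V, which materializes an n x n matrix."""
    A = phi(Q) @ phi(K).transpose(-2, -1)                 # [B, n, n]
    if causal:
        A = torch.tril(A)                                 # zero out future keys
    return (A @ V) / A.sum(dim=-1, keepdim=True)

def linear_attention(Q, K, V):
    """phi(Q) (phi(K)^T V): only a d x d_v summary is formed, so cost is linear in n."""
    KV = phi(K).transpose(-2, -1) @ V                     # [B, d, d_v]
    z = phi(K).sum(dim=1, keepdim=True)                   # [B, 1, d] normalizer
    return (phi(Q) @ KV) / (phi(Q) * z).sum(dim=-1, keepdim=True)

def causal_linear_attention(Q, K, V):
    """Streaming (RNN-like) form: update running sums S and z one token at a time."""
    B, n, d = Q.shape
    S = Q.new_zeros(B, d, V.shape[-1])                    # running sum of phi(k_j) v_j^T
    z = Q.new_zeros(B, d)                                 # running sum of phi(k_j)
    outs = []
    for t in range(n):
        q, k, v = phi(Q[:, t]), phi(K[:, t]), V[:, t]
        S = S + k.unsqueeze(-1) * v.unsqueeze(1)          # outer product, [B, d, d_v]
        z = z + k
        num = (q.unsqueeze(1) @ S).squeeze(1)             # [B, d_v]
        den = (q * z).sum(dim=-1, keepdim=True)           # [B, 1]
        outs.append(num / den)
    return torch.stack(outs, dim=1)

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 64, 16, dtype=torch.float64) for _ in range(3))
print(torch.allclose(linear_attention_quadratic(Q, K, V), linear_attention(Q, K, V)))
print(torch.allclose(linear_attention_quadratic(Q, K, V, causal=True), causal_linear_attention(Q, K, V)))
```

Both comparisons should print True. Note that this is attention with a different kernel, not an exact reproduction of softmax attention.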
In summary,
Quadratic attention is feasible for short sequences, but becomes memory-bound at scale.
Surface-level fixes like Sliding Window, MHA, KV caching help but only partially.
Root-level fixes like Flash Attention, low-rank methods, linear approximations unlock large-context capability while remaining GPU-friendly.
Hybrid strategies often combine local attention, linear approximations, and optimized kernels for best performance.
In the next section, we’ll explore specific efficient attention architectures (MHA, MQA, GQA, SWA), including equations, time/space complexity, and full code implementations.
6. Efficient Attention Architectures
Vanilla attention is powerful but quadratic in memory and compute. Over the years, researchers have developed variants optimized for speed, memory, or both (as idealized in the previous section).
Our attention time complexity, O(N² · d), has two culprits and hence two potential zones of optimization: N and d. d is generally of order 10² or 10³, while N can be much larger (10⁵–10⁷).
In this section, we’ll cover:
Optimization over d
Multi-Head Attention (MHA) : H-to-H
Multi-Query Attention (MQA) : H-to-1
Group-Query Attention (GQA) : H-to-K
Optimization over N
Sliding Window Attention (SWA) : local attention; N →W
Here, H is the number of heads and N is sequence length (You will get more context below). Let’s understand why they exist, how they work internally, equations, and code.
6.1 Multi-Head Attention (MHA)
The idea is simple: instead of computing a single attention map, split the model into H heads, each working in a d/H-dimensional subspace. Each head learns to focus on different aspects of the sequence. This is born out of the fact that our complexity term contains d (the embedding dimension, generally of order 10² or 10³), which adds to the cost. So why not address it first, before fixing issues with sequence length?
Think of multiple “eyes” looking at the same scene, one head focuses on verbs, another on nouns, another on different aspects like style.
Time & Space Complexity:
Time: O(H · n² · d/H) → O(n² · d) (same as vanilla, but parallelizable across heads [see the Tensor Parallelism section above])
Space: O(H · n²) for the per-head attention maps, plus O(n · d) for Q, K, V
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_head = d_model // n_heads
        self.n_heads = n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, seq_len, d_model = x.shape
        Q = self.q(x).view(B, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        K = self.k(x).view(B, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        V = self.v(x).view(B, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head**0.5)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, seq_len, d_model)
        return self.out(out), attn
6.2 Multi-Query Attention (MQA)
The idea rests on a simple observation: keep multiple query heads but share a single K and V across all heads. This reduces the memory footprint significantly, because only one set of K, V needs to be stored. But doing this removes many degrees of freedom (fewer trainable parameters in the K/V projections), so model quality can also suffer.
Imagine many people (queries) asking the same expert (shared K,V), it results in less memory used than each having their own expert (but puts theoretical and practical bounds on the expert).
Time & Space Complexity:
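Time: O(H · n² · d/H) (same as MHA)
Memory: K and V (and hence the KV cache) shrink from O(n · d) to O(n · d/H), ~H× savings vs MHA; the per-head attention maps are still O(H · n²)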
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, self.d_head)  # single K
        self.v = nn.Linear(d_model, self.d_head)  # single V
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, seq_len, d_model = x.shape
        Q = self.q(x).view(B, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        K = self.k(x).unsqueeze(1)  # shared across heads
        V = self.v(x).unsqueeze(1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head**0.5)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, seq_len, d_model)
        return self.out(out), attn
6.3 Group-Query Attention (GQA)
Above we saw two extremes: in MHA each head has separate Q, K, V tensors, whereas in MQA there is a single shared K, V but multiple Qs. What if we do something in between? Why not partition the heads into groups, each with a shared K, V? Instead of H or 1 K/V heads, we then have g (groups), which is itself a hyper-parameter (something I personally don’t like). GQA is thus a trade-off between MHA and MQA that balances memory savings against expressivity.
This could be thought of like you have several query groups, each consulting a smaller team of experts, instead of everyone having separate experts (MHA) or a single expert (MQA).
Time & Space Complexity:
Memory: K and V shrink to O(n · g · d/H), where g = number of groups, an H/g× savings vs MHA (g = H recovers MHA, g = 1 recovers MQA)
class GroupQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_groups):
        super().__init__()
        assert n_heads % n_groups == 0
        self.n_heads = n_heads
        self.n_groups = n_groups
        self.d_head = d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, self.d_head * n_groups)  # one K per group
        self.v = nn.Linear(d_model, self.d_head * n_groups)  # one V per group
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, seq_len, d_model = x.shape
        Q = self.q(x).view(B, seq_len, self.n_heads, self.d_head).transpose(1, 2)   # [B, H, seq, d_head]
        K = self.k(x).view(B, seq_len, self.n_groups, self.d_head).transpose(1, 2)  # [B, g, seq, d_head]
        V = self.v(x).view(B, seq_len, self.n_groups, self.d_head).transpose(1, 2)
        # assign each head to a group
        group_size = self.n_heads // self.n_groups
        out_heads = []
        attn_list = []
        for i in range(self.n_groups):
            Qg = Q[:, i*group_size:(i+1)*group_size]  # [B, group_size, seq, d_head]
            # keep the head axis (size 1) so K/V broadcast over this group's heads
            # and stay aligned with the batch axis
            Kg = K[:, i:i+1]  # [B, 1, seq, d_head]
            Vg = V[:, i:i+1]
            scores = torch.matmul(Qg, Kg.transpose(-2, -1)) / (self.d_head**0.5)
            attn = F.softmax(scores, dim=-1)
            out_heads.append(torch.matmul(attn, Vg))
            attn_list.append(attn)
        out = torch.cat(out_heads, dim=1).transpose(1, 2).contiguous().view(B, seq_len, d_model)
        return self.out(out), attn_list
6.4 Sliding Window Attention (SWA)
So far we have worked on d and the corresponding heads; what if we do something about N itself? This is where SWA comes into the picture. The idea is simple: instead of each token looking at all tokens, why not look at just a few? SWA attends to a local window of size w around each token, which reduces the quadratic dependence on N to linear (O(N · w · d)). The resulting mask is a band around the diagonal. This differs from the full lower-triangular mask we usually see in masked/autoregressive/decoder-only attention; decoder-only models combine the two into a lower-triangular band.
You only “look around your neighborhood” instead of at the entire sequence, which is useful for long sequences like 64k tokens. Again, this is similar to CNN receptive fields.
def sliding_window_attention(Q, K, V, window_size):
    # Q, K, V: [B, seq_len, d]; window_size = number of neighbors attended on each side
    B, seq_len, d = Q.shape
    out = torch.zeros_like(Q)
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        scores = torch.bmm(Q[:, i:i+1], K[:, start:end].transpose(1, 2)) / (d**0.5)
        attn = F.softmax(scores, dim=-1)
        out[:, i:i+1] = torch.bmm(attn, V[:, start:end])
    return out
Enhancements / Hybrid Approaches on top of SWA
Dilated / Strided Windows: Attend to every kth token to capture longer-range context.
Global Tokens: Some tokens (like CLS or important entities) attend to all others.
Hierarchical Attention: Combine SWA with occasional full attention layers.
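The window patterns above can be written as boolean masks, where True means query i may attend to key j. A dense implementation would apply such a mask by setting the disallowed scores to -inf before the softmax. This is a dependency-free sketch with hypothetical helper names. `w` counts neighbors on each side, as `window_size` does in the SWA code above.

```python
def sliding_window_mask(n, w, causal=False):
    """mask[i][j] is True when |i - j| <= w (and j <= i if causal)."""
    return [[abs(i - j) <= w and (not causal or j <= i) for j in range(n)] for i in range(n)]

def dilated_window_mask(n, w, dilation):
    """Attend to keys i +/- m*dilation for m = 0..w: a wider reach at the same cost."""
    return [[abs(i - j) % dilation == 0 and abs(i - j) <= w * dilation for j in range(n)]
            for i in range(n)]

def add_global_tokens(mask, global_idx):
    """Global tokens attend to every position and every position attends to them."""
    n = len(mask)
    out = [row[:] for row in mask]
    for g in global_idx:
        for j in range(n):
            out[g][j] = True
            out[j][g] = True
    return out

def show(mask):
    for row in mask:
        print("".join("#" if allowed else "." for allowed in row))

show(sliding_window_mask(6, 1, causal=True))
print()
show(add_global_tokens(sliding_window_mask(6, 1), [0]))
```

With causal=True, the band becomes the lower-triangular band used by decoder-only models.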
Many long-context models use SWA or some variant of it instead of full attention. For example:
Longformer: SWA + global attention
BigBird: SWA + random + global attention
GPT-3: alternating dense and locally banded sparse attention layers (similar to the Sparse Transformer)
In summary:
Multi-head attention is expressive but memory-heavy.
MQA and GQA save memory at the cost of a slight loss in expressivity.
SWA reduces memory/time from quadratic → linear.
Choosing the right mechanism depends on sequence length, GPU memory, and task.
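Since GPU memory is often the deciding factor, here is a rough KV-cache size calculator comparing MHA, GQA, and MQA. The model configuration is hypothetical: 32 layers, 32 query heads, d_head = 128, FP16, 32k tokens, batch 1. Only the number of K/V heads changes between rows.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, d_head, batch=1, bytes_per_elem=2):
    """K and V tensors per layer, each of shape [batch, n_kv_heads, seq_len, d_head]."""
    return 2 * n_layers * batch * n_kv_heads * seq_len * d_head * bytes_per_elem

# Hypothetical model: 32 layers, 32 query heads, d_head = 128, FP16, 32k-token context.
cfg = dict(seq_len=32_768, n_layers=32, d_head=128)
for name, kv_heads in [("MHA", 32), ("GQA (g=8)", 8), ("MQA", 1)]:
    print(f"{name:>10}: {kv_cache_bytes(n_kv_heads=kv_heads, **cfg) / 2**30:.1f} GiB")
```

Going from 32 K/V heads to 8 to 1 shrinks the cache from 16 GiB to 4 GiB to 0.5 GiB. These are the H/g savings described in the GQA section.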
7. Approximation / Bootstrapping Methods for Attention
When sequence lengths get very long, even Sliding Window Attention may not suffice: a purely local pattern can miss long-range context, while full attention’s quadratic memory and compute still overwhelm GPUs. Researchers therefore introduced approximation techniques that reduce attention cost by exploiting sparsity or low-rank structure in the attention matrix.
Core Idea
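Attention can be seen as:
$$A = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right) \in \mathbb{R}^{n \times n}, \qquad O = AV$$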
A is large for long sequences → O(n²) storage.
But in practice:
Many entries in A are near-zero (most tokens are irrelevant).
A often has low-rank structure, meaning the matrix can be approximated by a smaller set of bases.
Hence, not every token cares about every other token. Approximation methods “ignore irrelevant attention” or “compress the attention matrix” without losing much performance.
7.1 Sparse Attention
Idea: only compute attention for a subset of the entries in A. Rather than following a single pattern like SWA’s local window, combine global, local, and random attention (the random links introduce long-range dependencies cheaply). This is like finding a sweet spot between a locally connected and a fully connected graph.
$$O_i = \sum_{j \in S_i} \text{softmax}_{j \in S_i}\left(\frac{q_i^\top k_j}{\sqrt{d}}\right) v_j$$
where S_i = the subset of keys that token i attends to (local window, random selection, global tokens).
Reduces O(n²·d) → O(n·k·d) where k≪n.
Popular Sparse Attention Patterns:
Sliding window (covered in Section 6.4)
Strided / dilated attention
Random attention (BigBird)
Global tokens (CLS, special entities)
```python
# Local + random + global sparse attention (naive per-token loop, for clarity)
import random

import torch
import torch.nn.functional as F


def sparse_attention(Q, K, V, window_size, num_random=2, global_idx=None):
    """
    Q, K, V: [B, seq_len, d]
    window_size: number of tokens to attend on each side (local window)
    num_random: number of random tokens (outside the local window) per token
    global_idx: index (int or list) of global token(s) that attend to all tokens
    """
    B, seq_len, d = Q.shape
    out = torch.zeros_like(Q)

    # Precompute random indices for each token, avoiding its local window
    random_indices = []
    for i in range(seq_len):
        local_range = set(range(max(0, i - window_size), min(seq_len, i + window_size + 1)))
        possible_indices = sorted(set(range(seq_len)) - local_range)
        # random.sample(population, 0) returns [] when no candidates are left
        random_indices.append(random.sample(possible_indices, min(num_random, len(possible_indices))))

    for i in range(seq_len):
        # Local window around token i (truncated at the sequence edges)
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        all_idxs = list(range(start, end)) + random_indices[i]

        q_slice = Q[:, i:i + 1]   # [B, 1, d]
        k_slice = K[:, all_idxs]  # [B, w + r, d]
        v_slice = V[:, all_idxs]  # [B, w + r, d]
        scores = torch.bmm(q_slice, k_slice.transpose(1, 2)) / (d ** 0.5)
        attn = F.softmax(scores, dim=-1)
        out[:, i:i + 1] = torch.bmm(attn, v_slice)

    # Global tokens: overwrite their rows with full attention over all tokens.
    # (Longformer/BigBird also let every token attend to the global tokens;
    # this sketch only implements the global -> all direction.)
    if global_idx is not None:
        if isinstance(global_idx, int):
            global_idx = [global_idx]
        for g_idx in global_idx:
            q_global = Q[:, g_idx:g_idx + 1]  # [B, 1, d]
            scores = torch.bmm(q_global, K.transpose(1, 2)) / (d ** 0.5)
            attn = F.softmax(scores, dim=-1)
            out[:, g_idx] = torch.bmm(attn, V)[:, 0]
    return out
```
Sparse attention is linear in sequence length (for a fixed number of keys per token), memory-friendly, and easy to combine with SWA or other patterns.
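The per-token Python loop above is easy to read but slow, and edge tokens see fewer keys than middle ones. A common alternative, sketched below under the assumption that every query attends to exactly k distinct keys, gathers those keys with an [n, k] index tensor, so the scores take O(n·k) memory and no n×n matrix is ever built. `gather_sparse_attention` and `local_window_indices` are hypothetical helpers; the window here is shifted (rather than truncated) at the sequence edges so every row has the same k, which differs slightly from the loop above.

```python
import torch


def gather_sparse_attention(Q, K, V, idx):
    """
    Q, K, V: [B, n, d]
    idx: [n, k] long tensor; row i lists the k distinct key positions query i attends to.
    Scores take O(B * n * k) memory instead of O(B * n * n).
    """
    B, n, d = Q.shape
    K_sel = K[:, idx]                                     # [B, n, k, d]
    V_sel = V[:, idx]                                     # [B, n, k, d]
    scores = (Q.unsqueeze(2) * K_sel).sum(-1) / d ** 0.5  # [B, n, k]
    attn = torch.softmax(scores, dim=-1)
    return (attn.unsqueeze(-1) * V_sel).sum(dim=2)        # [B, n, d]


def local_window_indices(n, window_size):
    """Every query gets exactly k = 2 * window_size + 1 neighbours.

    Near the edges the window is shifted inward instead of truncated,
    so this assumes n >= k.
    """
    k = 2 * window_size + 1
    starts = (torch.arange(n) - window_size).clamp(0, n - k)
    return starts[:, None] + torch.arange(k)[None, :]     # [n, k]


# Usage
Q, K, V = (torch.randn(2, 10, 8) for _ in range(3))
out = gather_sparse_attention(Q, K, V, local_window_indices(10, window_size=2))
print(out.shape)  # torch.Size([2, 10, 8])
```

Random or global keys can be added by concatenating extra columns to `idx`, as long as each row keeps the same width and contains no duplicate positions (a duplicated key would be counted twice in the softmax).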
7.2 Low-Rank Attention
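The idea is to approximate the n×n attention matrix A as a product of smaller matrices (the raw scores QKᵀ already have rank at most d, and the softmax attention matrix is often approximately low-rank):
$$A \approx Q' K'^{\top}, \qquad Q', K' \in \mathbb{R}^{n \times r}, \quad r \ll n$$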
Exploits the low-rank nature of many real-world sequences (e.g., repeated patterns, redundant info).
Complexity reduces O(n²·d) → O(n·r·d).
Methods:
Nyström Method: Sample landmarks, compute attention via low-rank projection.
Linformer: Project K and V along the sequence dimension, from length n down to r, via learned linear maps.
Performer / FAVOR+: Random feature approximation of softmax kernel.
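Mathematical Formulation (Nyström Example):
With r landmark queries Q̃ and keys K̃ ∈ ℝ^(r×d) (e.g., segment means of Q and K), the n×n softmax matrix is approximated as
$$\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right) \approx \mathrm{softmax}\left(\frac{Q\tilde{K}^\top}{\sqrt{d}}\right) \left[\mathrm{softmax}\left(\frac{\tilde{Q}\tilde{K}^\top}{\sqrt{d}}\right)\right]^{+} \mathrm{softmax}\left(\frac{\tilde{Q}K^\top}{\sqrt{d}}\right)$$
where ⁺ denotes the Moore-Penrose pseudoinverse.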
Think of summarizing a crowd with a few representatives: you don’t need every individual to compute attention, because a few “landmarks” capture the essence.
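A minimal single-head PyTorch sketch of the Nyström idea, assuming the landmarks are segment means of Q and K (the Nyströmformer choice) and that n is divisible by the number of landmarks. It uses an exact `torch.linalg.pinv` for the small r×r middle matrix, whereas Nyströmformer approximates the pseudoinverse iteratively; the function name `nystrom_attention` is hypothetical.

```python
import torch


def nystrom_attention(Q, K, V, num_landmarks):
    """Nystrom-style approximation of softmax(Q K^T / sqrt(d)) V.

    Q, K, V: [n, d] (single head, no batch, for clarity).
    Landmarks are means over consecutive segments of n // num_landmarks tokens.
    """
    n, d = Q.shape
    r = num_landmarks
    assert n % r == 0, "sketch assumes n is divisible by num_landmarks"
    Q_l = Q.reshape(r, n // r, d).mean(dim=1)  # [r, d] landmark queries
    K_l = K.reshape(r, n // r, d).mean(dim=1)  # [r, d] landmark keys
    scale = d ** 0.5

    left = torch.softmax(Q @ K_l.T / scale, dim=-1)      # [n, r]
    middle = torch.softmax(Q_l @ K_l.T / scale, dim=-1)  # [r, r]
    right = torch.softmax(Q_l @ K.T / scale, dim=-1)     # [r, n]

    # Multiply right-to-left so no n x n matrix is ever formed
    return left @ (torch.linalg.pinv(middle) @ (right @ V))  # [n, d]
```

Every intermediate is [r, d], [r, n] or [n, r]. With r = n (every token its own landmark), the formula reproduces exact softmax attention, which makes a handy correctness check.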
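```python
# Linformer-style low-rank attention (single head, fixed sequence length)
import torch
import torch.nn as nn
import torch.nn.functional as F


class LowRankAttention(nn.Module):
    def __init__(self, d_model, seq_len, r):
        super().__init__()
        self.r = r
        # Projects the sequence axis of K and V from seq_len down to r.
        # (Linformer allows separate projections for K and V; here one is shared.)
        self.E = nn.Linear(seq_len, r, bias=False)
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: [B, n, d], where n must equal the seq_len used at construction
        B, n, d = x.shape
        Q = self.q(x)  # [B, n, d]
        K = self.k(x)  # [B, n, d]
        V = self.v(x)  # [B, n, d]
        K_low = self.E(K.transpose(1, 2)).transpose(1, 2)  # [B, r, d]
        V_low = self.E(V.transpose(1, 2)).transpose(1, 2)  # [B, r, d]
        scores = torch.bmm(Q, K_low.transpose(1, 2)) / (d ** 0.5)  # [B, n, r]
        attn = F.softmax(scores, dim=-1)
        out = torch.bmm(attn, V_low)  # [B, n, d]
        return self.out(out), attn
```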
r ≪ n controls the memory/computation trade-off, and the approach works well for long sequences with redundancy.
A low-rank approximation discards the directions that carry little information (where the signal is not well spread), so we capture the main patterns of attention without wasting compute on near-zero or redundant correlations.
Example time
Let’s look at a quick example to understand why this low-rank computation saves compute.
Scenario: Full Attention
Sequence length: n=100
Dimension: d=64
Attention matrix A = QKᵀ ∈ ℝ^(100×100)
Compute cost: n²·d = 100 × 100 × 64 = 640,000 multiply-adds
Memory: storing the 10,000 entries of A at single precision (4 bytes each) takes 40 KB. Confused about this computation? Check out our article here.
Low-Rank Approximation
Assume the attention matrix is low-rank with rank r = 10.
Approximate: A ≈ Q’K’ᵀ, with Q’, K’ ∈ ℝ^(100×10)
Compute cost: n·r·d = 100 × 10 × 64 = 64,000 multiply-adds
Savings: 10× less compute! That is already a lot for n = 100, and since the saving factor is n/r, imagine it for n of order 10^6, 10^7, or even larger.
Memory cost:
Store Q’ and K’: 100⋅10 + 100⋅10 = 2,000 entries → 8 KB
Instead of 10,000 entries → 5× less memory
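The same arithmetic in a few lines of Python, using the example's counting conventions (multiply-adds for the score computation, 4-byte single precision, Q’ and K’ of size n×r); `attention_costs` is just an illustrative helper.

```python
def attention_costs(n, d, r, bytes_per_el=4):
    """Compute and memory counts for full vs. low-rank attention scores."""
    full_flops = n * n * d                   # QK^T: n*n dot products of length d
    low_flops = n * r * d                    # scores against r low-rank components
    full_mem = n * n * bytes_per_el          # store A (n x n)
    low_mem = 2 * n * r * bytes_per_el       # store Q' and K' (n x r each)
    return full_flops, low_flops, full_mem, low_mem


full_flops, low_flops, full_mem, low_mem = attention_costs(n=100, d=64, r=10)
print(full_flops, low_flops, full_flops // low_flops)  # 640000 64000 10
print(full_mem, low_mem, full_mem // low_mem)          # 40000 8000 5

# The compute saving is n / r, so it grows with sequence length
big = attention_costs(n=10**6, d=64, r=10)
print(big[0] // big[1])  # 100000
```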
So far, we have seen the issues, the bottlenecks, and the eventual optimizations at the architectural level of attention. But what if we go one layer deeper, not just working around the problem but actually fixing it at the GPU memory level?
8. GPU Memory & Compute Bottlenecks in Attention
Modern GPUs are extremely fast, but memory movement is often the true bottleneck, not raw FLOPs. Understanding this is key to grasping why attention scales poorly and why architectures like Flash and Paged Attention matter.
8.1 GPU Memory Hierarchy Recap
A simple observation: even though ALUs / tensor cores can perform trillions of FLOPs per second, they stall if data doesn’t arrive fast enough. Hence, moving data to and from these cores is the actual bottleneck.
8.2 Read/Write Bottlenecks in Attention
Attention involves large matrix multiplications and a softmax normalization, which require several rounds of memory access:
Read Q, K, V from memory
For a 50k-token sequence, Q, K, V matrices are huge: O(n·d) each.
Fetching from HBM → many GBs moved per layer.
Compute QKᵀ
Requires n²·d multiply-adds (n² dot products of length d).
Results stored in attention matrix A.
Softmax normalization
Softmax is element-wise, but requires row-wise sums → additional reads/writes.
Multiply with V to get O
Another read of V, multiply-add, write output O.
This leads to the problem: each read/write may involve moving data between SRAM ↔ HBM ↔ DRAM, which can cost 10–100× more time than the actual computation. Intuitively, a GPU is like a chef with blazing-fast knives (ALUs) but tiny counter space (SRAM). If the ingredients (Q, K, V) are in the pantry (HBM/DRAM), the chef spends more time fetching and returning ingredients than chopping.
8.3 Quadratic Scaling Worsens Bottlenecks
For sequence length n, attention matrix is n² entries.
Each entry is read/written at least twice: once for softmax, once for multiplying with V.
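Example: n = 32k → QKᵀ has 32,768² ≈ 1.07 × 10⁹ entries, about 2 GiB in FP16 per head and per sequence (d does not change its size); across many heads and batch elements it quickly exceeds GPU memory.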
Even if GPU has enough ALUs, HBM bandwidth can’t keep up, resulting in compute stalls.
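To sanity-check sizes like these, here is the arithmetic as a small Python sketch, assuming FP16 (2 bytes per element) and one n×n score matrix per head per sequence. The second helper counts HBM element reads/writes per head for a standard, unfused attention that writes S = QKᵀ and P = softmax(S) out to HBM before multiplying with V, following the steps in 8.2; both helper names are hypothetical.

```python
def attention_matrix_bytes(n, num_heads=1, batch=1, bytes_per_el=2):
    """Size of the materialized n x n score matrix (FP16 by default)."""
    return n * n * num_heads * batch * bytes_per_el


def naive_attention_hbm_elements(n, d):
    """HBM element reads/writes per head for unfused attention:
    read Q, K (2nd) -> write S (n^2) -> read S, write P (2n^2)
    -> read P (n^2), read V (nd) -> write O (nd).
    """
    return 4 * n * d + 4 * n * n


GIB = 2 ** 30
n = 32 * 1024
print(attention_matrix_bytes(n) / GIB)                # 2.0  (one head, one sequence)
print(attention_matrix_bytes(n, num_heads=32) / GIB)  # 64.0 (32 heads)
# With d = 128 the 4*n*n term dwarfs the 4*n*d term
print(naive_attention_hbm_elements(n, 128))
```

The n² terms are exactly the S and P round-trips that fused kernels such as FlashAttention keep out of HBM.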
Techniques we saw above, like KV caching, avoid recomputation, but storing all previous keys/values increases HBM usage and can spill over to host DRAM. Hence, the trade-off is saved compute vs. memory pressure.
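How quickly the KV cache grows can be estimated with one formula: 2 (keys and values) × layers × KV heads × head dimension × cached tokens × batch × bytes per element. A minimal sketch, where the 32-layer, head_dim-128 configuration is a hypothetical example rather than any specific released model:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch=1, bytes_per_el=2):
    """Keys + values stored for every layer, KV head and cached token (FP16 by default)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_el


GIB = 2 ** 30
# Hypothetical config: 32 layers, head_dim 128, 32k cached tokens, batch 1
print(kv_cache_bytes(32, 32, 128, seq_len=32 * 1024) / GIB)  # 16.0 with 32 KV heads (MHA)
print(kv_cache_bytes(32, 8, 128, seq_len=32 * 1024) / GIB)   # 4.0 with 8 KV heads (GQA)
```

The cache grows linearly with both context length and batch size, which is why MQA/GQA (fewer KV heads) and better KV-cache allocation matter so much at inference time.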
In summary, we need a method that cuts down the reads/writes between HBM and SRAM. This will not only speed things up but also be our next best attempt at scalable architectures, beyond optimizing attention itself (which we saw above).
9. Advanced Attention Architectures
This section is just an introduction, as I am still learning the exact workings, implementation, and intuition behind these techniques. Once done, I will write a cleaner, more elaborate, in-depth article on just these three techniques and their prerequisites. That said, let’s get started.
9.1 Flash Attention
Reorganize the computation to avoid storing the full attention matrix in memory. We then compute attention in tiles that fit in fast on-chip SRAM (shared memory), and the softmax is computed on the fly, per tile (online softmax). (This tiling operation is not fully clear to me yet, so treat it as a black box for now.) Given this, the memory complexity is O(n·d) instead of O(n²).
You can think of this as “streaming the matrix” through fast cache, never holding the full n×n attention matrix at once. It is similar in spirit to data loading in torch/TF/Keras, where the dataset sits in storage and we load and compute on only the required batch, or to gradient accumulation.
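Equation-wise, for each query row i:
$$O_i = \sum_{j=1}^{n} \frac{\exp(q_i k_j^\top / \sqrt{d})}{\sum_{l=1}^{n} \exp(q_i k_l^\top / \sqrt{d})}\, v_j$$
Flash attention computes this block by block: for a block of rows of Q, it loops over blocks of K and V held in SRAM, keeping a running row-wise max and sum so the softmax normalization can be corrected as each new block arrives.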
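A minimal pure-PyTorch sketch of the tiling + online-softmax idea, assuming a single head and tiling only over K/V. The real FlashAttention kernel also tiles Q and runs this loop in on-chip SRAM inside one fused CUDA kernel, which plain PyTorch cannot express. The sketch still returns exact attention, but only an [n, block_size] block of scores exists at any time; `tiled_attention` is a hypothetical name.

```python
import torch


def tiled_attention(Q, K, V, block_size=64):
    """Exact softmax(Q K^T / sqrt(d)) V, processing K/V one block at a time.

    Q, K, V: [n, d] (single head). Only an [n, block_size] score block
    exists at any moment; the full [n, n] matrix is never formed.
    """
    n, d = Q.shape
    scale = d ** -0.5
    out = torch.zeros_like(Q)                                   # running unnormalized output
    row_max = torch.full((n, 1), float("-inf"), dtype=Q.dtype)  # running row-wise max
    row_sum = torch.zeros((n, 1), dtype=Q.dtype)                # running softmax denominator

    for start in range(0, n, block_size):
        K_blk = K[start:start + block_size]                     # [b, d]
        V_blk = V[start:start + block_size]                     # [b, d]
        S = (Q @ K_blk.T) * scale                               # [n, b]
        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - new_max)                              # unnormalized block probs
        correction = torch.exp(row_max - new_max)               # rescale earlier blocks
        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        out = out * correction + P @ V_blk
        row_max = new_max

    return out / row_sum
```

The `correction` factor is the key trick: when a later block contains a larger score, the previously accumulated sums and outputs are rescaled by exp(old_max - new_max), so the final division by `row_sum` gives the same result as one full softmax.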
Key Benefits:
Drastically reduces the memory footprint, which allows longer sequences.
Leverages the GPU’s SRAM and tensor cores efficiently.
Memory is linear in n; compute is still quadratic, but in a GPU-friendly form.
9.2 Flash Attention v2
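This is an enhancement over v1 (as the name suggests):
Reduces non-matmul FLOPs (such as the rescaling in the online softmax), since GPU tensor cores are far faster at matrix multiplies than at other operations.
Parallelizes over the sequence-length dimension in addition to batch and heads, keeping more of the GPU busy for long sequences with small batches.
Partitions work better between warps within a thread block, reducing reads and writes to shared memory.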
Practical Impact:
It makes training with long contexts (e.g., 50k+ tokens) feasible on A100/H100-class GPUs in settings where standard attention would run out of memory.
It is reported to be about 2× faster than Flash Attention (v1) and up to roughly 9× faster than standard attention implementations.
9.3 Paged Attention
Unlike the optimizations we have seen so far, Paged Attention targets the KV cache. The KV cache grows and shrinks during inference, so allocating it is challenging (GPUs prefer fixed-size pre-allocations, and reserving space for the maximum length wastes memory). Paged Attention therefore splits the KV cache into fixed-size blocks (pages), similar to virtual memory paging in an OS.
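Each sequence has a block table mapping its logical blocks (consecutive token positions) to physical blocks, which can sit anywhere in GPU memory.
New blocks are allocated on demand as the sequence grows and returned to a free pool when it finishes.
During attention, the kernel fetches the sequence’s KV blocks one at a time via the block table and accumulates attention over all of them, so the result is still exact attention.
This eliminates external memory fragmentation and limits waste to the last, partially filled block of each sequence, so many more sequences fit in the same GPU memory. It is also compatible with Flash Attention-style tiling, enabling long-context LLMs.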
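To make the block-table idea concrete, here is a toy Python sketch, assuming PyTorch, one layer and one attention head. The `PagedKVCache` class and its methods are hypothetical and do not mirror vLLM’s API; in particular, `gather` copies blocks into a contiguous tensor for clarity, whereas a real Paged Attention kernel reads the blocks directly through the block table.

```python
import torch


class PagedKVCache:
    """Toy paged KV cache: a pool of fixed-size physical blocks + per-sequence block tables."""

    def __init__(self, num_blocks, block_size, head_dim):
        self.block_size = block_size
        self.k_pool = torch.zeros(num_blocks, block_size, head_dim)
        self.v_pool = torch.zeros(num_blocks, block_size, head_dim)
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.lengths = {}       # seq_id -> number of cached tokens

    def append(self, seq_id, k, v):
        """Append one token's key/value vectors ([head_dim] each)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:          # last block is full (or none yet)
            table.append(self.free_blocks.pop())   # allocate a block on demand
        block, offset = table[-1], length % self.block_size
        self.k_pool[block, offset] = k
        self.v_pool[block, offset] = v
        self.lengths[seq_id] = length + 1

    def gather(self, seq_id):
        """Return this sequence's keys/values as contiguous [length, head_dim] tensors."""
        table, length = self.block_tables[seq_id], self.lengths[seq_id]
        head_dim = self.k_pool.shape[-1]
        K = self.k_pool[table].reshape(-1, head_dim)[:length]
        V = self.v_pool[table].reshape(-1, head_dim)[:length]
        return K, V

    def free(self, seq_id):
        """Return a finished sequence's blocks to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id))
        self.lengths.pop(seq_id)
```

Because a block is allocated only when the previous one fills up, a sequence wastes at most block_size - 1 slots, and freed blocks can be reused immediately by other sequences.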
(More about each keyword and piece of jargon in the next blog, so stay tuned.)
10. Thoughts
Attention has come a long way: from a clever mechanism, to the superpower of sequence models, to a general template for multi-modal reasoning, and now to highly optimized GPU-aware architectures. But even after all these innovations, several patterns and challenges stand out.
1. Memory Is the Hidden Bottleneck
We often talk about FLOPs and compute, but in practice, memory movement dominates runtime.
Reading Q, K, V from HBM or DRAM can stall even the fastest tensor cores.
KV caching avoids recomputation, but pays for it in memory.
Advanced methods like Flash Attention cleverly stream tiles through SRAM, showing that algorithmic genius is often about memory, not math.
Maybe the next big gains in LLMs won’t come from bigger models alone, but from memory-efficient attention primitives that let GPUs reach full utilization.
2. Approximation Works, but It’s a Double-Edged Sword
Sparse attention, low-rank attention, and hybrid approaches like BigBird illustrate a key principle:
You don’t always need every token interacting with every other token.
Approximations reduce O(n²) → O(n·k) or O(n·r), saving compute and memory.
Yet, trade-offs remain:
Sparse patterns may miss subtle long-range dependencies.
Low-rank methods can smooth out fine-grained token interactions.
It’s like summarizing a conversation: you might capture the gist but miss nuanced jokes or side remarks.
3. Hardware and Algorithm Co-Design Is Key
The rise of Flash Attention and Paged Attention shows a broader trend:
Algorithm design can’t ignore GPU architecture.
Techniques that look great on paper (O(n·r)) may underperform if they cause cache thrashing or excessive HBM reads.
Future attention mechanisms may co-optimize compute, memory, and communication simultaneously, especially for ultra-long sequences or multi-modal data.
4. Scaling Is a Multi-Dimensional Problem
Sequence length, model depth, hidden dimension, memory bandwidth, and compute all interplay.
Quadratic attention isn’t just a mathematical nuisance, it’s a practical limit enforced by hardware.
Every innovation (SWA, MQA, Flash, Paged) is essentially a creative workaround to this fundamental scaling challenge.
5. The Bigger Picture
Attention isn’t just about transformers anymore. Its principles generalize to vision, audio, proteins, and even satellite imagery.
The core idea that “focus where it matters, ignore the rest” is universal.
What remains is to scale this principle efficiently without exploding memory or compute costs. Perhaps the next revolution isn’t bigger models, but attention mechanisms that adapt dynamically, deciding which tokens, channels, or modalities truly deserve compute in real-time.
11. Conclusion
And that’s a wrap on our attention discussion: the humble mechanism that quietly became the backbone of modern AI. What started as a neat trick for sequence-to-sequence translation has spread its wings to LLMs, VLMs, audio, proteins, and satellite imagery; attention is everywhere.
We began by understanding why compute and memory matter. Without knowing how GPUs tick (SRAM as the lightning-fast counter, HBM as the “desk space”, and DRAM as the “warehouse”), scaling attention is like trying to carry an ocean in a teacup. Then we dug into the core: Q, K, V, softmax, and KV caches. That elegant formula looked deceptively simple, until the quadratic monster reared its head.
We retraced history through Bahdanau and Luong attention, additive vs multiplicative, and realized each choice is a trade-off between precision, memory, and compute. And then the scaling problem hit: long sequences, massive matrices, and GPUs that can compute trillions of FLOPs but spend half the time just fetching data from memory.
That’s where all the clever hacks come in: MHA, MQA, GQA, SWA, and approximation tricks like sparse and low-rank attention. Each one is basically a way of saying: “Hey, not every token matters equally, so let’s be smart about where we spend our compute calories.” Finally, hardware-aware strategies like Flash Attention, Flash Attention v2, and Paged Attention show us that sometimes the solution isn’t just math; it’s thinking like a GPU: streaming tiles through fast memory, paging long sequences, and keeping tensor cores happily busy.
(We have yet to do a proper deep dive into these three; that will likely come in our next blog.)
The big picture: attention isn’t just about seeing all the tokens; it’s about deciding what to focus on and doing it efficiently. The real magic happens at the intersection of math, memory, and compute.
Some References:
GQA: https://arxiv.org/pdf/2305.13245
GQA (blog post): https://verticalserve.medium.com/group-query-attention-58283b337c65
MLHA and KV-cache: https://vizuara.substack.com/p/decoding-multi-head-latent-attention
Linformer: https://arxiv.org/abs/2006.04768
BigBird: https://arxiv.org/abs/2007.14062
Sparse attention: https://arxiv.org/abs/2406.16747
Longformer: https://arxiv.org/abs/2004.05150
Data types and precision: https://vizuara.substack.com/p/4-bit-llm-training-and-primer-on
That's all for today.
Follow me on LinkedIn and Substack for more insightful posts. Till then, happy learning. Bye👋