Linear Attention Approximations
Explore linear complexity attention mechanisms including Performer, Linformer, and other efficient transformers that scale to very long sequences.
Linear Attention: From O(n²) to O(n)
Linear attention approximations break the quadratic complexity barrier of self-attention, enabling transformers to process sequences with millions of tokens while maintaining reasonable quality.
The sections below compare the major linear attention methods and their trade-offs.
Visualizing the n² Bottleneck
The attention matrix grows quadratically with sequence length. For an 8-token example:
Standard Attention: (Q × K^T) × V
- Must compute 8 × 8 = 64 attention scores
- Must store all 64 values in memory
- For n = 16,384 tokens: ~268M values, over 1 GB of memory (in fp32) just for a single attention matrix
The Kernel Trick: Reverse the Multiplication Order
Compute Q × (K^T × V) instead of (Q × K^T) × V
Standard Attention: (Q × K^T) × V
- Creates an n × n matrix (8 × 8 in the example)
- Memory: O(n²)
- Computation: O(n²d)

Linear Attention: φ(Q) × (φ(K)^T × V)
- Creates a d × d matrix (4 × 4 in the example, with d = 4)
- Memory: O(d²)
- Computation: O(nd²)
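Because matrix multiplication is associative, the two orders give identical results once the softmax is replaced by a feature map; only the cost changes. A minimal sketch:

```python
import torch

n, d = 8, 4                       # sequence length, head dimension
phi_q = torch.rand(n, d)          # φ(Q): any non-negative feature map output
phi_k = torch.rand(n, d)          # φ(K)
v = torch.randn(n, d)

# Quadratic order: builds an n×n matrix -> O(n²d) time, O(n²) memory
out_quadratic = (phi_q @ phi_k.T) @ v

# Linear order: builds a d×d matrix -> O(nd²) time, O(d²) extra memory
out_linear = phi_q @ (phi_k.T @ v)

assert torch.allclose(out_quadratic, out_linear, atol=1e-5)
```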
Different Ways to Achieve Linearity
Each method below transforms the Q/K/V computation differently to avoid materializing the n × n matrix.
The Linearization Problem
Standard attention has quadratic complexity:

Attention(Q, K, V) = softmax(Q × K^T / √d_k) × V

The bottleneck is the n × n attention matrix Q × K^T, which must be computed and stored explicitly. Linear methods approximate it or avoid materializing it altogether.
Major Linear Attention Methods
1. Performer (FAVOR+)
Key Idea: Approximate the softmax kernel with random feature maps (FAVOR+ uses positive orthogonal random features rather than trigonometric Fourier features)
Architecture:
- Use random projection matrix for feature mapping
- Orthogonal or Gaussian random features
- FAVOR+ algorithm for positive random features
Core Steps:
- Create random projection matrix (orthogonal preferred)
- Project Q and K to random feature space: φ(Q), φ(K)
- Apply a positive feature map (exp-based features approximate softmax; a ReLU map gives Performer's generalized attention)
- Compute using associative property: φ(Q) @ (φ(K)^T @ V) instead of (φ(Q) @ φ(K)^T) @ V
- Normalize with sum of K features
Complexity: O(nkd) where k = number of random features (typically 256)
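A simplified single-head sketch of random-feature attention in the spirit of FAVOR+; the full algorithm additionally uses orthogonal random features, extra numerical-stability terms, and periodic redrawing, and the names here are illustrative:

```python
import torch

def performer_attention(q, k, v, nb_features=256, eps=1e-6):
    # q, k, v: (n, d); single head, non-causal
    n, d = q.shape
    w = torch.randn(nb_features, d)               # random projection (i.i.d. Gaussian)

    def feature_map(x):
        x = x / d ** 0.25                         # accounts for the 1/sqrt(d) in softmax
        # positive random features: exp(w·x - ||x||²/2) / sqrt(m)
        # (the paper also subtracts a max term for numerical stability)
        return torch.exp(x @ w.T - x.pow(2).sum(-1, keepdim=True) / 2) / nb_features ** 0.5

    q_p, k_p = feature_map(q), feature_map(k)     # (n, m)
    kv = k_p.T @ v                                # (m, d): O(n·m·d)
    normalizer = q_p @ k_p.sum(dim=0).unsqueeze(-1)   # (n, 1): φ(Q) · Σφ(K)
    return (q_p @ kv) / (normalizer + eps)        # (n, d)
```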
2. Linformer
Key Idea: The attention matrix is approximately low-rank, so project K and V along the sequence-length dimension to a smaller size k
Architecture:
- Learnable projection matrices E and F
- Project sequence length dimension from n → k
- Apply standard attention in lower-dimensional space
Core Steps:
- Compute Q, K, V normally
- Project K using E: K' = E × K (reduces seq_len from n to k)
- Project V using F: V' = F × V (reduces seq_len from n to k)
- Compute attention: softmax(Q × K'^T / √d_k) × V'
- Output has original sequence length n
Complexity: O(nkd) where k = projection dimension (typically 256-512)
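A minimal single-head sketch, with E and F shown as fixed random matrices standing in for trained parameters:

```python
import torch

def linformer_attention(q, k, v, E, F, d_k):
    # q, k, v: (n, d_k); E, F: (proj_k, n) projections over the sequence dimension
    k_proj = E @ k                               # (proj_k, d_k)
    v_proj = F @ v                               # (proj_k, d_k)
    scores = q @ k_proj.T / d_k ** 0.5           # (n, proj_k): no n×n matrix
    return torch.softmax(scores, dim=-1) @ v_proj    # (n, d_k)

n, d, proj_k = 4096, 64, 256
q, k, v = (torch.randn(n, d) for _ in range(3))
E = torch.randn(proj_k, n) / n ** 0.5            # stand-ins for learned projections
F = torch.randn(proj_k, n) / n ** 0.5
out = linformer_attention(q, k, v, E, F, d)      # (4096, 64)
```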
3. Linear Transformer
Key Idea: Simple kernel trick with φ(x) = elu(x) + 1 feature map
Architecture:
- Minimal changes to standard attention
- Use ELU activation + 1 as feature map
- Leverages associative property of matrix multiplication
Core Steps:
- Compute Q, K, V normally
- Apply feature map: φ(Q) = elu(Q) + 1, φ(K) = elu(K) + 1
- Non-causal: Compute KV = φ(K)^T @ V, then output = φ(Q) @ KV
- Causal: Use cumulative sum trick for O(1) per-token generation
- Normalize by sum of K features
Complexity: O(nd²) where d = head_dim
Special Feature: Causal masking with constant-time decoding via cumulative sums
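A minimal non-causal sketch with the elu(x) + 1 feature map (the causal variant via cumulative sums is sketched under Implementation Tips below):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (n, d); single head, non-causal
    q = F.elu(q) + 1                       # φ(Q), strictly positive
    k = F.elu(k) + 1                       # φ(K)
    kv = torch.einsum('nd,ne->de', k, v)   # (d, d): summed over positions once
    z = q @ k.sum(dim=0).unsqueeze(-1)     # (n, 1): normalizer φ(Q) · Σφ(K)
    return (q @ kv) / (z + eps)            # (n, d)
```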
4. Cosformer
Key Idea: Decompose attention using cosine basis functions
Architecture:
- Apply cosine-based reweighting to Q and K
- Use position-dependent weights
- Combines benefits of positional encoding with linear complexity
Core Steps:
- Pre-compute cosine weights: cos(πi/2n) for position i
- Apply weights to Q and K
- Use linear attention mechanism on weighted features
- Normalize output
Complexity: O(nd²); the cos(π(i − j)/2n) weight decomposes into per-position cos/sin factors, so the reweighting adds only an elementwise O(nd) cost
Benefit: Better captures positional information than simple feature maps
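A simplified non-causal sketch of the cos-based reweighting: the cos(π(i − j)/2n) weight splits into cos·cos + sin·sin terms, so it reduces to two ordinary linear-attention passes. The paper applies this per head with a ReLU feature map and also handles the causal case:

```python
import torch
import torch.nn.functional as F

def cosformer_attention(q, k, v, eps=1e-6):
    # q, k, v: (n, d); single head, non-causal
    n, d = q.shape
    q, k = F.relu(q), F.relu(k)                              # non-negative feature map
    idx = torch.arange(n, dtype=q.dtype)
    cos_w = torch.cos(torch.pi * idx / (2 * n)).unsqueeze(-1)    # (n, 1)
    sin_w = torch.sin(torch.pi * idx / (2 * n)).unsqueeze(-1)

    out, z = 0, 0
    for w in (cos_w, sin_w):                                 # cos·cos and sin·sin terms
        qw, kw = q * w, k * w
        out = out + qw @ (kw.T @ v)                          # (n, d), linear in n
        z = z + qw @ kw.sum(dim=0).unsqueeze(-1)             # (n, 1) normalizer
    return out / (z + eps)
```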
Comparison of Methods
| Method | Kernel φ(x) | Time | Memory | Quality | Key Idea |
|---|---|---|---|---|---|
| Performer | Positive random features | O(nkd) | O(nk) | 95% | FAVOR+ algorithm |
| Linformer | Low-rank projection | O(nkd) | O(nk) | 92% | Attention is low-rank |
| Linear | elu(x) + 1 | O(nd²) | O(d²) | 90% | Simple feature map |
| Cosformer | Cosine reweighting | O(nd²) | O(nd) | 93% | cos/sin decomposition |
| FlashAttention* | Exact softmax | O(n²d) | O(n) | 100% | IO-aware tiling |

*FlashAttention is not a linear-attention method: it computes exact softmax attention in O(n²) time but reaches O(n) memory through tiling.
Advanced Techniques
Hybrid Approaches
Combine Local and Global Attention:
- Split heads between local windowed attention and linear global attention
- Local attention: High quality for nearby tokens
- Global attention: Linear complexity for long-range dependencies
- Combine outputs additively or with learned gating
Benefits:
- Best of both worlds: quality + efficiency
- Tuneable quality-speed trade-off
- Used in Longformer and BigBird (which pair local windows with sparse global attention rather than kernelized linear attention)
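A sketch of one way to combine the two outputs with a learned gate; local_attn and linear_attn are assumed to be attention modules defined elsewhere (e.g. a sliding-window attention and one of the linear sketches above):

```python
import torch
import torch.nn as nn

class HybridAttention(nn.Module):
    def __init__(self, d_model, local_attn, linear_attn):
        super().__init__()
        self.local_attn = local_attn      # callable: (n, d_model) -> (n, d_model)
        self.linear_attn = linear_attn    # callable: (n, d_model) -> (n, d_model)
        self.gate = nn.Linear(d_model, d_model)

    def forward(self, x):
        local_out = self.local_attn(x)    # high quality for nearby tokens
        global_out = self.linear_attn(x)  # linear-cost long-range mixing
        g = torch.sigmoid(self.gate(x))   # learned per-position, per-channel gate
        return g * local_out + (1 - g) * global_out
```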
Learned Feature Maps
Adaptive Kernel Learning:
- Replace fixed feature maps (elu, ReLU) with learned neural networks
- Small MLP projects queries/keys to feature space
- Use softplus activation to ensure positivity
- Can adapt to data distribution during training
Trade-off: More parameters but potentially better approximation
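A sketch of a small learned feature map with a softplus output to keep features positive (layer sizes are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnedFeatureMap(nn.Module):
    def __init__(self, d_head, d_feature=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_head, d_feature),
            nn.GELU(),
            nn.Linear(d_feature, d_feature),
        )

    def forward(self, x):
        # softplus keeps features strictly positive, as linear attention requires
        return F.softplus(self.net(x))
```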
Implementation Tips
1. Numerical Stability
Add Small Epsilon to Normalizer:
- Prevent division by zero
- Use eps=1e-6 typically
- Apply before division: out / (normalizer + eps)
Stable Computation Order:
- Compute KV matrix first: einsum('...nd,...ne->...de', k, v)
- Then apply queries: einsum('...nd,...de->...ne', q, kv)
- Normalize with sum of k features
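A batched, multi-head sketch of this computation order using the einsum patterns above, with eps guarding the normalizer:

```python
import torch

def stable_linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (batch, heads, n, d), already feature-mapped (non-negative)
    kv = torch.einsum('...nd,...ne->...de', k, v)            # (batch, heads, d, d)
    z = torch.einsum('...nd,...d->...n', q, k.sum(dim=-2))   # (batch, heads, n)
    out = torch.einsum('...nd,...de->...ne', q, kv)          # (batch, heads, n, d)
    return out / (z.unsqueeze(-1) + eps)
```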
2. Memory-Efficient Training
Chunked Processing:
- Process queries in chunks of 64-128 tokens
- Compute full KV matrix once
- Apply each query chunk separately
- Concatenate results
Benefits:
- Reduces peak memory usage
- Allows training with limited GPU memory
- Minimal overhead
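A sketch of chunked query processing, assuming Q, K, V have already been feature-mapped: the (d × d) KV matrix is computed once, then query chunks are applied one at a time to cap peak activation memory.

```python
import torch

def chunked_linear_attention(q, k, v, chunk_size=128, eps=1e-6):
    # q, k, v: (n, d), already feature-mapped (non-negative)
    kv = k.T @ v                       # (d, d), computed once
    k_sum = k.sum(dim=0)               # (d,)
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        q_chunk = q[start:start + chunk_size]          # (chunk, d)
        z = q_chunk @ k_sum.unsqueeze(-1) + eps        # (chunk, 1) normalizer
        outputs.append((q_chunk @ kv) / z)
    return torch.cat(outputs, dim=0)                   # (n, d)
```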
3. Causal Masking
Cumulative Sum Trick:
- Maintain running sums: kv_cumsum and k_cumsum
- For each position i, update sums with current K,V
- Compute output using only cumulative values
- O(1) complexity per token for generation
Implementation:
- Initialize sums to zero
- Loop through sequence positions
- Update sums incrementally
- Apply query to cumulative sums
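A sketch of the cumulative-sum recurrence; the explicit Python loop is for clarity, and practical kernels vectorize or fuse it:

```python
import torch

def causal_linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (n, d), already feature-mapped (non-negative)
    n, d = q.shape
    kv_cumsum = torch.zeros(d, d, dtype=q.dtype)
    k_cumsum = torch.zeros(d, dtype=q.dtype)
    out = torch.empty_like(v)
    for i in range(n):
        kv_cumsum = kv_cumsum + torch.outer(k[i], v[i])   # running Σ_{j<=i} k_j v_jᵀ
        k_cumsum = k_cumsum + k[i]                        # running Σ_{j<=i} k_j
        z = q[i] @ k_cumsum + eps                         # scalar normalizer
        out[i] = (q[i] @ kv_cumsum) / z                   # uses only past positions
    return out
```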
Production Configurations
Performer
Typical Configuration:
- d_model: 1024
- n_heads: 16
- nb_features: 256 (random feature dimension)
- use_orthogonal: True (orthogonal random features)
- redraw_features: True (periodically redraw for better approximation)
- redraw_interval: 1000 steps
Linformer
Typical Configuration:
- d_model: 768
- n_heads: 12
- seq_len: 4096 (fixed maximum sequence length)
- k: 256 (projection dimension, ~6% of seq_len)
- share_projections: True (share E, F across layers to save parameters)
Linear Transformer
Typical Configuration:
- d_model: 512
- n_heads: 8
- feature_map: "elu" (i.e. elu(x) + 1; "relu" is a common alternative, and the map should stay non-negative)
- eps: 1e-6 (numerical stability)
- use_rotary_embeddings: True (better positional encoding)
Best Practices
Choosing the Right Method
Decision Guide:
Quality > 98% required:
- Use FlashAttention (exact attention with linear memory)
Very long sequences (>10K tokens):
- Use Performer (best scaling for ultra-long contexts)
Quality acceptable at 90-92%:
- Use Linear Transformer (simplest, fastest)
Tight memory budget:
- Use Linformer (low-rank projections keep the projected attention matrix small, but the maximum sequence length is fixed)
Balanced requirements:
- Use Cosformer (good quality-efficiency trade-off)
Combining with Other Techniques
Layerwise Strategy:
- Early layers: Local/sliding window attention (capture fine-grained patterns)
- Middle layers: Linear attention (efficient long-range modeling)
- Final layers: Full or sparse attention (high-quality output)
Benefits:
- Leverages strengths of each method
- Progressive refinement of representations
- Optimal quality-speed trade-off