Linear Attention Approximations

Explore linear complexity attention mechanisms including Performer, Linformer, and other efficient transformers that scale to very long sequences.

Linear Attention: From O(n²) to O(n)

Linear attention approximations break the quadratic complexity barrier of self-attention, enabling transformers to process sequences with millions of tokens while maintaining reasonable quality.

Visualizing the n² Bottleneck

The attention matrix grows quadratically with sequence length. A toy example with n = 8 tokens and head dimension d = 4 shows where the cost comes from.

Standard Attention: (Q × K^T) × V

Step 1: Compute Q × K^T. Q is 8 × 4 and K^T is 4 × 8, so their product is an 8 × 8 attention matrix. This n × n matrix is the bottleneck: it holds 64 values here, and roughly 268M values for n = 16,384 tokens.

Step 2: Multiply the 8 × 8 attention matrix by V (8 × 4) to produce the 8 × 4 output.

Why This is O(n²):

  • Must compute 8 × 8 = 64 attention scores
  • Must store 8 × 8 = 64 values in memory
  • For n = 16,384 tokens: ~268M values, more than 1 GB of memory in fp32, just for the attention matrix
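
To make this concrete, here is a minimal single-head sketch of standard scaled dot-product attention in PyTorch, with the n × n score matrix made explicit. The function name and toy shapes are illustrative, not taken from any library.

```python
import torch

def standard_attention(q, k, v):
    # q, k, v: (n, d)
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d**0.5   # (n, n): the quadratic bottleneck
    attn = torch.softmax(scores, dim=-1)        # (n, n)
    return attn @ v                             # (n, d)

n, d = 8, 4
q, k, v = (torch.randn(n, d) for _ in range(3))
print(standard_attention(q, k, v).shape)        # torch.Size([8, 4])
# The (n, n) score matrix holds n**2 values: 64 here, ~268M for n = 16_384.
```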

The Kernel Trick: Reverse the Multiplication Order

Compute Q × (K^T × V) instead of (Q × K^T) × V

Linear Attention: φ(Q) × (φ(K)^T × V)

Step 1: Compute φ(K)^T × V first. φ(K)^T is 4 × 8 and V is 8 × 4, so the product is only a 4 × 4 KV matrix. That is just 16 values here, and d² = 4,096 values for d = 64, regardless of sequence length.

Step 2: Multiply φ(Q) (8 × 4) by the 4 × 4 KV matrix to produce the 8 × 4 output. The n × n attention matrix is never materialized.

Why This is O(n) instead of O(n²):

❌ Standard Attention
  • Creates the 8 × 8 attention matrix
  • Memory: O(n²)
  • Computation: O(n²d)

✓ Linear Attention
  • Creates only the 4 × 4 KV matrix
  • Memory: O(d²)
  • Computation: O(nd²)
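
A matching sketch of the reordered computation, using elu(x) + 1 as a stand-in for φ. It is only meant to show the ordering and the shapes; it approximates, rather than reproduces, softmax attention.

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (n, d)
    phi_q = F.elu(q) + 1                              # (n, d), strictly positive features
    phi_k = F.elu(k) + 1                              # (n, d)
    kv = phi_k.transpose(-2, -1) @ v                  # (d, d): size independent of n
    z = phi_q @ phi_k.sum(dim=0, keepdim=True).transpose(-2, -1)  # (n, 1) normalizer
    return (phi_q @ kv) / (z + eps)                   # (n, d): no (n, n) matrix anywhere

n, d = 8, 4
q, k, v = (torch.randn(n, d) for _ in range(3))
print(linear_attention(q, k, v).shape)                # torch.Size([8, 4])
```

The output differs from exact softmax attention; the point is that the largest intermediate is d × d, not n × n.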

Different Ways to Achieve Linearity

The key insight: change the multiplication order or compress dimensions so the n × n attention matrix is never materialized. The methods below take different routes to the same goal.

The Linearization Problem

Standard attention has quadratic complexity:

Attention(Q, K, V) = softmax(Q × K^T / √d_k) × V

The bottleneck is the n × n attention matrix. Linear methods approximate or avoid computing it explicitly.

Major Linear Attention Methods

1. Performer (FAVOR+)

Key Idea: Approximate softmax kernel using random Fourier features

Architecture:

  • Use random projection matrix for feature mapping
  • Orthogonal or Gaussian random features
  • FAVOR+ algorithm for positive random features

Core Steps:

  1. Create random projection matrix (orthogonal preferred)
  2. Project Q and K to random feature space: φ(Q), φ(K)
  3. Apply feature map (ReLU + small constant for positivity)
  4. Compute using associative property: φ(Q) @ (φ(K)^T @ V) instead of (φ(Q) @ φ(K)^T) @ V
  5. Normalize with sum of K features

Complexity: O(nkd) where k = number of random features (typically 256)
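
A compressed sketch of the recipe above, using a Gaussian random projection and a ReLU feature map. The real FAVOR+ positive features use an exponential map (and orthogonal projections are preferred); the scaling factor and names here are assumptions made for illustration.

```python
import torch

def performer_attention(q, k, v, nb_features=256, eps=1e-6):
    # q, k, v: (n, d)
    d = q.shape[-1]
    proj = torch.randn(d, nb_features) / d**0.25   # random projection (orthogonal preferred in practice)

    def phi(x):
        return torch.relu(x @ proj) + eps          # positive random features, (n, nb_features)

    q_f, k_f = phi(q), phi(k)
    kv = k_f.transpose(-2, -1) @ v                 # (nb_features, d)
    z = q_f @ k_f.sum(dim=0, keepdim=True).transpose(-2, -1)  # (n, 1) normalizer
    return (q_f @ kv) / (z + eps)                  # (n, d)

q, k, v = (torch.randn(1024, 64) for _ in range(3))
print(performer_attention(q, k, v).shape)          # torch.Size([1024, 64])
```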

2. Linformer

Key Idea: Attention is approximately low-rank, so project K,V to smaller dimension

Architecture:

  • Learnable projection matrices E and F
  • Project sequence length dimension from n → k
  • Apply standard attention in lower-dimensional space

Core Steps:

  1. Compute Q, K, V normally
  2. Project K using E: K' = E × K (reduces seq_len from n to k)
  3. Project V using F: V' = F × V (reduces seq_len from n to k)
  4. Compute attention: softmax(Q × K'^T / √d_k) × V'
  5. Output has original sequence length n

Complexity: O(nkd) where k = projection dimension (typically 256-512)
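
A sketch of a single Linformer-style head following these steps. The module name, the initialization of E and F, and the toy sizes are assumptions, not a reference implementation.

```python
import torch
import torch.nn as nn

class LinformerHead(nn.Module):
    def __init__(self, seq_len, d_head, k=256):
        super().__init__()
        # Learnable projections over the sequence-length dimension (n -> k)
        self.E = nn.Parameter(torch.randn(k, seq_len) / seq_len**0.5)  # for K
        self.F = nn.Parameter(torch.randn(k, seq_len) / seq_len**0.5)  # for V
        self.d_head = d_head

    def forward(self, q, k, v):
        # q, k, v: (n, d_head)
        k_proj = self.E @ k                                        # (k, d_head)
        v_proj = self.F @ v                                        # (k, d_head)
        scores = q @ k_proj.transpose(-2, -1) / self.d_head**0.5   # (n, k), never (n, n)
        return torch.softmax(scores, dim=-1) @ v_proj              # (n, d_head)

head = LinformerHead(seq_len=4096, d_head=64, k=256)
q, k, v = (torch.randn(4096, 64) for _ in range(3))
print(head(q, k, v).shape)   # torch.Size([4096, 64])
```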

3. Linear Transformer

Key Idea: Simple kernel trick with φ(x) = elu(x) + 1 feature map

Architecture:

  • Minimal changes to standard attention
  • Use ELU activation + 1 as feature map
  • Leverages associative property of matrix multiplication

Core Steps:

  1. Compute Q, K, V normally
  2. Apply feature map: φ(Q) = elu(Q) + 1, φ(K) = elu(K) + 1
  3. Non-causal: Compute KV = φ(K)^T @ V, then output = φ(Q) @ KV
  4. Causal: Use cumulative sum trick for O(1) per-token generation
  5. Normalize by sum of K features

Complexity: O(nd²) where d = head_dim

Special Feature: Causal masking with constant-time decoding via cumulative sums

4. Cosformer

Key Idea: Decompose attention using cosine basis functions

Architecture:

  • Apply cosine-based reweighting to Q and K
  • Use position-dependent weights
  • Combines benefits of positional encoding with linear complexity

Core Steps:

  1. Pre-compute position weights cos(πi/2n) and sin(πi/2n) for each position i
  2. Apply the cos and sin weights to the (non-negative) Q and K features
  3. Run linear attention on the cos-weighted and sin-weighted features and sum the two outputs
  4. Normalize the result

Complexity: O(nd²); the cos/sin reweighting splits attention into two kernel-trick passes (via cos(a - b) = cos(a)cos(b) + sin(a)sin(b)), so the cost stays linear in sequence length

Benefit: Better captures positional information than simple feature maps
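
A sketch of the idea, assuming ReLU features and the cos/sin decomposition (cos(a - b) = cos(a)cos(b) + sin(a)sin(b)) that keeps the reweighting compatible with the kernel trick; names and the exact weight schedule are illustrative.

```python
import math
import torch

def cosformer_attention(q, k, v, eps=1e-6):
    # q, k, v: (n, d)
    n = q.shape[0]
    pos = torch.arange(n, dtype=q.dtype)
    w = math.pi * pos / (2 * n)                               # position weights in [0, pi/2)
    q_r, k_r = torch.relu(q), torch.relu(k)                   # non-negative features
    q_cos, q_sin = q_r * torch.cos(w)[:, None], q_r * torch.sin(w)[:, None]
    k_cos, k_sin = k_r * torch.cos(w)[:, None], k_r * torch.sin(w)[:, None]
    # Two kernel-trick passes (cos branch + sin branch), each O(n d^2)
    num = q_cos @ (k_cos.transpose(-2, -1) @ v) + q_sin @ (k_sin.transpose(-2, -1) @ v)
    den = q_cos @ k_cos.sum(0, keepdim=True).T + q_sin @ k_sin.sum(0, keepdim=True).T
    return num / (den + eps)

q, k, v = (torch.randn(1024, 64) for _ in range(3))
print(cosformer_attention(q, k, v).shape)                     # torch.Size([1024, 64])
```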

Comparison of Methods

Method    | Kernel φ(x)         | Time   | Memory | Quality | Key Idea
----------|---------------------|--------|--------|---------|----------------------
Performer | Random Fourier      | O(nkd) | O(nk)  | 95%     | FAVOR+ algorithm
Linformer | Low-rank projection | O(nkd) | O(kd)  | 92%     | Attention is low-rank
Linear    | elu(x) + 1          | O(nd²) | O(d²)  | 90%     | Simple feature map
Cosformer | Cosine reweighting  | O(nd²) | O(nd)  | 93%     | Decomposition
Flash*    | Exact softmax       | O(n²d) | O(n)   | 100%    | IO-aware

*FlashAttention is not linear-time; it computes exact softmax attention but achieves linear memory through IO-aware tiling.

Advanced Techniques

Hybrid Approaches

Combine Local and Global Attention:

  • Split heads between local windowed attention and linear global attention
  • Local attention: High quality for nearby tokens
  • Global attention: Linear complexity for long-range dependencies
  • Combine outputs additively or with learned gating

Benefits:

  • Best of both worlds: quality + efficiency
  • Tuneable quality-speed trade-off
  • Used in Longformer, BigBird
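
To make the split concrete, here is a toy sketch that gives half the heads windowed softmax attention and half linear attention, then concatenates the outputs. It illustrates the pattern only; it is not Longformer's or BigBird's implementation, and the local branch still builds an n × n mask for simplicity.

```python
import torch
import torch.nn.functional as F

def local_attention(q, k, v, window=64):
    # Windowed softmax attention. Toy version: it still builds an (n, n) score
    # matrix and masks it; a real implementation computes only the band.
    n, d = q.shape
    scores = q @ k.transpose(-2, -1) / d**0.5
    idx = torch.arange(n)
    band = (idx[:, None] - idx[None, :]).abs() <= window
    scores = scores.masked_fill(~band, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def linear_attention(q, k, v, eps=1e-6):
    # Kernel-trick attention with the elu(x) + 1 feature map.
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    kv = phi_k.transpose(-2, -1) @ v
    z = phi_q @ phi_k.sum(0, keepdim=True).transpose(-2, -1)
    return (phi_q @ kv) / (z + eps)

def hybrid_heads(q, k, v, n_heads=8):
    # First half of the heads: local attention; second half: linear attention.
    n, d_model = q.shape
    d_head = d_model // n_heads
    outs = []
    for h in range(n_heads):
        sl = slice(h * d_head, (h + 1) * d_head)
        attend = local_attention if h < n_heads // 2 else linear_attention
        outs.append(attend(q[:, sl], k[:, sl], v[:, sl]))
    return torch.cat(outs, dim=-1)

q, k, v = (torch.randn(512, 512) for _ in range(3))
print(hybrid_heads(q, k, v).shape)   # torch.Size([512, 512])
```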

Learned Feature Maps

Adaptive Kernel Learning:

  • Replace fixed feature maps (elu, ReLU) with learned neural networks
  • Small MLP projects queries/keys to feature space
  • Use softplus activation to ensure positivity
  • Can adapt to data distribution during training

Trade-off: More parameters but potentially better approximation
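
A sketch of such a learned map: a small MLP followed by softplus so the features stay positive, plugged into the same kernel-trick attention. Module names and sizes are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnedFeatureMap(nn.Module):
    def __init__(self, d_head, d_feature=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_head, d_feature),
            nn.GELU(),
            nn.Linear(d_feature, d_feature),
        )

    def forward(self, x):
        return F.softplus(self.net(x))       # strictly positive features

def linear_attention_with_map(q, k, v, phi, eps=1e-6):
    q_f, k_f = phi(q), phi(k)
    kv = k_f.transpose(-2, -1) @ v                                 # (d_feature, d_head)
    z = q_f @ k_f.sum(0, keepdim=True).transpose(-2, -1)           # (n, 1)
    return (q_f @ kv) / (z + eps)

phi = LearnedFeatureMap(d_head=64)
q, k, v = (torch.randn(1024, 64) for _ in range(3))
print(linear_attention_with_map(q, k, v, phi).shape)               # torch.Size([1024, 64])
```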

Implementation Tips

1. Numerical Stability

Add Small Epsilon to Normalizer:

  • Prevent division by zero
  • Use eps=1e-6 typically
  • Add it to the normalizer before dividing: out / (normalizer + eps)

Stable Computation Order:

  • Compute KV matrix first: einsum('...nd,...ne->...de', k, v)
  • Then apply queries: einsum('...nd,...de->...ne', q, kv)
  • Normalize with sum of k features
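
Putting the two einsum calls and the epsilon together, in batched multi-head form with (batch, heads, n, d) tensors; the function name and the elu(x) + 1 feature map are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def stable_linear_attention(q, k, v, eps=1e-6):
    q, k = F.elu(q) + 1, F.elu(k) + 1                              # positive feature maps
    kv = torch.einsum('...nd,...ne->...de', k, v)                  # (..., d, d): KV matrix first
    out = torch.einsum('...nd,...de->...ne', q, kv)                # then apply queries
    normalizer = torch.einsum('...nd,...d->...n', q, k.sum(dim=-2))  # sum of K features
    return out / (normalizer.unsqueeze(-1) + eps)                  # eps avoids divide-by-zero

q, k, v = (torch.randn(2, 8, 1024, 64) for _ in range(3))
print(stable_linear_attention(q, k, v).shape)                      # torch.Size([2, 8, 1024, 64])
```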

2. Memory-Efficient Training

Chunked Processing:

  • Process queries in chunks of 64-128 tokens
  • Compute full KV matrix once
  • Apply each query chunk separately
  • Concatenate results

Benefits:

  • Reduces peak memory usage
  • Allows training with limited GPU memory
  • Minimal overhead
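
A sketch of chunked query processing: the d × d KV matrix and the K-feature sum are computed once, then query chunks are processed one at a time. Chunk size and names are illustrative.

```python
import torch
import torch.nn.functional as F

def chunked_linear_attention(q, k, v, chunk_size=128, eps=1e-6):
    # q, k, v: (n, d)
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    kv = phi_k.transpose(-2, -1) @ v              # (d, d), computed once
    k_sum = phi_k.sum(dim=0)                      # (d,)
    outs = []
    for start in range(0, q.shape[0], chunk_size):
        q_chunk = phi_q[start:start + chunk_size]            # (c, d)
        z = q_chunk @ k_sum + eps                            # (c,)
        outs.append((q_chunk @ kv) / z.unsqueeze(-1))        # (c, d)
    return torch.cat(outs, dim=0)

q, k, v = (torch.randn(4096, 64) for _ in range(3))
print(chunked_linear_attention(q, k, v).shape)               # torch.Size([4096, 64])
```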

3. Causal Masking

Cumulative Sum Trick:

  • Maintain running sums: kv_cumsum and k_cumsum
  • For each position i, update sums with current K,V
  • Compute output using only cumulative values
  • O(1) complexity per token for generation

Implementation:

  • Initialize sums to zero
  • Loop through sequence positions
  • Update sums incrementally
  • Apply query to cumulative sums
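
A sketch of the cumulative-sum formulation: running sums of φ(k) v^T and φ(k) are updated position by position, so each token attends only to earlier tokens and each step costs O(d²). A production implementation would vectorize or fuse this Python loop.

```python
import torch
import torch.nn.functional as F

def causal_linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (n, d)
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    n, d = q.shape
    kv_cumsum = torch.zeros(d, d)                  # running sum of phi(k_i) v_i^T
    k_cumsum = torch.zeros(d)                      # running sum of phi(k_i)
    out = torch.empty_like(v)
    for i in range(n):
        kv_cumsum += phi_k[i].unsqueeze(-1) @ v[i].unsqueeze(0)   # (d, d) update
        k_cumsum += phi_k[i]
        z = phi_q[i] @ k_cumsum + eps
        out[i] = (phi_q[i] @ kv_cumsum) / z        # uses only positions <= i
    return out

q, k, v = (torch.randn(256, 64) for _ in range(3))
print(causal_linear_attention(q, k, v).shape)      # torch.Size([256, 64])
```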

Production Configurations

Performer

Typical Configuration:

  • d_model: 1024
  • n_heads: 16
  • nb_features: 256 (random feature dimension)
  • use_orthogonal: True (orthogonal random features)
  • redraw_features: True (periodically redraw for better approximation)
  • redraw_interval: 1000 steps

Linformer

Typical Configuration:

  • d_model: 768
  • n_heads: 12
  • seq_len: 4096 (fixed maximum sequence length)
  • k: 256 (projection dimension, ~6% of seq_len)
  • share_projections: True (share E, F across layers to save parameters)

Linear Transformer

Typical Configuration:

  • d_model: 512
  • n_heads: 8
  • feature_map: "elu" (can use "relu", "gelu")
  • eps: 1e-6 (numerical stability)
  • use_rotary_embeddings: True (better positional encoding)
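
For reference, a minimal sketch of how such knobs might be grouped into a config object. The field names simply mirror the list above; they are assumptions, not a real library's API.

```python
from dataclasses import dataclass

@dataclass
class LinearAttentionConfig:
    d_model: int = 512
    n_heads: int = 8
    feature_map: str = "elu"            # "relu" or "gelu" also work with the +1 shift
    eps: float = 1e-6                   # numerical stability for the normalizer
    use_rotary_embeddings: bool = True  # better positional encoding

cfg = LinearAttentionConfig()
print(cfg)
```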

Best Practices

Choosing the Right Method

Decision Guide:

Quality > 98% required:

  • Use FlashAttention (exact attention with linear memory)

Very long sequences (>10K tokens):

  • Use Performer (best scaling for ultra-long contexts)

Quality acceptable at 90-92%:

  • Use Linear Transformer (simplest, fastest)

Tight memory budget:

  • Use Linformer (most memory efficient, low-rank projections)

Balanced requirements:

  • Use Cosformer (good quality-efficiency trade-off)

Combining with Other Techniques

Layerwise Strategy:

  • Early layers: Local/sliding window attention (capture fine-grained patterns)
  • Middle layers: Linear attention (efficient long-range modeling)
  • Final layers: Full or sparse attention (high-quality output)

Benefits:

  • Leverages strengths of each method
  • Progressive refinement of representations
  • Optimal quality-speed trade-off
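
A tiny sketch of this layerwise assignment; the thresholds (first quarter local, middle half linear, final quarter full) are illustrative assumptions.

```python
def attention_type_for_layer(layer_idx: int, n_layers: int) -> str:
    # Early quarter: local window; middle half: linear; final quarter: full attention.
    if layer_idx < n_layers // 4:
        return "local"     # fine-grained nearby patterns
    if layer_idx < 3 * n_layers // 4:
        return "linear"    # efficient long-range mixing
    return "full"          # high-quality final refinement

print([attention_type_for_layer(i, 12) for i in range(12)])  # 3 local, 6 linear, 3 full
```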
