Multi-Query Attention (MQA)

Understand Multi-Query Attention, the radical efficiency optimization that shares keys and values across all attention heads, enabling massive memory savings for inference.

Multi-Query Attention: Maximum Efficiency Through Sharing

Multi-Query Attention (MQA) is a radical simplification of multi-head attention that shares a single set of keys and values across all query heads, achieving dramatic memory savings with acceptable quality trade-offs.

The Problem: Memory Bottleneck in Multi-Head Attention

Traditional Multi-Head Attention (MHA) stores separate Key and Value projections for each attention head. With 8 heads across 32 layers, this creates a massive KV cache that limits batch sizes and increases latency.

  • MHA cache size: 256 MB per sequence (2K context)
  • Maximum batch size: ~195 sequences with 50 GB of GPU memory available for the cache
  • Redundancy: 8 separate K,V sets, one per head

The Insight: Share Keys and Values Across All Heads

Multi-Query Attention (MQA) uses a radical simplification: maintain separate Query projections for each head, but share a single set of Keys and Values across all heads. This dramatically reduces memory while maintaining most of the model's expressiveness.

Key Difference:

Multi-Head Attention:
  • Q1 → K1, V1
  • Q2 → K2, V2
  • Q3 → K3, V3
  • Q4 → K4, V4
  • ...each head has its own K,V

Multi-Query Attention:
  • Q1 → K_shared, V_shared
  • Q2 → K_shared, V_shared
  • Q3 → K_shared, V_shared
  • Q4 → K_shared, V_shared
  • ...all heads share the same K,V

Architecture Comparison: See the Difference

Compare how queries, keys, and values flow through MHA vs MQA. Notice how MQA replaces the per-head key and value projections with a single shared pair, shrinking the KV cache and memory traffic while keeping head-specific queries.

[Diagram: the input X is projected to eight query heads Q1-Q8; in MHA there are also eight key heads K1-K8 and eight value heads V1-V8, whereas in MQA a single shared K and V serves all eight query heads.]

Massive Memory Savings: The Numbers

By sharing K,V across heads, MQA reduces the KV cache from O(n·h·d) to O(n·d), where h is the number of heads. With 8 heads, this means ~87.5% memory reduction with minimal quality loss.

  • MQA cache: 32 MB per sequence (vs 256 MB for MHA)
  • Memory saved: 87.5% reduction in cache size
  • Maximum batch size: ~1,562 sequences (vs ~195 for MHA)

Real-World Impact: Production Models

MQA has been adopted by major language models including Google's PaLM and TII's Falcon. These models achieve 1.5-2× inference speedup with only ~2% quality degradation, making MQA ideal for large-scale serving where memory and throughput are critical.

Google PaLM
  • 540B parameters
  • 48 attention heads
  • Uses MQA throughout
  • ~98% cache reduction (see the PaLM numbers below)

TII Falcon
  • 40B parameters
  • 64 attention heads
  • MQA architecture
  • Topped open LLM benchmarks at release

Key Trade-offs:

✓ Benefits
  • 87-96% memory savings
  • 1.5-2× faster inference
  • Larger batch sizes

~ Quality
  • ~2% perplexity increase
  • Minimal downstream impact
  • Acceptable for most tasks

⚡ Use Cases
  • Large-scale serving
  • Long context windows
  • Memory-constrained GPUs

The Core Insight

Traditional Multi-Head Attention (MHA) maintains separate K, V projections for each head:

  • Memory: O(n · h · d)
  • Redundancy: Similar patterns learned across heads

MQA's breakthrough: One K, V pair serves all heads

  • Memory: O(n · d)
  • Efficiency: the KV cache shrinks by a factor of h (e.g., 32× for a 32-head model)

How MQA Works

The Architecture

MQA(X) = Concat(head_1, ..., head_h) W_O

Where each head computes:

head_i = Attention(Q_i, K_shared, V_shared)

Key differences from MHA:

  • Queries: still head-specific (Q_1, Q_2, ..., Q_h)
  • Keys/Values: shared across all heads (K_shared, V_shared)

Implementation Details

Key Architecture Components

Projection Layers:

  • Queries: one projection per head, i.e. a single d_model × (n_heads · head_dim) matrix
  • Keys: one shared d_model × head_dim projection
  • Values: one shared d_model × head_dim projection
  • Output: standard projection combining all heads, (n_heads · head_dim) × d_model

Forward Pass Steps:

  1. Project input X to multi-head queries Q
  2. Project input X to single K and V (shared across heads)
  3. Expand K, V to match all query heads via broadcasting
  4. Compute scaled dot-product attention for each head
  5. Concatenate head outputs and apply final projection
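
A minimal PyTorch sketch of these five steps (illustrative only: the MultiQueryAttention name, bias-free projections, and the omission of causal masking and KV caching are assumptions, not any particular model's implementation):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    """Illustrative MQA layer: per-head queries, one shared key/value pair."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)  # per-head queries
        self.k_proj = nn.Linear(d_model, self.head_dim, bias=False)            # single shared K
        self.v_proj = nn.Linear(d_model, self.head_dim, bias=False)            # single shared V
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        # Step 1: multi-head queries, shape (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Step 2: a single K and V with a singleton head axis, shape (B, 1, T, head_dim)
        k = self.k_proj(x).unsqueeze(1)
        v = self.v_proj(x).unsqueeze(1)
        # Steps 3-4: broadcasting over the head axis shares K,V across every query head
        # (causal masking omitted for brevity)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)  # (B, n_heads, T, T)
        out = F.softmax(scores, dim=-1) @ v                          # (B, n_heads, T, head_dim)
        # Step 5: concatenate the heads and apply the output projection
        return self.o_proj(out.transpose(1, 2).reshape(B, T, self.n_heads * self.head_dim))
```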

KV Cache Management:

  • Cache stores only one K,V pair per layer (not per head)
  • Cache shape per layer: [batch, 1, max_seq_len, head_dim] for K, and the same for V (no n_heads dimension)
  • Dramatically reduced memory: ~87-96% smaller than MHA
  • Simple concatenation for incremental decoding
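
As a hedged sketch, the per-layer cache update during incremental decoding could look like this (the append_kv helper and tensor layout are assumptions):

```python
import torch

# The per-layer MQA cache holds one shared K and V:
# past_k, past_v have shape (batch, 1, cached_len, head_dim) -- no n_heads axis.
def append_kv(past_k, past_v, new_k, new_v):
    """Extend the shared K,V cache with the current step's projections.

    new_k, new_v: (batch, 1, 1, head_dim) for a single decoding step.
    """
    if past_k is None:  # first step: nothing cached yet
        return new_k, new_v
    return (torch.cat([past_k, new_k], dim=2),   # concatenate along the sequence axis
            torch.cat([past_v, new_v], dim=2))
```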

Memory Savings Analysis

KV Cache Comparison

For a model with 32 heads, 40 layers, sequence length 2048, head dimension 128:

Method | Cache size per token               | Total for 2K context | Reduction
MHA    | 2 × 40 × 32 × 128 = 327,680 floats | 640 MB               | 0%
MQA    | 2 × 40 × 1 × 128 = 10,240 floats   | 20 MB                | 96.9%

Batch Serving Benefits

Calculating Maximum Batch Size:

Given GPU memory and model size, the KV cache determines maximum batch capacity:

Example: 80GB A100 GPU with 30GB Model

  • Available memory for cache: 50GB
  • MHA cache per sequence: 640 MB → Max batch: ~78 sequences
  • MQA cache per sequence: 20 MB → Max batch: ~2,500 sequences

Impact: MQA enables 32× larger batches for the same hardware, dramatically improving serving throughput.
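
The example can be reproduced with a short back-of-the-envelope script (a sketch; the element size of 1 byte is chosen so the totals match the MB figures above, and an fp16 cache would double them):

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, head_dim, bytes_per_elem=1):
    """KV cache size for one sequence: K and V, per layer, per KV head, per token."""
    return 2 * n_layers * n_kv_heads * seq_len * head_dim * bytes_per_elem

budget = 50 * 1024**3  # ~50 GB left for the cache on an 80 GB GPU running a 30 GB model
mha = kv_cache_bytes(n_layers=40, n_kv_heads=32, seq_len=2048, head_dim=128)  # ~640 MB
mqa = kv_cache_bytes(n_layers=40, n_kv_heads=1,  seq_len=2048, head_dim=128)  # ~20 MB
print(budget // mha, budget // mqa)  # ~80 vs ~2560 sequences (≈ the ~78 / ~2,500 above)
```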

Quality Considerations

The Trade-off

MQA trades expressiveness for efficiency:

Parameter Count Comparison:

  • MHA: n_heads × d_model × d_head × 3 (separate Q, K, V per head)
  • MQA: n_heads × d_model × d_head + 2 × d_model × d_head (Q per head, shared K,V)
  • Reduction: roughly 58% fewer Q/K/V projection parameters for 8 heads, approaching ~67% as the head count grows (the output projection is unchanged)
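
A quick sanity check of those counts (illustrative sizes; d_model = 4096 with 8 heads of dimension 512 is an assumption):

```python
def qkv_params(n_heads, d_model, d_head, n_kv_heads):
    """Q/K/V projection parameters only; the output projection is excluded."""
    return n_heads * d_model * d_head + 2 * n_kv_heads * d_model * d_head

mha = qkv_params(n_heads=8, d_model=4096, d_head=512, n_kv_heads=8)
mqa = qkv_params(n_heads=8, d_model=4096, d_head=512, n_kv_heads=1)
print(f"{1 - mqa / mha:.0%}")  # ~58% for 8 heads; the ratio approaches 2/3 as n_heads grows
```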

Empirical Results

From the original MQA paper (Shazeer, 2019, "Fast Transformer Decoding: One Write-Head Is All You Need"):

Model | Attention type | Perplexity | Speed
Base  | MHA            | 10.2       | 1.0×
Base  | MQA            | 10.4       | 1.8×
Large | MHA            | 8.1        | 1.0×
Large | MQA            | 8.3        | 2.4×

Key findings:

  • Small quality loss (~2% perplexity increase)
  • Significant speed gains (1.8-2.4×)
  • Benefits scale with model size

Training Strategies

Training from Scratch

Configuration Adjustments:

  • Increase model dimension by ~10% to compensate for reduced parameters
  • Keep same number of query heads
  • Use single K,V projection (n_kv_heads = 1)
  • Reduce learning rate by ~10% for stability
  • Increase training steps by ~10% to converge
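
As a concrete illustration, those adjustments might translate into configs like the following; every number here is hypothetical rather than a published recipe:

```python
# Hypothetical MHA baseline vs. MQA training config (values are illustrative only).
mha_config = dict(d_model=2048, n_heads=16, n_kv_heads=16, lr=1.0e-3, train_steps=100_000)

mqa_config = dict(
    d_model=2256,         # ~10% wider to compensate for the smaller K/V projections
    n_heads=16,           # same number of query heads
    n_kv_heads=1,         # single shared K,V projection
    lr=9.0e-4,            # ~10% lower learning rate for stability
    train_steps=110_000,  # ~10% more steps to converge
)
```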

Fine-tuning from MHA

Conversion Strategy:

  1. Copy query projections directly from MHA
  2. Average K,V projections across all heads to create shared K,V
  3. Keep output projection unchanged
  4. Fine-tune with lower learning rate (1e-5)
  5. Typically converges after additional (uptraining) compute equal to roughly 5-10% of the original training run
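
A sketch of the weight-averaging in step 2 above; the (n_heads · head_dim, d_model) weight layout follows nn.Linear conventions and is an assumption:

```python
import torch

def mha_to_mqa_kv(w_kv: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """Average an MHA K (or V) projection across heads to get the shared MQA projection.

    w_kv: (n_heads * head_dim, d_model), as stored by nn.Linear; returns (head_dim, d_model).
    """
    d_model = w_kv.shape[1]
    per_head = w_kv.view(n_heads, head_dim, d_model)  # split the output rows by head
    return per_head.mean(dim=0)                       # mean-pool the heads into one

# Example: convert the K projection of a toy 8-head layer, then fine-tune as described above.
w_k_mha = torch.randn(8 * 64, 512)
w_k_mqa = mha_to_mqa_kv(w_k_mha, n_heads=8, head_dim=64)  # shape (64, 512)
```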

Benefits of Uptraining:

  • Start from strong pretrained weights
  • Faster than training from scratch
  • Often achieves better final quality

Optimization Techniques

1. Fused Kernels

Efficient Broadcasting Strategy:

  • Insert a singleton head dimension with unsqueeze(1) (a view, not a copy) so K,V broadcast across all query heads
  • Compute attention scores with broadcasting instead of explicit expansion
  • Single fused kernel reduces memory allocations
  • TorchScript compilation for additional speedup

Benefits:

  • Reduced memory bandwidth usage
  • Fewer kernel launches
  • Better cache utilization
  • 20-30% faster than naive implementation

2. Memory Layout Optimization

Optimized Weight Storage:

  • Pack query weights for better cache locality
  • Store K,V weights contiguously
  • Use einsum for efficient multi-head query computation
  • Minimize memory fragmentation
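
For instance, the packed multi-head query projection can be written as a single einsum (a sketch; the w_q layout is an assumption):

```python
import torch

B, T, d_model, n_heads, head_dim = 2, 16, 512, 8, 64
x = torch.randn(B, T, d_model)
w_q = torch.randn(n_heads, d_model, head_dim)  # per-head query weights stored contiguously

# (B, T, d_model) x (n_heads, d_model, head_dim) -> (B, n_heads, T, head_dim)
q = torch.einsum('btd,hde->bhte', x, w_q)
```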

Performance Impact:

  • Improved CPU-to-GPU transfer efficiency
  • Better CUDA kernel occupancy
  • Reduced register pressure

3. Dynamic Batching

Adaptive Batch Processing:

  • Group requests by similar sequence lengths
  • Bucket lengths to reduce padding waste (e.g., round to nearest 128)
  • Pad sequences within each bucket
  • Process each bucket in single forward pass
  • Unpad results before returning
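
A minimal sketch of the bucketing step, rounding each request length up to the nearest multiple of 128 (the function name and request format are illustrative):

```python
from collections import defaultdict

def bucket_by_length(requests, multiple=128):
    """Group requests into padding buckets by rounding lengths up to the nearest multiple."""
    buckets = defaultdict(list)
    for req_id, length in requests:
        padded = ((length + multiple - 1) // multiple) * multiple
        buckets[padded].append(req_id)
    return dict(buckets)

# Three requests land in two buckets instead of all padding to the longest length.
print(bucket_by_length([("a", 37), ("b", 120), ("c", 300)]))  # {128: ['a', 'b'], 384: ['c']}
```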

Throughput Gains:

  • Minimize wasted computation on padding
  • Better GPU utilization
  • 2-3× higher throughput for variable-length batches

When to Use MQA

Ideal Use Cases

Large-scale serving

  • Maximize throughput
  • Minimize latency
  • Large batch sizes

Memory-constrained environments

  • Edge devices
  • Mobile deployment
  • Limited GPU memory

Long context applications

  • Document processing
  • Multi-turn dialogue
  • Code generation

When to Avoid

Research/Experimentation

  • Need best quality
  • Small-scale deployment
  • Abundant resources

Latency-insensitive training

  • Pre-training from scratch
  • Have sufficient memory
  • Quality is paramount

Comparison with Alternatives

Feature         | MHA     | GQA-8     | MQA
KV parameters   | 100%    | 25%       | 3.1%
Cache size      | 100%    | 25%       | 3.1%
Quality         | Best    | Near-best | Good
Inference speed | 1×      | ~1.5×     | 1.5-2×
Implementation  | Complex | Moderate  | Simple

Production Examples

PaLM (Google)

Configuration:

  • d_model: 18,432
  • n_heads: 48
  • n_kv_heads: 1 (MQA)
  • layers: 118
  • context: 2,048 tokens

Memory Impact:

  • MHA cache: ~35.9 GB per sequence
  • MQA cache: ~0.75 GB per sequence
  • Savings: 97.9%

Falcon (TII)

Falcon-40B Configuration:

  • d_model: 8,192
  • n_heads: 64
  • attention_type: multi_query
  • n_kv_heads: 1

Performance: Topped open LLM benchmarks at release while maintaining efficient serving

Common Pitfalls and Solutions

Pitfall 1: Incorrect Broadcasting

Problem: Using repeat() creates unnecessary copies, consuming extra memory

Solution: Use unsqueeze() with broadcasting for memory-efficient expansion

Impact: Reduces memory usage by 8× (for 8 heads) during forward pass
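
A small sketch of the difference (shapes are illustrative):

```python
import torch

B, n_heads, T, head_dim = 2, 8, 256, 128
q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, 1, T, head_dim)          # single shared key head

k_copied = k.repeat(1, n_heads, 1, 1)       # materializes n_heads copies of K
scores_a = q @ k_copied.transpose(-2, -1)

scores_b = q @ k.transpose(-2, -1)          # broadcasting: no copy of K is created
assert torch.allclose(scores_a, scores_b)   # same result, far less memory traffic
```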

Pitfall 2: Cache Shape Mismatch

Problem: Caching K,V with head dimension wastes memory

Correct approach: Cache shape should be (batch, 1, seq_len, head_dim) not (batch, n_heads, seq_len, head_dim)

Impact: Ensures cache stays small and shared across heads

Pitfall 3: Learning Rate

Problem: Using same learning rate as MHA can cause training instability

Solution: Reduce learning rate by 40-50% for MQA training (e.g., 5e-4 instead of 1e-3)

Reason: Shared K,V parameters update from all heads simultaneously, requiring more careful optimization

Future Directions

Hybrid Approaches

  • First/last layers with MHA
  • Critical layers with GQA
  • Others with MQA

Learned Sharing

  • Dynamic K,V sharing
  • Attention-based routing
  • Adaptive compression
