Multi-Query Attention (MQA)
Understand Multi-Query Attention, the radical efficiency optimization that shares keys and values across all attention heads, enabling massive memory savings for inference.
Multi-Query Attention: Maximum Efficiency Through Sharing
Multi-Query Attention (MQA) is a radical simplification of multi-head attention that shares a single set of keys and values across all query heads, achieving dramatic memory savings with acceptable quality trade-offs.
The Problem: Memory Bottleneck in Multi-Head Attention
Traditional Multi-Head Attention (MHA) stores separate Key and Value projections for each attention head. With 8 heads across 32 layers, this creates a massive KV cache that limits batch sizes and increases latency.
The Insight: Share Keys and Values Across All Heads
Multi-Query Attention (MQA) takes sharing to its extreme: keep a separate Query projection for each head, but share a single set of Keys and Values across all of them. This dramatically reduces memory while preserving most of the model's expressiveness.
Architecture Comparison: See the Difference
Compare how queries, keys, and values flow through MHA versus MQA: every MQA query head attends against the same shared K,V pair, which leaves the attention computation itself largely unchanged but cuts the key/value projection work and, crucially, the KV cache that must be kept in memory.
Massive Memory Savings: The Numbers
By sharing K,V across heads, MQA reduces the KV cache from O(n·h·d) to O(n·d), where h is the number of heads. With 8 heads, this means ~87.5% memory reduction with minimal quality loss.
Real-World Impact: Production Models
MQA has been adopted by major language models including Google's PaLM and TII's Falcon. These models achieve 1.5-2× inference speedup with only ~2% quality degradation, making MQA ideal for large-scale serving where memory and throughput are critical.
The Core Insight
Traditional Multi-Head Attention (MHA) maintains separate K, V projections for each head:
- Memory: O(n · h · d)
- Redundancy: Similar patterns learned across heads
MQA's breakthrough: One K, V pair serves all heads
- Memory: O(n · d)
- Efficiency: KV cache shrinks by a factor of h (e.g., 32× for a 32-head model)
How MQA Works
The Architecture

MQA keeps the standard attention block but collapses the key and value projections into a single shared pair:

MQA(X) = Concat(head1, ..., headh) · WO

Where each head computes:

headi = softmax(Qi · Kshared^T / √dk) · Vshared,  with Qi = X·WQi, Kshared = X·WK, Vshared = X·WV

Key differences from MHA:
- Queries: Still head-specific (Q1, Q2, ..., Qh), each with its own projection WQi
- Keys/Values: A single WK and WV shared across all heads (Kshared, Vshared)
Implementation Details
Key Architecture Components
Projection Layers:
- Queries: one projection per head, mapping d_model → n_heads × head_dim in total
- Keys: a single shared projection, mapping d_model → head_dim
- Values: a single shared projection, mapping d_model → head_dim
- Output: standard projection combining all heads, mapping n_heads × head_dim → d_model
Forward Pass Steps:
- Project input X to multi-head queries Q
- Project input X to single K and V (shared across heads)
- Expand K, V to match all query heads via broadcasting
- Compute scaled dot-product attention for each head
- Concatenate head outputs and apply final projection
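A minimal PyTorch sketch of these steps follows. The module and parameter names (MultiQueryAttention, q_proj, k_proj, v_proj) are illustrative assumptions rather than any particular library's API, and masking/dropout are omitted for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    """Minimal MQA: per-head queries, one shared K and one shared V."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Per-head queries: d_model -> n_heads * head_dim
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        # Shared keys/values: d_model -> head_dim (one head's worth each)
        self.k_proj = nn.Linear(d_model, self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.head_dim, bias=False)
        self.out_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        # Step 1: multi-head queries, shape (b, n_heads, t, head_dim)
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        # Step 2: single K and V shared by every head, shape (b, 1, t, head_dim)
        k = self.k_proj(x).unsqueeze(1)
        v = self.v_proj(x).unsqueeze(1)
        # Steps 3-4: broadcasting over the head dimension replaces explicit expansion
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5   # (b, h, t, t)
        attn = F.softmax(scores, dim=-1)
        out = attn @ v                                            # (b, h, t, head_dim)
        # Step 5: concatenate heads and apply the output projection
        out = out.transpose(1, 2).reshape(b, t, self.n_heads * self.head_dim)
        return self.out_proj(out)
```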
KV Cache Management:
- Cache stores only one K,V pair per layer (not per head)
- Per-layer cache shape: [batch, 1, max_seq_len, head_dim] for K (and likewise for V), with no per-head dimension
- Dramatically reduced memory: ~87-96% smaller than MHA
- Simple concatenation for incremental decoding
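A sketch of the per-layer cache update during incremental decoding; the function name and variable layout follow the shape described above but are otherwise assumptions:

```python
import torch

def append_to_kv_cache(cache_k, cache_v, new_k, new_v):
    """Append one decoding step's shared K/V to an MQA cache.

    Cache tensors have shape (batch, 1, seq_so_far, head_dim): a single
    K/V "head" per layer, shared by all query heads.
    """
    if cache_k is None:                            # first step: start the cache
        return new_k, new_v
    cache_k = torch.cat([cache_k, new_k], dim=2)   # grow along the sequence axis
    cache_v = torch.cat([cache_v, new_v], dim=2)
    return cache_k, cache_v

# Usage: new_k / new_v for the current token have shape (batch, 1, 1, head_dim),
# produced exactly as in the forward pass sketch above (k_proj(x_t).unsqueeze(1)).
```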
Memory Savings Analysis
KV Cache Comparison
For a model with 32 heads, 40 layers, sequence length 2048, head dimension 128, stored in fp16 (2 bytes per element):
Method | Cache Size per Token | Total for 2K Context | Reduction |
---|---|---|---|
MHA | 2 × 40 × 32 × 128 = 327,680 values (~640 KB) | ~1.34 GB | 0% |
MQA | 2 × 40 × 1 × 128 = 10,240 values (~20 KB) | ~42 MB | 96.9% |
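These totals can be reproduced with a few lines of arithmetic (illustrative helper, fp16 assumed):

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # 2 tensors (K and V) per layer, one per KV head, one vector per token
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

mha = kv_cache_bytes(n_layers=40, n_kv_heads=32, head_dim=128, seq_len=2048)
mqa = kv_cache_bytes(n_layers=40, n_kv_heads=1,  head_dim=128, seq_len=2048)
print(mha / 1e9, mqa / 1e6, 1 - mqa / mha)   # ~1.34 GB, ~41.9 MB, ~0.969
```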
Batch Serving Benefits
Calculating Maximum Batch Size:
Given GPU memory and model size, the KV cache determines maximum batch capacity:
Example: 80GB A100 GPU with 30GB Model
- Available memory for cache: 50GB
- MHA cache per sequence: ~1.34 GB → Max batch: ~37 sequences
- MQA cache per sequence: ~42 MB → Max batch: ~1,190 sequences
Impact: MQA enables 32× larger batches for the same hardware, dramatically improving serving throughput.
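Under the same fp16 assumption, the capacity estimate is a simple division (numbers are illustrative):

```python
available_bytes = 50e9                      # 80 GB GPU minus ~30 GB of weights
mha_per_seq = 1.34e9                        # ~1.34 GB KV cache per 2K-token sequence
mqa_per_seq = 42e6                          # ~42 MB KV cache per 2K-token sequence

print(int(available_bytes // mha_per_seq))  # ~37 concurrent sequences with MHA
print(int(available_bytes // mqa_per_seq))  # ~1,190 concurrent sequences with MQA
```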
Quality Considerations
The Trade-off
MQA trades expressiveness for efficiency:
Parameter Count Comparison:
- MHA: n_heads × d_model × d_head × 3 (separate Q, K, V per head)
- MQA: n_heads × d_model × d_head + 2 × d_model × d_head (Q per head, shared K,V)
- Reduction: roughly two-thirds fewer Q/K/V projection parameters (about 65% at 32 heads, approaching 2/3 as the head count grows)
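A quick sanity check of that reduction, counting Q/K/V projection parameters only, for a hypothetical 32-head configuration with d_model = 4096 and d_head = 128:

```python
def qkv_params(n_heads, d_model, d_head, n_kv_heads):
    q = n_heads * d_model * d_head          # per-head query projections
    kv = 2 * n_kv_heads * d_model * d_head  # key and value projections
    return q + kv

mha = qkv_params(n_heads=32, d_model=4096, d_head=128, n_kv_heads=32)
mqa = qkv_params(n_heads=32, d_model=4096, d_head=128, n_kv_heads=1)
print(1 - mqa / mha)   # ~0.65: roughly two-thirds fewer Q/K/V parameters at 32 heads
```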
Empirical Results
Representative results (following Shazeer, 2019):
Model | Attention Type | Perplexity | Speed |
---|---|---|---|
Base | MHA | 10.2 | 1.0× |
Base | MQA | 10.4 | 1.8× |
Large | MHA | 8.1 | 1.0× |
Large | MQA | 8.3 | 2.4× |
Key findings:
- Small quality loss (~2% perplexity increase)
- Significant speed gains (1.8-2.4×)
- Benefits scale with model size
Training Strategies
Training from Scratch
Configuration Adjustments:
- Increase model dimension by ~10% to compensate for reduced parameters
- Keep same number of query heads
- Use single K,V projection (n_kv_heads = 1)
- Reduce the learning rate for stability (e.g., by 40-50%; see Pitfall 3 below)
- Increase training steps by ~10% to converge
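A hypothetical configuration sketch reflecting these adjustments; the field names and values are illustrative, not a specific framework's schema:

```python
from dataclasses import dataclass

@dataclass
class AttentionConfig:
    d_model: int = 4096          # e.g., widened ~10% relative to the MHA baseline it replaces
    n_heads: int = 32            # query heads unchanged
    n_kv_heads: int = 1          # MQA: a single shared K/V head
    learning_rate: float = 5e-4  # reduced relative to the MHA baseline
    train_steps: int = 110_000   # ~10% more steps than the baseline (hypothetical value)

cfg = AttentionConfig()
```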
Fine-tuning from MHA
Conversion Strategy:
- Copy query projections directly from MHA
- Average K,V projections across all heads to create shared K,V
- Keep output projection unchanged
- Fine-tune with lower learning rate (1e-5)
- Usually converges within 5-10% of original training
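A sketch of the K/V weight-averaging step described above; the tensor layout assumes the usual nn.Linear convention of (out_features, in_features) and may differ for other checkpoint formats:

```python
import torch

def mha_to_mqa_kv(w_kv: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """Collapse a per-head K (or V) projection into one shared projection.

    Assumes the MHA weight is stored as (n_heads * head_dim, d_model);
    the result has shape (head_dim, d_model).
    """
    d_model = w_kv.shape[1]
    per_head = w_kv.view(n_heads, head_dim, d_model)
    return per_head.mean(dim=0)   # average the heads into one shared projection

# Query and output projections are copied unchanged; only K and V are pooled,
# then the model is fine-tuned briefly at a low learning rate (e.g., 1e-5).
```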
Benefits of Uptraining:
- Start from strong pretrained weights
- Faster than training from scratch
- Often achieves better final quality
Optimization Techniques
1. Fused Kernels
Efficient Broadcasting Strategy:
- Give K,V a singleton head dimension (unsqueeze(1)) and let broadcasting handle the expansion
- Compute attention scores with broadcasting instead of explicit expansion
- Single fused kernel reduces memory allocations
- TorchScript compilation for additional speedup
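The difference between an explicit repeat and broadcasting, sketched with illustrative shapes (the backend may still expand internally, but no repeated copy is created in user code):

```python
import torch

b, h, t, d = 4, 8, 1024, 128
q = torch.randn(b, h, t, d)   # per-head queries
k = torch.randn(b, t, d)      # single shared key tensor

# Wasteful: materializes h copies of K before the matmul
k_repeated = k.unsqueeze(1).repeat(1, h, 1, 1)          # (b, h, t, d) in memory
scores_copy = q @ k_repeated.transpose(-2, -1)

# Efficient: a singleton head dim lets the matmul broadcast over heads
scores_bcast = q @ k.unsqueeze(1).transpose(-2, -1)     # (b, h, t, t)

assert torch.allclose(scores_copy, scores_bcast, atol=1e-5)
```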
Benefits:
- Reduced memory bandwidth usage
- Fewer kernel launches
- Better cache utilization
- 20-30% faster than naive implementation
2. Memory Layout Optimization
Optimized Weight Storage:
- Pack query weights for better cache locality
- Store K,V weights contiguously
- Use einsum for efficient multi-head query computation
- Minimize memory fragmentation
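One possible packed-weight layout with einsum-based projections; the shapes and names here are assumptions, not a prescribed format:

```python
import torch

b, t, d_model, n_heads, head_dim = 2, 16, 512, 8, 64
x = torch.randn(b, t, d_model)

# Query weights packed as one contiguous tensor: (d_model, n_heads, head_dim)
w_q = torch.randn(d_model, n_heads, head_dim)
# Shared K and V weights stored contiguously next to each other: (2, d_model, head_dim)
w_kv = torch.randn(2, d_model, head_dim)

q = torch.einsum("btd,dhe->bhte", x, w_q)        # (b, n_heads, t, head_dim)
k = torch.einsum("btd,de->bte", x, w_kv[0])      # (b, t, head_dim), shared by all heads
v = torch.einsum("btd,de->bte", x, w_kv[1])
```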
Performance Impact:
- Better memory-access patterns and higher effective bandwidth
- Better CUDA kernel occupancy
- Reduced register pressure
3. Dynamic Batching
Adaptive Batch Processing:
- Group requests by similar sequence lengths
- Bucket lengths to reduce padding waste (e.g., round to nearest 128)
- Pad sequences within each bucket
- Process each bucket in single forward pass
- Unpad results before returning
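A simplified, framework-agnostic bucketing sketch; the request representation is an assumption:

```python
from collections import defaultdict

def bucket_by_length(requests, bucket_size=128):
    """Group requests whose lengths round up to the same multiple of bucket_size.

    Each request is assumed to be a (request_id, token_list) pair; within a
    bucket, sequences are padded to the bucket boundary and run in one pass.
    """
    buckets = defaultdict(list)
    for req_id, tokens in requests:
        bucket_len = ((len(tokens) + bucket_size - 1) // bucket_size) * bucket_size
        buckets[bucket_len].append((req_id, tokens))
    return buckets

# Example: lengths 70, 130, 500 land in buckets of padded length 128, 256, 512.
buckets = bucket_by_length([("a", [0] * 70), ("b", [0] * 130), ("c", [0] * 500)])
```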
Throughput Gains:
- Minimize wasted computation on padding
- Better GPU utilization
- 2-3× higher throughput for variable-length batches
When to Use MQA
Ideal Use Cases
✅ Large-scale serving
- Maximize throughput
- Minimize latency
- Large batch sizes
✅ Memory-constrained environments
- Edge devices
- Mobile deployment
- Limited GPU memory
✅ Long context applications
- Document processing
- Multi-turn dialogue
- Code generation
When to Avoid
❌ Research/Experimentation
- Need best quality
- Small-scale deployment
- Abundant resources
❌ Latency-insensitive training
- Pre-training from scratch
- Have sufficient memory
- Quality is paramount
Comparison with Alternatives
Feature | MHA | GQA-8 | MQA |
---|---|---|---|
KV Parameters | 100% | 25% | 3.1% |
Cache Size | 100% | 25% | 3.1% |
Quality | Best | Near-best | Good |
Inference Speed | 1× | 1.5× | 2× |
Implementation | Complex | Moderate | Simple |
Production Examples
PaLM (Google)
Configuration:
- d_model: 18,432
- n_heads: 48
- n_kv_heads: 1 (MQA)
- layers: 118
- context: 2,048 tokens
Memory Impact (fp16, 2,048-token sequence):
- MHA cache: ~17.8 GB per sequence
- MQA cache: ~0.37 GB per sequence
- Savings: 97.9%
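These figures follow from the listed configuration and the same cache-size formula used earlier, with head_dim taken as d_model / n_heads and fp16 storage assumed:

```python
def kv_cache_gb(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem / 1e9

head_dim = 18432 // 48                                                   # 384
mha = kv_cache_gb(118, n_kv_heads=48, head_dim=head_dim, seq_len=2048)   # ~17.8 GB
mqa = kv_cache_gb(118, n_kv_heads=1,  head_dim=head_dim, seq_len=2048)   # ~0.37 GB
print(mha, mqa, 1 - mqa / mha)                                           # savings ≈ 0.979
```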
Falcon (TII)
Falcon-40B Configuration:
- d_model: 8,192
- n_heads: 64
- attention_type: multi_query
- n_kv_heads: 1
Performance: At release, Falcon-40B ranked at the top of the Hugging Face Open LLM Leaderboard among open models while remaining efficient to serve, thanks in part to its MQA-based KV cache
Common Pitfalls and Solutions
Pitfall 1: Incorrect Broadcasting
Problem: Using repeat() creates unnecessary copies, consuming extra memory
Solution: Use unsqueeze() with broadcasting for memory-efficient expansion
Impact: Avoids materializing n_heads copies of K and V, e.g., 8× less K/V activation memory with 8 heads
Pitfall 2: Cache Shape Mismatch
Problem: Caching K,V with head dimension wastes memory
Correct approach: Cache shape should be (batch, 1, seq_len, head_dim) not (batch, n_heads, seq_len, head_dim)
Impact: Ensures cache stays small and shared across heads
Pitfall 3: Learning Rate
Problem: Using same learning rate as MHA can cause training instability
Solution: Reduce learning rate by 40-50% for MQA training (e.g., 5e-4 instead of 1e-3)
Reason: Shared K,V parameters update from all heads simultaneously, requiring more careful optimization
Future Directions
Hybrid Approaches
- First/last layers with MHA
- Critical layers with GQA
- Others with MQA
Learned Sharing
- Dynamic K,V sharing
- Attention-based routing
- Adaptive compression