MHA vs GQA vs MQA: Choosing the Right Attention
Compare Multi-Head, Grouped-Query, and Multi-Query Attention mechanisms to understand their trade-offs and choose the optimal approach for your use case.
MHA vs GQA vs MQA: The Complete Comparison
Understanding the trade-offs between Multi-Head Attention (MHA), Grouped-Query Attention (GQA), and Multi-Query Attention (MQA) is crucial for deploying efficient transformer models. Each approach offers different balances between quality, memory, and speed.
Interactive Comparison Tool
An interactive side-by-side comparison of MHA, GQA-8, GQA-4, and MQA is available in the web version of this page; for a 7B model with a 2048-token context at batch size 1, the recommended configuration is GQA-4.
Quick Decision Matrix
Use Case | Recommended | Why |
---|---|---|
Research/Training | MHA | Maximum quality, parameter count |
Cloud Serving (>30B) | GQA-8 | Balance of quality and efficiency |
Edge Deployment | MQA | Minimum memory footprint |
Long Context (>8K) | GQA-4 or MQA | Memory becomes critical |
Batch Inference | GQA-8 | Good balance for multiple requests |
Real-time Systems | MQA | Lowest latency |
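The matrix above can also be expressed as a small lookup helper. This is a minimal sketch; the `recommend_attention` name and the use-case keys are illustrative assumptions, not part of any library:

```python
def recommend_attention(use_case: str) -> tuple[str, str]:
    """Return (recommended attention, rationale) for a use case.

    Illustrative lookup mirroring the decision matrix above; keys and
    wording are assumptions, not a library API.
    """
    matrix = {
        "research_training": ("MHA", "Maximum quality, full parameter count"),
        "cloud_serving_30b_plus": ("GQA-8", "Balance of quality and efficiency"),
        "edge_deployment": ("MQA", "Minimum memory footprint"),
        "long_context_8k_plus": ("GQA-4 or MQA", "Memory becomes critical"),
        "batch_inference": ("GQA-8", "Good balance for multiple requests"),
        "real_time": ("MQA", "Lowest latency"),
    }
    return matrix[use_case]

print(recommend_attention("edge_deployment"))  # ('MQA', 'Minimum memory footprint')
```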
Detailed Comparison
Architecture Differences
Feature | MHA | GQA | MQA |
---|---|---|---|
Q Projections | H separate | H separate | H separate |
K Projections | H separate | G groups | 1 shared |
V Projections | H separate | G groups | 1 shared |
Parameters | 3 × H × D² | (H + 2G) × D² | (H + 2) × D² |
KV Heads | H | G | 1 |
Where H = number of query heads, G = number of KV groups, D = head dimension
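To make the parameter formulas concrete, here is a small sketch that counts the actual Q/K/V projection matrices (d_model → heads × head_dim, output projection excluded). The function name and the d_model value are illustrative assumptions; the ratios it prints match the (H + 2G)/(3H) and (H + 2)/(3H) reductions implied by the table:

```python
def attention_param_count(n_heads: int, n_kv_heads: int, d_model: int, head_dim: int) -> int:
    """Count Q/K/V projection parameters (output projection excluded).

    Q projects d_model -> n_heads * head_dim; K and V project
    d_model -> n_kv_heads * head_dim. MHA is n_kv_heads == n_heads,
    MQA is n_kv_heads == 1.
    """
    q_params = d_model * n_heads * head_dim
    kv_params = 2 * d_model * n_kv_heads * head_dim
    return q_params + kv_params

# H=32 heads, head_dim=128, d_model=4096 (illustrative, 7B-class shape)
mha = attention_param_count(32, 32, 4096, 128)
gqa8 = attention_param_count(32, 8, 4096, 128)
mqa = attention_param_count(32, 1, 4096, 128)
print(mha, gqa8, mqa)          # 50331648 25165824 17825792
print(gqa8 / mha, mqa / mha)   # 0.5 (= (H+2G)/3H), ~0.354 (= (H+2)/3H)
```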
Memory Footprint
For a typical configuration (H=32, L=2048, D=128):
Method | KV Cache Size | Relative | Example (Llama 70B) |
---|---|---|---|
MHA | 2 × L × H × D | 100% | 8.4 GB/sequence |
GQA-8 | 2 × L × 8 × D | 25% | 2.1 GB/sequence |
GQA-4 | 2 × L × 4 × D | 12.5% | 1.0 GB/sequence |
MQA | 2 × L × 1 × D | 3.1% | 0.26 GB/sequence |
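The per-sequence sizes follow directly from the KV-head count. Below is a minimal calculator, assuming an fp16 cache and a Llama-2-70B-like shape (80 layers, head_dim 128); exact figures depend on context length, layer count, and precision, so they will not match the example column exactly:

```python
def kv_cache_bytes(context_len: int, n_kv_heads: int, head_dim: int,
                   n_layers: int, bytes_per_elem: int = 2) -> int:
    """Per-sequence KV cache size in bytes (fp16 by default).

    The leading 2 accounts for storing both K and V at every position
    of every layer.
    """
    return 2 * context_len * n_kv_heads * head_dim * n_layers * bytes_per_elem

# Illustrative Llama-2-70B-like shape: 80 layers, head_dim 128, 4096-token context
for name, kv_heads in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    gb = kv_cache_bytes(4096, kv_heads, 128, 80) / 1e9
    print(f"{name}: {gb:.2f} GB/sequence")
# MHA ~10.7 GB, GQA-8 ~1.3 GB, MQA ~0.17 GB for this shape
```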
Performance Metrics
Metric | MHA | GQA-8 | MQA |
---|---|---|---|
Quality (Perplexity) | 10.0 (best) | 10.1 | 10.3 |
Inference Speed | 1.0× | 1.5× | 2.0× |
Training Speed | 1.0× | 1.1× | 1.2× |
Max Batch Size | 1× | 4× | 32× |
Implementation Complexity | High | Medium | Low |
Mathematical Formulations
MHA: Full Expressiveness
Each head has its own query, key, and value projections:
head_i = Attention(Q_i, K_i, V_i) for each head i
- Total attention parameters: 3HD²
- KV cache per token: 2HD
GQA: Balanced Approach
Each head keeps its own query projection, but K and V are shared within groups:
head_i = Attention(Q_i, K_g(i), V_g(i)), where g(i) = ⌊i · G / H⌋ maps heads to groups
- Total attention parameters: (H + 2G)D²
- KV cache per token: 2GD
MQA: Maximum Sharing
All heads share a single K and V:
head_i = Attention(Q_i, K, V)
- Total attention parameters: (H + 2)D²
- KV cache per token: 2D
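A minimal PyTorch sketch of the grouped computation (not a production kernel): queries keep H heads while K and V carry G heads that are repeated to cover their groups, so n_kv_heads == n_heads reduces to MHA and n_kv_heads == 1 to MQA:

```python
import torch

def grouped_query_attention(q, k, v):
    """Scaled dot-product attention with grouped K/V heads.

    q: (batch, n_heads, seq, head_dim)
    k, v: (batch, n_kv_heads, seq, head_dim), with n_heads % n_kv_heads == 0.
    """
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    group_size = n_heads // n_kv_heads
    # g(i) = floor(i * G / H) is equivalent to i // group_size here,
    # implemented by repeating each KV head group_size times.
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

# 32 query heads sharing 8 KV heads (GQA-8), batch 2, seq 16, head_dim 64
q = torch.randn(2, 32, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
print(grouped_query_attention(q, k, v).shape)  # torch.Size([2, 32, 16, 64])
```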
Production Model Configurations
Current Industry Adoption
Model | Size | Attention Type | Config | Rationale |
---|---|---|---|---|
GPT-3 | 175B | MHA | 96 heads | Quality priority |
Llama 2 | 70B | GQA | 64Q, 8KV | Balanced approach |
Llama 2 | 7B | MHA | 32Q, 32KV | Small model, no KV reduction needed |
Mistral | 7B | GQA + SWA | 32Q, 8KV | Combined optimizations |
Falcon | 40B | MQA | 64Q, 1KV | Maximum efficiency |
PaLM | 540B | MQA | 48Q, 1KV | Extreme scale requires MQA |
GPT-4 | ? | GQA (likely) | ? | Balanced for quality |
Configuration Examples
```python
# Llama 2 70B Configuration
config_llama_70b = {
    "n_heads": 64,
    "n_kv_heads": 8,      # GQA with 8 groups
    "group_size": 8,      # 64 / 8 = 8 heads per group
    "hidden_size": 8192,
    "head_dim": 128,
}

# Falcon 40B Configuration
config_falcon_40b = {
    "n_heads": 64,
    "n_kv_heads": 1,      # MQA
    "group_size": 64,     # all heads share one KV head
    "hidden_size": 8192,
    "head_dim": 128,
}

# Mistral 7B Configuration
config_mistral_7b = {
    "n_heads": 32,
    "n_kv_heads": 8,      # GQA
    "group_size": 4,
    "hidden_size": 4096,
    "head_dim": 128,
    "sliding_window": 4096,  # additional optimization
}
```
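A small helper, assumed for illustration rather than taken from any of these codebases, that classifies a configuration from its head counts:

```python
def classify_attention(n_heads: int, n_kv_heads: int) -> str:
    """Classify a config as MHA, GQA, or MQA from its head counts."""
    if n_kv_heads == n_heads:
        return "MHA"
    if n_kv_heads == 1:
        return "MQA"
    assert n_heads % n_kv_heads == 0, "query heads must divide evenly into groups"
    return f"GQA-{n_kv_heads}"

print(classify_attention(64, 8))  # GQA-8 (Llama 2 70B-style)
print(classify_attention(64, 1))  # MQA   (Falcon 40B-style)
print(classify_attention(32, 8))  # GQA-8 (Mistral 7B-style)
```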
Quality Impact Analysis
Perplexity Comparison (from papers)
Model Size | MHA | GQA-8 | GQA-4 | MQA |
---|---|---|---|---|
1B params | 15.2 | 15.3 | 15.5 | 15.8 |
7B params | 10.1 | 10.2 | 10.3 | 10.5 |
30B params | 8.2 | 8.3 | 8.3 | 8.5 |
70B params | 7.1 | 7.1 | 7.2 | 7.3 |
Observations:
- Quality gap decreases with model size
- GQA-8 nearly matches MHA quality
- MQA shows ~2-3% degradation
Downstream Task Performance
Task | MHA | GQA-8 | MQA |
---|---|---|---|
MMLU | 67.3% | 67.1% | 66.5% |
HumanEval | 48.2% | 48.0% | 47.1% |
GSM8K | 78.5% | 78.2% | 77.3% |
Translation | 35.2 BLEU | 35.0 BLEU | 34.6 BLEU |
Deployment Scenarios
Scenario 1: Maximum Quality (Research)
```python
# Use MHA for best results
config = {
    "attention_type": "MHA",
    "n_heads": 32,
    "n_kv_heads": 32,  # same as n_heads
    "rationale": "Quality matters most, resources available",
}
```
Scenario 2: Production API (Cloud)
```python
# Use GQA for balance
config = {
    "attention_type": "GQA",
    "n_heads": 32,
    "n_kv_heads": 8,  # 4:1 ratio
    "rationale": "Balance quality with serving costs",
}
```
Scenario 3: Edge Device
```python
# Use MQA for efficiency
config = {
    "attention_type": "MQA",
    "n_heads": 32,
    "n_kv_heads": 1,
    "rationale": "Memory/power constraints critical",
}
```
Scenario 4: Long Context (32K+)
```python
# Use aggressive GQA or MQA
config = {
    "attention_type": "GQA",
    "n_heads": 32,
    "n_kv_heads": 4,  # 8:1 ratio
    "rationale": "KV cache dominates at long context",
}
```
Conversion Strategies
Converting Between Attention Types
```python
def convert_attention_type(model, from_type, to_type):
    """Convert a model between attention types (helper functions defined elsewhere)."""
    if from_type == "MHA" and to_type == "GQA":
        # Average K,V weights within each group
        return average_kv_weights_for_groups(model)
    elif from_type == "MHA" and to_type == "MQA":
        # Average all K,V weights into a single head
        return average_all_kv_weights(model)
    elif from_type == "GQA" and to_type == "MQA":
        # Average the per-group K,V weights
        return average_group_weights(model)
    elif from_type in ["GQA", "MQA"] and to_type == "MHA":
        # Replicate K,V weights across heads (requires fine-tuning)
        return replicate_kv_weights(model)
    raise ValueError(f"Unsupported conversion: {from_type} -> {to_type}")
```
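For the MHA → GQA/MQA direction, the GQA paper's recipe is to mean-pool the K and V projection weights of the heads in each group, then uptrain briefly. Below is a sketch of that pooling step, assuming heads are laid out contiguously along the projection's output dimension; the function name is illustrative:

```python
import torch

def mean_pool_kv_heads(w_k: torch.Tensor, n_heads: int, n_kv_heads: int) -> torch.Tensor:
    """Mean-pool per-head K (or V) projection weights into n_kv_heads groups.

    w_k: (n_heads * head_dim, d_model) MHA projection weight.
    Returns a (n_kv_heads * head_dim, d_model) GQA/MQA projection weight.
    """
    head_dim = w_k.shape[0] // n_heads
    group_size = n_heads // n_kv_heads
    # (n_kv_heads, group_size, head_dim, d_model) -> average over the group axis
    grouped = w_k.view(n_kv_heads, group_size, head_dim, -1)
    return grouped.mean(dim=1).reshape(n_kv_heads * head_dim, -1)

w_k = torch.randn(32 * 128, 4096)            # 32-head MHA K projection
print(mean_pool_kv_heads(w_k, 32, 8).shape)  # torch.Size([1024, 4096]) -> GQA-8
```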
Fine-tuning After Conversion
```python
# Recommended fine-tuning settings after conversion
fine_tune_config = {
    "MHA_to_GQA": {"lr": 1e-5, "steps": 10000, "warmup": 1000},
    "MHA_to_MQA": {"lr": 5e-6, "steps": 20000, "warmup": 2000},
}
```
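As a sketch of how one of these recipes might be wired into a training loop (reusing `fine_tune_config` from the block above): the AdamW optimizer and linear-warmup LambdaLR schedule are illustrative assumptions, not settings prescribed by the table.

```python
import torch

def build_optimizer(model: torch.nn.Module, recipe: dict):
    """Turn a fine-tuning recipe into an optimizer plus linear-warmup schedule."""
    opt = torch.optim.AdamW(model.parameters(), lr=recipe["lr"])
    warmup = recipe["warmup"]
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda step: min(1.0, (step + 1) / warmup)
    )
    return opt, sched

# Toy model stands in for the converted network
opt, sched = build_optimizer(torch.nn.Linear(8, 8), fine_tune_config["MHA_to_GQA"])
```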
Optimization Combinations
Synergistic Optimizations
Base Attention | + Optimization | Result | Example |
---|---|---|---|
GQA | + Flash Attention | Fast + memory efficient | Llama 2 |
GQA | + Sliding Window | Local + global efficiency | Mistral |
MQA | + Flash Attention | Maximum efficiency | Optimized Falcon |
GQA | + RoPE | Efficient + better positions | Most modern LLMs |
Implementation Stack
```python
import torch.nn as nn

class OptimizedAttention(nn.Module):
    """Combine multiple attention optimizations behind one module.

    MHA, GQA, MQA, FlashWrapper, SlidingWindowWrapper, and RoPE are
    assumed to be defined elsewhere in the codebase.
    """

    def __init__(self, config):
        super().__init__()
        # Choose the base attention variant from the KV head count
        if config.n_kv_heads == 1:
            self.attention = MQA(config)
        elif config.n_kv_heads < config.n_heads:
            self.attention = GQA(config)
        else:
            self.attention = MHA(config)

        # Layer additional optimizations on top
        if config.use_flash:
            self.attention = FlashWrapper(self.attention)
        if config.use_sliding_window:
            self.attention = SlidingWindowWrapper(self.attention)
        if config.use_rope:
            self.pos_encoding = RoPE(config)
```
Cost Analysis
Serving Cost Comparison
For 1M tokens/day with 2K context:
Attention | GPU Hours | Memory | Cost/Day |
---|---|---|---|
MHA | 24 | 80GB A100 | $72 |
GQA-8 | 16 | 40GB A100 | $32 |
MQA | 12 | 40GB A100 | $24 |
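The Cost/Day column is consistent with roughly $3/hour for an 80GB A100 and $2/hour for a 40GB A100; these rates are inferred from the table itself, and actual cloud pricing varies.

```python
# Hourly rates inferred from the table above; actual cloud pricing varies.
hourly_rate = {"80GB A100": 3.0, "40GB A100": 2.0}
serving = {
    "MHA":   (24, "80GB A100"),
    "GQA-8": (16, "40GB A100"),
    "MQA":   (12, "40GB A100"),
}
for name, (gpu_hours, gpu) in serving.items():
    print(f"{name}: ${gpu_hours * hourly_rate[gpu]:.0f}/day")
# MHA: $72/day, GQA-8: $32/day, MQA: $24/day
```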
Scaling Analysis
```python
# Illustrative model shape (assumed): 48 layers, 32 query heads, head_dim 128
N_LAYERS, N_HEADS, HEAD_DIM = 48, 32, 128

def estimate_max_batch_size(gpu_memory_gb, model_params_gb, attention_type, context_len):
    """Estimate the maximum batch size that fits on a given GPU."""
    # Memory left over for the KV cache (2 GB overhead for activations, etc.)
    available = gpu_memory_gb - model_params_gb - 2

    # KV cache per sequence in GB: K and V (x2), fp16 (2 bytes per element)
    kv_heads = {"MHA": N_HEADS, "GQA-8": 8, "MQA": 1}[attention_type]
    kv_cache_size = context_len * N_LAYERS * kv_heads * HEAD_DIM * 2 * 2 / 1e9

    return int(available / kv_cache_size)

# Example: 80GB A100, 30GB model, 4K context
print(estimate_max_batch_size(80, 30, "MHA", 4096))    # ~15
print(estimate_max_batch_size(80, 30, "GQA-8", 4096))  # ~60
print(estimate_max_batch_size(80, 30, "MQA", 4096))    # ~480
```
Best Practices
Selection Guidelines
- Start with GQA-8 as default
  - Good balance for most use cases
  - Minimal quality loss
  - 4× memory savings
- Consider MQA when:
  - Serving at scale (more than 1000 QPS)
  - Memory constrained (less than 40GB)
  - Long context (more than 16K)
  - Batch size is critical
- Stick with MHA when:
  - Research/experimentation
  - Quality is paramount
  - Small models (less than 1B params)
  - Abundant resources
Implementation Checklist
- Profile memory usage with target context length
- Benchmark quality on your specific tasks
- Test different group sizes for GQA
- Consider combining with other optimizations
- Implement proper KV cache management
- Add monitoring for attention patterns
Future Directions
Emerging Approaches
- Dynamic Attention Selection
  - Switch between MHA/GQA/MQA per layer
  - Adapt based on input complexity
- Learned Group Assignment
  - Learn which heads to group
  - Task-specific grouping
- Hybrid Architectures
  - MHA for critical layers
  - MQA for the others