Grouped-Query Attention (GQA)
Learn how Grouped-Query Attention balances the quality of Multi-Head Attention with the efficiency of Multi-Query Attention, enabling faster inference in large language models.
Grouped-Query Attention: The Best of Both Worlds
Grouped-Query Attention (GQA) is an attention mechanism that strikes an optimal balance between the quality of Multi-Head Attention (MHA) and the efficiency of Multi-Query Attention (MQA), making it the preferred choice for modern large language models.
The Problem: Quality vs Efficiency Trade-off
Multi-Head Attention (MHA) offers the best quality but uses the most memory. Multi-Query Attention (MQA) is highly efficient but sacrifices quality. We need something in between that balances the two.
The Insight: Group Heads to Share K,V
Instead of all heads sharing one K,V (MQA) or each having their own (MHA), Grouped-Query Attention (GQA) divides heads into groups. Heads within a group share K,V, giving us configurable quality-efficiency balance.
Architecture Comparison
Comparing the three approaches side by side, GQA sits between MHA (one K,V pair per head) and MQA (a single shared K,V pair), here using 8 KV groups for 32 query heads.
Perfect Balance: 75% Memory Savings, Minimal Quality Loss
With 8 groups, GQA uses only 25% of MHA's KV cache while maintaining quality much closer to MHA than MQA. This sweet spot makes GQA the preferred choice for modern LLMs.
Industry Standard: Llama 2, Mistral, and Beyond
GQA has become the de facto standard for modern LLMs, adopted by Meta's Llama 2, Mistral AI, and many others. It delivers the quality needed for production with efficiency that scales.
The Evolution: MHA → MQA → GQA
Multi-Head Attention (MHA)
- Every head has its own Q, K, V
- Best quality but highest memory usage
- KV cache size: 2 × L × H × D
Multi-Query Attention (MQA)
- All heads share single K, V
- Most efficient but quality degradation
- KV cache size: 2 × L × D
Grouped-Query Attention (GQA)
- Groups of heads share K, V
- Balanced quality and efficiency
- KV cache size: 2 × L × G × D
Where L = sequence length, H = num heads, G = num groups, D = head dimension
How GQA Works
The Grouping Mechanism
Instead of H separate KV pairs (MHA) or 1 shared KV pair (MQA), GQA uses G groups:
Configuration Example (32 heads, 8 groups):
- Total attention heads: 32
- Number of KV groups: 8
- Group size: 4 heads per group
- Memory savings: 75% reduction compared to MHA
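A quick back-of-the-envelope check of this configuration in plain Python (the head dimension is an assumed example value, not taken from the configuration above):

```python
num_heads = 32      # query heads
num_kv_heads = 8    # KV groups
head_dim = 128      # assumed example head dimension

group_size = num_heads // num_kv_heads      # 4 heads share each K,V pair
kv_cache_ratio = num_kv_heads / num_heads   # 0.25 -> 25% of the MHA cache
memory_savings = 1 - kv_cache_ratio         # 0.75 -> 75% reduction

print(group_size, kv_cache_ratio, memory_savings)  # 4 0.25 0.75
```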
Mathematical Formulation
For head h in group g, attention is computed as in standard MHA, except that the head's query attends over its group's shared keys and values:

head_h = Attention(Qh, Kg, Vg) = softmax(Qh Kgᵀ / √D) Vg

Where:
- Qh is the query for head h
- Kg, Vg are the shared keys and values for group g
- Group assignment: g = ⌊h × G / H⌋
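The group assignment can be spelled out as a tiny helper; the values are assumed from the 32-head, 8-group example above, and head indices are 0-based:

```python
H, G = 32, 8  # query heads, KV groups (assumed from the example above)

def kv_group(h: int, H: int = H, G: int = G) -> int:
    """Group index for query head h: g = floor(h * G / H)."""
    return (h * G) // H

# Heads 0-3 share group 0, heads 4-7 share group 1, and so on.
assert [kv_group(h) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```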
Implementation
Key Architecture Components
Projection Layers:
- Query projections: Separate for each head (num_heads × d_model)
- Key projections: one per KV group, shared by all heads in that group (num_kv_heads × d_model)
- Value projections: one per KV group, shared by all heads in that group (num_kv_heads × d_model)
- Output projection: Standard linear layer combining all heads
Forward Pass Steps:
- Project input to multi-head queries Q
- Project input to grouped K and V (fewer projections than queries)
- Repeat/expand K, V to match query head count using efficient views
- Compute scaled dot-product attention for each head
- Concatenate outputs and apply final projection
Key Constraint: Number of query heads must be divisible by number of KV heads to ensure even grouping
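A minimal PyTorch sketch of these steps. The module name `GQAttention` and its arguments are illustrative, not a reference implementation; caching, masking, and positional encodings are omitted:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GQAttention(nn.Module):
    """Minimal grouped-query attention (no cache, no mask, no positional encoding)."""
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "query heads must be divisible by KV heads"
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        group = self.num_heads // self.num_kv_heads
        # Steps 1-2: project to per-head queries and grouped keys/values.
        q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Step 3: repeat K,V across the heads of each group (see the expansion notes below).
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        # Step 4: scaled dot-product attention per head (unmasked here for brevity).
        out = F.scaled_dot_product_attention(q, k, v)
        # Step 5: concatenate heads and apply the output projection.
        return self.o_proj(out.transpose(1, 2).reshape(b, s, -1))
```

For example, `GQAttention(d_model=512, num_heads=8, num_kv_heads=2)(torch.randn(1, 16, 512))` returns a tensor of shape `[1, 16, 512]`.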
Efficient KV Cache Management
Cache Structure:
- Pre-allocated tensors for keys and values
- Shape: [num_layers, batch_size, num_kv_heads, max_seq_len, head_dim]
- Significantly smaller than MHA cache (uses num_kv_heads instead of num_heads)
- Incremental updates during autoregressive generation
Cache Operations:
- Store only num_kv_heads KV pairs per layer (not num_heads)
- Concatenate with previous cache for each new token
- Return expanded cache to match query head count during attention
- Enables efficient long-context inference
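A hedged sketch of the decode-time cache update for a single layer, assuming a simple "concatenate as you go" cache rather than the pre-allocated buffer described above; tensor names are illustrative:

```python
import torch

def update_kv_cache(k_new, v_new, cache=None):
    """Append this step's grouped K,V (shape [batch, num_kv_heads, 1, head_dim])
    to the running cache and return the full K,V seen so far."""
    if cache is None:
        return k_new, v_new
    k_cache, v_cache = cache
    k = torch.cat([k_cache, k_new], dim=2)  # grow along the sequence dimension
    v = torch.cat([v_cache, v_new], dim=2)
    return k, v

# During generation: store only num_kv_heads entries per layer; expand to num_heads
# (or pass grouped K,V to a GQA-aware kernel) just before the attention call.
```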
Memory and Compute Analysis
Memory Comparison
| Method | KV Cache Size | Relative Size (H = 32) | Quality |
|---|---|---|---|
| MHA | 2 × L × H × D | 100% | Best |
| GQA-8 | 2 × L × 8 × D | 25% | Near-MHA |
| GQA-4 | 2 × L × 4 × D | 12.5% | Good |
| MQA | 2 × L × 1 × D | 3.1% | Degraded |
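The "Relative Size" column follows directly from the number of KV heads; a quick check in Python:

```python
H = 32  # query heads, as assumed in the table
for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("GQA-4", 4), ("MQA", 1)]:
    print(f"{name}: {kv_heads / H:.1%} of the MHA KV cache")
# MHA: 100.0%, GQA-8: 25.0%, GQA-4: 12.5%, MQA: 3.1%
```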
Compute Overhead
GQA adds minimal compute overhead:
- Projection: Same as MHA (different weight shapes)
- Repeat operation: near-free when implemented with tensor views rather than data copies
- Attention computation: Identical to MHA
Efficient K,V Expansion:
- Use unsqueeze and expand operations instead of copying
- Creates view over existing memory
- No additional memory allocation during forward pass
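A sketch of the view-based expansion, similar in spirit to the `repeat_kv` helpers found in open-source GQA implementations but not copied from any particular one. Note that `expand` itself is a free view; if a downstream kernel needs a contiguous `[batch, num_heads, seq, head_dim]` tensor, the final `reshape` is where any materialization happens:

```python
import torch

def expand_kv(x: torch.Tensor, group_size: int) -> torch.Tensor:
    """Expand grouped K or V from [batch, num_kv_heads, seq, head_dim]
    to [batch, num_kv_heads * group_size, seq, head_dim]."""
    if group_size == 1:
        return x
    b, kv_heads, s, d = x.shape
    x = x.unsqueeze(2).expand(b, kv_heads, group_size, s, d)  # view, no copy
    return x.reshape(b, kv_heads * group_size, s, d)          # may materialize for contiguity
```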
GQA in Production Models
Llama 2 Configuration
Llama 2 70B Architecture:
- Query heads: 64
- KV heads: 8 (GQA-8)
- Group size: 8 heads per group
- Context length: 4096 tokens
- Head dimension: 128
Memory Impact (FP16 KV cache, 80 layers, full 4096-token context):
- MHA-equivalent cache: ~10.7 GB per sequence
- GQA-8 cache: ~1.3 GB per sequence
- 87.5% memory reduction
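These figures can be sanity-checked with a short calculation, assuming an FP16 cache (2 bytes per value) over the full 4096-token context and all 80 layers:

```python
# Back-of-the-envelope check of the Llama 2 70B numbers above.
layers, seq_len, heads, kv_heads, head_dim, bytes_per = 80, 4096, 64, 8, 128, 2

def kv_cache_bytes(n_kv_heads: int) -> int:
    # 2 accounts for storing both keys and values.
    return 2 * layers * seq_len * n_kv_heads * head_dim * bytes_per

mha = kv_cache_bytes(heads)      # ~10.7e9 bytes  (~10.7 GB)
gqa = kv_cache_bytes(kv_heads)   # ~1.34e9 bytes  (~1.3 GB)
print(mha / 1e9, gqa / 1e9, 1 - gqa / mha)  # ≈ 10.74, 1.34, 0.875
```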
Mistral Configuration
Mistral 7B Architecture:
- Query heads: 32
- KV heads: 8 (GQA-8)
- Sliding window: 4096 tokens
- Combined with sliding-window attention, so the two memory optimizations compound
- Enables efficient long-context processing
Training Considerations
Converting MHA to GQA
Uptraining Strategy - Convert pre-trained MHA models to GQA:
Conversion Steps:
- Keep query projections unchanged (all heads)
- Group K,V projections and average weights within each group
- Reshape averaged weights to match num_kv_heads dimension
- Fine-tune with lower learning rate to recover quality
- Typically converges within 5-10% of original training steps
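A hedged sketch of the weight-grouping step: mean-pooling the per-head K (or V) projection weights within each group. The tensor layout is an assumption, with the projection stored as `[num_heads * head_dim, d_model]` in the PyTorch `nn.Linear` convention and consecutive heads assigned to the same group:

```python
import torch

def group_kv_weights(w: torch.Tensor, num_heads: int, num_kv_heads: int) -> torch.Tensor:
    """Mean-pool an MHA K (or V) projection weight [num_heads * head_dim, d_model]
    down to a GQA weight [num_kv_heads * head_dim, d_model]."""
    head_dim = w.shape[0] // num_heads
    group = num_heads // num_kv_heads
    w = w.view(num_kv_heads, group, head_dim, w.shape[1])  # split heads into groups
    return w.mean(dim=1).reshape(num_kv_heads * head_dim, w.shape[1])
```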
Benefits:
- Leverage existing pre-trained weights
- Faster than training GQA from scratch
- Often achieves comparable final quality to original MHA
Training from Scratch
GQA can be trained directly:
- Similar convergence to MHA
- Slightly faster training (less memory movement)
- Regularization effect from parameter sharing
Choosing the Right Configuration
Decision Guide
Research and Experimentation:
- Use MHA when quality is paramount
- Acceptable when memory is abundant
Cloud Serving (Large Models, >30B):
- Use GQA-8 for memory-critical scenarios
- Balances quality with serving efficiency
Cloud Serving (Smaller Models, <30B):
- Use GQA-16 or higher
- Can afford more groups for better quality
Edge Devices:
- Use MQA or GQA-4 for maximum efficiency
- Prioritize memory savings over quality
Batch Serving:
- Use GQA-8 to GQA-16
- Balance based on batch size and throughput requirements
Recommended Configurations
| Model Size | Batch Size | Recommended | Groups | Rationale |
|---|---|---|---|---|
| <7B | 1 | MHA or GQA-16 | 16-32 | Quality focus |
| 7B-13B | 1-8 | GQA-8 | 8 | Balanced |
| 13B-70B | 1-4 | GQA-8 | 8 | Memory critical |
| >70B | 1-2 | GQA-4 or MQA | 4 or 1 | Extreme efficiency |
Performance Tips
1. Hardware Considerations
GPU-Specific Optimization:
- A100 GPUs: Use num_kv_heads as multiples of 8 (tensor core optimization)
- V100 GPUs: Prefer fewer groups (4-8) due to limited memory
- Align group sizes with hardware warp/thread block sizes
2. Dynamic Group Size
Adaptive Strategy:
- Use fewer groups for long sequences (>2048 tokens) to save memory, and more groups for short sequences to maintain quality
- Because the number of KV heads is baked into the trained weights, switching requires either uptraining the model for multiple groupings or merging (mean-pooling) KV heads at inference, at some quality cost
- Balance memory usage against quality requirements
3. Optimize Memory Layout
Cache Organization:
- Use contiguous memory layout: [layer, batch, num_kv_heads, seq_len, head_dim]
- Optimize for sequential access patterns during autoregressive generation
- Pre-allocate cache to avoid dynamic memory allocation
- Keep cache aligned to cache line boundaries
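A minimal pre-allocation sketch following this layout; the sizes are illustrative and the helper name is hypothetical:

```python
import torch

num_layers, batch, num_kv_heads, max_seq_len, head_dim = 32, 1, 8, 4096, 128

# One pre-allocated buffer each for keys and values, indexed by layer; written
# in place during generation so no allocation happens inside the decode loop.
k_cache = torch.zeros(num_layers, batch, num_kv_heads, max_seq_len, head_dim,
                      dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)

def write_step(layer: int, pos: int, k_new: torch.Tensor, v_new: torch.Tensor) -> None:
    """Store this step's grouped K,V ([batch, num_kv_heads, head_dim]) at position pos."""
    k_cache[layer, :, :, pos] = k_new
    v_cache[layer, :, :, pos] = v_new
```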
Common Pitfalls
Pitfall 1: Uneven Group Sizes
Problem: Number of query heads not divisible by number of KV heads
- Example: 32 query heads with 7 KV heads (32 % 7 ≠ 0)
- Solution: Ensure num_heads is divisible by num_kv_heads
- Recommended: Use powers of 2 for both (8, 16, 32, 64)
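A one-line guard catches this at construction time (the same assertion appears in the forward-pass sketch earlier):

```python
num_heads, num_kv_heads = 32, 8  # example values
assert num_heads % num_kv_heads == 0, (
    f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
)
```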
Pitfall 2: Inefficient Repeat
Problem: Using operations that copy data instead of creating views
- Copying data wastes memory and bandwidth
- Solution: Use unsqueeze + expand so the per-group repetition is a view over existing memory
- Any remaining copy is deferred to the final reshape, or avoided entirely by attention kernels that accept grouped K,V directly
Pitfall 3: Wrong Cache Shape
Problem: Allocating cache with num_heads instead of num_kv_heads
- Wastes memory by storing unnecessary duplicates
- Solution: Cache shape should use num_kv_heads dimension
- Correct: [batch, seq_len, num_kv_heads, head_dim]
- Wrong: [batch, seq_len, num_heads, head_dim]