Grouped-Query Attention (GQA)

Learn how Grouped-Query Attention balances the quality of Multi-Head Attention with the efficiency of Multi-Query Attention, enabling faster inference in large language models.

Grouped-Query Attention: The Best of Both Worlds

Grouped-Query Attention (GQA) is an attention mechanism that strikes an optimal balance between the quality of Multi-Head Attention (MHA) and the efficiency of Multi-Query Attention (MQA), making it the preferred choice for modern large language models.

The Problem: Quality vs Efficiency Trade-off

Multi-Head Attention (MHA) offers the best quality but consumes massive amounts of KV-cache memory. Multi-Query Attention (MQA) is highly efficient but sacrifices quality. We need something in between that balances both concerns.

Multi-Head Attention (MHA)

  • Best quality: each head learns unique patterns
  • Huge memory: 2560 MB cache
  • Limited batch sizes in production

Multi-Query Attention (MQA)

  • Highly efficient: 80 MB cache
  • 32× larger batches possible
  • Quality degradation (~2% perplexity loss)

The Insight: Group Heads to Share K,V

Instead of all heads sharing one K,V (MQA) or each having their own (MHA), Grouped-Query Attention (GQA) divides heads into groups. Heads within a group share K,V, giving us configurable quality-efficiency balance.

How Grouping Works (32 heads with 8 groups = 4 heads per group):

  • Group 1: Q1, Q2, Q3, Q4 → K₁, V₁
  • Group 2: Q5, Q6, Q7, Q8 → K₂, V₂
  • Group 3: Q9, Q10, Q11, Q12 → K₃, V₃
  • ...

✓ Each group maintains some diversity while sharing K,V

Architecture Comparison: Visual Difference

Compare the three approaches side-by-side. Notice how GQA sits between MHA (many K,V pairs) and MQA (single K,V pair), using 8 groups for 32 heads.

[Diagram: 32 query heads (Q1-Q32) mapped onto 8 shared K,V pairs (K0V0-K7V7); a 'Number of Groups' slider runs from 1 (=MQA) to 32 (=MHA) and is set to 8, i.e. 4 heads per group.]

Perfect Balance: 75.0% Memory Savings, Minimal Quality Loss

With 8 groups, GQA uses only 25.0% of MHA's cache while maintaining quality much closer to MHA than MQA. This sweet spot makes GQA the preferred choice for modern LLMs.

  • GQA cache size: 640 MB per sequence (2K context)
  • Memory savings: 75% vs the MHA baseline
  • Quality impact: <1% perplexity degradation

Configurable Trade-off:

  • More groups (→32): better quality, more memory (approaches MHA)
  • Fewer groups (→1): less memory, slight quality loss (approaches MQA)
  • Sweet spot: 8 groups for 32 heads = 75% memory savings, <1% quality loss

Industry Standard: Llama 2, Mistral, and Beyond

GQA has become the de facto standard for modern LLMs, adopted by Meta's Llama 2, Mistral AI, and many others. It delivers the quality needed for production with efficiency that scales.

Llama 2 (Meta)

  • 70B parameters
  • 64 query heads
  • 8 KV heads (GQA-8)
  • 87.5% cache reduction

Mistral 7B

  • 7B parameters
  • 32 query heads
  • 8 KV heads (GQA-8)
  • Top-tier performance/efficiency
Why GQA Wins:

  ✓ Quality: minimal degradation (<1%) compared to MHA
  ✓ Efficiency: 75-87% memory savings vs MHA
  ✓ Flexibility: configurable groups for custom trade-offs

The Evolution: MHA → MQA → GQA

Multi-Head Attention (MHA)

  • Every head has its own Q, K, V
  • Best quality but highest memory usage
  • KV cache size: 2 × L × H × D

Multi-Query Attention (MQA)

  • All heads share single K, V
  • Most efficient but quality degradation
  • KV cache size: 2 × L × D

Grouped-Query Attention (GQA)

  • Groups of heads share K, V
  • Balanced quality and efficiency
  • KV cache size: 2 × L × G × D

Where L = sequence length, H = num heads, G = num groups, D = head dimension
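
As a quick check of these cache-size formulas, here is a small Python sketch (element counts per layer only; the function name is illustrative, and actual memory also scales with the number of layers and bytes per element):

```python
def kv_cache_elements(seq_len: int, num_kv_heads: int, head_dim: int) -> int:
    """Elements stored per layer: 2 (K and V) x L x (heads or groups) x D."""
    return 2 * seq_len * num_kv_heads * head_dim

L, H, G, D = 2048, 32, 8, 128           # example configuration from above
mha = kv_cache_elements(L, H, D)        # MHA: every head has its own K,V
gqa = kv_cache_elements(L, G, D)        # GQA: one K,V per group
mqa = kv_cache_elements(L, 1, D)        # MQA: a single shared K,V

print(f"GQA uses {gqa / mha:.1%} of the MHA cache")   # 25.0%
print(f"MQA uses {mqa / mha:.1%} of the MHA cache")   # 3.1%
```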

How GQA Works

The Grouping Mechanism

Instead of H separate KV pairs (MHA) or 1 shared KV pair (MQA), GQA uses G groups:

Configuration Example (32 heads, 8 groups):

  • Total attention heads: 32
  • Number of KV groups: 8
  • Group size: 4 heads per group
  • Memory savings: 75% reduction compared to MHA

Mathematical Formulation

For head h in group g:

Attention_h = softmax(Q_h · K_g^T / √d_k) · V_g

Where:

  • Q_h is the query for head h
  • K_g, V_g are the shared keys and values for group g
  • Group assignment: g = ⌊h × G / H⌋
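
The head-to-group mapping can be checked directly; the snippet below assumes 0-based head and group indices:

```python
# Group assignment g = floor(h * G / H), with 0-based indices.
H, G = 32, 8                                 # query heads, KV groups
groups = [h * G // H for h in range(H)]

# Heads 0-3 share group 0, heads 4-7 share group 1, ..., heads 28-31 share group 7.
assert all(groups.count(g) == H // G for g in range(G))
print(groups)
```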

Implementation

Key Architecture Components

Projection Layers:

  • Query projections: separate for each head (num_heads × head_dim output features)
  • Key projections: shared within each group (num_kv_heads × head_dim output features)
  • Value projections: shared within each group (num_kv_heads × head_dim output features)
  • Output projection: standard linear layer combining all heads

Forward Pass Steps:

  1. Project input to multi-head queries Q
  2. Project input to grouped K and V (fewer projections than queries)
  3. Repeat/expand K, V to match query head count using efficient views
  4. Compute scaled dot-product attention for each head
  5. Concatenate outputs and apply final projection

Key Constraint: Number of query heads must be divisible by number of KV heads to ensure even grouping
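
Put together, a minimal sketch of these steps might look like the following (assumes PyTorch 2.x; class and variable names are illustrative, and KV caching and padding masks are omitted for brevity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Minimal GQA layer: num_heads query heads share num_kv_heads K/V projections."""

    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        # Key constraint from above: query heads must divide evenly into KV groups.
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        self.group_size = num_heads // num_kv_heads

        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape

        # Steps 1-2: per-head queries, per-group keys/values, shaped (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Step 3: expand K, V so every query head in a group sees its group's K/V.
        # repeat_interleave copies; see the zero-copy view approach discussed later.
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # Steps 4-5: scaled dot-product attention per head, then merge heads and project.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(B, T, self.num_heads * self.head_dim)
        return self.o_proj(out)
```

For example, GroupedQueryAttention(d_model=4096, num_heads=32, num_kv_heads=8) matches the 32-head, 8-group configuration used throughout this page.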

Efficient KV Cache Management

Cache Structure:

  • Pre-allocated tensors for keys and values
  • Shape: [num_layers, batch_size, num_kv_heads, max_seq_len, head_dim]
  • Significantly smaller than MHA cache (uses num_kv_heads instead of num_heads)
  • Incremental updates during autoregressive generation

Cache Operations:

  • Store only num_kv_heads KV pairs per layer (not num_heads)
  • Concatenate with previous cache for each new token
  • Return expanded cache to match query head count during attention
  • Enables efficient long-context inference
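
One way to realize this is a pre-allocated cache indexed by layer, written in place rather than re-concatenated each step; a hedged sketch (class name and API are invented for illustration, single device, fixed batch size):

```python
import torch

class GQAKVCache:
    """Pre-allocated KV cache sized by num_kv_heads rather than num_heads."""

    def __init__(self, num_layers, batch_size, num_kv_heads, max_seq_len, head_dim,
                 dtype=torch.float16, device="cpu"):
        shape = (num_layers, batch_size, num_kv_heads, max_seq_len, head_dim)
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)
        self.pos = 0  # number of positions already cached

    def update(self, layer, k_new, v_new):
        """Write new K/V of shape (batch, num_kv_heads, t_new, head_dim) for one
        layer and return the cached prefix including the new tokens."""
        t_new = k_new.shape[2]
        end = self.pos + t_new
        self.k[layer, :, :, self.pos:end] = k_new
        self.v[layer, :, :, self.pos:end] = v_new
        return self.k[layer, :, :, :end], self.v[layer, :, :, :end]

    def advance(self, t_new):
        """Call once per decoding step, after all layers have been updated."""
        self.pos += t_new
```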

Memory and Compute Analysis

Memory Comparison

| Method | KV Cache Size | Relative Size (H=32) | Quality |
|--------|---------------|----------------------|---------|
| MHA    | 2 × L × H × D | 100%                 | Best    |
| GQA-8  | 2 × L × 8 × D | 25%                  | Near-MHA |
| GQA-4  | 2 × L × 4 × D | 12.5%                | Good    |
| MQA    | 2 × L × 1 × D | 3.1%                 | Degraded |

Compute Overhead

GQA adds minimal compute overhead:

  • Projections: no more compute than MHA (the K/V projection matrices are smaller)
  • Repeat operation: O(1) using efficient tensor views (no data copying)
  • Attention computation: Identical to MHA

Efficient K,V Expansion:

  • Use unsqueeze and expand operations instead of copying
  • Creates view over existing memory
  • No additional memory allocation during forward pass
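
In PyTorch terms, the view-based expansion looks roughly like this (a sketch; the trailing reshape is only needed when the attention kernel expects a flat head dimension, and that step may materialize a contiguous copy):

```python
import torch

B, G, group_size, T, D = 2, 8, 4, 16, 64   # illustrative sizes

k = torch.randn(B, G, T, D)                # grouped keys: one per KV head

# unsqueeze + expand produces a broadcasted view over the same storage.
k_view = k.unsqueeze(2).expand(B, G, group_size, T, D)
assert k_view.data_ptr() == k.data_ptr()   # no new K storage was allocated

# Flatten groups into a per-query-head axis only if the kernel requires it;
# reshape on an expanded view may allocate a contiguous copy at this point.
k_flat = k_view.reshape(B, G * group_size, T, D)
```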

GQA in Production Models

Llama 2 Configuration

Llama 2 70B Architecture:

  • Query heads: 64
  • KV heads: 8 (GQA-8)
  • Group size: 8 heads per group
  • Context length: 4096 tokens
  • Head dimension: 128

Memory Impact:

  • MHA cache: ~8.4 GB per sequence
  • GQA cache: ~1.0 GB per sequence
  • 88% memory reduction

Mistral Configuration

Mistral 7B Architecture:

  • Query heads: 32
  • KV heads: 8 (GQA-8)
  • Sliding window: 4096 tokens
  • Combined with sliding window attention for double efficiency
  • Enables efficient long-context processing

Training Considerations

Converting MHA to GQA

Uptraining Strategy - Convert pre-trained MHA models to GQA:

Conversion Steps:

  1. Keep query projections unchanged (all heads)
  2. Group K,V projections and average weights within each group (see the sketch after these steps)
  3. Reshape averaged weights to match num_kv_heads dimension
  4. Fine-tune with lower learning rate to recover quality
  5. Typically converges within 5-10% of original training steps
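
A minimal sketch of step 2, assuming PyTorch-style projection weights laid out as (num_heads × head_dim, d_model); the function name is illustrative:

```python
import torch

def pool_kv_weights(w: torch.Tensor, num_heads: int, num_kv_heads: int,
                    head_dim: int) -> torch.Tensor:
    """Mean-pool per-head K (or V) projection rows into one set of rows per group."""
    group_size = num_heads // num_kv_heads
    d_model = w.shape[1]
    # (num_heads * head_dim, d_model) -> (num_kv_heads, group_size, head_dim, d_model)
    w = w.view(num_kv_heads, group_size, head_dim, d_model)
    # Average the heads within each group, then flatten back to a projection matrix.
    return w.mean(dim=1).reshape(num_kv_heads * head_dim, d_model)

# e.g. pool_kv_weights(mha_k_proj_weight, num_heads=32, num_kv_heads=8, head_dim=128)
```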

Benefits:

  • Leverage existing pre-trained weights
  • Faster than training GQA from scratch
  • Often achieves comparable final quality to original MHA

Training from Scratch

GQA can be trained directly:

  • Similar convergence to MHA
  • Slightly faster training (less memory movement)
  • Regularization effect from parameter sharing

Choosing the Right Configuration

Decision Guide

Research and Experimentation:

  • Use MHA when quality is paramount
  • Acceptable when memory is abundant

Cloud Serving (Large Models, >30B):

  • Use GQA-8 for memory-critical scenarios
  • Balances quality with serving efficiency

Cloud Serving (Smaller Models, <30B):

  • Use GQA-16 or higher
  • Can afford more groups for better quality

Edge Devices:

  • Use MQA or GQA-4 for maximum efficiency
  • Prioritize memory savings over quality

Batch Serving:

  • Use GQA-8 to GQA-16
  • Balance based on batch size and throughput requirements

| Model Size | Batch Size | Recommended   | Groups | Rationale          |
|------------|------------|---------------|--------|--------------------|
| < 7B       | 1          | MHA or GQA-16 | 16-32  | Quality focus      |
| 7B-13B     | 1-8        | GQA-8         | 8      | Balanced           |
| 13B-70B    | 1-4        | GQA-8         | 8      | Memory critical    |
| > 70B      | 1-2        | GQA-4 or MQA  | 4 or 1 | Extreme efficiency |

Performance Tips

1. Hardware Considerations

GPU-Specific Optimization:

  • A100 GPUs: Use num_kv_heads as multiples of 8 (tensor core optimization)
  • V100 GPUs: Prefer fewer groups (4-8) due to limited memory
  • Align group sizes with hardware warp/thread block sizes

2. Dynamic Group Size

Adaptive Strategy:

  • Use fewer groups for long sequences (>2048 tokens) to save memory
  • Use more groups for short sequences to maintain quality
  • Switch dynamically based on input length
  • Balance memory usage with quality requirements

3. Optimize Memory Layout

Cache Organization:

  • Use contiguous memory layout: [layer, batch, num_kv_heads, seq_len, head_dim]
  • Optimize for sequential access patterns during autoregressive generation
  • Pre-allocate cache to avoid dynamic memory allocation
  • Keep cache aligned to cache line boundaries

Common Pitfalls

Pitfall 1: Uneven Group Sizes

Problem: Number of query heads not divisible by number of KV heads

  • Example: 32 query heads with 7 KV heads (32 % 7 ≠ 0)
  • Solution: Ensure num_heads is divisible by num_kv_heads
  • Recommended: Use powers of 2 for both (8, 16, 32, 64)
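
A simple configuration check catches this early; a minimal sketch (function name is illustrative):

```python
def validate_gqa_config(num_heads: int, num_kv_heads: int) -> int:
    """Return heads per group, or raise if the grouping would be uneven."""
    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
        )
    return num_heads // num_kv_heads

validate_gqa_config(32, 8)    # OK: 4 heads per group
# validate_gqa_config(32, 7)  # raises ValueError
```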

Pitfall 2: Inefficient Repeat

Problem: Using operations that copy data instead of creating views

  • Copying data wastes memory and bandwidth
  • Solution: Use unsqueeze + expand + reshape for zero-copy expansion
  • Creates view over existing memory without duplication

Pitfall 3: Wrong Cache Shape

Problem: Allocating cache with num_heads instead of num_kv_heads

  • Wastes memory by storing unnecessary duplicates
  • Solution: Cache shape should use num_kv_heads dimension
  • Correct: [batch, seq_len, num_kv_heads, head_dim]
  • Wrong: [batch, seq_len, num_heads, head_dim]
