MHA vs GQA vs MQA: Choosing the Right Attention

Compare Multi-Head, Grouped-Query, and Multi-Query Attention mechanisms to understand their trade-offs and choose the optimal approach for your use case.


MHA vs GQA vs MQA: The Complete Comparison

Understanding the trade-offs between Multi-Head Attention (MHA), Grouped-Query Attention (GQA), and Multi-Query Attention (MQA) is crucial for deploying efficient transformer models. Each approach offers different balances between quality, memory, and speed.

Side-by-Side Comparison

The three attention mechanisms compared on a 7B-class model with a 2048-token context:

| Variant | KV Cache | Quality | Speed | Max Batch (40 GB GPU) |
| --- | --- | --- | --- | --- |
| MHA | 2048 MB | 100% | 1× | 16 |
| GQA-8 (recommended) | 512 MB | 98% | 1.5× | 64 |
| GQA-4 | 256 MB | 96% | 1.8× | 128 |
| MQA | 64 MB | 94% | 2× | 512 |

Recommendation

For a 7B model with a 2048-token context and batch size 1, GQA-8 offers the best balance of quality and efficiency.

Quick Decision Matrix

| Use Case | Recommended | Why |
| --- | --- | --- |
| Research / training | MHA | Maximum quality, full parameter count |
| Cloud serving (>30B) | GQA-8 | Balance of quality and efficiency |
| Edge deployment | MQA | Minimum memory footprint |
| Long context (>8K) | GQA-4 or MQA | Memory becomes critical |
| Batch inference | GQA-8 | Good balance for multiple requests |
| Real-time systems | MQA | Lowest latency |
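
Encoded as code, the matrix above amounts to a handful of threshold checks. The sketch below is purely illustrative: the `Workload` fields, the thresholds, and the `recommend_attention` helper are assumptions of mine, not an established API.

```python
from dataclasses import dataclass

@dataclass
class Workload:
    """Hypothetical description of a deployment scenario."""
    model_params_b: float      # model size in billions of parameters
    context_len: int           # tokens of context
    memory_gb: float           # accelerator memory available
    realtime: bool = False     # hard latency constraints?
    research: bool = False     # quality/experimentation over efficiency?

def recommend_attention(w: Workload) -> str:
    """Map a workload onto the decision matrix above (illustrative heuristics)."""
    if w.research or w.model_params_b < 1:
        return "MHA"            # quality is paramount, reduction not needed
    if w.realtime or w.memory_gb < 40:
        return "MQA"            # lowest latency / smallest KV cache
    if w.context_len > 8192:
        return "GQA-4 or MQA"   # KV cache dominates at long context
    return "GQA-8"              # sensible default for cloud serving

print(recommend_attention(Workload(model_params_b=70, context_len=2048, memory_gb=80)))
# GQA-8
```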

Detailed Comparison

Architecture Differences

| Feature | MHA | GQA | MQA |
| --- | --- | --- | --- |
| Q projections | H separate | H separate | H separate |
| K projections | H separate | G groups | 1 shared |
| V projections | H separate | G groups | 1 shared |
| Attention parameters | 3 × H × D² | (H + 2G) × D² | (H + 2) × D² |
| KV heads | H | G | 1 |

Where H = number of query heads, G = number of KV groups, and D = head dimension

Memory Footprint

For a typical configuration (H = 32 heads, sequence length L = 2048, head dimension D = 128), the per-layer KV cache scales as follows:

| Method | KV Cache Size | Relative | Example (Llama 70B) |
| --- | --- | --- | --- |
| MHA | 2 × L × H × D | 100% | 8.4 GB/sequence |
| GQA-8 | 2 × L × 8 × D | 25% | 2.1 GB/sequence |
| GQA-4 | 2 × L × 4 × D | 12.5% | 1.0 GB/sequence |
| MQA | 2 × L × 1 × D | 3.1% | 0.26 GB/sequence |
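
To make the scaling concrete, the per-sequence cache can be computed directly from the configuration. The sketch below is mine and assumes an fp16 (2-byte) cache and a 7B-like shape (32 layers, head dimension 128); absolute figures depend on precision, context length, and layer count, so they will not match the Llama 70B column above.

```python
def kv_cache_gb(seq_len: int, n_kv_heads: int, head_dim: int,
                n_layers: int, bytes_per_elem: int = 2) -> float:
    """Per-sequence KV cache in GB: 2 (K and V) x L x KV heads x head_dim x layers x bytes."""
    return 2 * seq_len * n_kv_heads * head_dim * n_layers * bytes_per_elem / 1e9

# 7B-like shape: 32 layers, head_dim 128, 2048-token context, fp16 cache
for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("GQA-4", 4), ("MQA", 1)]:
    print(f"{name}: {kv_cache_gb(2048, kv_heads, 128, 32):.2f} GB/sequence")
```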

Performance Metrics

| Metric | MHA | GQA-8 | MQA |
| --- | --- | --- | --- |
| Quality (perplexity) | 10.0 (best) | 10.1 | 10.3 |
| Inference speed | 1.0× | 1.5× | 2.0× |
| Training speed | 1.0× | 1.1× | 1.2× |
| Max batch size | 1× | 4× | 32× |
| Implementation complexity | High | Medium | Low |

Mathematical Formulations

MHA: Full Expressiveness

head_i = Attention(Q_i, K_i, V_i)   ∀ i ∈ [1, H]

Each head has independent parameters:

  • Total attention parameters: 3HD²
  • KV cache per token: 2HD

GQA: Balanced Approach

head_i = Attention(Q_i, K_g(i), V_g(i))

Where g(i) = ⌊i · G / H⌋ maps heads to groups:

  • Total attention parameters: (H + 2G)D²
  • KV cache per token: 2GD

MQA: Maximum Sharing

head_i = Attention(Q_i, K_shared, V_shared)   ∀ i

All heads share the same K,V:

  • Total attention parameters: (H + 2)D²
  • KV cache per token: 2D
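
In code, the only difference between the three variants is how many K/V heads are projected and how they are broadcast to the query heads. Below is a minimal PyTorch sketch of that idea (my own simplified code, not any particular library's implementation): n_kv_heads = n_heads gives MHA, 1 < n_kv_heads < n_heads gives GQA, and n_kv_heads = 1 gives MQA.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Minimal GQA sketch: n_kv_heads in {1..n_heads} covers MQA, GQA, and MHA."""
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Broadcast each KV group to the query heads that share it: g(i) = floor(i * G / H)
        repeat = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, -1))

attn = GroupedQueryAttention(d_model=512, n_heads=8, n_kv_heads=2)  # GQA-2
print(attn(torch.randn(1, 16, 512)).shape)  # torch.Size([1, 16, 512])
```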

Production Model Configurations

Current Industry Adoption

| Model | Size | Attention Type | Config | Rationale |
| --- | --- | --- | --- | --- |
| GPT-3 | 175B | MHA | 96 heads | Quality priority |
| Llama 2 | 70B | GQA | 64 Q, 8 KV | Balanced approach |
| Llama 2 | 7B | MHA | 32 Q, 32 KV | Small model, less reduction needed |
| Mistral | 7B | GQA + SWA | 32 Q, 8 KV | Combined optimizations |
| Falcon | 40B | MQA | 64 Q, 1 KV | Maximum efficiency |
| PaLM | 540B | MQA | 48 Q, 1 KV | Extreme scale requires MQA |
| GPT-4 | ? | GQA (likely) | ? | Balanced for quality |

Configuration Examples

```python
# Llama 2 70B configuration
config_llama_70b = {
    "n_heads": 64,
    "n_kv_heads": 8,        # GQA with 8 groups
    "group_size": 8,        # 64 / 8 = 8 heads per group
    "hidden_size": 8192,
    "head_dim": 128,
}

# Falcon 40B configuration
config_falcon_40b = {
    "n_heads": 64,
    "n_kv_heads": 1,        # MQA
    "group_size": 64,       # all heads share
    "hidden_size": 8192,
    "head_dim": 128,
}

# Mistral 7B configuration
config_mistral_7b = {
    "n_heads": 32,
    "n_kv_heads": 8,        # GQA
    "group_size": 4,
    "hidden_size": 4096,
    "head_dim": 128,
    "sliding_window": 4096, # additional optimization
}
```

Quality Impact Analysis

Perplexity Comparison (from papers)

| Model Size | MHA | GQA-8 | GQA-4 | MQA |
| --- | --- | --- | --- | --- |
| 1B params | 15.2 | 15.3 | 15.5 | 15.8 |
| 7B params | 10.1 | 10.2 | 10.3 | 10.5 |
| 30B params | 8.2 | 8.3 | 8.3 | 8.5 |
| 70B params | 7.1 | 7.1 | 7.2 | 7.3 |

Observations:

  • Quality gap decreases with model size
  • GQA-8 nearly matches MHA quality
  • MQA shows ~2-3% degradation

Downstream Task Performance

| Task | MHA | GQA-8 | MQA |
| --- | --- | --- | --- |
| MMLU | 67.3% | 67.1% | 66.5% |
| HumanEval | 48.2% | 48.0% | 47.1% |
| GSM8K | 78.5% | 78.2% | 77.3% |
| Translation | 35.2 BLEU | 35.0 BLEU | 34.6 BLEU |

Deployment Scenarios

Scenario 1: Maximum Quality (Research)

```python
# Use MHA for best results
config = {
    "attention_type": "MHA",
    "n_heads": 32,
    "n_kv_heads": 32,  # same as n_heads
    "rationale": "Quality matters most, resources available",
}
```

Scenario 2: Production API (Cloud)

```python
# Use GQA for balance
config = {
    "attention_type": "GQA",
    "n_heads": 32,
    "n_kv_heads": 8,  # 4:1 ratio
    "rationale": "Balance quality with serving costs",
}
```

Scenario 3: Edge Device

```python
# Use MQA for efficiency
config = {
    "attention_type": "MQA",
    "n_heads": 32,
    "n_kv_heads": 1,
    "rationale": "Memory/power constraints critical",
}
```

Scenario 4: Long Context (32K+)

```python
# Use aggressive GQA or MQA
config = {
    "attention_type": "GQA",
    "n_heads": 32,
    "n_kv_heads": 4,  # 8:1 ratio
    "rationale": "KV cache dominates at long context",
}
```

Conversion Strategies

Converting Between Attention Types

```python
def convert_attention_type(model, from_type, to_type):
    """Convert a model between attention types (helper functions defined elsewhere)."""
    if from_type == "MHA" and to_type == "GQA":
        # Average K,V weights within groups
        return average_kv_weights_for_groups(model)
    elif from_type == "MHA" and to_type == "MQA":
        # Average all K,V weights
        return average_all_kv_weights(model)
    elif from_type == "GQA" and to_type == "MQA":
        # Average group K,V weights
        return average_group_weights(model)
    elif from_type in ["GQA", "MQA"] and to_type == "MHA":
        # Replicate K,V weights (requires fine-tuning)
        return replicate_kv_weights(model)
    raise ValueError(f"Unsupported conversion: {from_type} -> {to_type}")
```
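
To make the MHA-to-GQA branch concrete, the sketch below mean-pools the per-head K/V projection weights within each group, which is the initialization strategy proposed in the GQA paper (uptraining afterwards is still recommended). The assumed weight layout (heads stacked along the output dimension) and the function name are illustrative, not tied to a specific framework.

```python
import torch

def mean_pool_kv_heads(w_kv: torch.Tensor, n_heads: int, n_groups: int) -> torch.Tensor:
    """Turn an MHA K or V projection into a GQA one by averaging heads within each group.

    Assumes w_kv has shape (n_heads * head_dim, d_model) with heads stacked along dim 0.
    Returns a tensor of shape (n_groups * head_dim, d_model).
    """
    head_dim = w_kv.shape[0] // n_heads
    # (n_heads, head_dim, d_model) -> (n_groups, heads_per_group, head_dim, d_model)
    per_head = w_kv.view(n_heads, head_dim, -1)
    grouped = per_head.view(n_groups, n_heads // n_groups, head_dim, -1)
    return grouped.mean(dim=1).reshape(n_groups * head_dim, -1)

w_k = torch.randn(32 * 128, 4096)          # MHA: 32 KV heads of dim 128
w_k_gqa = mean_pool_kv_heads(w_k, 32, 8)   # GQA-8: 8 KV heads
print(w_k_gqa.shape)                       # torch.Size([1024, 4096])
```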

Fine-tuning After Conversion

```python
# Recommended fine-tuning settings after conversion
fine_tune_config = {
    "MHA_to_GQA": {"lr": 1e-5, "steps": 10000, "warmup": 1000},
    "MHA_to_MQA": {"lr": 5e-6, "steps": 20000, "warmup": 2000},
}
```

Optimization Combinations

Synergistic Optimizations

| Base Attention | + Optimization | Result | Example |
| --- | --- | --- | --- |
| GQA | + Flash Attention | Fast + memory efficient | Llama 2 |
| GQA | + Sliding Window | Local + global efficiency | Mistral |
| MQA | + Flash Attention | Maximum efficiency | Optimized Falcon |
| GQA | + RoPE | Efficient + better positions | Most modern LLMs |

Implementation Stack

```python
class OptimizedAttention(nn.Module):
    """Combine multiple optimizations around a base attention variant."""

    def __init__(self, config):
        super().__init__()
        # Choose the base attention from the number of KV heads
        if config.n_kv_heads == 1:
            self.attention = MQA(config)
        elif config.n_kv_heads < config.n_heads:
            self.attention = GQA(config)
        else:
            self.attention = MHA(config)
        # Add optimizations on top
        if config.use_flash:
            self.attention = FlashWrapper(self.attention)
        if config.use_sliding_window:
            self.attention = SlidingWindowWrapper(self.attention)
        if config.use_rope:
            self.pos_encoding = RoPE(config)
```

Cost Analysis

Serving Cost Comparison

For 1M tokens/day with 2K context:

| Attention | GPU Hours | Memory | Cost/Day |
| --- | --- | --- | --- |
| MHA | 24 | 80GB A100 | $72 |
| GQA-8 | 16 | 40GB A100 | $32 |
| MQA | 12 | 40GB A100 | $24 |

Scaling Analysis

```python
# Assumed model shape for the example below (illustrative values, not a specific model)
n_layers, n_heads, head_dim = 24, 32, 128

def estimate_max_batch_size(gpu_memory_gb, model_params_gb, attention_type, context_len):
    """Estimate the maximum batch size for a given GPU."""
    # Memory left for the KV cache (2 GB reserved for overhead)
    available = gpu_memory_gb - model_params_gb - 2
    # KV cache per sequence in GB: context x layers x KV heads x head_dim x 2 (K and V) x 4 bytes
    kv_cache_size = {
        "MHA":   context_len * n_layers * n_heads * head_dim * 2 * 4 / 1e9,
        "GQA-8": context_len * n_layers * 8 * head_dim * 2 * 4 / 1e9,
        "MQA":   context_len * n_layers * 1 * head_dim * 2 * 4 / 1e9,
    }
    return int(available / kv_cache_size[attention_type])

# Example: 80GB A100, 30GB model, 4K context
print(estimate_max_batch_size(80, 30, "MHA", 4096))    # ~15
print(estimate_max_batch_size(80, 30, "GQA-8", 4096))  # ~60
print(estimate_max_batch_size(80, 30, "MQA", 4096))    # ~480
```

Best Practices

Selection Guidelines

  1. Start with GQA-8 as default

    • Good balance for most use cases
    • Minimal quality loss
    • 4× memory savings
  2. Consider MQA when:

    • Serving at scale (more than 1000 QPS)
    • Memory constrained (less than 40GB)
    • Long context (more than 16K)
    • Batch size critical
  3. Stick with MHA when:

    • Research/experimentation
    • Quality is paramount
    • Small models (less than 1B params)
    • Abundant resources

Implementation Checklist

  • Profile memory usage with target context length
  • Benchmark quality on your specific tasks
  • Test different group sizes for GQA
  • Consider combining with other optimizations
  • Implement proper KV cache management
  • Add monitoring for attention patterns

Future Directions

Emerging Approaches

  1. Dynamic Attention Selection

    • Switch between MHA/GQA/MQA per layer
    • Adapt based on input complexity
  2. Learned Group Assignment

    • Learn which heads to group
    • Task-specific grouping
  3. Hybrid Architectures

    • MHA for critical layers
    • MQA for others
