KV Cache: The Secret to Fast LLM Inference
Interactive visualization of key-value caching in LLMs - how caching transformer attention states enables efficient text generation without quadratic recomputation.
KV Cache: Accelerating LLM Inference
The Key-Value (KV) cache is a critical optimization that makes autoregressive text generation practical. Without it, generating each new token would require recomputing attention for the entire sequence - making long-form generation prohibitively expensive.
Interactive KV Cache Demonstration
See how caching dramatically reduces computation during text generation:
(Interactive demo panels: KV Cache Settings, Token Generation & Caching, Memory Scaling Across Models)
With KV Cache
- Compute K and V once, reuse them at every later step
- O(1) new K/V computation per generated token
- Linear memory growth with sequence length
- Fast incremental generation
- Essential for production serving
Without KV Cache
- Recompute K and V for all previous tokens at every step
- O(n) recomputation per token, O(n²) over the whole generation
- No additional memory needed
- Extremely slow generation
- Impractical for long sequences
How KV Cache Works
During autoregressive generation, each new token needs to attend to all previous tokens. Without caching, this means recomputing the key and value projections for the entire sequence at every step.
The KV cache stores these projections after they're computed once, allowing the model to:
- Only compute K,V for the new token
- Reuse cached K,V for all previous tokens
- Reduce computation from O(n²) to O(n)
Memory formula: 2 × batch × seq_len × n_layers × n_heads × head_dim × sizeof(dtype)
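As a minimal sketch of this behavior (plain PyTorch; the projection layers and dimensions below are illustrative assumptions, not taken from a real model), each step computes K and V only for the newest token, appends them to the cache, and reports how the cache memory grows according to the formula above:

```python
import torch

# Illustrative sizes (single layer, batch of 1); not from a real model
d_model, n_heads, head_dim = 64, 4, 16
W_k = torch.nn.Linear(d_model, n_heads * head_dim, bias=False)
W_v = torch.nn.Linear(d_model, n_heads * head_dim, bias=False)

cache_k, cache_v = [], []                    # one (n_heads, head_dim) entry per token
for step in range(5):
    x_new = torch.randn(1, d_model)          # hidden state of the newest token only
    cache_k.append(W_k(x_new).view(n_heads, head_dim))
    cache_v.append(W_v(x_new).view(n_heads, head_dim))

    K = torch.stack(cache_k)                 # [seq_len, n_heads, head_dim] - reused, never recomputed
    V = torch.stack(cache_v)
    # 2 (K and V) × seq_len × n_heads × head_dim × 4 bytes (float32), per the formula above
    print(step, K.shape, 2 * K.numel() * K.element_size(), "bytes")
```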
The Recomputation Problem
Without Caching
During autoregressive generation, each token generation step requires recomputing the key and value projections for every token generated so far. The cost of step t is:

O(t × L × H × d)

Total computation for generating n tokens:

O((1 + 2 + ... + n) × L × H × d) = O(n² × L × H × d)

Where:
- L = number of layers
- H = number of attention heads
- d = head dimension
- n = sequence length
With KV Caching
Each step only computes the key and value projections for the new token:

O(L × H × d) per step

Total computation becomes linear:

O(n × L × H × d)
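To make the asymptotics concrete, a few lines of plain Python count how many per-token K/V projections are performed in each regime:

```python
# Number of per-token K/V projection computations needed to generate n tokens
n = 1000

without_cache = sum(t for t in range(1, n + 1))  # recompute all t tokens at step t
with_cache = n                                   # only the one new token per step

print(without_cache)   # 500500 -> grows as n^2 / 2
print(with_cache)      # 1000   -> grows as n
```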
Implementing the KV Cache
The Attention Mechanism
Standard attention computation:
```python
import math
import torch

def attention(Q, K, V):
    # Q: [batch, seq_len, d_model]
    # K, V: [batch, seq_len, d_model]
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(weights, V)
    return output
```
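During incremental decoding the query covers only the newest token while the keys and values span the whole cached prefix; the same function handles that case unchanged. A quick shape check, reusing the attention function above (dimensions chosen arbitrarily for illustration):

```python
prefix_len, d_model = 7, 64
q_new = torch.randn(1, 1, d_model)               # query for the single new token
K_all = torch.randn(1, prefix_len + 1, d_model)  # cached keys plus the new one
V_all = torch.randn(1, prefix_len + 1, d_model)

out = attention(q_new, K_all, V_all)
print(out.shape)   # torch.Size([1, 1, 64]) -> one output row per new token
```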
Caching Strategy
During generation with caching:
```python
import torch

class KVCache:
    def __init__(self, max_length, n_layers, n_heads, head_dim):
        # Pre-allocate space for up to max_length tokens in every layer
        self.cache_k = torch.zeros(n_layers, max_length, n_heads, head_dim)
        self.cache_v = torch.zeros(n_layers, max_length, n_heads, head_dim)
        # Filled length per layer, so each layer can be updated independently
        self.seq_len = [0] * n_layers

    def update(self, layer_idx, new_k, new_v):
        # new_k, new_v: [T, n_heads, head_dim]; T = 1 while decoding,
        # T = prompt length during the prefill pass
        start = self.seq_len[layer_idx]
        end = start + new_k.shape[0]
        self.cache_k[layer_idx, start:end] = new_k
        self.cache_v[layer_idx, start:end] = new_v
        self.seq_len[layer_idx] = end

    def get(self, layer_idx):
        # Return only the filled portion of the cache for this layer
        length = self.seq_len[layer_idx]
        return (self.cache_k[layer_idx, :length],
                self.cache_v[layer_idx, :length])
```
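A short usage sketch of the class above (shapes chosen arbitrarily for illustration; torch is imported in the block above):

```python
cache = KVCache(max_length=16, n_layers=2, n_heads=4, head_dim=8)

for _ in range(3):                       # simulate three decoding steps for layer 0
    new_k = torch.randn(1, 4, 8)         # one new token's keys: [1, n_heads, head_dim]
    new_v = torch.randn(1, 4, 8)
    cache.update(layer_idx=0, new_k=new_k, new_v=new_v)

k, v = cache.get(layer_idx=0)
print(k.shape, v.shape)                  # torch.Size([3, 4, 8]) for both
```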
Generation Loop
```python
def generate_with_cache(prompt_tokens, max_new_tokens):
    kv_cache = KVCache(...)  # sized for prompt + new tokens
    tokens = list(prompt_tokens)

    # Process the prompt (prefill) - all positions in parallel
    hidden_states = embed(prompt_tokens)
    for layer in model.layers:
        k, v = layer.compute_kv(hidden_states)
        kv_cache.update(layer.idx, k, v)
        hidden_states = layer(hidden_states)
    last_hidden = hidden_states[-1]

    # Generate new tokens one at a time (decode)
    for _ in range(max_new_tokens):
        new_token = sample(last_hidden)
        new_hidden = embed(new_token)
        for layer in model.layers:
            # Only compute K,V for the new token...
            new_k, new_v = layer.compute_kv(new_hidden)
            kv_cache.update(layer.idx, new_k, new_v)
            # ...and reuse the cached K,V for attention over the prefix
            cached_k, cached_v = kv_cache.get(layer.idx)
            new_hidden = layer.attention(new_hidden, cached_k, cached_v)
        last_hidden = new_hidden
        tokens.append(new_token)
    return tokens
```
Memory Requirements
Cache Size Formula
Cache size (bytes) = 2 × B × L × S × H × D × sizeof(dtype)
Where:
- 2 = Keys + Values
- B = batch size
- L = number of layers
- S = sequence length
- H = number of heads
- D = head dimension
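Plugging the formula in directly (kv_cache_bytes is a hypothetical helper written for this example; FP16 is 2 bytes per element):

```python
def kv_cache_bytes(batch, layers, seq_len, heads, head_dim, dtype_bytes=2):
    # Factor of 2 accounts for storing both keys and values
    return 2 * batch * layers * seq_len * heads * head_dim * dtype_bytes

# LLaMA-7B at its full 4K context, batch size 1, FP16
size = kv_cache_bytes(batch=1, layers=32, seq_len=4096, heads=32, head_dim=128)
print(size / 2**30, "GiB")   # 2.0 GiB, matching the LLaMA-7B row in the table below
```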
Real-World Examples
Model | Context | Layers | Heads | Head Dim | Cache Size (FP16) |
---|---|---|---|---|---|
GPT-3 175B | 2K | 96 | 96 | 128 | 4.7 GB |
LLaMA-7B | 4K | 32 | 32 | 128 | 2.0 GB |
LLaMA-70B | 4K | 80 | 64 | 128 | 10.0 GB |
GPT-4* | 32K | 120 | 120 | 128 | 94.0 GB |
*Estimated configuration
Optimization Techniques
1. Multi-Query Attention (MQA)
Share keys and values across heads:
# Standard Multi-Head Attention K, V: [batch, seq_len, n_heads, head_dim] # Multi-Query Attention K, V: [batch, seq_len, 1, head_dim] # Shared across heads
Memory reduction: a factor of H, where H is the number of attention heads.
2. Grouped-Query Attention (GQA)
Balance between MHA and MQA:
```python
# Grouped-Query Attention: n_kv_heads < n_heads
K, V: [batch, seq_len, n_kv_heads, head_dim]
# Each KV head serves n_heads / n_kv_heads query heads
```
Used in: LLaMA-2 70B, Mistral
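A minimal sketch of how grouped K/V heads are expanded to match the query heads at attention time; MQA is simply the n_kv_heads = 1 case. The shapes and the repeat-based expansion are illustrative, not a specific library's implementation:

```python
import torch

batch, seq_len, n_heads, n_kv_heads, head_dim = 1, 10, 8, 2, 16
group = n_heads // n_kv_heads    # 4 query heads share each KV head

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)   # only n_kv_heads are cached
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Expand the cached KV heads so each query head has a matching KV head
k_exp = k.repeat_interleave(group, dim=1)   # [batch, n_heads, seq_len, head_dim]
v_exp = v.repeat_interleave(group, dim=1)

scores = (q @ k_exp.transpose(-2, -1)) / head_dim**0.5
out = torch.softmax(scores, dim=-1) @ v_exp
print(out.shape)   # torch.Size([1, 8, 10, 16])
```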
3. Sliding Window Cache
Only cache recent tokens:
```python
from collections import deque

class SlidingKVCache:
    def __init__(self, window_size):
        self.window_size = window_size
        # deque with maxlen drops the oldest (k, v) pair automatically when full
        self.cache = deque(maxlen=window_size)

    def update(self, k, v):
        self.cache.append((k, v))
```
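For example, with a window of 4 the cache never holds more than the 4 most recent entries (integers stand in for real key/value tensors here):

```python
cache = SlidingKVCache(window_size=4)
for t in range(10):
    cache.update(k=t, v=t)              # placeholder values instead of real tensors

print(len(cache.cache))                 # 4
print([k for k, _ in cache.cache])      # [6, 7, 8, 9] - only the most recent tokens
```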
4. PagedAttention
Virtual memory for the KV cache (a simplified block-table sketch follows this list):
- Non-contiguous memory allocation
- Dynamic growth
- Memory sharing across sequences
- Used in vLLM
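The core bookkeeping can be sketched as a block table that maps a sequence's logical token positions to fixed-size physical blocks. This is a simplified illustration of the idea, not vLLM's actual data structures:

```python
BLOCK_SIZE = 16                      # tokens per physical block (assumption)

class BlockAllocator:
    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))

    def allocate(self):
        return self.free.pop()

class BlockTable:
    def __init__(self):
        self.blocks = []             # physical block ids owned by this sequence

    def ensure_capacity(self, num_tokens, allocator):
        # Allocate blocks lazily as the sequence grows
        needed = -(-num_tokens // BLOCK_SIZE)      # ceiling division
        while len(self.blocks) < needed:
            self.blocks.append(allocator.allocate())

    def locate(self, token_pos):
        # Map a logical token position to (physical block, offset within block)
        return self.blocks[token_pos // BLOCK_SIZE], token_pos % BLOCK_SIZE

allocator = BlockAllocator(num_blocks=64)
table = BlockTable()
table.ensure_capacity(num_tokens=40, allocator=allocator)   # 40 tokens -> 3 blocks
print(table.blocks, table.locate(37))
```

Because blocks are allocated only as a sequence grows and returned to the free list when it finishes, memory is not pre-reserved for the maximum length and can be shared more flexibly across sequences.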
Cache Management Strategies
1. Static Allocation
Pre-allocate maximum size:
```python
cache = torch.zeros(max_batch, max_length, ...)
```
- ✅ Simple, no fragmentation
- ❌ Wastes memory for short sequences
2. Dynamic Growth
Grow cache as needed:
```python
if seq_len > cache_size:
    cache = torch.cat([cache, new_allocation], dim=1)
```
- ✅ Memory efficient
- ❌ Reallocation overhead
3. Block-wise Allocation
Allocate in fixed-size blocks:
```python
blocks = []
while need_more_space:
    blocks.append(allocate_block())
```
- ✅ Balance of efficiency and flexibility
- ❌ More complex implementation
Advanced Techniques
Quantized KV Cache
Reduce precision to save memory:
```python
import torch

def quantize_cache(cache, bits=8):
    # Symmetric per-tensor quantization: map the largest magnitude to the int range
    scale = cache.abs().max() / (2**(bits - 1) - 1)
    quantized = torch.round(cache / scale).to(torch.int8)
    return quantized, scale

def dequantize_cache(quantized, scale):
    return quantized.to(torch.float16) * scale
```
Memory savings: 50% (FP16 → INT8) or 75% (FP16 → INT4)
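A quick round trip with the functions above shows the storage saving and the reconstruction error (random tensors stand in for real cached keys):

```python
k = torch.randn(32, 128, dtype=torch.float16)

q, scale = quantize_cache(k.float())             # quantize in FP32 for the rounding math
restored = dequantize_cache(q, scale)

print(k.element_size(), "->", q.element_size())  # 2 bytes -> 1 byte per element
print((restored - k).abs().max().item())         # worst-case absolute error
```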
Hierarchical Caching
Cache at multiple granularities:
- Token-level: Full resolution
- Segment-level: Compressed representations
- Document-level: Summary embeddings
Speculative Decoding Cache
Separate caches for draft and target models:
```python
draft_cache = KVCache(small_model_config)
target_cache = KVCache(large_model_config)

# Generate candidate tokens with the draft model
draft_tokens = draft_model.generate(prompt, cache=draft_cache)

# Verify them with the target model
verified_tokens = target_model.verify(draft_tokens, cache=target_cache)
```
Performance Impact
Generation Speed
Without cache:
- Time per token: O(n) where n is current sequence length
- Total time: O(n²)
- Example: 1000 tokens = 500,000 attention computations
With cache:
- Time per token: O(1)
- Total time: O(n)
- Example: 1000 tokens = 1000 attention computations
Throughput Comparison
Sequence Length | Without Cache | With Cache | Speedup |
---|---|---|---|
128 tokens | 1.2 sec | 0.13 sec | 9.2× |
512 tokens | 19.5 sec | 0.51 sec | 38.2× |
2048 tokens | 312 sec | 2.05 sec | 152× |
8192 tokens | ~5000 sec | 8.2 sec | 610× |
Common Issues and Solutions
1. Memory Overflow
Problem: The cache exceeds available memory.
Solutions:
- Use sliding window cache
- Implement cache eviction
- Quantize cache values
- Use CPU offloading
2. Cache Invalidation
Problem: Prompt changes require a cache reset.
Solutions:
- Incremental cache updates
- Prefix caching for common prompts
- Cache versioning
3. Batch Processing
Problem: Different sequences have different lengths.
Solutions:
- Padding and masking
- Dynamic batching
- Continuous batching (vLLM)
Implementation Best Practices
1. Memory Pool Management
```python
class CachePool:
    def __init__(self, total_memory):
        self.total_memory = total_memory
        self.pool = []          # free cache buffers available for reuse
        self.allocated = {}     # request_id -> cache buffer currently in use

    def allocate(self, request_id, size):
        # Hand out a buffer if the pool can cover the request, otherwise signal failure
        if size <= self.available():
            cache = self._get_from_pool(size)
            self.allocated[request_id] = cache
            return cache
        return None

    def free(self, request_id):
        # Return the request's buffer to the pool so it can be reused
        cache = self.allocated.pop(request_id)
        self._return_to_pool(cache)

    # available(), _get_from_pool() and _return_to_pool() are left as
    # implementation details of the underlying memory pool
```
2. Cache Warming
Pre-compute common prefixes:
```python
common_prefixes = ["You are a helpful", "Please analyze", ...]

for prefix in common_prefixes:
    cache = compute_kv_cache(prefix)
    cache_store[hash(prefix)] = cache
```
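At request time, warmed entries can be looked up before falling back to a fresh prefill. get_prompt_cache and extend_kv_cache below are hypothetical helpers in the same pseudocode style as the snippet above:

```python
def get_prompt_cache(prompt):
    # Start from a warmed prefix cache when possible; the tokens after the
    # prefix still need a normal prefill pass on top of it.
    for prefix in common_prefixes:
        if prompt.startswith(prefix):
            warmed = cache_store[hash(prefix)]
            return extend_kv_cache(warmed, prompt[len(prefix):])
    return compute_kv_cache(prompt)   # cache miss: prefill from scratch
```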
3. Monitoring
Track cache metrics:
- Hit rate
- Memory usage
- Eviction rate
- Recomputation frequency
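A lightweight way to track these is a small counter object updated on every cache operation; the structure below is only an illustration of what to record:

```python
from dataclasses import dataclass

@dataclass
class CacheMetrics:
    hits: int = 0
    misses: int = 0          # each miss implies a recomputation
    evictions: int = 0
    bytes_in_use: int = 0

    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total else 0.0

metrics = CacheMetrics()
metrics.hits += 8
metrics.misses += 2
print(metrics.hit_rate())   # 0.8
```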
Related Concepts
- Context Windows - Maximum cache size limits
- Flash Attention - Memory-efficient attention
- Tokenization - What gets cached
- Attention Mechanisms - What we're caching
Conclusion
KV caching transforms LLM inference from quadratic to linear complexity, making real-time text generation feasible. Understanding cache dynamics is essential for optimizing inference performance and managing memory constraints in production deployments.