KV Cache: The Secret to Fast LLM Inference

Interactive visualization of key-value caching in LLMs - how caching transformer attention states enables efficient text generation without quadratic recomputation.

KV Cache: Accelerating LLM Inference

The Key-Value (KV) cache is a critical optimization that makes autoregressive text generation practical. Without it, generating each new token would require recomputing attention for the entire sequence - making long-form generation prohibitively expensive.

Interactive KV Cache Demonstration

See how caching dramatically reduces computation during text generation:

[Interactive demo: adjust the KV cache settings and step through generation from the prompt "The quick brown fox", watching the counters for FLOPs saved, cache memory used (KB), elapsed seconds, and tokens/sec update as each new token's keys and values are cached.]

Memory Scaling Across Models

  • GPT-2 (12L × 12H × 64D): 36.0 MB for 1024 tokens
  • GPT-3 (96L × 96H × 128D): 9216.0 MB for 2048 tokens
  • LLaMA-7B (32L × 32H × 128D): 2048.0 MB for 4096 tokens
  • LLaMA-70B (80L × 64H × 128D): 10240.0 MB for 4096 tokens

With KV Cache

  • Compute once, reuse many times
  • O(1) complexity per token
  • Linear memory growth
  • Fast incremental generation
  • Essential for production

Without KV Cache

  • Recompute all tokens each step
  • O(n²) complexity per token
  • No additional memory needed
  • Extremely slow generation
  • Impractical for long sequences

How KV Cache Works

During autoregressive generation, each new token needs to attend to all previous tokens. Without caching, this means recomputing the key and value projections for the entire sequence at every step.

The KV cache stores these projections after they're computed once, allowing the model to:

  • Only compute K,V for the new token
  • Reuse cached K,V for all previous tokens
  • Reduce computation from O(n²) to O(n)

Memory formula: 2 × batch × seq_len × n_layers × n_heads × head_dim × sizeof(dtype)
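
The append-and-reuse pattern is easiest to see for a single attention head. The sketch below is illustrative only (the function and variable names are assumptions, not any particular library's API): each decoding step appends the new token's key and value to the cache and attends over everything cached so far.

import math
import torch

def decode_step(q_new, k_new, v_new, cache_k, cache_v):
    # q_new, k_new, v_new: [1, head_dim] -- projections for the new token only
    # cache_k, cache_v:    [t, head_dim] -- projections kept from earlier steps
    cache_k = torch.cat([cache_k, k_new], dim=0)   # append, never recompute
    cache_v = torch.cat([cache_v, v_new], dim=0)
    scores = q_new @ cache_k.T / math.sqrt(cache_k.shape[-1])
    out = torch.softmax(scores, dim=-1) @ cache_v
    return out, cache_k, cache_v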

The Recomputation Problem

Without Caching

During autoregressive generation, each token generation step requires:

Compute_step(i) = i × L × H × d²

Total computation for generating n tokens:

Total = Σᵢ₌₁ⁿ i × L × H × d² = O(n²)

Where:

  • L = number of layers
  • H = number of attention heads
  • d = head dimension
  • n = sequence length

With KV Caching

Each step only computes for the new token:

Compute_step = L × H × d² = O(1)

Total computation becomes linear:

Total_cached = n × L × H × d² = O(n)
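
To make the gap concrete, here is a back-of-the-envelope check with GPT-2-sized numbers from the table above (L = 12, H = 12, d = 64) over 1000 generated tokens:

# Rough K/V-projection cost per the formulas above (GPT-2-sized model)
L, H, d, n = 12, 12, 64, 1000

without_cache = sum(i * L * H * d**2 for i in range(1, n + 1))  # O(n²) total
with_cache = n * L * H * d**2                                   # O(n) total

print(without_cache / with_cache)  # ≈ 500, i.e. roughly (n + 1) / 2 times more work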

How KV Cache Works

The Attention Mechanism

Standard attention computation:

import math
import torch

def attention(Q, K, V):
    # Q: [batch, seq_len, d_model]
    # K, V: [batch, seq_len, d_model]
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(weights, V)
    return output

Caching Strategy

During generation with caching:

class KVCache:
    def __init__(self, max_length, n_layers, n_heads, head_dim):
        # Pre-allocate storage for keys and values across all layers
        self.cache_k = torch.zeros(n_layers, max_length, n_heads, head_dim)
        self.cache_v = torch.zeros(n_layers, max_length, n_heads, head_dim)
        self.seq_len = [0] * n_layers   # filled length, tracked per layer

    def update(self, layer_idx, new_k, new_v):
        # new_k, new_v: [t, n_heads, head_dim] -- t is the prompt length during
        # prefill and 1 for each newly decoded token
        t = new_k.shape[0]
        pos = self.seq_len[layer_idx]
        self.cache_k[layer_idx, pos:pos + t] = new_k
        self.cache_v[layer_idx, pos:pos + t] = new_v
        self.seq_len[layer_idx] = pos + t

    def get(self, layer_idx):
        # Return only the filled portion of this layer's cache
        pos = self.seq_len[layer_idx]
        return (self.cache_k[layer_idx, :pos],
                self.cache_v[layer_idx, :pos])

Generation Loop

def generate_with_cache(prompt_tokens, max_new_tokens):
    kv_cache = KVCache(...)  # sized from the model configuration
    tokens = list(prompt_tokens)

    # Process the prompt (prefill) -- all positions in parallel
    hidden_states = embed(prompt_tokens)
    for layer in model.layers:
        k, v = layer.compute_kv(hidden_states)
        kv_cache.update(layer.idx, k, v)
        hidden_states = layer(hidden_states)

    # Generate new tokens (decode) -- strictly sequential
    for _ in range(max_new_tokens):
        new_token = sample(hidden_states[-1])
        tokens.append(new_token)
        new_hidden = embed([new_token])
        for layer in model.layers:
            # Only compute K, V for the new token ...
            new_k, new_v = layer.compute_kv(new_hidden)
            kv_cache.update(layer.idx, new_k, new_v)
            # ... and reuse the cached K, V for all previous tokens
            cached_k, cached_v = kv_cache.get(layer.idx)
            new_hidden = layer.attention(new_hidden, cached_k, cached_v)
        hidden_states = new_hidden

    return tokens

Memory Requirements

Cache Size Formula

Memory_KV = 2 × B × L × S × H × D × sizeof(dtype)

Where:

  • 2 = Keys + Values
  • B = batch size
  • L = number of layers
  • S = sequence length
  • H = number of heads
  • D = head dimension
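
As a sanity check of the formula, plugging in LLaMA-7B's configuration (32 layers, 32 heads, head dimension 128) at a 4096-token context, batch size 1, FP16:

# Memory_KV = 2 × B × L × S × H × D × sizeof(dtype)
kv_bytes = 2 * 1 * 32 * 4096 * 32 * 128 * 2
print(kv_bytes / 2**20)  # 2048.0 MB, i.e. the 2.0 GB listed in the table below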

Real-World Examples

Model         Context    Layers    Heads    Dim    Cache Size (FP16)
GPT-3 175B    2K         96        96       128    9.0 GB
LLaMA-7B      4K         32        32       128    2.0 GB
LLaMA-70B     4K         80        64       128    10.0 GB
GPT-4*        32K        120       120      128    225.0 GB

*Estimated configuration

Optimization Techniques

1. Multi-Query Attention (MQA)

Share keys and values across heads:

# Standard Multi-Head Attention
K, V: [batch, seq_len, n_heads, head_dim]

# Multi-Query Attention
K, V: [batch, seq_len, 1, head_dim]  # Shared across heads

Memory reduction: a factor of H, where H is the number of heads.
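
A quick shape-level check of that factor, using an assumed 32-head, 128-dimensional configuration (the tensors are synthetic and only compared by size):

import torch

batch, seq_len, n_heads, head_dim = 1, 4096, 32, 128

k_mha = torch.zeros(batch, seq_len, n_heads, head_dim, dtype=torch.float16)
k_mqa = torch.zeros(batch, seq_len, 1, head_dim, dtype=torch.float16)

bytes_mha = k_mha.numel() * k_mha.element_size()
bytes_mqa = k_mqa.numel() * k_mqa.element_size()
print(bytes_mha // bytes_mqa)  # 32 -- the K/V cache shrinks by a factor of n_heads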

2. Grouped-Query Attention (GQA)

Balance between MHA and MQA:

# n_kv_heads < n_heads
K, V: [batch, seq_len, n_kv_heads, head_dim]
# Each KV head serves n_heads / n_kv_heads query heads

Used in: LLaMA-2 70B, Mistral
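
At attention time, the grouped K/V heads are simply repeated so they line up with the query heads. Below is a minimal sketch with LLaMA-2-70B-like sizes (64 query heads, 8 KV heads); repeat_kv is an illustrative helper, not a specific library function:

import torch

def repeat_kv(x, n_rep):
    # x: [batch, seq_len, n_kv_heads, head_dim]
    # Duplicate each KV head n_rep times so it matches the query heads
    return torch.repeat_interleave(x, repeats=n_rep, dim=2)

batch, seq_len, head_dim = 1, 4096, 128
n_heads, n_kv_heads = 64, 8        # 8 query heads share each KV head

k = torch.randn(batch, seq_len, n_kv_heads, head_dim)
k_expanded = repeat_kv(k, n_heads // n_kv_heads)
print(k_expanded.shape)            # torch.Size([1, 4096, 64, 128])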

3. Sliding Window Cache

Only cache recent tokens:

from collections import deque

class SlidingKVCache:
    def __init__(self, window_size):
        self.window_size = window_size
        self.cache = deque(maxlen=window_size)

    def update(self, k, v):
        # Automatically drops the oldest entry when the window is full
        self.cache.append((k, v))

4. PagedAttention

Virtual memory for KV cache:

  • Non-contiguous memory allocation
  • Dynamic growth
  • Memory sharing across sequences
  • Used in vLLM
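
A toy sketch of the block-table idea behind this (illustrative and greatly simplified, not vLLM's actual implementation or API): each sequence maps its logical token positions through a table onto fixed-size physical blocks drawn from a shared pool, so cache memory grows on demand and need not be contiguous.

import torch

class PagedKVCache:
    def __init__(self, n_blocks, block_size, n_heads, head_dim):
        self.block_size = block_size
        # One shared physical pool of KV blocks for all sequences
        self.k_pool = torch.zeros(n_blocks, block_size, n_heads, head_dim)
        self.v_pool = torch.zeros(n_blocks, block_size, n_heads, head_dim)
        self.free_blocks = list(range(n_blocks))
        self.block_tables = {}   # seq_id -> list of physical block ids

    def append(self, seq_id, pos, k, v):
        # k, v: [n_heads, head_dim] for the token at logical position `pos`
        table = self.block_tables.setdefault(seq_id, [])
        if pos // self.block_size == len(table):
            # Crossed a block boundary: grab a fresh block from the free list
            table.append(self.free_blocks.pop())
        block_id = table[pos // self.block_size]
        self.k_pool[block_id, pos % self.block_size] = k
        self.v_pool[block_id, pos % self.block_size] = v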

Cache Management Strategies

1. Static Allocation

Pre-allocate maximum size:

cache = torch.zeros(max_batch, max_length, ...)
  • ✅ Simple, no fragmentation
  • ❌ Wastes memory for short sequences

2. Dynamic Growth

Grow cache as needed:

if seq_len > cache_size:
    cache = torch.cat([cache, new_allocation], dim=1)
  • ✅ Memory efficient
  • ❌ Reallocation overhead

3. Block-wise Allocation

Allocate in fixed-size blocks:

blocks = []
while need_more_space:
    blocks.append(allocate_block())
  • ✅ Balance of efficiency and flexibility
  • ❌ More complex implementation

Advanced Techniques

Quantized KV Cache

Reduce precision to save memory:

def quantize_cache(cache, bits=8):
    # Symmetric per-tensor quantization: map values onto signed integers
    scale = cache.abs().max() / (2**(bits - 1) - 1)
    quantized = torch.round(cache / scale).to(torch.int8)
    return quantized, scale

def dequantize_cache(quantized, scale):
    return quantized.to(torch.float16) * scale

Memory savings: 50% (FP16 → INT8) or 75% (FP16 → INT4)
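
A quick round trip with the functions above on a synthetic FP16 cache tensor shows the size reduction and the (small) reconstruction error:

import torch

cache = torch.randn(1024, 32, 128, dtype=torch.float16)
q, scale = quantize_cache(cache)
restored = dequantize_cache(q, scale)

print(cache.element_size(), q.element_size())  # 2 bytes/element vs 1 byte/element
print((restored - cache).abs().max())          # worst-case quantization error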

Hierarchical Caching

Cache at multiple granularities:

  1. Token-level: Full resolution
  2. Segment-level: Compressed representations
  3. Document-level: Summary embeddings

Speculative Decoding Cache

Separate caches for draft and target models:

draft_cache = KVCache(small_model_config)
target_cache = KVCache(large_model_config)

# Generate with the draft model
draft_tokens = draft_model.generate(prompt, cache=draft_cache)

# Verify with the target model
verified_tokens = target_model.verify(draft_tokens, cache=target_cache)

Performance Impact

Generation Speed

Without cache:

  • Time per token: O(n) where n is current sequence length
  • Total time: O(n²)
  • Example: 1000 tokens ≈ 500,000 attention computations

With cache:

  • Time per token: O(1)
  • Total time: O(n)
  • Example: 1000 tokens = 1000 attention computations

Throughput Comparison

Sequence Length    Without Cache    With Cache    Speedup
128 tokens         1.2 sec          0.13 sec      9.2×
512 tokens         19.5 sec         0.51 sec      38.2×
2048 tokens        312 sec          2.05 sec      152×
8192 tokens        ~5000 sec        8.2 sec       610×

Common Issues and Solutions

1. Memory Overflow

Problem: Cache exceeds available memory.

Solutions:

  • Use sliding window cache
  • Implement cache eviction
  • Quantize cache values
  • Use CPU offloading

2. Cache Invalidation

Problem: Prompt changes require a cache reset.

Solutions:

  • Incremental cache updates
  • Prefix caching for common prompts
  • Cache versioning
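
One common form of the prefix caching mentioned above is to key precomputed K/V states by the token prefix and reuse the longest match; a minimal sketch with illustrative names (prefix_store and the cached objects are assumptions):

# prefix_store: tuple of token ids -> precomputed KV cache for that prefix
prefix_store = {}

def longest_cached_prefix(tokens):
    # Walk back from the full prompt until a cached prefix is found
    for cut in range(len(tokens), 0, -1):
        prefix = tuple(tokens[:cut])
        if prefix in prefix_store:
            return prefix_store[prefix], cut
    return None, 0

prompt_tokens = [1, 2, 3, 4]   # token ids of the incoming prompt
# The cache covers tokens[:hit_len]; only tokens[hit_len:] need a prefill pass
cache, hit_len = longest_cached_prefix(prompt_tokens)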

3. Batch Processing

Problem: Sequences in a batch have different lengths.

Solutions:

  • Padding and masking
  • Dynamic batching
  • Continuous batching (vLLM)

Implementation Best Practices

1. Memory Pool Management

class CachePool:
    def __init__(self, total_memory):
        self.total_memory = total_memory   # memory budget in bytes
        self.pool = []                     # freed buffers available for reuse
        self.allocated = {}                # request_id -> buffer

    def available(self):
        # Bytes not currently handed out to active requests
        used = sum(buf.numel() for buf in self.allocated.values())
        return self.total_memory - used

    def allocate(self, request_id, size):
        if size <= self.available():
            cache = self._get_from_pool(size)
            self.allocated[request_id] = cache
            return cache
        return None                        # caller must evict or wait

    def free(self, request_id):
        cache = self.allocated.pop(request_id)
        self._return_to_pool(cache)

    def _get_from_pool(self, size):
        # Reuse a freed buffer if one is large enough, else allocate fresh
        for i, buf in enumerate(self.pool):
            if buf.numel() >= size:
                return self.pool.pop(i)
        return torch.empty(size, dtype=torch.uint8)

    def _return_to_pool(self, cache):
        self.pool.append(cache)

2. Cache Warming

Pre-compute common prefixes:

common_prefixes = ["You are a helpful", "Please analyze", ...]

for prefix in common_prefixes:
    cache = compute_kv_cache(prefix)
    cache_store[hash(prefix)] = cache

3. Monitoring

Track cache metrics:

  • Hit rate
  • Memory usage
  • Eviction rate
  • Recomputation frequency
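
A minimal container for tracking these metrics (the names are illustrative):

from dataclasses import dataclass

@dataclass
class CacheMetrics:
    hits: int = 0
    misses: int = 0
    evictions: int = 0
    recomputations: int = 0
    bytes_used: int = 0

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0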

Conclusion

KV caching transforms LLM inference from quadratic to linear complexity, making real-time text generation feasible. Understanding cache dynamics is essential for optimizing inference performance and managing memory constraints in production deployments.

If you found this explanation helpful, consider sharing it with others.
