Masked and Causal Attention
Learn how masked attention enables autoregressive generation and prevents information leakage in transformers, essential for language models and sequential generation.
Masked and Causal Attention: Preserving Causality in Generation
Masked attention is the mechanism that lets transformers generate sequences one token at a time: by ensuring each position attends only to past tokens, it preserves the autoregressive property that generation tasks require.
Interactive Masked Attention Visualization
Explore how masking patterns control information flow in attention:
Information Leakage in Standard Attention
Without masking, every position can attend to all positions, including future ones:
- During training: the model can "cheat" by seeing future tokens
- During inference: future tokens don't exist yet
- Result: a train/test mismatch that breaks autoregressive generation
Apply Masks to Control Information Flow
Set masked positions to -∞ before softmax. The causal masking formula is scores = (Q × K^T / √d_k) + M, where M[i][j] = 0 if j ≤ i and -∞ otherwise, so position i can only attend to positions 0 through i.
Why -∞? Because softmax(-∞) = 0, attention to masked positions is completely blocked and no information can leak.
Different Masking Patterns
Explore various mask types and their uses. With a causal mask (used for autoregressive generation in GPT), the token at position 4 of "The cat sat on the" can attend to 0:The, 1:cat, 2:sat, 3:on, and 4:the, but to nothing after it.
Enables Autoregressive Generation
Essential for language models and sequential generation.
Training Benefits
- Process the entire sequence in parallel with a causal mask
- Compute the loss for all positions simultaneously (teacher forcing)
- Fast training on GPUs with massive parallelization
Inference Benefits
- Generate tokens one at a time autoregressively
- Cache previous keys/values for efficiency (KV cache)
- Consistent with how the model was trained
Additional Mask Benefits
- Efficiently handle variable-length sequences in batches
- Bidirectional context for prompts, causal for generation
- Reduced computation for long sequences (local + global attention)
Complexity
- Memory: O(n²) for the mask (often pre-computed)
- Compute: negligible overhead when fused
Used in All Autoregressive Models
Foundation of GPT, LLaMA, and transformer decoders.
GPT Family
- Mask Type: Causal (lower triangular)
- Use Case: Language modeling, text generation
- Training: Parallel with causal mask
- Inference: Autoregressive with KV cache
LLaMA / Mistral
- Mask Type: Causal + optional sliding window
- Optimization: Flash Attention with is_causal=True
- Cache: Efficient KV cache with grouped-query attention
- Performance: Fused kernels eliminate mask materialization
T5 / BART Decoders
- Mask Type: Causal for self-attention, none for cross-attention
- Architecture: Encoder-decoder with different masking per attention type
- Use Case: Translation, summarization
- Flexibility: Combine multiple mask types
Modern Optimizations: In PyTorch 2.0+, F.scaled_dot_product_attention(Q, K, V, is_causal=True) automatically applies causal masking with fused kernels for a 2-4× speedup.
Why Masked Attention?
The Information Leakage Problem
In standard self-attention, every position can attend to every other position:
- During training: Model can "cheat" by looking at future tokens
- During inference: Future tokens don't exist yet
Solution: Apply masks to prevent attending to future positions
Types of Masking
- Causal Mask: For autoregressive generation (GPT-style)
- Padding Mask: For variable-length sequences
- Custom Masks: For specific attention patterns
- Combined Masks: Multiple masks applied together
How Causal Masking Works
The Causal Mask
For a sequence of length n, the causal mask is a lower triangular matrix. For n = 4:
M = [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]
This ensures position i can only attend to positions 0 through i: entry M[i][j] = 1 means position i may attend to position j, and 0 means the connection is masked.
Applying the Mask
Mask Application Process:
- Compute attention scores: Q × K^T / √d_k
- Apply mask: Replace masked positions with -∞
- Apply softmax: -∞ becomes 0, preventing attention flow
- Result: Clean attention weights with no information leakage
Key Insight: Softmax(-∞) = 0, completely blocking attention to masked positions
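A minimal sketch of these steps with toy tensors (names and sizes here are illustrative, not from a specific model):

```python
import math
import torch
import torch.nn.functional as F

seq_len, d_k = 4, 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

# 1. Compute attention scores: Q x K^T / sqrt(d_k)
scores = Q @ K.T / math.sqrt(d_k)

# 2. Apply the causal mask: replace future positions with -inf
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float("-inf"))

# 3. Softmax: -inf becomes exactly 0, so no attention flows to masked positions
weights = F.softmax(scores, dim=-1)
print(weights)  # the upper triangle is all zeros
```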
Implementation
Creating Causal Masks
Standard Approach:
- Create lower triangular matrix where position i can attend to positions 0...i
- Use torch.tril() for efficient generation
- Shape: [seq_len, seq_len]
Memory-Efficient Approach:
- Generate mask on-the-fly using broadcasting with row/column indices
- Compare row_indices >= col_indices
- No memory allocation for full matrix
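Both approaches can be sketched in a few lines of PyTorch (variable names are illustrative):

```python
import torch

seq_len = 5

# Standard approach: materialize a [seq_len, seq_len] lower triangular matrix
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Memory-efficient approach: compare row/column indices via broadcasting,
# so no mask buffer has to be stored ahead of time
rows = torch.arange(seq_len).unsqueeze(1)   # shape [seq_len, 1]
cols = torch.arange(seq_len).unsqueeze(0)   # shape [1, seq_len]
on_the_fly_mask = rows >= cols              # True where attention is allowed

assert torch.equal(causal_mask, on_the_fly_mask)
```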
Masked Self-Attention
Key Components:
- Projections: Standard Q, K, V linear layers
- Causal Mask Buffer: Pre-computed lower triangular matrix registered as buffer
- Mask Application: Scores masked before softmax
- Optional Padding Mask: Can combine causal + padding masks
Forward Pass Steps:
- Project input to Q, K, V and reshape for multi-head attention
- Compute attention scores: Q × K^T / √d_k
- Slice pre-computed causal mask to current sequence length
- Apply mask: scores.masked_fill(mask == 0, -∞)
- Apply softmax and dropout
- Apply attention to values
- Reshape and project output
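Putting these pieces together, here is a minimal sketch of such a module (class and argument names are illustrative, not taken from a specific library):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedSelfAttention(nn.Module):
    """Multi-head self-attention with a pre-computed causal mask."""

    def __init__(self, d_model, num_heads, max_len=1024, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # Standard Q, K, V projections plus output projection
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Causal mask registered as a buffer: stored once for the maximum
        # sequence length and moved to the module's device automatically
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer("causal_mask", mask)

    def forward(self, x, padding_mask=None):
        B, T, C = x.shape

        # Project to Q, K, V and reshape to [B, heads, T, d_head]
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        # Attention scores: Q x K^T / sqrt(d_k)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)

        # Slice the pre-computed mask to the current length and apply it
        mask = self.causal_mask[:T, :T]                       # [T, T]
        scores = scores.masked_fill(~mask, float("-inf"))

        # Optionally combine with a padding mask of shape [B, T] (True = real token)
        if padding_mask is not None:
            pad = padding_mask[:, None, None, :]              # [B, 1, 1, T]
            scores = scores.masked_fill(~pad, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))
        out = weights @ v                                     # [B, heads, T, d_head]
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
```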
Training vs Inference
Training: Parallel Processing
Efficient Batch Training:
- Process entire sequence at once with causal mask applied
- Compute logits for all positions in parallel
- Calculate loss across all positions simultaneously (teacher forcing)
- GPU can parallelize across sequence dimension
Benefit: Extremely fast training compared to sequential generation
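A sketch of such a training step, assuming a hypothetical `model` that returns logits of shape [batch, seq_len, vocab_size]:

```python
import torch
import torch.nn.functional as F

def training_step(model, tokens):
    # Inputs are tokens 0..n-2, targets are tokens 1..n-1 (shifted by one)
    inputs, targets = tokens[:, :-1], tokens[:, 1:]

    # One forward pass scores every position in parallel; the causal mask
    # inside the model guarantees position i never sees targets[i]
    logits = model(inputs)                       # [batch, seq_len-1, vocab]

    # Cross-entropy over all positions at once (teacher forcing)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),     # [batch*(seq_len-1), vocab]
        targets.reshape(-1),                     # [batch*(seq_len-1)]
    )
    return loss
```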
Inference: Sequential Generation
Autoregressive Token Generation:
- Start with prompt tokens
- Run forward pass to get logits for all positions
- Use only last position's logits to predict next token
- Sample next token (greedy, top-k, or nucleus sampling)
- Append to sequence and repeat
- Stop at max length or end-of-sequence token
Key Difference: Training is parallel, inference is sequential
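A minimal greedy decoding loop, assuming the same hypothetical `model` interface as above (no KV cache yet):

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, eos_id=None):
    tokens = prompt_ids                                    # [1, prompt_len]
    for _ in range(max_new_tokens):
        logits = model(tokens)                             # [1, cur_len, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)    # append and repeat
        if eos_id is not None and next_token.item() == eos_id:
            break
    return tokens
```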
KV Cache Optimization
Efficient Generation with KV Cache:
The Problem: Recomputing K,V for all previous tokens at each step is wasteful
The Solution:
- Store computed K,V pairs in cache
- For each new token, compute only new K,V
- Concatenate with cached K,V from previous tokens
- Compute Q only for current token
- Massive speedup for long sequences
Cache Management:
- Initialize empty cache
- For each generated token:
  - Compute K,V for the new token only
  - Concatenate with the cached K,V
  - Update the cache with the full K,V
  - Compute attention with the full context
- Reuse the cache across generation steps
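A sketch of the cache bookkeeping for a single attention layer and a single new token (shapes and names are illustrative; real implementations keep one cache per layer and head):

```python
import math
import torch
import torch.nn.functional as F

def attend_with_cache(q_new, k_new, v_new, cache):
    """One decoding step: q/k/v are computed for the newest token only,
    cache holds K and V for every previously processed token."""
    if cache is None:
        k_all, v_all = k_new, v_new
    else:
        # Concatenate the new K,V with the cached K,V along the time axis
        k_all = torch.cat([cache["k"], k_new], dim=1)   # [1, t+1, d]
        v_all = torch.cat([cache["v"], v_new], dim=1)
    cache = {"k": k_all, "v": v_all}                    # update the cache

    # Query only for the current token: no causal mask is needed here,
    # because every cached key is already in the current token's past
    scores = q_new @ k_all.transpose(-2, -1) / math.sqrt(q_new.size(-1))
    out = F.softmax(scores, dim=-1) @ v_all             # [1, 1, d]
    return out, cache
```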
Types of Attention Masks
1. Standard Causal Mask
- Purpose: Autoregressive generation (GPT-style)
- Pattern: Lower triangular matrix
- Use Case: Language modeling, text generation
2. Padding Mask
- Purpose: Handle variable-length sequences in batches
- Pattern: Mask positions beyond the actual sequence length
- Implementation: Create a binary mask based on sequence lengths
- Use Case: Efficient batching with different-length inputs
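As a sketch, a padding mask can be built from sequence lengths and combined with a causal mask by a logical AND (tensor names and sizes are illustrative):

```python
import torch

batch_size, max_len = 2, 5
lengths = torch.tensor([5, 3])                     # actual length of each sequence

# Padding mask: True where a real token exists, False for padding
padding_mask = torch.arange(max_len)[None, :] < lengths[:, None]           # [B, T]

# Causal mask: True where attention is allowed
causal_mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))   # [T, T]

# Combined mask: a key position may be attended to only if it is both
# in the past (causal) and not padding
combined = causal_mask[None, :, :] & padding_mask[:, None, :]              # [B, T, T]
```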
3. Prefix LM Mask
- Purpose: Bidirectional attention for the prefix, causal for generation
- Pattern: Full attention within the prefix, then causal
- Use Case: Given context (bidirectional), generate a completion (causal)
- Example: T5, UL2 models
4. Block-Sparse Mask
- Purpose: Reduce computation for long sequences
- Pattern: Local attention within blocks + global attention to special tokens
- Use Case: Long-context models (Longformer, BigBird)
- Benefit: O(n) complexity instead of O(n²)
Attention Patterns with Masking
Visualization of Different Masks
Interactive Visualization: Use the interactive component above to explore different mask patterns
Common Patterns:
- Causal: Lower triangular (each position attends to itself and previous)
- Padding: Block pattern (mask padding tokens)
- Prefix LM: Hybrid pattern (bidirectional prefix + causal suffix)
- Block-Sparse: Diagonal blocks + vertical/horizontal stripes for global tokens
Special Considerations
1. Numerical Stability
Mask Value Selection:
- Use -1e9 instead of -∞ for better numerical stability
- Or use torch.finfo(dtype).min for dtype-specific minimum
- Prevents NaN issues in some implementations
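For example (a sketch; a finite fill value behaves the same as -∞ after softmax but avoids overflow in low-precision dtypes):

```python
import torch

scores = torch.randn(4, 4, dtype=torch.float16)
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# -1e9 overflows float16 (max ~65504), so use the dtype-specific minimum instead
neg_fill = torch.finfo(scores.dtype).min
masked = scores.masked_fill(~mask, neg_fill)
weights = torch.softmax(masked.float(), dim=-1)   # upcast the softmax for stability
```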
2. Efficient Masking
Pre-computation Strategy:
- Register causal mask as buffer in module initialization
- Pre-compute for maximum sequence length
- Slice to actual sequence length during forward pass
- Avoids repeated mask creation overhead
Benefits:
- No repeated tensor allocations
- Mask moves with model to correct device
- Included in state_dict for checkpointing
3. Flash Attention with Causal Mask
Modern Optimization (PyTorch 2.0+):
- Use F.scaled_dot_product_attention() with is_causal=True
- Automatically applies causal masking with fused kernels
- 2-4× speedup over manual implementation
- Reduces memory usage (no explicit mask materialization)
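Usage is a one-liner (PyTorch 2.0+; shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

B, H, T, D = 2, 8, 128, 64   # batch, heads, sequence length, head dim
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Fused causal attention: no explicit [T, T] mask tensor is materialized
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # [B, H, T, D]
```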
Common Applications
Language Modeling (GPT)
Architecture:
- Stack of transformer decoder blocks
- Each block contains masked self-attention
- Layer normalization before attention (pre-norm)
- Residual connections around attention and feedforward
Forward Pass:
- Apply layer norm
- Masked self-attention with causal mask
- Add residual connection
- Apply layer norm
- Feedforward network
- Add residual connection
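A pre-norm decoder block following these steps might look like this sketch, reusing the MaskedSelfAttention module from above (a hypothetical composition, not GPT's exact code):

```python
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Pre-norm transformer decoder block: masked self-attention and a
    feedforward network, each wrapped in a residual connection."""

    def __init__(self, d_model, num_heads, d_ff, max_len=1024, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MaskedSelfAttention(d_model, num_heads, max_len, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # layer norm -> masked attention -> residual
        x = x + self.ff(self.ln2(x))     # layer norm -> feedforward -> residual
        return x
```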
Decoder in Seq2Seq
Architecture:
- Masked self-attention on decoder inputs (causal)
- Cross-attention to encoder outputs (no mask)
- Feedforward network
Key Differences:
- Self-attention uses causal mask (can't see future)
- Cross-attention has no mask (can see all encoder outputs)
- Used in translation, summarization models (T5, BART)
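A sketch of a decoder block with both attention types, using PyTorch's built-in nn.MultiheadAttention (for a boolean attn_mask, True marks positions that may not be attended to; cross-attention passes no mask):

```python
import torch
import torch.nn as nn

class Seq2SeqDecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.ln1, self.ln2, self.ln3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, encoder_out):
        T = x.size(1)
        # Causal mask for self-attention: True blocks attention to future positions
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)

        # Masked self-attention over decoder inputs (cannot see the future)
        attn_out, _ = self.self_attn(x, x, x, attn_mask=causal)
        x = self.ln1(x + attn_out)

        # Cross-attention to encoder outputs: no mask, all encoder positions visible
        cross_out, _ = self.cross_attn(x, encoder_out, encoder_out)
        x = self.ln2(x + cross_out)

        return self.ln3(x + self.ff(x))
```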
Performance Implications
Memory Usage
- Standard: O(seq_len²) for mask storage
- Optimized: O(1) with on-the-fly generation
- Flash Attention: Fused kernels eliminate mask materialization
Computational Cost
- Masking itself: O(seq_len²) comparisons
- Can be fused with attention computation
- Negligible overhead with proper implementation
Best Practices
- Pre-compute masks when sequence length is known
- Use buffers for fixed masks to avoid re-allocation
- Leverage built-in functions like is_causal in PyTorch 2.0+
- Combine masks efficiently using logical operations
- Profile memory usage for long sequences
Common Pitfalls
Pitfall 1: Wrong Mask Shape
Problem: A mask whose shape doesn't line up with the scores tensor [batch, heads, seq_len, seq_len]; for example, a padding mask of shape [batch, seq_len] will not broadcast correctly
Solution: Reshape masks to broadcast-compatible shapes, e.g. [1, 1, seq_len, seq_len] for a causal mask and [batch, 1, 1, seq_len] for a padding mask
Why: PyTorch broadcasting aligns dimensions from the right, so explicit singleton dimensions make the intended alignment unambiguous
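For example (a sketch with illustrative shapes):

```python
import torch

batch, heads, T = 2, 4, 6
scores = torch.randn(batch, heads, T, T)

causal = torch.tril(torch.ones(T, T, dtype=torch.bool))        # [T, T]
keep = torch.tensor([[1]*6, [1]*4 + [0]*2], dtype=torch.bool)  # [batch, T] padding mask

# Add singleton dims so each mask lines up with [batch, heads, T, T]
scores = scores.masked_fill(~causal[None, None, :, :], float("-inf"))
scores = scores.masked_fill(~keep[:, None, None, :], float("-inf"))
```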
Pitfall 2: Forgetting to Mask During Inference
Problem: Not applying causal mask during generation leads to incorrect behavior
Solution: Always apply causal mask, even during inference
Impact: Without mask, model sees "future" tokens that don't exist yet
Pitfall 3: Mask Value Too Small
Problem: Using small negative values (like -1) doesn't effectively block attention
Solution: Use large negative value (-1e9 or -∞)
Reason: Softmax needs very negative values to produce near-zero outputs
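A quick numeric check of why the magnitude matters (a sketch):

```python
import torch

scores = torch.tensor([2.0, 1.0, -1.0])   # -1 used as the "mask" value
print(torch.softmax(scores, dim=-1))       # masked slot still gets a few percent of the weight

scores = torch.tensor([2.0, 1.0, -1e9])    # large negative mask value
print(torch.softmax(scores, dim=-1))       # masked slot gets essentially zero weight
```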
