Masked and Causal Attention: Preserving Causality in Generation
Masked attention is the key mechanism that allows transformers to generate sequences one token at a time, ensuring models only attend to past tokens and maintaining the autoregressive property essential for generation tasks.
Interactive Masked Attention Visualization
[Interactive visualization: explore how masking patterns control information flow in attention.]
Why Masked Attention?
The Information Leakage Problem
In standard self-attention, every position can attend to every other position:
- During training: Model can "cheat" by looking at future tokens
- During inference: Future tokens don't exist yet
Solution: Apply masks to prevent attending to future positions
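A minimal sketch of the fix, assuming PyTorch: filling future positions with -inf before the softmax drives their attention weights to exactly zero, while each row still sums to 1.

```python
import torch
import torch.nn.functional as F

# Raw attention scores for a 4-token sequence (single head, single example)
scores = torch.randn(4, 4)

# Lower triangular mask: 1 = allowed, 0 = future position
causal_mask = torch.tril(torch.ones(4, 4))
masked_scores = scores.masked_fill(causal_mask == 0, float('-inf'))

weights = F.softmax(masked_scores, dim=-1)
print(weights)              # upper triangle is exactly 0
print(weights.sum(dim=-1))  # each row still sums to 1
```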
Types of Masking
- Causal Mask: For autoregressive generation (GPT-style)
- Padding Mask: For variable-length sequences
- Custom Masks: For specific attention patterns
- Combined Masks: Multiple masks applied together (see the sketch below)
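As a rough sketch of how masks combine (assuming PyTorch and boolean masks), a causal mask and a padding mask can be merged with a logical AND and broadcast over batch and heads:

```python
import torch

seq_len = 5
lengths = torch.tensor([5, 3])  # second sequence has 2 padding tokens

# Causal mask: [seq_len, seq_len], True where attention is allowed
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Padding (key) mask: [batch, seq_len], True for real tokens
padding = torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)

# Combine and add a head dimension: [batch, 1, seq_len, seq_len]
combined = (causal.unsqueeze(0) & padding.unsqueeze(1)).unsqueeze(1)
print(combined.shape)  # torch.Size([2, 1, 5, 5])
```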
How Causal Masking Works
The Causal Mask
For a sequence of length n, the causal mask is an n×n lower triangular matrix M with M[i][j] = 1 when j ≤ i and 0 otherwise. This ensures position i can only attend to positions 0 through i.
Applying the Mask
```python
import torch.nn.functional as F

def apply_causal_mask(scores, mask):
    """
    scores: [batch, heads, seq_len, seq_len]
    mask:   [seq_len, seq_len] or broadcastable shape
    """
    # Replace masked positions with -inf
    scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax will turn -inf into 0
    attention = F.softmax(scores, dim=-1)
    return attention
```
Visualization of Causal Mask
```
Position:  0  1  2  3  4
    0     [1  0  0  0  0]  → Can only see position 0
    1     [1  1  0  0  0]  → Can see positions 0-1
    2     [1  1  1  0  0]  → Can see positions 0-2
    3     [1  1  1  1  0]  → Can see positions 0-3
    4     [1  1  1  1  1]  → Can see all previous
```
Implementation
Creating Causal Masks
```python
def create_causal_mask(seq_len):
    """Create a causal mask for autoregressive attention"""
    # Lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

def create_causal_mask_efficient(seq_len):
    """Memory-efficient version using broadcasting"""
    row_indices = torch.arange(seq_len).unsqueeze(-1)
    col_indices = torch.arange(seq_len)
    mask = row_indices >= col_indices
    return mask
```
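A quick check (a sketch, not part of the listing above) that both helpers produce the same lower triangular pattern:

```python
mask_a = create_causal_mask(6)             # float tensor of 0s/1s
mask_b = create_causal_mask_efficient(6)   # boolean tensor
print(torch.equal(mask_a.bool(), mask_b))  # True
```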
Masked Self-Attention
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Register causal mask as a buffer (not a learnable parameter)
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(1024, 1024)).view(1, 1, 1024, 1024)
        )

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()

        # Generate Q, K, V
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k)

        # Transpose for attention: [batch, heads, seq_len, d_k]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply causal mask
        causal_mask = self.causal_mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(causal_mask == 0, float('-inf'))

        # Apply additional mask if provided (e.g., padding)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax and dropout
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention to values
        context = torch.matmul(attention, V)

        # Reshape and project
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, d_model)
        output = self.W_o(context)

        return output, attention
```
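A small shape check, using hypothetical sizes, to show what the module returns:

```python
attn = MaskedSelfAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)    # [batch, seq_len, d_model]
output, attention = attn(x)
print(output.shape)      # torch.Size([2, 10, 64])
print(attention.shape)   # torch.Size([2, 4, 10, 10]), one map per head
```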
Training vs Inference
Training: Parallel Processing
During training, we can process the entire sequence at once:
```python
def train_step(model, input_ids, target_ids):
    # Process the entire sequence; the causal mask is applied internally
    logits = model(input_ids)

    # Compute the loss for all positions in parallel
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),  # flatten over batch and positions
        target_ids.view(-1)
    )
    return loss
```
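The targets are simply the inputs shifted one position to the left (teacher forcing), so every position is trained to predict the next token. A hedged sketch, assuming `model` is a decoder-only language model and `train_step` is the function above:

```python
tokens = torch.tensor([[101, 42, 7, 13, 102]])  # a toy token sequence
input_ids = tokens[:, :-1]    # the model sees tokens 0..n-2
target_ids = tokens[:, 1:]    # and predicts tokens 1..n-1
loss = train_step(model, input_ids, target_ids)
```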
Inference: Sequential Generation
During inference, generate one token at a time:
```python
def generate(model, prompt, max_length=100):
    input_ids = tokenize(prompt)

    for _ in range(max_length):
        # Get logits for all positions
        with torch.no_grad():
            logits = model(input_ids)

        # Only use the last position's logits
        next_token_logits = logits[:, -1, :]

        # Sample the next token
        next_token = torch.multinomial(
            F.softmax(next_token_logits, dim=-1),
            num_samples=1
        )

        # Append to the sequence
        input_ids = torch.cat([input_ids, next_token], dim=1)

        if next_token.item() == eos_token_id:
            break

    return input_ids
```
KV Cache Optimization
For efficient generation, cache previous key-value pairs:
```python
class CachedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.cache = {'k': None, 'v': None}

    def forward(self, x, use_cache=True):
        batch_size, seq_len, d_model = x.size()

        # Compute K, V only for the new token(s)
        K_new = self.W_k(x)
        V_new = self.W_v(x)

        if use_cache and self.cache['k'] is not None:
            # Concatenate with cached K, V from previous steps
            K = torch.cat([self.cache['k'], K_new], dim=1)
            V = torch.cat([self.cache['v'], V_new], dim=1)
        else:
            K = K_new
            V = V_new

        # Update the cache
        if use_cache:
            self.cache['k'] = K
            self.cache['v'] = V

        # Compute Q only for the new positions
        Q = self.W_q(x)

        # Split heads: [batch, heads, len, d_k]
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # With one new token per step, the new query may attend to every
        # cached position, so no explicit causal mask is needed here.
        # (Prefilling a multi-token prompt would still require one.)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention = F.softmax(scores, dim=-1)
        context = torch.matmul(attention, V)

        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(context)
```
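A hedged usage sketch: with the cache enabled, each forward pass only projects the new token, while keys and values from earlier steps are reused.

```python
layer = CachedAttention(d_model=64, n_heads=4)
layer.eval()

# Decode loop: one new token per step, the cache grows by one position each time
for step in range(3):
    x_new = torch.randn(1, 1, 64)     # [batch, 1, d_model]
    out = layer(x_new, use_cache=True)

print(layer.cache['k'].shape)         # torch.Size([1, 3, 64])
```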
Types of Attention Masks
1. Standard Causal Mask
```python
# Lower triangular matrix
mask = torch.tril(torch.ones(n, n))
```
2. Padding Mask
```python
def create_padding_mask(lengths, max_len):
    """Mask padding tokens"""
    batch_size = len(lengths)
    mask = torch.zeros(batch_size, max_len)
    for i, length in enumerate(lengths):
        mask[i, :length] = 1
    return mask
```
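The loop above is easy to read; an equivalent vectorized version (a sketch, with `lengths` given as a 1-D tensor) uses a broadcast comparison instead of Python iteration:

```python
def create_padding_mask_vectorized(lengths, max_len):
    """lengths: 1-D tensor of valid lengths, shape [batch]"""
    positions = torch.arange(max_len)                      # [max_len]
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)   # [batch, max_len]
    return mask.float()
```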
3. Prefix LM Mask
```python
def create_prefix_lm_mask(seq_len, prefix_len):
    """Bidirectional attention within the prefix, causal for generation"""
    # Start from a causal mask, then allow full attention inside the prefix
    mask = torch.tril(torch.ones(seq_len, seq_len))
    mask[:prefix_len, :prefix_len] = 1
    return mask
```
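For example, with a sequence length of 5 and a prefix length of 2, the first two positions attend to each other bidirectionally while the rest remain causal:

```python
print(create_prefix_lm_mask(5, 2))
# tensor([[1., 1., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])
```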
4. Block-Sparse Mask
```python
def create_block_sparse_mask(seq_len, block_size):
    """Attention within blocks + global attention"""
    mask = torch.zeros(seq_len, seq_len)

    # Local blocks
    for i in range(0, seq_len, block_size):
        end = min(i + block_size, seq_len)
        mask[i:end, i:end] = 1

    # Global tokens (first few)
    mask[:, :2] = 1  # All can attend to the first 2
    mask[:2, :] = 1  # The first 2 can attend to all
    return mask
```
Attention Patterns with Masking
Visualization of Different Masks
```python
import matplotlib.pyplot as plt

def visualize_mask(mask, title):
    plt.figure(figsize=(6, 6))
    plt.imshow(mask, cmap='Blues', interpolation='nearest')
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.tight_layout()
    plt.show()

# Examples
visualize_mask(create_causal_mask(8), "Causal Mask")
visualize_mask(create_prefix_lm_mask(8, 3), "Prefix LM Mask")
visualize_mask(create_block_sparse_mask(16, 4), "Block-Sparse Mask")
```
Special Considerations
1. Numerical Stability
```python
# Use -1e9 instead of -inf for better stability
mask_value = -1e9  # Large negative but not infinite

# Or use torch.finfo
mask_value = torch.finfo(scores.dtype).min
```
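One reason finite mask values matter: if an entire row ends up masked (for example, a query over all-padding keys), softmax over a row of -inf produces NaNs, while a large finite value degrades gracefully to a uniform distribution. A small demonstration:

```python
import torch
import torch.nn.functional as F

row = torch.full((4,), float('-inf'))
print(F.softmax(row, dim=-1))   # tensor([nan, nan, nan, nan])

row = torch.full((4,), torch.finfo(torch.float32).min)
print(F.softmax(row, dim=-1))   # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```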
2. Efficient Masking
```python
# Pre-compute masks when possible
class EfficientMaskedAttention(nn.Module):
    def __init__(self, max_seq_len=1024):
        super().__init__()
        # Pre-compute and register as a buffer
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        self.register_buffer('mask', mask)

    def get_mask(self, seq_len):
        # Slice the pre-computed mask
        return self.mask[:seq_len, :seq_len]
```
3. Flash Attention with Causal Mask
```python
# PyTorch 2.0+ optimized implementation
output = F.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True,  # Automatically applies the causal mask
    dropout_p=0.1
)
```
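A brief usage sketch: `scaled_dot_product_attention` expects head-split tensors of shape [batch, heads, seq_len, d_k], and `is_causal=True` takes the place of an explicit attention mask (passing both raises an error). Dropout is set to 0 here, as it usually is at inference time.

```python
Q = torch.randn(2, 4, 10, 16)  # [batch, heads, seq_len, d_k]
K = torch.randn(2, 4, 10, 16)
V = torch.randn(2, 4, 10, 16)

out = F.scaled_dot_product_attention(Q, K, V, is_causal=True, dropout_p=0.0)
print(out.shape)               # torch.Size([2, 4, 10, 16])
```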
Common Applications
Language Modeling (GPT)
```python
class GPTBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.masked_attn = MaskedSelfAttention(
            config.d_model, config.n_heads
        )
        self.mlp = FeedForward(config.d_model)  # position-wise MLP (defined elsewhere)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)

    def forward(self, x):
        # Masked self-attention with residual connection (pre-norm)
        x = x + self.masked_attn(self.ln1(x))[0]
        # Feedforward with residual connection
        x = x + self.mlp(self.ln2(x))
        return x
```
Decoder in Seq2Seq
```python
class TransformerDecoder(nn.Module):
    # (Submodule definitions omitted for brevity)
    def forward(self, x, encoder_output):
        # Masked self-attention (causal)
        x = self.masked_self_attn(x)
        # Cross-attention over the encoder output (no causal mask needed)
        x = self.cross_attn(x, encoder_output)
        # Feedforward
        x = self.ffn(x)
        return x
```
Performance Implications
Memory Usage
- Standard: O(seq_len²) for mask storage
- Optimized: no persistent O(seq_len²) buffer when the mask is generated on the fly (see the sketch below)
- Flash Attention: Fused kernels eliminate mask materialization
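A sketch of on-the-fly masking: no [seq_len, seq_len] buffer is stored on the module, and only a transient boolean comparison is built when needed (fused kernels such as Flash Attention avoid even that temporary).

```python
def apply_causal_mask_on_the_fly(scores):
    """scores: [..., seq_len, seq_len]; no pre-computed mask buffer required"""
    seq_len = scores.size(-1)
    i = torch.arange(seq_len, device=scores.device).unsqueeze(-1)  # query index
    j = torch.arange(seq_len, device=scores.device)                # key index
    return scores.masked_fill(j > i, float('-inf'))
```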
Computational Cost
- Masking itself: O(seq_len²) comparisons
- Can be fused with attention computation
- Negligible overhead with proper implementation
Best Practices
- Pre-compute masks when sequence length is known
- Use buffers for fixed masks to avoid re-allocation
- Leverage built-in functions like `is_causal` in PyTorch 2.0+
- Combine masks efficiently using logical operations
- Profile memory usage for long sequences
Common Pitfalls
Pitfall 1: Wrong Mask Shape
```python
# Wrong: mask shape doesn't match the scores
mask = torch.tril(torch.ones(seq_len, seq_len))  # [seq_len, seq_len]
scores = ...                                     # [batch, heads, seq_len, seq_len]

# Correct: broadcast-compatible shape
mask = mask.view(1, 1, seq_len, seq_len)
```
Pitfall 2: Forgetting to Mask During Inference
```python
# Wrong: no masking during generation
logits = model(input_ids, mask=None)

# Correct: always apply the causal mask
logits = model(input_ids, causal_mask=True)
```
Pitfall 3: Mask Value Too Small
```python
# Wrong: the value is not negative enough, so attention still leaks through
scores.masked_fill_(mask == 0, -1)

# Correct: use a very large negative value
scores.masked_fill_(mask == 0, -1e9)
```