Masked and Causal Attention

Learn how masked attention enables autoregressive generation and prevents information leakage in transformers, essential for language models and sequential generation.

Masked and Causal Attention: Preserving Causality in Generation

Masked attention is the key mechanism that allows transformers to generate sequences one token at a time: it ensures the model attends only to past tokens, preserving the autoregressive property essential for generation tasks.

Interactive Masked Attention Visualization

Explore how masking patterns control information flow in attention:

[Interactive widget: a causal (GPT-style) attention mask over the token sequence "The cat sat on the mat and slept". The attention matrix is lower triangular, so each position can attend only to itself and earlier positions; the legend distinguishes attendable positions (high vs. low attention weight) from masked ones. For example, position 3 ("on") can attend to: The, cat, sat, on.]

Why Masked Attention?

The Information Leakage Problem

In standard self-attention, every position can attend to every other position:

  • During training: Model can "cheat" by looking at future tokens
  • During inference: Future tokens don't exist yet

Solution: Apply masks to prevent attending to future positions
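
As a minimal, self-contained sketch of that idea (toy scores, assuming only PyTorch), masking the upper triangle of the score matrix forces each row's softmax to put zero weight on future positions:

import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(seq_len, seq_len)           # toy attention scores
mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = can attend, 0 = future position

masked_scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(masked_scores, dim=-1)

print(weights)              # the upper triangle is exactly 0
print(weights.sum(dim=-1))  # each row still sums to 1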

Types of Masking

  1. Causal Mask: For autoregressive generation (GPT-style)
  2. Padding Mask: For variable-length sequences
  3. Custom Masks: For specific attention patterns
  4. Combined Masks: Multiple masks applied together
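
For the combined case, a typical recipe (sketched below, reusing the create_padding_mask helper defined later on this page; lengths, seq_len, and scores are assumed to already exist) is to AND a causal mask with a padding mask and broadcast the result over the score tensor:

# causal: [seq_len, seq_len], padding: [batch, seq_len]
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
padding = create_padding_mask(lengths, seq_len).bool()

# Broadcast to [batch, 1, seq_len, seq_len]: a key is visible only if it is
# both in the causal past AND not a padding token
combined = causal.unsqueeze(0).unsqueeze(0) & padding.unsqueeze(1).unsqueeze(2)

# Apply to scores of shape [batch, heads, seq_len, seq_len]
scores = scores.masked_fill(~combined, float('-inf'))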

How Causal Masking Works

The Causal Mask

For a sequence of length n, the causal mask is a lower triangular matrix:

M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}

This ensures position i can only attend to positions 0 through i.
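
To see why the -∞ entries behave correctly, note that softmax maps them to exactly zero weight. A quick check (assuming only PyTorch):

import torch
import torch.nn.functional as F

# Scores for position i = 1 in a length-4 sequence: the last two entries are future positions
row = torch.tensor([2.0, 1.0, float('-inf'), float('-inf')])
print(F.softmax(row, dim=-1))  # tensor([0.7311, 0.2689, 0.0000, 0.0000])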

Applying the Mask

def apply_causal_mask(scores, mask):
    """
    scores: [batch, heads, seq_len, seq_len]
    mask:   [seq_len, seq_len] or broadcastable shape
    """
    # Replace masked positions with -inf
    scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax will turn -inf into 0
    attention = F.softmax(scores, dim=-1)
    return attention

Visualization of Causal Mask

Position:  0  1  2  3  4
    0     [1  0  0  0  0]  → Can only see position 0
    1     [1  1  0  0  0]  → Can see positions 0-1
    2     [1  1  1  0  0]  → Can see positions 0-2
    3     [1  1  1  1  0]  → Can see positions 0-3
    4     [1  1  1  1  1]  → Can see all previous positions

Implementation

Creating Causal Masks

def create_causal_mask(seq_len):
    """Create a causal mask for autoregressive attention"""
    # Lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

def create_causal_mask_efficient(seq_len):
    """Memory-efficient version using broadcasting"""
    row_indices = torch.arange(seq_len).unsqueeze(-1)
    col_indices = torch.arange(seq_len)
    mask = row_indices >= col_indices
    return mask

Masked Self-Attention

class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Register causal mask as a buffer
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(1024, 1024)).view(1, 1, 1024, 1024)
        )

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()

        # Generate Q, K, V
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k)

        # Transpose for attention: [batch, heads, seq_len, d_k]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply causal mask
        causal_mask = self.causal_mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(causal_mask == 0, float('-inf'))

        # Apply additional mask if provided (e.g., padding)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax and dropout
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention to values
        context = torch.matmul(attention, V)

        # Reshape and project
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, d_model)
        output = self.W_o(context)

        return output, attention

Training vs Inference

Training: Parallel Processing

During training, we can process the entire sequence at once:

def train_step(model, input_ids, target_ids):
    # Process the entire sequence; the causal mask is applied internally
    logits = model(input_ids)

    # Compute loss for all positions at once
    loss = F.cross_entropy(
        logits.view(-1, vocab_size),
        target_ids.view(-1)
    )
    return loss

Inference: Sequential Generation

During inference, generate one token at a time:

def generate(model, prompt, max_length=100):
    input_ids = tokenize(prompt)

    for _ in range(max_length):
        # Get logits for all positions
        with torch.no_grad():
            logits = model(input_ids)

        # Only use the last position's logits
        next_token_logits = logits[:, -1, :]

        # Sample the next token
        next_token = torch.multinomial(
            F.softmax(next_token_logits, dim=-1),
            num_samples=1
        )

        # Append to the sequence
        input_ids = torch.cat([input_ids, next_token], dim=1)

        if next_token.item() == eos_token_id:
            break

    return input_ids

KV Cache Optimization

For efficient generation, cache previous key-value pairs:

class CachedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.cache = {'k': None, 'v': None}

    def forward(self, x, use_cache=True):
        batch_size, seq_len, _ = x.size()

        # Project only the new token(s)
        Q = self.W_q(x)
        K_new = self.W_k(x)
        V_new = self.W_v(x)

        if use_cache and self.cache['k'] is not None:
            # Concatenate with cached K, V from previous steps
            K = torch.cat([self.cache['k'], K_new], dim=1)
            V = torch.cat([self.cache['v'], V_new], dim=1)
        else:
            K = K_new
            V = V_new

        # Update the cache
        if use_cache:
            self.cache['k'] = K
            self.cache['v'] = V

        # Split heads: [batch, heads, len, d_k]
        def split_heads(t):
            return t.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        Q, K, V = split_heads(Q), split_heads(K), split_heads(V)

        # When decoding one token at a time, the single new query may attend to
        # every cached position, so no causal mask is needed here; a multi-token
        # prefill step would still require one
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention = F.softmax(scores, dim=-1)
        context = torch.matmul(attention, V)

        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(context)

Types of Attention Masks

1. Standard Causal Mask

# Lower triangular matrix
mask = torch.tril(torch.ones(n, n))

2. Padding Mask

def create_padding_mask(lengths, max_len):
    """Mask out padding tokens"""
    batch_size = len(lengths)
    mask = torch.zeros(batch_size, max_len)
    for i, length in enumerate(lengths):
        mask[i, :length] = 1
    return mask

3. Prefix LM Mask

def create_prefix_lm_mask(seq_len, prefix_len):
    """Bidirectional attention over the prefix, causal attention afterwards"""
    mask = torch.zeros(seq_len, seq_len)
    # Every position can attend to the prefix
    mask[:, :prefix_len] = 1
    # Positions after the prefix attend causally to each other
    mask[prefix_len:, prefix_len:] = torch.tril(
        torch.ones(seq_len - prefix_len, seq_len - prefix_len)
    )
    return mask

4. Block-Sparse Mask

def create_block_sparse_mask(seq_len, block_size):
    """Attention within local blocks plus a few global tokens"""
    mask = torch.zeros(seq_len, seq_len)

    # Local blocks
    for i in range(0, seq_len, block_size):
        end = min(i + block_size, seq_len)
        mask[i:end, i:end] = 1

    # Global tokens (the first two positions)
    mask[:, :2] = 1  # all positions can attend to the first 2 tokens
    mask[:2, :] = 1  # the first 2 tokens can attend to all positions
    return mask

Attention Patterns with Masking

Visualization of Different Masks

import matplotlib.pyplot as plt

def visualize_mask(mask, title):
    plt.figure(figsize=(6, 6))
    plt.imshow(mask, cmap='Blues', interpolation='nearest')
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.tight_layout()

# Examples
visualize_mask(create_causal_mask(8), "Causal Mask")
visualize_mask(create_prefix_lm_mask(8, 3), "Prefix LM Mask")
visualize_mask(create_block_sparse_mask(16, 4), "Block-Sparse Mask")

Special Considerations

1. Numerical Stability

# Use -1e9 instead of -inf for better numerical stability
mask_value = -1e9  # large and negative, but finite

# Or derive the value from the score dtype
mask_value = torch.finfo(scores.dtype).min

2. Efficient Masking

# Pre-compute masks when possible
class EfficientMaskedAttention(nn.Module):
    def __init__(self, max_seq_len=1024):
        super().__init__()
        # Pre-compute the mask and register it as a buffer
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        self.register_buffer('mask', mask)

    def get_mask(self, seq_len):
        # Slice the pre-computed mask
        return self.mask[:seq_len, :seq_len]

3. Flash Attention with Causal Mask

# PyTorch 2.0+ optimized implementation
output = F.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True,  # automatically applies the causal mask
    dropout_p=0.1
)

Common Applications

Language Modeling (GPT)

class GPTBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.masked_attn = MaskedSelfAttention(
            config.d_model, config.n_heads
        )
        self.mlp = FeedForward(config.d_model)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)

    def forward(self, x):
        # Masked self-attention with a residual connection
        x = x + self.masked_attn(self.ln1(x))[0]
        # Feedforward with a residual connection
        x = x + self.mlp(self.ln2(x))
        return x

Decoder in Seq2Seq

class TransformerDecoder(nn.Module):
    def forward(self, x, encoder_output):
        # Masked (causal) self-attention over the target sequence
        x = self.masked_self_attn(x)
        # Cross-attention over the encoder output (no causal mask needed)
        x = self.cross_attn(x, encoder_output)
        # Feedforward
        x = self.ffn(x)
        return x

Performance Implications

Memory Usage

  • Standard: O(seq_len²) for mask storage (see the quick calculation after this list)
  • Optimized: O(1) with on-the-fly generation
  • Flash Attention: Fused kernels eliminate mask materialization
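
To make the quadratic mask storage concrete, here is a back-of-the-envelope calculation (a sketch; it assumes 1 byte per entry for a bool mask and 4 bytes for float32, which matches PyTorch's tensor layouts):

seq_len = 4096
entries = seq_len * seq_len                  # 16,777,216 entries for one mask

print(entries * 1 / 2**20, "MiB (bool)")     # ~16 MiB
print(entries * 4 / 2**20, "MiB (float32)")  # ~64 MiB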

Computational Cost

  • Masking itself: O(seq_len²) comparisons
  • Can be fused with attention computation
  • Negligible overhead with proper implementation

Best Practices

  1. Pre-compute masks when sequence length is known
  2. Use buffers for fixed masks to avoid re-allocation
  3. Leverage built-in functions like is_causal in PyTorch 2.0+
  4. Combine masks efficiently using logical operations
  5. Profile memory usage for long sequences

Common Pitfalls

Pitfall 1: Wrong Mask Shape

# Wrong: mask shape doesn't match the scores
mask = torch.tril(torch.ones(seq_len, seq_len))  # [seq_len, seq_len]
scores = ...  # [batch, heads, seq_len, seq_len]

# Correct: reshape to a broadcast-compatible shape
mask = mask.view(1, 1, seq_len, seq_len)

Pitfall 2: Forgetting to Mask During Inference

# Wrong: no masking during generation
logits = model(input_ids, mask=None)

# Correct: always apply the causal mask
logits = model(input_ids, causal_mask=True)

Pitfall 3: Mask Value Too Small

# Wrong: small negative values don't mask effectively
scores.masked_fill_(mask == 0, -1)  # too small!

# Correct: use a large negative value
scores.masked_fill_(mask == 0, -1e9)
