Masked and Causal Attention

Learn how masked attention enables autoregressive generation and prevents information leakage in transformers, a mechanism essential for language models and sequential generation.

Best viewed on desktop for optimal interactive experience

Masked and Causal Attention: Preserving Causality in Generation

Masked attention is the key mechanism that lets transformers generate sequences one token at a time: it ensures each position attends only to past tokens, preserving the autoregressive property essential for generation tasks.

Interactive Masked Attention Visualization

Explore how masking patterns control information flow in attention:

Information Leakage in Standard Attention

Without masking, attention can access future tokens

The Problem

During Training: Model can "cheat" by seeing future tokens

During Inference: Future tokens don't exist yet

Result: Train/test mismatch breaks autoregressive generation

Standard Attention (No Mask):

Every position can see all positions (including future)

Apply Masks to Control Information Flow

Set masked positions to -∞ before softmax

Causal Masking Formula

M_ij =
  0    if i ≥ j  (can attend)
  -∞   if i < j  (masked)

Position i can only attend to positions 0 through i

Why -∞?

Softmax(-∞) = 0

Completely blocks attention

Prevents information leakage

Different Masking Patterns

Explore various mask types and their uses

Causal Mask

Autoregressive generation (GPT)

[Interactive 8×8 attention matrix: rows are query positions 0-7, columns are key positions 0-7; lower-triangle cells are marked "Can Attend", upper-triangle cells are marked "Masked".]

Token Sequence: 0: The, 1: cat, 2: sat, 3: on, 4: the, 5: mat, 6: and, 7: slept

Position 4 ("the") can attend to: 0: The, 1: cat, 2: sat, 3: on, 4: the

Enables Autoregressive Generation

Essential for language models and sequential generation

Training Benefits

Process entire sequence in parallel with causal mask

Compute loss for all positions simultaneously (teacher forcing)

Fast training on GPUs with massive parallelization

Inference Benefits

Generate tokens one at a time autoregressively

Cache previous keys/values for efficiency (KV cache)

Consistent with how model was trained

Additional Mask Benefits

Padding Masks

Efficiently handle variable-length sequences in batches

Prefix Masks

Bidirectional context for prompts, causal for generation

Sparse Masks

Reduce computation for long sequences (local + global attention)

Complexity

Memory: O(n²) for mask (often pre-computed)

Compute: Negligible overhead when fused

Used in All Autoregressive Models

Foundation of GPT, LLaMA, and decoder transformers

GPT Family

Mask Type: Causal (lower triangular)

Use Case: Language modeling, text generation

Training: Parallel with causal mask

Inference: Autoregressive with KV cache

LLaMA / Mistral

Mask Type: Causal + optional sliding window

Optimization: Flash Attention with is_causal=True

Cache: Efficient KV cache with grouped-query attention (GQA)

Performance: Fused kernels eliminate mask materialization

T5 / BART Decoders

Mask Type: Causal for self-attention, none for cross-attention

Architecture: Encoder-decoder with different masking

Use Case: Translation, summarization

Flexibility: Combine multiple mask types

Modern Optimizations: In PyTorch 2.0+, F.scaled_dot_product_attention(Q, K, V, is_causal=True) automatically applies causal masking with fused kernels for a 2-4× speedup

Click on tokens or matrix cells to explore different positions. Try different mask types to see how they affect attention patterns.

Why Masked Attention?

The Information Leakage Problem

In standard self-attention, every position can attend to every other position:

  • During training: Model can "cheat" by looking at future tokens
  • During inference: Future tokens don't exist yet

Solution: Apply masks to prevent attending to future positions

Types of Masking

  1. Causal Mask: For autoregressive generation (GPT-style)
  2. Padding Mask: For variable-length sequences
  3. Custom Masks: For specific attention patterns
  4. Combined Masks: Multiple masks applied together

How Causal Masking Works

The Causal Mask

For a sequence of length n, the causal mask is a lower triangular matrix:

M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}

This ensures position i can only attend to positions 0 through i.

Applying the Mask

Mask Application Process:

  1. Compute attention scores: Q × K^T / √d_k
  2. Apply mask: Replace masked positions with -∞
  3. Apply softmax: -∞ becomes 0, preventing attention flow
  4. Result: Clean attention weights with no information leakage

Key Insight: Softmax(-∞) = 0, completely blocking attention to masked positions
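
A minimal PyTorch sketch of these four steps, assuming single-head inputs of shape [batch, seq_len, d_k] (the function name and shapes here are illustrative, not a reference implementation):

```python
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """Sketch of causal attention: scores -> mask -> softmax -> weighted sum."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5                # [batch, seq, seq]
    seq_len = q.size(-2)
    # Lower-triangular boolean mask: True where attention is allowed (i >= j)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~allowed, float("-inf"))         # block future positions
    weights = F.softmax(scores, dim=-1)                          # masked entries become 0
    return weights @ v
```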

Implementation

Creating Causal Masks

Standard Approach:

  • Create lower triangular matrix where position i can attend to positions 0...i
  • Use torch.tril() for efficient generation
  • Shape: [seq_len, seq_len]

Memory-Efficient Approach:

  • Generate mask on-the-fly using broadcasting with row/column indices
  • Compare row_indices >= col_indices
  • Avoids pre-allocating and storing a full mask tensor for the maximum sequence length
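
Both approaches can be sketched in a few lines of PyTorch (the sequence length below is illustrative):

```python
import torch

seq_len = 8

# Standard approach: materialize a [seq_len, seq_len] lower-triangular matrix.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# On-the-fly approach: compare broadcast row/column indices when needed.
rows = torch.arange(seq_len).unsqueeze(1)   # [seq_len, 1]
cols = torch.arange(seq_len).unsqueeze(0)   # [1, seq_len]
mask_on_the_fly = rows >= cols              # True where attending is allowed

assert torch.equal(mask, mask_on_the_fly)
```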

Masked Self-Attention

Key Components:

  • Projections: Standard Q, K, V linear layers
  • Causal Mask Buffer: Pre-computed lower triangular matrix registered as buffer
  • Mask Application: Scores masked before softmax
  • Optional Padding Mask: Can combine causal + padding masks

Forward Pass Steps:

  1. Project input to Q, K, V and reshape for multi-head attention
  2. Compute attention scores: Q × K^T / √d_k
  3. Slice pre-computed causal mask to current sequence length
  4. Apply mask: scores.masked_fill(mask == 0, float('-inf'))
  5. Apply softmax and dropout
  6. Apply attention to values
  7. Reshape and project output
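
The steps above can be sketched as a PyTorch module. The hyperparameter names (d_model, n_heads, max_len) and the fused QKV projection are illustrative choices, not a reference implementation:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedSelfAttention(nn.Module):
    """Sketch of multi-head self-attention with a pre-computed causal mask."""

    def __init__(self, d_model: int, n_heads: int, max_len: int = 2048, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Causal mask registered as a buffer: moves with the module to the right device.
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer("causal_mask", mask.view(1, 1, max_len, max_len))

    def forward(self, x, padding_mask=None):
        b, t, d = x.shape
        # 1. Project to Q, K, V and split heads: [b, n_heads, t, d_head]
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (z.view(b, t, self.n_heads, self.d_head).transpose(1, 2) for z in (q, k, v))
        # 2. Scaled dot-product scores
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        # 3.-4. Slice the causal mask to the current length and apply it
        scores = scores.masked_fill(~self.causal_mask[:, :, :t, :t], float("-inf"))
        # Optional padding mask: [b, t] boolean, True for real tokens
        if padding_mask is not None:
            scores = scores.masked_fill(~padding_mask[:, None, None, :], float("-inf"))
        # 5. Softmax and dropout
        weights = self.dropout(F.softmax(scores, dim=-1))
        # 6.-7. Weighted sum of values, merge heads, project output
        out = (weights @ v).transpose(1, 2).contiguous().view(b, t, d)
        return self.out(out)
```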

Training vs Inference

Training: Parallel Processing

Efficient Batch Training:

  • Process entire sequence at once with causal mask applied
  • Compute logits for all positions in parallel
  • Calculate loss across all positions simultaneously (teacher forcing)
  • GPU can parallelize across sequence dimension

Benefit: Extremely fast training compared to sequential generation
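
A sketch of the parallel training objective, assuming a model that returns per-position logits and applies the causal mask internally:

```python
import torch.nn.functional as F

def language_model_loss(model, tokens):
    """Teacher-forcing loss: one forward pass scores every position in parallel.

    tokens: [batch, seq_len] integer ids; model returns logits [batch, seq_len, vocab].
    """
    inputs, targets = tokens[:, :-1], tokens[:, 1:]   # predict token t+1 from tokens <= t
    logits = model(inputs)                            # causal mask applied inside the model
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```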

Inference: Sequential Generation

Autoregressive Token Generation:

  1. Start with prompt tokens
  2. Run forward pass to get logits for all positions
  3. Use only last position's logits to predict next token
  4. Sample next token (greedy, top-k, or nucleus sampling)
  5. Append to sequence and repeat
  6. Stop at max length or end-of-sequence token

Key Difference: Training is parallel, inference is sequential
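
A greedy-decoding sketch of this loop (sampling strategy and stopping criteria are simplified; the KV-cache optimization described in the next subsection avoids re-running the full sequence at each step):

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens=50, eos_id=None):
    """Greedy autoregressive decoding sketch (no KV cache)."""
    tokens = prompt_ids                                   # [batch, prompt_len]
    for _ in range(max_new_tokens):
        logits = model(tokens)                            # [batch, len, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # only the last position matters
        tokens = torch.cat([tokens, next_token], dim=-1)
        if eos_id is not None and (next_token == eos_id).all():
            break
    return tokens
```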

KV Cache Optimization

Efficient Generation with KV Cache:

The Problem: Recomputing K,V for all previous tokens at each step is wasteful

The Solution:

  • Store computed K,V pairs in cache
  • For each new token, compute only new K,V
  • Concatenate with cached K,V from previous tokens
  • Compute Q only for current token
  • Massive speedup for long sequences

Cache Management:

  1. Initialize empty cache
  2. For each generated token:
    • Compute K,V for new token only
    • Concatenate with cached K,V
    • Update cache with full K,V
    • Compute attention with full context
  3. Reuse cache across generation steps
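
A minimal cache sketch, assuming keys and values shaped [batch, heads, seq_len, d_head]; real implementations typically pre-allocate the cache and track lengths per layer:

```python
import torch

class KVCache:
    """Minimal KV cache sketch: keys/values are appended once per generated token."""

    def __init__(self):
        self.k = None   # [batch, heads, cached_len, d_head]
        self.v = None

    def update(self, new_k, new_v):
        # Concatenate this step's K/V (sequence length 1) with everything cached so far.
        self.k = new_k if self.k is None else torch.cat([self.k, new_k], dim=2)
        self.v = new_v if self.v is None else torch.cat([self.v, new_v], dim=2)
        return self.k, self.v
```

Inside the attention layer, only the newest token's Q, K, V are computed; K and V are passed through cache.update() before the attention product. With a single query token, no causal mask is needed at decode time, since every cached position is already in the past.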

Types of Attention Masks

1. Standard Causal Mask

  • Purpose: Autoregressive generation (GPT-style)
  • Pattern: Lower triangular matrix
  • Use Case: Language modeling, text generation

2. Padding Mask

  • Purpose: Handle variable-length sequences in batches
  • Pattern: Mask positions beyond actual sequence length
  • Implementation: Create binary mask based on sequence lengths (see the sketch below)
  • Use Case: Efficient batching with different length inputs
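
A possible construction, assuming a tensor of actual sequence lengths per batch element:

```python
import torch

def make_padding_mask(lengths, max_len):
    """Boolean padding mask sketch: True marks real tokens, False marks padding."""
    positions = torch.arange(max_len, device=lengths.device)   # [max_len]
    return positions[None, :] < lengths[:, None]               # [batch, max_len]

mask = make_padding_mask(torch.tensor([3, 5]), max_len=5)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```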

3. Prefix LM Mask

  • Purpose: Bidirectional attention for prefix, causal for generation
  • Pattern: Full attention within prefix, then causal
  • Use Case: Given context (bidirectional), generate completion (causal)
  • Example: T5, UL2 models
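
A sketch of the mask construction, where prefix_len is the number of bidirectional prompt tokens (returned as a boolean "can attend" matrix):

```python
import torch

def make_prefix_lm_mask(seq_len, prefix_len):
    """Prefix-LM mask sketch: bidirectional inside the prefix, causal afterwards."""
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed[:, :prefix_len] = True   # every position may attend to all prefix tokens
    return allowed                   # True = can attend
```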

4. Block-Sparse Mask

  • Purpose: Reduce computation for long sequences
  • Pattern: Local attention within blocks + global attention to special tokens
  • Use Case: Long-context models (Longformer, BigBird)
  • Benefit: O(n) complexity instead of O(n²)
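
A simplified Longformer-style construction, assuming a symmetric local window plus a handful of global positions; real block-sparse kernels avoid materializing this full matrix:

```python
import torch

def make_local_global_mask(seq_len, window, global_positions):
    """Sketch of a local + global mask: sliding-window attention plus a few
    positions that attend, and are attended to, globally."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    allowed = (i - j).abs() <= window        # local band around the diagonal
    allowed[global_positions, :] = True      # global tokens attend everywhere
    allowed[:, global_positions] = True      # everyone attends to global tokens
    return allowed
```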

Attention Patterns with Masking

Visualization of Different Masks

Interactive Visualization: Use the interactive component above to explore different mask patterns

Common Patterns:

  • Causal: Lower triangular (each position attends to itself and previous)
  • Padding: Block pattern (mask padding tokens)
  • Prefix LM: Hybrid pattern (bidirectional prefix + causal suffix)
  • Block-Sparse: Diagonal blocks + vertical/horizontal stripes for global tokens

Special Considerations

1. Numerical Stability

Mask Value Selection:

  • Use -1e9 instead of -∞ for better numerical stability
  • Or use torch.finfo(dtype).min for dtype-specific minimum
  • Prevents NaN issues in some implementations
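
A small example of the dtype-aware alternative (shapes and dtype are illustrative):

```python
import torch

scores = torch.randn(1, 8, 8, dtype=torch.float16)
mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))

# Dtype-aware "minus infinity": stays finite under fp16/bf16.
neg_inf = torch.finfo(scores.dtype).min
scores = scores.masked_fill(~mask, neg_inf)
```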

2. Efficient Masking

Pre-computation Strategy:

  • Register causal mask as buffer in module initialization
  • Pre-compute for maximum sequence length
  • Slice to actual sequence length during forward pass
  • Avoids repeated mask creation overhead

Benefits:

  • No repeated tensor allocations
  • Mask moves with model to correct device
  • Included in state_dict for checkpointing

3. Flash Attention with Causal Mask

Modern Optimization (PyTorch 2.0+):

  • Use F.scaled_dot_product_attention() with is_causal=True
  • Automatically applies causal masking with fused kernels
  • 2-4× speedup over manual implementation
  • Reduces memory usage (no explicit mask materialization)
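
Usage sketch (tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)   # [batch, heads, seq_len, d_head]
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Fused causal attention: no explicit [seq_len, seq_len] mask is materialized.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```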

Common Applications

Language Modeling (GPT)

Architecture:

  • Stack of transformer decoder blocks
  • Each block contains masked self-attention
  • Layer normalization before attention (pre-norm)
  • Residual connections around attention and feedforward

Forward Pass:

  1. Apply layer norm
  2. Masked self-attention with causal mask
  3. Add residual connection
  4. Apply layer norm
  5. Feedforward network
  6. Add residual connection
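
A pre-norm block sketch reusing the MaskedSelfAttention module sketched earlier; d_ff and the GELU feedforward are typical but illustrative choices:

```python
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Pre-norm GPT-style block sketch (MaskedSelfAttention defined above)."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MaskedSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # pre-norm + residual around masked attention
        x = x + self.ff(self.ln2(x))     # pre-norm + residual around the feedforward
        return x
```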

Decoder in Seq2Seq

Architecture:

  • Masked self-attention on decoder inputs (causal)
  • Cross-attention to encoder outputs (no mask)
  • Feedforward network

Key Differences:

  • Self-attention uses causal mask (can't see future)
  • Cross-attention has no mask (can see all encoder outputs)
  • Used in translation, summarization models (T5, BART)

Performance Implications

Memory Usage

  • Standard: O(seq_len²) for mask storage
  • Optimized: O(1) with on-the-fly generation
  • Flash Attention: Fused kernels eliminate mask materialization

Computational Cost

  • Masking itself: O(seq_len²) comparisons
  • Can be fused with attention computation
  • Negligible overhead with proper implementation

Best Practices

  1. Pre-compute masks when sequence length is known
  2. Use buffers for fixed masks to avoid re-allocation
  3. Leverage built-in functions like is_causal in PyTorch 2.0+
  4. Combine masks efficiently using logical operations
  5. Profile memory usage for long sequences

Common Pitfalls

Pitfall 1: Wrong Mask Shape

Problem: Mask shape [seq_len, seq_len] doesn't broadcast with scores [batch, heads, seq_len, seq_len]

Solution: Reshape mask to broadcast-compatible shape [1, 1, seq_len, seq_len]

Why: PyTorch broadcasting rules require compatible dimensions
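
A minimal illustration of the fix (shapes are illustrative):

```python
import torch

batch, heads, seq_len = 2, 4, 8
scores = torch.randn(batch, heads, seq_len, seq_len)

mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Reshape to [1, 1, seq_len, seq_len] so it broadcasts over batch and head dimensions.
scores = scores.masked_fill(~mask[None, None, :, :], float("-inf"))
```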

Pitfall 2: Forgetting to Mask During Inference

Problem: Not applying causal mask during generation leads to incorrect behavior

Solution: Always apply causal mask, even during inference

Impact: Without mask, model sees "future" tokens that don't exist yet

Pitfall 3: Mask Value Too Small

Problem: Using small negative values (like -1) doesn't effectively block attention

Solution: Use a large negative value (-1e9, torch.finfo(dtype).min, or -∞)

Reason: Softmax needs very negative values to produce near-zero outputs
