Cross-Attention: Bridging Different Modalities
Understand cross-attention, the mechanism that enables transformers to align and fuse information from different sources, sequences, or modalities.
Cross-Attention: Connecting Different Information Sources
Cross-attention is the bridge that allows transformers to align and combine information from different sequences, making it fundamental for tasks like translation, image captioning, and multimodal understanding.
What is Cross-Attention?
Unlike self-attention where Q, K, and V come from the same sequence, cross-attention uses:
- Queries (Q) from one sequence (e.g., decoder)
- Keys (K) and Values (V) from another sequence (e.g., encoder)
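The difference is easiest to see in code. The minimal sketch below uses PyTorch's built-in nn.MultiheadAttention with illustrative shapes; the only change between self- and cross-attention is which tensors are passed as key and value.

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

decoder_states = torch.randn(2, 10, 512)   # target sequence: [batch, tgt_len, d_model]
encoder_states = torch.randn(2, 25, 512)   # source sequence: [batch, src_len, d_model]

# Self-attention: Q, K, and V all come from the same sequence
self_out, _ = attn(decoder_states, decoder_states, decoder_states)

# Cross-attention: Q from the target, K and V from the source
cross_out, _ = attn(decoder_states, encoder_states, encoder_states)

print(self_out.shape, cross_out.shape)   # both [2, 10, 512] -- output length follows the queries
```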
Why Cross-Attention?
The Connection Problem
Many tasks require relating two different sequences:
- Translation: Source language → Target language
- Image Captioning: Image features → Text description
- VQA: Question + Image → Answer
- Speech Recognition: Audio → Text
Cross-attention provides the mechanism to:
- Align elements between sequences
- Transfer information from source to target
- Learn relationships across modalities
How Cross-Attention Works
Step-by-Step Process
- Extract Representations
```python
# Source sequence (e.g., encoder output)
source_hidden = encoder(source_input)             # [batch, src_len, d_model]

# Target sequence (e.g., decoder state)
target_hidden = decoder_self_attn(target_input)   # [batch, tgt_len, d_model]
```
- Generate Q, K, V
```python
# Queries from target
Q = W_q(target_hidden)   # [batch, tgt_len, d_k]

# Keys and Values from source
K = W_k(source_hidden)   # [batch, src_len, d_k]
V = W_v(source_hidden)   # [batch, src_len, d_v]
```
- Compute Attention
```python
# Each target position attends to all source positions
scores = Q @ K.transpose(-2, -1) / sqrt(d_k)   # [batch, tgt_len, src_len]
attention = softmax(scores, dim=-1)
output = attention @ V                         # [batch, tgt_len, d_v]
```
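To sanity-check these shapes, here is a short, self-contained example using random tensors and PyTorch's fused scaled_dot_product_attention (available since PyTorch 2.0); the sizes are illustrative.

```python
import torch
import torch.nn.functional as F

batch, n_heads, tgt_len, src_len, d_k = 2, 8, 10, 25, 64

Q = torch.randn(batch, n_heads, tgt_len, d_k)   # queries from the target
K = torch.randn(batch, n_heads, src_len, d_k)   # keys from the source
V = torch.randn(batch, n_heads, src_len, d_k)   # values from the source

scores = Q @ K.transpose(-2, -1) / d_k**0.5     # [2, 8, 10, 25]: one row of scores per target position
output = torch.softmax(scores, dim=-1) @ V      # [2, 8, 10, 64]

# The fused PyTorch op computes the same thing
assert torch.allclose(output, F.scaled_dot_product_attention(Q, K, V), atol=1e-4)
```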
Cross-Attention in Transformers
Encoder-Decoder Architecture
```python
class TransformerDecoder(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model)

    def forward(self, x, encoder_output, mask=None):
        # (residual connections and layer norms omitted for clarity)

        # Self-attention on decoder
        x = self.self_attention(x, x, x, mask=mask)

        # Cross-attention to encoder:
        # Q from decoder, K and V from encoder
        x = self.cross_attention(
            query=x,               # What we're looking for
            key=encoder_output,    # Where to look
            value=encoder_output   # What to retrieve
        )

        x = self.feed_forward(x)
        return x
```
Information Flow
```
Encoder Input → Encoder → Encoder Output
                                ↓ (K, V)
Decoder Input → Self-Attn → Cross-Attn → FFN → Output
                                ↑ (Q)
```
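The same flow is available out of the box in PyTorch: nn.TransformerDecoderLayer runs decoder self-attention followed by cross-attention over the encoder output (called memory in the API). A minimal sketch with illustrative shapes:

```python
import torch
import torch.nn as nn

layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)

memory = torch.randn(2, 25, 512)   # encoder output: provides K and V
tgt = torch.randn(2, 10, 512)      # decoder input: provides Q

out = layer(tgt, memory)           # self-attn → cross-attn → feed-forward
print(out.shape)                   # [2, 10, 512]
```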
Types of Cross-Attention
1. Encoder-Decoder Attention
Classic transformer architecture:
- Used in: Machine translation, summarization
- Q: Decoder states
- K, V: Encoder outputs
2. Multi-Modal Cross-Attention
Between different modalities:
- Vision-Language: Flamingo, BLIP-2, Stable Diffusion's text conditioning
- Audio-Text: Whisper and other encoder-decoder speech models
- Video-Text: Video understanding models
3. Memory Attention
Attending to external memory:
- Retrieval-Augmented: RAG models
- Memory Networks: Neural Turing Machines
- Q: Current state
- K, V: Memory bank
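As an illustration (not tied to any particular retrieval system), the sketch below lets a single query vector for the current state attend over a fixed bank of memory embeddings; all names and sizes are made up.

```python
import torch
import torch.nn as nn

d_model, n_slots = 512, 100

memory_bank = torch.randn(1, n_slots, d_model)   # K, V: e.g., retrieved passage embeddings
state = torch.randn(1, 1, d_model)               # Q: current decoder / query state

memory_attn = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
read_out, weights = memory_attn(state, memory_bank, memory_bank)

print(read_out.shape)   # [1, 1, 512]  -- information read from memory
print(weights.shape)    # [1, 1, 100]  -- how much each memory slot contributed
```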
4. Cross-Attention in Diffusion
Conditioning image generation:
- Q: Image features at timestep t
- K, V: Text embeddings
- Guides generation based on the text prompt (see the sketch below)
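A rough sketch of this conditioning step, with illustrative shapes (a 16×16 latent feature grid flattened to 256 query positions, and 77 text-token embeddings); this is a simplification of what latent diffusion models do inside their U-Net attention blocks.

```python
import torch
import torch.nn as nn

d_model = 320
attn = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)

latent = torch.randn(1, 16 * 16, d_model)   # Q: image features at timestep t (flattened spatial grid)
text = torch.randn(1, 77, d_model)          # K, V: text embeddings from the prompt

conditioned, _ = attn(latent, text, text)   # each spatial position attends to the prompt
print(conditioned.shape)                    # [1, 256, 320]
```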
Implementation Patterns
Basic Cross-Attention
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Transform and reshape for multi-head
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k)

        # Transpose for attention computation
        Q = Q.transpose(1, 2)   # [batch, heads, tgt_len, d_k]
        K = K.transpose(1, 2)   # [batch, heads, src_len, d_k]
        V = V.transpose(1, 2)   # [batch, heads, src_len, d_k]

        # Compute attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention to values
        context = torch.matmul(attention, V)

        # Reshape and project
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.n_heads * self.d_k)
        output = self.W_o(context)

        return output, attention
```
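A quick usage check for the module above (the shapes are illustrative):

```python
cross_attn = CrossAttention(d_model=512, n_heads=8)

decoder_states = torch.randn(2, 10, 512)   # queries
encoder_states = torch.randn(2, 25, 512)   # keys and values

out, attn_weights = cross_attn(decoder_states, encoder_states, encoder_states)
print(out.shape)            # [2, 10, 512]
print(attn_weights.shape)   # [2, 8, 10, 25]  -- one alignment matrix per head
```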
Conditional Cross-Attention
For controlled generation:
```python
class ConditionalCrossAttention(nn.Module):
    def __init__(self, d_model, d_condition):
        super().__init__()
        self.cross_attn = CrossAttention(d_model)
        self.condition_proj = nn.Linear(d_condition, d_model)
        self.gate = nn.Linear(d_model, d_model)

    def forward(self, x, condition):
        # Project condition to model dimension
        cond_hidden = self.condition_proj(condition)

        # Cross-attention with gating
        attended, _ = self.cross_attn(x, cond_hidden, cond_hidden)

        # Gated fusion
        gate = torch.sigmoid(self.gate(attended))
        output = gate * attended + (1 - gate) * x
        return output
```
Attention Patterns in Cross-Attention
Common Patterns
- Alignment: Direct correspondence between source and target positions (e.g., translation)
- Coverage: Ensuring every source position is attended to at some point
- Focusing: Sharp attention on a few relevant positions or regions
- Distributed: Broad attention that gathers overall context
Visualization
```python
import matplotlib.pyplot as plt

def visualize_cross_attention(attention_weights, source_tokens, target_tokens):
    """
    attention_weights: [tgt_len, src_len]
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(attention_weights, cmap='Blues', aspect='auto')
    plt.colorbar()

    # Add labels
    plt.xticks(range(len(source_tokens)), source_tokens, rotation=45)
    plt.yticks(range(len(target_tokens)), target_tokens)
    plt.xlabel('Source Sequence')
    plt.ylabel('Target Sequence')
    plt.title('Cross-Attention Alignment')
    plt.tight_layout()
```
Multimodal Cross-Attention
Vision-Language Example
```python
class VisionLanguageCrossAttention(nn.Module):
    def __init__(self, d_visual, d_text, d_model):
        super().__init__()
        # Project to common dimension
        self.visual_proj = nn.Linear(d_visual, d_model)
        self.text_proj = nn.Linear(d_text, d_model)
        self.cross_attn = CrossAttention(d_model)

    def forward(self, image_features, text_features):
        # Project to common space
        visual = self.visual_proj(image_features)
        textual = self.text_proj(text_features)

        # Bidirectional cross-attention
        text_to_image, _ = self.cross_attn(textual, visual, visual)
        image_to_text, _ = self.cross_attn(visual, textual, textual)

        return text_to_image, image_to_text
```
Best Practices
1. Dimension Matching
Ensure query and key dimensions match:
```python
assert Q.size(-1) == K.size(-1), "Q and K must have same dimension"
```
2. Proper Masking
Handle variable-length sequences:
```python
def create_padding_mask(lengths, max_len):
    batch_size = len(lengths)
    mask = torch.zeros(batch_size, max_len)
    for i, length in enumerate(lengths):
        mask[i, :length] = 1
    return mask
```
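The mask above is per source token ([batch, src_len]), while the attention scores in the CrossAttention module are [batch, heads, tgt_len, src_len], so it has to be broadcast before use. A short sketch, reusing the CrossAttention module and example tensors from earlier (lengths are illustrative):

```python
src_lengths = [25, 18]                                      # true source lengths before padding
src_mask = create_padding_mask(src_lengths, max_len=25)     # [batch, src_len], 1 = token, 0 = padding

# Broadcast to the score shape [batch, heads, tgt_len, src_len]
src_mask = src_mask[:, None, None, :]                       # [batch, 1, 1, src_len]

out, attn = cross_attn(decoder_states, encoder_states, encoder_states, mask=src_mask)
```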
3. Position Information
Add positional encodings when needed:
```python
# Add position to source for better alignment
source_with_pos = source + positional_encoding(source)
```
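The positional_encoding helper is not defined in the snippet above; one common choice is the sinusoidal encoding from the original Transformer paper, sketched here (assumes an even d_model):

```python
import math
import torch

def positional_encoding(x):
    """Sinusoidal positional encoding for x of shape [batch, seq_len, d_model]."""
    _, seq_len, d_model = x.shape
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)   # [seq_len, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe.unsqueeze(0).to(x.device)            # [1, seq_len, d_model], broadcasts over batch
```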
4. Regularization
Prevent attention collapse:
- Dropout on attention weights
- Entropy regularization (a small sketch follows this list)
- Attention diversity loss
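As an illustration of entropy regularization, here is a minimal penalty that discourages every query from collapsing onto a single source position; the loss weight in the comment is arbitrary.

```python
import torch

def attention_entropy_penalty(attention, eps=1e-9):
    """attention: [batch, heads, tgt_len, src_len], rows sum to 1."""
    entropy = -(attention * (attention + eps).log()).sum(dim=-1)   # [batch, heads, tgt_len]
    return -entropy.mean()   # penalize LOW entropy (very peaked attention)

# loss = task_loss + 0.01 * attention_entropy_penalty(attn_weights)
```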
Common Applications
Machine Translation
```python
# Encoder processes source
encoder_out = encoder(src_tokens)

# Decoder generates with cross-attention
for i in range(max_length):
    decoder_out = decoder(tgt_tokens[:i], encoder_out)
    next_token = predict(decoder_out[-1])
    tgt_tokens[i] = next_token
```
Image Captioning
```python
# Extract image features
image_features = vision_encoder(image)

# Generate caption with cross-attention
caption = text_decoder(image_features)
```
Visual Question Answering
```python
# Encode question and image
question_enc = text_encoder(question)
image_enc = vision_encoder(image)

# Cross-attention fusion
fused = cross_attention(question_enc, image_enc, image_enc)
answer = classifier(fused)
```
Performance Considerations
Computational Complexity
- Time: O(n_target × n_source × d)
- Memory: O(n_target × n_source)
- Can become a bottleneck for long source or target sequences
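To make the memory term concrete, a back-of-the-envelope calculation with illustrative sizes (2,048 target positions, 2,048 source positions, 16 heads, batch of 8, fp16 activations):

```python
tgt_len, src_len, n_heads, batch, bytes_per_elem = 2048, 2048, 16, 8, 2

# One [tgt_len, src_len] attention matrix per head per batch element
attn_bytes = tgt_len * src_len * n_heads * batch * bytes_per_elem
print(f"{attn_bytes / 2**30:.1f} GiB")   # 1.0 GiB just for the attention weights
```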
Optimization Strategies
- Sparse Cross-Attention: Attend to subset
- Hierarchical: Multi-resolution attention
- Caching: Reuse encoder outputs
- Quantization: Reduce precision
Common Issues
Issue 1: Attention Drift
Problem: Attention doesn't align properly
Solution:
- Add positional encodings
- Use supervised attention
- Increase model capacity
Issue 2: Information Bottleneck
Problem: Too much compression in cross-attention
Solution:
- Multiple cross-attention layers
- Increase hidden dimension
- Use skip connections
Issue 3: Modality Gap
Problem: Different modalities don't align
Solution:
- Pre-training with alignment objectives
- Learnable modality embeddings
- Projection to common space