Cross-Attention: Bridging Different Modalities

Understand cross-attention, the mechanism that enables transformers to align and fuse information from different sources, sequences, or modalities.


Cross-Attention: Connecting Different Information Sources

Cross-attention is the bridge that allows transformers to align and combine information from different sequences, making it fundamental for tasks like translation, image captioning, and multimodal understanding.

Interactive Cross-Attention Visualization

Explore how queries from one sequence attend to keys and values from another:

(Interactive demo: queries from the French decoder tokens "Le chat est ici" attend to keys and values from the English encoder tokens "The cat sits here".)

What is Cross-Attention?

Unlike self-attention where Q, K, and V come from the same sequence, cross-attention uses:

  • Queries (Q) from one sequence (e.g., decoder)
  • Keys (K) and Values (V) from another sequence (e.g., encoder)
CrossAttention(Q_target, K_source, V_source) = softmax(Q_target · K_sourceᵀ / √d_k) · V_source
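
As a quick shape check, here is a minimal single-head sketch (assuming PyTorch; the sizes are made up) showing that the output keeps the target's length while its content is assembled from the source's values:

import math
import torch
import torch.nn.functional as F

# Hypothetical sizes: 4 target (decoder) tokens, 6 source (encoder) tokens
batch, tgt_len, src_len, d_k = 2, 4, 6, 64

Q = torch.randn(batch, tgt_len, d_k)   # queries from the target sequence
K = torch.randn(batch, src_len, d_k)   # keys from the source sequence
V = torch.randn(batch, src_len, d_k)   # values from the source sequence

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # [batch, tgt_len, src_len]
weights = F.softmax(scores, dim=-1)                 # each row sums to 1 over the source
output = weights @ V                                # [batch, tgt_len, d_k]

print(weights.shape, output.shape)  # torch.Size([2, 4, 6]) torch.Size([2, 4, 64])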

Why Cross-Attention?

The Connection Problem

Many tasks require relating two different sequences:

  • Translation: Source language → Target language
  • Image Captioning: Image features → Text description
  • Visual Question Answering (VQA): Question + Image → Answer
  • Speech Recognition: Audio → Text

Cross-attention provides the mechanism to:

  • Align elements between sequences
  • Transfer information from source to target
  • Learn relationships across modalities

How Cross-Attention Works

Step-by-Step Process

  1. Extract Representations

# Source sequence (e.g., encoder output)
source_hidden = encoder(source_input)            # [batch, src_len, d_model]

# Target sequence (e.g., decoder state)
target_hidden = decoder_self_attn(target_input)  # [batch, tgt_len, d_model]

  2. Generate Q, K, V

# Queries come from the target
Q = W_q(target_hidden)   # [batch, tgt_len, d_k]

# Keys and Values come from the source
K = W_k(source_hidden)   # [batch, src_len, d_k]
V = W_v(source_hidden)   # [batch, src_len, d_v]

  3. Compute Attention

# Each target position attends to all source positions
scores = Q @ K.transpose(-2, -1) / sqrt(d_k)   # [batch, tgt_len, src_len]
attention = softmax(scores, dim=-1)
output = attention @ V                         # [batch, tgt_len, d_v]

Cross-Attention in Transformers

Encoder-Decoder Architecture

class TransformerDecoder(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model)

    def forward(self, x, encoder_output, mask=None):
        # Self-attention on decoder
        x = self.self_attention(x, x, x, mask=mask)

        # Cross-attention to encoder:
        # Q from decoder, K and V from encoder
        x = self.cross_attention(
            query=x,               # What we're looking for
            key=encoder_output,    # Where to look
            value=encoder_output,  # What to retrieve
        )

        x = self.feed_forward(x)
        return x

Information Flow

Encoder Input → Encoder → Encoder Output
                                ↓ (K, V)
Decoder Input → Self-Attn ─(Q)→ Cross-Attn → FFN → Output

Types of Cross-Attention

1. Encoder-Decoder Attention

Classic transformer architecture:

  • Used in: Machine translation, summarization
  • Q: Decoder states
  • K, V: Encoder outputs

2. Multi-Modal Cross-Attention

Between different modalities:

  • Vision-Language: Flamingo, BLIP-2
  • Audio-Text: Whisper (the text decoder attends to the audio encoder's output)
  • Video-Text: Video understanding models

3. Memory Attention

Attending to external memory, as in the sketch after this list:

  • Retrieval-Augmented: RAG models
  • Memory Networks: Neural Turing Machines
  • Q: Current state
  • K, V: Memory bank
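
A minimal sketch of this pattern, reusing the CrossAttention module defined under Implementation Patterns below; the memory bank here is a hypothetical stack of already-encoded retrieved passages:

import torch

d_model, n_docs, batch = 512, 8, 2

current_state = torch.randn(batch, 1, d_model)      # Q: the current hidden state
memory_bank = torch.randn(batch, n_docs, d_model)   # K, V: retrieved memories (hypothetical)

memory_attn = CrossAttention(d_model, n_heads=8)    # module defined later in this article
fused, weights = memory_attn(query=current_state,
                             key=memory_bank,
                             value=memory_bank)

print(fused.shape)    # [2, 1, 512]  - current state enriched with memory content
print(weights.shape)  # [2, 8, 1, 8] - per-head weights over the 8 memory slots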

4. Cross-Attention in Diffusion

Conditioning image generation (see the sketch after this list):

  • Q: Image features at timestep t
  • K, V: Text embeddings
  • Guides generation based on text
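
A simplified sketch of how a denoising block might inject text conditioning; the class, names, and dimensions are illustrative rather than taken from any particular diffusion library, and it relies on PyTorch's nn.MultiheadAttention with separate key/value dimensions:

import torch
import torch.nn as nn

class TextConditionedBlock(nn.Module):
    """Toy denoising block: image latents (Q) attend to text embeddings (K, V)."""
    def __init__(self, d_latent, d_text, n_heads=8):
        super().__init__()
        # d_latent is assumed to be divisible by n_heads
        self.norm = nn.LayerNorm(d_latent)
        self.attn = nn.MultiheadAttention(d_latent, n_heads,
                                          kdim=d_text, vdim=d_text,
                                          batch_first=True)

    def forward(self, latents, text_emb):
        # latents:  [batch, h*w, d_latent]    image features at timestep t
        # text_emb: [batch, n_tokens, d_text] prompt embeddings
        attended, _ = self.attn(self.norm(latents), text_emb, text_emb)
        return latents + attended  # residual: the text guides the denoising step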

Implementation Patterns

Basic Cross-Attention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Transform and reshape for multi-head attention
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k)

        # Transpose for attention computation
        Q = Q.transpose(1, 2)  # [batch, heads, tgt_len, d_k]
        K = K.transpose(1, 2)  # [batch, heads, src_len, d_k]
        V = V.transpose(1, 2)  # [batch, heads, src_len, d_k]

        # Compute scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention weights to values
        context = torch.matmul(attention, V)

        # Reshape back and project to the output dimension
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.n_heads * self.d_k)
        output = self.W_o(context)

        return output, attention

Conditional Cross-Attention

For controlled generation:

class ConditionalCrossAttention(nn.Module):
    def __init__(self, d_model, d_condition):
        super().__init__()
        self.cross_attn = CrossAttention(d_model)
        self.condition_proj = nn.Linear(d_condition, d_model)
        self.gate = nn.Linear(d_model, d_model)

    def forward(self, x, condition):
        # Project condition to model dimension
        cond_hidden = self.condition_proj(condition)

        # Cross-attention to the condition
        attended, _ = self.cross_attn(x, cond_hidden, cond_hidden)

        # Gated fusion
        gate = torch.sigmoid(self.gate(attended))
        output = gate * attended + (1 - gate) * x
        return output

Attention Patterns in Cross-Attention

Common Patterns

  1. Alignment: Direct correspondence (translation)
  2. Coverage: Ensuring every source position receives attention (quantified in the sketch below)
  3. Focusing: Attending to specific regions
  4. Distributed: Broad attention for context
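
These patterns can be read directly off the attention matrix. A small sketch, assuming a weight matrix shaped [tgt_len, src_len] like the one passed to the visualization helper in the next subsection:

import torch

def attention_stats(weights):
    """weights: [tgt_len, src_len], each row sums to 1."""
    # Coverage: total attention each source position receives across the target;
    # values near zero flag source tokens that were never attended to.
    coverage = weights.sum(dim=0)                               # [src_len]

    # Entropy per target position: low = focused, high = distributed attention.
    entropy = -(weights * (weights + 1e-9).log()).sum(dim=-1)   # [tgt_len]

    return coverage, entropy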

Visualization

import matplotlib.pyplot as plt

def visualize_cross_attention(attention_weights, source_tokens, target_tokens):
    """
    attention_weights: [tgt_len, src_len]
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(attention_weights, cmap='Blues', aspect='auto')
    plt.colorbar()

    # Add labels
    plt.xticks(range(len(source_tokens)), source_tokens, rotation=45)
    plt.yticks(range(len(target_tokens)), target_tokens)
    plt.xlabel('Source Sequence')
    plt.ylabel('Target Sequence')
    plt.title('Cross-Attention Alignment')
    plt.tight_layout()

Multimodal Cross-Attention

Vision-Language Example

class VisionLanguageCrossAttention(nn.Module):
    def __init__(self, d_visual, d_text, d_model):
        super().__init__()
        # Project both modalities to a common dimension
        self.visual_proj = nn.Linear(d_visual, d_model)
        self.text_proj = nn.Linear(d_text, d_model)
        self.cross_attn = CrossAttention(d_model)

    def forward(self, image_features, text_features):
        # Project to common space
        visual = self.visual_proj(image_features)
        textual = self.text_proj(text_features)

        # Bidirectional cross-attention
        text_to_image, _ = self.cross_attn(textual, visual, visual)
        image_to_text, _ = self.cross_attn(visual, textual, textual)

        return text_to_image, image_to_text

Best Practices

1. Dimension Matching

Ensure query and key dimensions match:

assert Q.size(-1) == K.size(-1), "Q and K must have same dimension"

2. Proper Masking

Handle variable-length sequences:

def create_padding_mask(lengths, max_len):
    batch_size = len(lengths)
    mask = torch.zeros(batch_size, max_len)
    for i, length in enumerate(lengths):
        mask[i, :length] = 1
    return mask
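
In cross-attention the padding mask marks source positions, so it must broadcast against scores of shape [batch, heads, tgt_len, src_len]. A usage sketch, assuming cross_attention_layer is an instance of the CrossAttention module above and decoder_states / encoder_states are hypothetical [batch, len, d_model] tensors:

# Two source sequences with true lengths 3 and 5, padded to length 5
src_mask = create_padding_mask(lengths=[3, 5], max_len=5)   # [2, 5]

# Add singleton head and target dimensions so the mask broadcasts
# over scores of shape [batch, heads, tgt_len, src_len]
src_mask = src_mask[:, None, None, :]                       # [2, 1, 1, 5]

output, attn = cross_attention_layer(query=decoder_states,
                                     key=encoder_states,
                                     value=encoder_states,
                                     mask=src_mask)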

3. Position Information

Add positional encodings when needed:

# Add position to source for better alignment
source_with_pos = source + positional_encoding(source)
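
A minimal sinusoidal positional_encoding sketch consistent with the snippet above (fixed, non-learned encodings in the style of the original Transformer; assumes d_model is even):

import math
import torch

def positional_encoding(x):
    """Sinusoidal encodings shaped to broadcast over x: [batch, seq_len, d_model]."""
    _, seq_len, d_model = x.shape
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)       # [seq_len, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                         * (-math.log(10000.0) / d_model))                 # [d_model / 2]
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0).to(x.device)   # [1, seq_len, d_model], broadcasts over the batch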

4. Regularization

Prevent attention collapse (an entropy-regularization sketch follows this list):

  • Dropout on attention weights
  • Entropy regularization
  • Attention diversity loss
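
A sketch of the entropy-regularization idea as an auxiliary loss term; lambda_ent is a hypothetical hyperparameter, not something from a specific library:

import torch

def entropy_regularizer(attention, eps=1e-9):
    """attention: [batch, heads, tgt_len, src_len], each row sums to 1.

    Returns the negative mean entropy; adding it (scaled) to the training loss
    penalizes distributions that collapse onto a single source position.
    """
    entropy = -(attention * (attention + eps).log()).sum(dim=-1)   # [batch, heads, tgt_len]
    return -entropy.mean()

# loss = task_loss + lambda_ent * entropy_regularizer(attention_weights)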

Common Applications

Machine Translation

# Encoder processes the source once
encoder_out = encoder(src_tokens)

# Decoder generates with cross-attention
for i in range(max_length):
    decoder_out = decoder(tgt_tokens[:i], encoder_out)
    next_token = predict(decoder_out[-1])
    tgt_tokens[i] = next_token

Image Captioning

# Extract image features
image_features = vision_encoder(image)

# Generate caption with cross-attention
caption = text_decoder(image_features)

Visual Question Answering

# Encode question and image
question_enc = text_encoder(question)
image_enc = vision_encoder(image)

# Cross-attention fusion: question queries attend to image features
fused = cross_attention(question_enc, image_enc, image_enc)
answer = classifier(fused)

Performance Considerations

Computational Complexity

  • Time: O(n_target × n_source × d)
  • Memory: O(n_target × n_source)
  • Can become a bottleneck for long sequences (see the estimate below)
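
A quick back-of-the-envelope estimate (sizes chosen purely for illustration) shows how fast the attention matrix alone grows:

# Illustrative memory estimate for one example's cross-attention scores
tgt_len, src_len, n_heads = 1024, 4096, 16
bytes_fp32 = 4

n_scores = tgt_len * src_len * n_heads        # ~67.1M attention scores
print(n_scores * bytes_fp32 / 2**20, "MiB")   # 256.0 MiB per example in fp32, before gradients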

Optimization Strategies

  1. Sparse Cross-Attention: Attend to subset
  2. Hierarchical: Multi-resolution attention
  3. Caching: Reuse encoder outputs across decoding steps (sketched below)
  4. Quantization: Reduce precision
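
A sketch of the caching strategy: the encoder-side projections depend only on the source, so they can be computed once and reused at every decoding step. This single-head version assumes PyTorch 2.x for F.scaled_dot_product_attention; the class and method names are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedCrossAttention(nn.Module):
    """Single-head sketch in which source projections are computed once."""
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.K_cache = None
        self.V_cache = None

    def precompute(self, encoder_output):
        # Run once per source sequence, before decoding starts
        self.K_cache = self.W_k(encoder_output)
        self.V_cache = self.W_v(encoder_output)

    def forward(self, decoder_state):
        # Only the query projection is recomputed at each decoding step
        Q = self.W_q(decoder_state)
        return F.scaled_dot_product_attention(Q, self.K_cache, self.V_cache)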

Common Issues

Issue 1: Attention Drift

Problem: Attention doesn't align properly.

Solution:

  • Add positional encodings
  • Use supervised attention
  • Increase model capacity

Issue 2: Information Bottleneck

Problem: Too much compression in cross-attention.

Solution:

  • Multiple cross-attention layers
  • Increase hidden dimension
  • Use skip connections

Issue 3: Modality Gap

Problem: Different modalities don't align.

Solution:

  • Pre-training with alignment objectives
  • Learnable modality embeddings
  • Projection to common space
