Rotary Position Embeddings (RoPE)
Understand Rotary Position Embeddings, the elegant position encoding method that encodes relative positions through rotation matrices, used in LLaMA, GPT-NeoX, and most modern LLMs.
Rotary Position Embeddings: Elegant Position Encoding Through Rotation
Rotary Position Embeddings (RoPE) is a position encoding method that encodes absolute positions while naturally capturing relative position information through rotation matrices in complex space. It's become the standard in modern LLMs like LLaMA, Mistral, and Qwen.
Interactive RoPE Visualization
See how positions are encoded through rotations in 2D space:
Position Embedding Matrix
Pos | Dim 0 cos | Dim 0 sin | Dim 1 cos | Dim 1 sin | Dim 2 cos | Dim 2 sin | Dim 3 cos | Dim 3 sin |
---|---|---|---|---|---|---|---|---|
0 | 1.00 | 0.00 | 1.00 | 0.00 | 1.00 | 0.00 | 1.00 | 0.00 |
1 | 1.00 | -0.00 | 0.81 | 0.59 | 1.00 | 0.06 | 1.00 | 0.01 |
2 | 1.00 | -0.00 | 0.31 | 0.95 | 0.99 | 0.13 | 1.00 | 0.01 |
3 | 1.00 | -0.00 | -0.31 | 0.95 | 0.98 | 0.19 | 1.00 | 0.02 |
4 | 1.00 | -0.00 | -0.81 | 0.59 | 0.97 | 0.25 | 1.00 | 0.03 |
5 | 1.00 | -0.00 | -1.00 | 0.00 | 0.95 | 0.31 | 1.00 | 0.03 |
6 | 1.00 | -0.00 | -0.81 | -0.59 | 0.93 | 0.37 | 1.00 | 0.04 |
7 | 1.00 | -0.00 | -0.31 | -0.95 | 0.90 | 0.43 | 1.00 | 0.04 |
- Relative Position: The dot product between positions m and n depends only on (m - n), naturally encoding relative distances.
- No Parameters: RoPE requires zero learned parameters, using deterministic rotations based on position.
- Extrapolation: RoPE can handle sequences longer than those seen in training through interpolation or base frequency adjustment.
The Core Insight
Traditional position encodings add position information to embeddings. RoPE instead rotates the query and key vectors based on their position, with key properties:
- Relative positions emerge from rotation differences
- Long-range decay naturally occurs
- Extrapolation to unseen lengths works better
- No additional parameters needed
Mathematical Foundation
The Rotation Formula
For position m and dimension pair (x_2i, x_2i+1), RoPE rotates the pair by the angle m·θ_i:

x'_2i = x_2i · cos(m·θ_i) - x_2i+1 · sin(m·θ_i)
x'_2i+1 = x_2i · sin(m·θ_i) + x_2i+1 · cos(m·θ_i)

where θ_i = 10000^(-2i/d) controls the rotation frequency.
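A minimal NumPy sketch of this formula (the head dimension d = 8, the position, and the sample pair values below are arbitrary choices for illustration):

import numpy as np

d = 8                                     # illustrative head dimension
i = np.arange(d // 2)
theta = 10000.0 ** (-2 * i / d)           # per-pair rotation frequencies
print(theta)                              # [1.0, 0.1, 0.01, 0.001]

m = 3                                     # token position
x_pair = np.array([0.5, -1.2])            # one (x_2i, x_2i+1) pair, arbitrary values
angle = m * theta[0]                      # rotation angle for pair i = 0
rotated = np.array([
    x_pair[0] * np.cos(angle) - x_pair[1] * np.sin(angle),
    x_pair[0] * np.sin(angle) + x_pair[1] * np.cos(angle),
])
print(rotated)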
Complex Number View
Equivalently, treating each pair (x_2i, x_2i+1) as a complex number, the rotation is multiplication by a unit complex exponential:

RoPE(x, m)_i = (x_2i + j·x_2i+1) · e^(j·m·θ_i), with j the imaginary unit
This rotation preserves two quantities (checked numerically in the sketch after this list):
- Vector magnitude: |RoPE(x, m)| = |x|
- Relative angles: ⟨RoPE(q, m), RoPE(k, n)⟩ = f(q, k, m-n)
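A minimal PyTorch sketch that verifies both properties with random vectors (the head dimension of 8 and the positions are arbitrary choices for illustration):

import torch

torch.manual_seed(0)
d = 8
theta = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))    # θ_i per pair

def rope(x, m):
    """Rotate a d-dim vector x for position m using complex multiplication."""
    xc = torch.view_as_complex(x.reshape(d // 2, 2).contiguous())
    return torch.view_as_real(xc * torch.polar(torch.ones(d // 2), m * theta)).reshape(d)

q, k = torch.randn(d), torch.randn(d)
print(torch.allclose(rope(q, 5.0).norm(), q.norm()))            # magnitude preserved -> True
s1 = rope(q, 7.0) @ rope(k, 3.0)                                # relative distance 4
s2 = rope(q, 12.0) @ rope(k, 8.0)                               # relative distance 4
print(torch.allclose(s1, s2))                                   # same score -> True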
How RoPE Works
Step 1: Pair Dimensions
Split d-dimensional vectors into d/2 pairs:
# Original vector: [x0, x1, x2, x3, ..., x_{d-1}]
# Paired:          [(x0, x1), (x2, x3), ..., (x_{d-2}, x_{d-1})]
Step 2: Apply Rotations
Each pair rotates by a position-dependent angle:

import numpy as np

def rope_rotation(x, position, dim_pair, d_model):
    """Rotate one (even, odd) dimension pair x by its position-dependent angle."""
    theta = 10000 ** (-2 * dim_pair / d_model)
    angle = position * theta
    cos_angle = np.cos(angle)
    sin_angle = np.sin(angle)

    x_rot = np.zeros_like(x, dtype=float)
    x_rot[0] = x[0] * cos_angle - x[1] * sin_angle
    x_rot[1] = x[0] * sin_angle + x[1] * cos_angle
    return x_rot
Step 3: Relative Position Emerges
When computing attention between positions m and n, the query at position m and the key at position n are rotated before the dot product:

⟨RoPE(q, m), RoPE(k, n)⟩ = f(q, k, m - n)

The dot product depends only on the relative position (m - n), not on the absolute positions themselves.
Implementation
PyTorch Implementation
import torch
import torch.nn as nn

class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos and sin for efficiency
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """Precompute cos and sin values"""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)

        # Outer product gives all position-frequency combinations
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)

        # Duplicate so cos/sin cover the full head dimension
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, q, k, seq_len=None):
        """Apply rotary embeddings to query and key"""
        if seq_len is None:
            seq_len = q.shape[1]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        return (
            self.apply_rotary_pos_emb(q, self.cos_cached, self.sin_cached),
            self.apply_rotary_pos_emb(k, self.cos_cached, self.sin_cached)
        )

    @staticmethod
    def apply_rotary_pos_emb(x, cos, sin):
        """Apply rotation to input tensor"""
        # x: [batch, seq_len, num_heads, head_dim]
        batch, seq_len, num_heads, head_dim = x.shape

        # Split into two halves (rotate-half convention)
        x1 = x[..., :head_dim // 2]
        x2 = x[..., head_dim // 2:]

        # Broadcast cos/sin over the batch and head dimensions
        cos = cos[:seq_len, :head_dim // 2].unsqueeze(0).unsqueeze(2)
        sin = sin[:seq_len, :head_dim // 2].unsqueeze(0).unsqueeze(2)

        # Rotate pairs
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos

        # Concatenate back
        return torch.cat([y1, y2], dim=-1)
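A quick usage check of the module above (the batch size, sequence length, and head counts are arbitrary illustrative values):

rope = RotaryPositionEmbedding(dim=64, max_position_embeddings=2048)
q = torch.randn(2, 16, 8, 64)   # [batch, seq_len, num_heads, head_dim]
k = torch.randn(2, 16, 8, 64)
q_rot, k_rot = rope(q, k, seq_len=16)
print(q_rot.shape)              # torch.Size([2, 16, 8, 64])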
Efficient Implementation with Complex Numbers
def rope_complex(x, base=10000):
    """RoPE using complex number operations (interleaved-pair convention)."""
    batch, seq_len, num_heads, head_dim = x.shape

    # View adjacent dimension pairs as complex numbers
    x_complex = x.reshape(batch, seq_len, num_heads, head_dim // 2, 2)
    x_complex = torch.view_as_complex(x_complex)

    # Compute per-pair frequencies and position angles
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), freqs)

    # Apply rotation in complex space
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    x_rotated = x_complex * freqs_complex.unsqueeze(0).unsqueeze(2)

    # Convert back to real; note this rotates interleaved pairs (0,1), (2,3), ...
    # rather than the rotate-half split used in the class above
    x_rotated = torch.view_as_real(x_rotated)
    return x_rotated.reshape(batch, seq_len, num_heads, head_dim)
Key Properties
1. Relative Position Awareness
The attention score between positions m and n is computed from the rotated query and key:

score(m, n) = ⟨RoPE(q, m), RoPE(k, n)⟩ = f(q, k, m - n)

It depends only on the relative distance (m - n).
2. Long-Range Decay
Dimension pairs rotate at different speeds: since θ_i = 10000^(-2i/d), low-index pairs (large θ_i) are high-frequency and rotate quickly, while high-index pairs are low-frequency and rotate slowly:
- Low frequencies: Capture global patterns
- High frequencies: Capture local patterns
Summed across all frequency pairs, this produces a natural decay of the attention score's upper bound with relative distance.
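A minimal sketch of this decay, following the relative upper bound quantity |Σ_j e^(j·(m-n)·θ_j)| / (d/2) analyzed in the original RoPE paper (the head dimension of 64 is an arbitrary choice here):

import torch

d, base = 64, 10000
theta = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))   # per-pair frequencies
for dist in [0, 1, 8, 64, 512]:
    # Magnitude of the summed unit rotations at this relative distance
    bound = torch.polar(torch.ones(d // 2), dist * theta).sum().abs() / (d // 2)
    print(dist, round(bound.item(), 3))   # the bound shrinks (roughly) as distance grows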
3. Extrapolation Capability
RoPE can handle sequences longer than training:
# Trained on 2K context
model = LLaMAWithRoPE(max_position=2048)

# Can extrapolate to 4K+ with techniques:
# 1. Linear interpolation
# 2. NTK-aware scaling
# 3. Dynamic scaling
4. No Learned Parameters
Unlike learned position embeddings:
- No embedding table to store
- Generalizes to any sequence length
- Deterministic computation
Advanced Techniques
1. Linear Interpolation
For longer contexts, compress positions:
def rope_with_linear_scaling(x, seq_len, max_train_len=2048, scale=2.0):
    """Apply RoPE with position interpolation"""
    # Scale positions down so long sequences map into the trained range
    position_ids = torch.arange(seq_len) / scale

    # Compute frequencies with scaled positions
    # (compute_freqs / apply_rotation are placeholder helpers, not defined here)
    freqs = compute_freqs(position_ids, base=10000)
    return apply_rotation(x, freqs)
2. NTK-Aware Scaling
Adjust base frequency for better extrapolation:
def rope_ntk_scaling(dim, max_position, scale_factor):
    """NTK-aware RoPE scaling"""
    base = 10000

    # Adjust base based on scale
    ntk_alpha = scale_factor
    base = base * ntk_alpha ** (dim / (dim - 2))

    return RotaryPositionEmbedding(dim, max_position, base)
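As a worked example of the base adjustment (the dimension and scale factor below are arbitrary illustrative values):

dim, scale_factor = 128, 4
base = 10000 * scale_factor ** (dim / (dim - 2))
print(round(base))   # ≈ 40890: the larger base slows the high-frequency rotations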
3. Dynamic Scaling
Adapt to sequence length dynamically:
class DynamicRoPE(nn.Module):
    def forward(self, x, seq_len):
        if seq_len <= self.max_train_len:
            scale = 1.0
        else:
            scale = seq_len / self.max_train_len

        # Apply scaled RoPE
        return self.rope(x, scale=scale)
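In practice these scaling modes are usually switched on through a model's configuration rather than hand-written code. As a rough, hedged illustration of how Hugging Face transformers exposes this for LLaMA-style models (exact field names vary across library versions; newer releases use "rope_type" instead of "type"):

# Illustrative config fragments only
linear_scaling = {"rope_scaling": {"type": "linear", "factor": 2.0}}
dynamic_scaling = {"rope_scaling": {"type": "dynamic", "factor": 2.0}}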
RoPE in Production Models
LLaMA Configuration
# LLaMA 2 RoPE settings
config = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "max_position_embeddings": 4096,
    "rope_theta": 10000.0,
    "rope_scaling": None  # Can add linear/dynamic scaling
}
Mistral with Sliding Window
# Mistral combines RoPE with sliding window attention
config = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "sliding_window": 4096,
    "max_position_embeddings": 32768,
    "rope_theta": 10000.0
}
GPT-NeoX
# GPT-NeoX RoPE settings
config = {
    "rotary_pct": 0.25,  # Apply RoPE to 25% of the head dim
    "rotary_emb_base": 10000,
    "use_parallel_residual": True
}
Comparison with Other Methods
Method | Relative Position | Extrapolation | Parameters | Used In |
---|---|---|---|---|
Learned Absolute PE | ❌ | Poor | O(L × D) | BERT, GPT-2 |
Relative PE | ✅ | Good | O(L²) or O(L) | T5, BERT variants |
RoPE | ✅ | Excellent | 0 | LLaMA, Mistral |
ALiBi | ✅ | Excellent | 0 | BLOOM, MPT |
Sinusoidal | ❌ | Good | 0 | Original Transformer |
Visualization of Rotation Effects
Rotation at Positions 0, 5, 10, and 20
import matplotlib.pyplot as plt
import numpy as np

def visualize_rope_rotation():
    # Original vector
    v = np.array([1, 0])

    # Positions to show
    positions = [0, 5, 10, 20]
    theta = np.pi / 10  # Base frequency

    fig, axes = plt.subplots(1, len(positions), figsize=(12, 3))

    for i, pos in enumerate(positions):
        angle = pos * theta

        # Rotation matrix
        R = np.array([[np.cos(angle), -np.sin(angle)],
                      [np.sin(angle),  np.cos(angle)]])
        v_rot = R @ v

        axes[i].arrow(0, 0, v_rot[0], v_rot[1], head_width=0.1, color='blue')
        axes[i].set_xlim(-1.5, 1.5)
        axes[i].set_ylim(-1.5, 1.5)
        axes[i].set_aspect('equal')
        axes[i].set_title(f'Position {pos}')
        axes[i].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
Common Issues and Solutions
Issue 1: Precision Loss at Long Contexts
# Problem: low-precision dtypes lose accuracy for large position × frequency products
# Solution: compute frequencies in float64 (or at least float32), then cast cos/sin back to the model dtype
freqs = freqs.to(torch.float64)
cos_sin = torch.cos(freqs), torch.sin(freqs)
cos_sin = (cos_sin[0].to(dtype), cos_sin[1].to(dtype))
Issue 2: Memory for Cached Values
# Problem: Storing cos/sin for all positions
# Solution: Compute on-the-fly for very long sequences
def dynamic_rope(x, positions):
    # Compute only the needed positions
    freqs = compute_freqs_for_positions(positions)
    return apply_rotation(x, freqs)
Issue 3: Different Head Dimensions
# Problem: Only part of the head dimension should be rotated
# (e.g., rotary_pct < 1.0 as in GPT-NeoX), or head_dim is not evenly pairable
# Solution: Apply RoPE to a leading slice of the dimensions and pass the rest through
def partial_rope(x, rotary_dim):
    x_rope = x[..., :rotary_dim]
    x_pass = x[..., rotary_dim:]
    x_rope = apply_rope(x_rope)
    return torch.cat([x_rope, x_pass], dim=-1)
Best Practices
- Precompute when possible: Cache cos/sin values
- Use appropriate dtype: compute frequencies and cos/sin in float32 (or float64), then cast to the model's bfloat16/float16
- Apply to Q and K only: Not to V (preserves information); see the sketch after this list
- Consider partial application: Can apply to subset of dimensions
- Test extrapolation: Verify behavior on longer sequences
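As a minimal sketch of the Q-and-K-only practice, here is an illustrative attention block. It reuses the RotaryPositionEmbedding class defined earlier; the layer name, sizes, and the omission of masking and dropout are simplifications, not a reference implementation:

import torch
import torch.nn as nn

class AttentionWithRoPE(nn.Module):
    """Illustrative attention block: rotate Q and K, leave V untouched."""
    def __init__(self, hidden_size=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.rope = RotaryPositionEmbedding(self.head_dim)

    def forward(self, x):
        b, s, _ = x.shape
        q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(b, s, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(b, s, self.num_heads, self.head_dim)  # V is NOT rotated

        q, k = self.rope(q, k, seq_len=s)

        scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / self.head_dim ** 0.5
        attn = scores.softmax(dim=-1)
        return torch.einsum('bhqk,bkhd->bqhd', attn, v).reshape(b, s, -1)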
Implementation Tips
Fused Kernel
@torch.jit.script
def fused_rope(q, k, cos, sin):
    """Fused RoPE application"""
    # Apply to both Q and K in a single scripted function
    # (apply_rotary stands for a rotation helper such as apply_rotary_pos_emb above)
    q_rot = apply_rotary(q, cos, sin)
    k_rot = apply_rotary(k, cos, sin)
    return q_rot, k_rot
Memory-Efficient Version
def memory_efficient_rope(qk, positions):
    """Apply RoPE without storing a full cos/sin cache"""
    batch, seq_len, heads, dim = qk.shape
    for i in range(seq_len):
        pos = positions[i]
        freqs = compute_freqs_at_position(pos, dim)
        qk[:, i] = rotate_at_position(qk[:, i], freqs)
    return qk