Rotary Position Embeddings (RoPE)

Understand Rotary Position Embeddings, the elegant position encoding method that encodes relative positions through rotation matrices, used in LLaMA, GPT-NeoX, and most modern LLMs.


Rotary Position Embeddings: Elegant Position Encoding Through Rotation

Rotary Position Embeddings (RoPE) is a position encoding method that encodes absolute positions while naturally capturing relative position information through rotation matrices in complex space. It's become the standard in modern LLMs like LLaMA, Mistral, and Qwen.

Interactive RoPE Visualization

See how positions are encoded through rotations in 2D space:

[Interactive visualization: a selected dimension pair rotating in 2D space as the position changes, alongside a position embedding matrix listing the cos/sin values for each position and dimension pair.]
Each position gets a unique combination of rotation angles across its dimension pairs. Low-index pairs (large θ_i) rotate quickly and capture local detail, while high-index pairs (small θ_i) rotate slowly and capture global context.

Relative Position

Dot product between positions m and n depends only on (m-n), naturally encoding relative distances.

No Parameters

RoPE requires zero learned parameters, using deterministic rotations based on position.

Extrapolation

Can handle sequences longer than training through interpolation or base frequency adjustment.

The Core Insight

Traditional position encodings add position information to the token embeddings. RoPE instead rotates the query and key vectors by an angle that depends on their position, which gives it several key properties (a minimal contrast sketch follows the list):

  1. Relative positions emerge from rotation differences
  2. Long-range decay naturally occurs
  3. Extrapolation to unseen lengths works better
  4. No additional parameters needed
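To make the contrast concrete, here is a minimal NumPy sketch (illustrative values only; the vector, position, and frequency are made up) applying an additive position vector and a RoPE rotation to the same 2D pair. Only the rotation leaves the vector's norm unchanged:

import numpy as np

pair = np.array([0.5, 1.0])   # one (x_{2i}, x_{2i+1}) pair of a vector
m = 3                         # token position
theta = 0.1                   # rotation frequency for this pair

# Additive encoding: position information is added to the pair
pos_embedding = np.array([np.sin(m * theta), np.cos(m * theta)])
additive = pair + pos_embedding

# RoPE: the pair is rotated by the position-dependent angle m * theta
angle = m * theta
R = np.array([[np.cos(angle), -np.sin(angle)],
              [np.sin(angle),  np.cos(angle)]])
rotary = R @ pair

print(additive, np.linalg.norm(additive))  # norm changes
print(rotary,   np.linalg.norm(rotary))    # norm equals ||pair||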

Mathematical Foundation

The Rotation Formula

For position m and dimension pair (2i, 2i+1):

RoPE(x, m) = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} \begin{bmatrix} x_{2i} \\ x_{2i+1} \end{bmatrix}

Where θ_i = 10000^(-2i/d) controls the rotation frequency.
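
For intuition about how θ_i varies, the following sketch (assuming a head dimension of 8 purely for illustration) prints each pair's frequency and the angle it rotates through at position m = 10:

import numpy as np

head_dim = 8  # assumed head dimension, for illustration only
m = 10        # example position

for i in range(head_dim // 2):
    theta_i = 10000 ** (-2 * i / head_dim)
    print(f"pair {i}: theta = {theta_i:.5f}, angle at m={m}: {m * theta_i:.3f} rad")

# pair 0 rotates fastest (theta = 1.0); higher pairs rotate progressively slower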

Complex Number View

Equivalently in complex space:

RoPE(x, m) = x · e^(imθ) = x · (cos(mθ) + i·sin(mθ))

This rotation preserves two properties, verified numerically in the sketch after this list:

  • Vector magnitude: |RoPE(x, m)| = |x|
  • Relative angles: ⟨RoPE(q, m), RoPE(k, n)⟩ = f(q, k, m-n)
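
Both properties are easy to check numerically. The sketch below (arbitrary example values) treats a query pair and a key pair as complex numbers and rotates them with e^(imθ):

import numpy as np

theta = 0.3
q = complex(1.2, -0.7)   # a query pair viewed as a complex number
k = complex(0.4,  0.9)   # a key pair viewed as a complex number

def rope(x, m):
    # Rotate x by the position-dependent angle m * theta
    return x * np.exp(1j * m * theta)

# Magnitude preservation: |RoPE(x, m)| == |x|
print(abs(rope(q, 5)), abs(q))

# Relative position: Re[q_m * conj(k_n)] depends only on (m - n)
print(np.real(rope(q, 7) * np.conj(rope(k, 4))))    # m - n = 3
print(np.real(rope(q, 12) * np.conj(rope(k, 9))))   # m - n = 3, same score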

How RoPE Works

Step 1: Pair Dimensions

Split d-dimensional vectors into d/2 pairs:

# Original vector: [x0, x1, x2, x3, ..., x_{d-1}]
# Paired:          [(x0, x1), (x2, x3), ..., (x_{d-2}, x_{d-1})]

Step 2: Apply Rotations

Each pair rotates by position-dependent angle:

import numpy as np

def rope_rotation(x, position, dim_pair, d_model):
    """Rotate one (x_{2i}, x_{2i+1}) pair by its position-dependent angle."""
    theta = 10000 ** (-2 * dim_pair / d_model)
    angle = position * theta
    cos_angle = np.cos(angle)
    sin_angle = np.sin(angle)
    x_rot = np.zeros_like(x, dtype=float)
    x_rot[0] = x[0] * cos_angle - x[1] * sin_angle
    x_rot[1] = x[0] * sin_angle + x[1] * cos_angle
    return x_rot

Step 3: Relative Position Emerges

When computing attention between positions m and n:

⟨q_m, k_n⟩ = ⟨RoPE(q, m), RoPE(k, n)⟩ = ⟨q, k⟩·cos((m-n)θ) + ...

The dot product depends only on the relative position (m - n)!
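
This shift invariance is easy to verify numerically. The sketch below (arbitrary example vectors and frequency, a single dimension pair) compares the query-key dot product at two different absolute positions with the same offset:

import numpy as np

def rotate_pair(x, position, theta=0.5):
    # 2D rotation of one dimension pair by the angle position * theta
    angle = position * theta
    R = np.array([[np.cos(angle), -np.sin(angle)],
                  [np.sin(angle),  np.cos(angle)]])
    return R @ x

q = np.array([0.8, -0.3])
k = np.array([0.2,  0.6])

# Same offset (m - n = 4) at two different absolute positions
score_a = rotate_pair(q, 10) @ rotate_pair(k, 6)
score_b = rotate_pair(q, 54) @ rotate_pair(k, 50)
print(score_a, score_b)  # identical up to floating-point error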

Implementation

PyTorch Implementation

import torch
import torch.nn as nn

class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos and sin for efficiency
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """Precompute cos and sin values"""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)

        # Outer product gives all position-frequency combinations
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)

        # Duplicate so the cache covers both halves of the head dimension
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, q, k, seq_len=None):
        """Apply rotary embeddings to query and key"""
        if seq_len is None:
            seq_len = q.shape[1]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        return (
            self.apply_rotary_pos_emb(q, self.cos_cached, self.sin_cached),
            self.apply_rotary_pos_emb(k, self.cos_cached, self.sin_cached),
        )

    @staticmethod
    def apply_rotary_pos_emb(x, cos, sin):
        """Apply rotation to input tensor"""
        # x: [batch, seq_len, num_heads, head_dim]
        batch, seq_len, num_heads, head_dim = x.shape

        # Split into two halves; element j of each half forms one pair
        x1 = x[..., :head_dim // 2]
        x2 = x[..., head_dim // 2:]

        # Broadcast cos/sin over the batch and head dimensions
        cos = cos[:seq_len, :head_dim // 2].unsqueeze(0).unsqueeze(2)
        sin = sin[:seq_len, :head_dim // 2].unsqueeze(0).unsqueeze(2)

        # Rotate pairs
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos

        # Concatenate back
        return torch.cat([y1, y2], dim=-1)

Efficient Implementation with Complex Numbers

def rope_complex(x, base=10000):
    """RoPE using complex number operations (x must be a float32/float64 tensor)"""
    batch, seq_len, num_heads, head_dim = x.shape

    # View adjacent dimension pairs as complex numbers
    x_complex = x.reshape(batch, seq_len, num_heads, head_dim // 2, 2)
    x_complex = torch.view_as_complex(x_complex)

    # Compute per-pair frequencies and the position-frequency angles
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), freqs)

    # Apply rotation in complex space: multiply by e^{i * m * theta}
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    x_rotated = x_complex * freqs_complex.unsqueeze(0).unsqueeze(2)

    # Convert back to real
    x_rotated = torch.view_as_real(x_rotated)
    return x_rotated.reshape(batch, seq_len, num_heads, head_dim)

Key Properties

1. Relative Position Awareness

The attention score between positions m and n:

score(m, n) = Re[⟨q_m, k_n*⟩] = Re[⟨q, k*⟩ · e^(i(m-n)θ)]

Depends only on relative distance (m - n)!

2. Long-Range Decay

Each dimension pair rotates at its own frequency: low-index pairs (small i) have high frequencies and rotate quickly, while high-index pairs (large i) have low frequencies and rotate slowly:

  • Low frequencies: Capture global patterns
  • High frequencies: Capture local patterns

This creates a natural decay with distance, illustrated in the sketch below.
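
One way to see this decay is to sum the per-pair contributions cos(d·θ_i) for matched unit query/key pairs at increasing relative distance d. The sketch below (assuming head_dim = 64 purely for illustration) does exactly that:

import numpy as np

head_dim = 64  # assumed head dimension for illustration
thetas = 10000 ** (-2 * np.arange(head_dim // 2) / head_dim)

# Attention contribution of identical unit q/k pairs as a function of
# relative distance d: sum over pairs of cos(d * theta_i)
for d in [0, 1, 4, 16, 64, 256]:
    score = np.sum(np.cos(d * thetas))
    print(f"distance {d:4d}: summed score = {score:7.2f}")

# The summed score is largest at distance 0 and oscillates toward smaller
# magnitudes as the distance grows, giving a (noisy) long-range decay.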

3. Extrapolation Capability

RoPE can handle sequences longer than training:

# Trained on 2K context
model = LLaMAWithRoPE(max_position=2048)

# Can extrapolate to 4K+ with techniques:
# 1. Linear interpolation
# 2. NTK-aware scaling
# 3. Dynamic scaling

4. No Learned Parameters

Unlike learned position embeddings:

  • No embedding table to store
  • Generalizes to any sequence length
  • Deterministic computation

Advanced Techniques

1. Linear Interpolation

For longer contexts, compress positions:

def rope_with_linear_scaling(x, seq_len, max_train_len=2048, scale=2.0):
    """Apply RoPE with position interpolation"""
    # Scale positions down so longer sequences fit in the trained range
    position_ids = torch.arange(seq_len) / scale

    # Compute frequencies with the scaled positions
    # (compute_freqs / apply_rotation stand in for the frequency and
    #  rotation helpers defined earlier)
    freqs = compute_freqs(position_ids, base=10000)
    return apply_rotation(x, freqs)

2. NTK-Aware Scaling

Adjust base frequency for better extrapolation:

def rope_ntk_scaling(dim, max_position, scale_factor):
    """NTK-aware RoPE scaling"""
    base = 10000

    # Adjust base based on the scale factor
    ntk_alpha = scale_factor
    base = base * ntk_alpha ** (dim / (dim - 2))

    return RotaryPositionEmbedding(dim, max_position, base)

3. Dynamic Scaling

Adapt to sequence length dynamically:

class DynamicRoPE(nn.Module):
    def __init__(self, rope, max_train_len=2048):
        super().__init__()
        # `rope` is assumed to be a rotary-embedding module that accepts a scale
        self.rope = rope
        self.max_train_len = max_train_len

    def forward(self, x, seq_len):
        if seq_len <= self.max_train_len:
            scale = 1.0
        else:
            scale = seq_len / self.max_train_len
        # Apply scaled RoPE
        return self.rope(x, scale=scale)

RoPE in Production Models

LLaMA Configuration

# LLaMA 2 RoPE settings
config = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "max_position_embeddings": 4096,
    "rope_theta": 10000.0,
    "rope_scaling": None,  # Can add linear/dynamic scaling
}

Mistral with Sliding Window

# Mistral combines RoPE with sliding window attention
config = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "sliding_window": 4096,
    "max_position_embeddings": 32768,
    "rope_theta": 10000.0,
}

GPT-NeoX

# GPT-NeoX RoPE settings
config = {
    "rotary_pct": 0.25,  # Apply RoPE to 25% of each head's dimensions
    "rotary_emb_base": 10000,
    "use_parallel_residual": True,
}

Comparison with Other Methods

| Method | Relative Position | Extrapolation | Parameters | Used In |
|---|---|---|---|---|
| Learned Absolute PE | No | Poor | O(L × D) | GPT-2, BERT |
| Relative PE | Yes | Good | O(L²) or O(L) | T5, BERT variants |
| RoPE | Yes | Excellent | 0 | LLaMA, Mistral |
| ALiBi | Yes | Excellent | 0 | BLOOM, MPT |
| Sinusoidal | Implicit | Good | 0 | Original Transformer |

Visualization of Rotation Effects

Rotating a Vector at Positions 0, 5, 10, and 20

import matplotlib.pyplot as plt
import numpy as np

def visualize_rope_rotation():
    # Original vector
    v = np.array([1, 0])

    # Positions to show
    positions = [0, 5, 10, 20]
    theta = np.pi / 10  # Base frequency

    fig, axes = plt.subplots(1, len(positions), figsize=(12, 3))
    for i, pos in enumerate(positions):
        angle = pos * theta

        # Rotation matrix
        R = np.array([[np.cos(angle), -np.sin(angle)],
                      [np.sin(angle),  np.cos(angle)]])
        v_rot = R @ v

        axes[i].arrow(0, 0, v_rot[0], v_rot[1], head_width=0.1, color='blue')
        axes[i].set_xlim(-1.5, 1.5)
        axes[i].set_ylim(-1.5, 1.5)
        axes[i].set_aspect('equal')
        axes[i].set_title(f'Position {pos}')
        axes[i].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

Common Issues and Solutions

Issue 1: Precision Loss at Long Contexts

# Problem: float32 loses precision for large position * frequency products
# Solution: compute the angles in float64 (or at least float32), then cast
#           the cached cos/sin back to the model dtype
freqs = freqs.to(torch.float64)
cos_sin = torch.cos(freqs), torch.sin(freqs)
cos_sin = (cos_sin[0].to(dtype), cos_sin[1].to(dtype))

Issue 2: Memory for Cached Values

# Problem: storing cos/sin caches for every position
# Solution: compute on-the-fly for very long sequences
def dynamic_rope(x, positions):
    # Compute only the frequencies for the positions actually present
    # (compute_freqs_for_positions / apply_rotation are stand-ins for the
    #  frequency and rotation helpers defined earlier)
    freqs = compute_freqs_for_positions(positions)
    return apply_rotation(x, freqs)

Issue 3: Different Head Dimensions

# Problem: head_dim not divisible by 2, or only part of each head should rotate
# Solution: apply RoPE to a prefix of the dimensions and pass the rest through
def partial_rope(x, rotary_dim):
    x_rope = x[..., :rotary_dim]
    x_pass = x[..., rotary_dim:]
    x_rope = apply_rope(x_rope)  # any of the rotation functions above
    return torch.cat([x_rope, x_pass], dim=-1)

Best Practices

  1. Precompute when possible: Cache cos/sin values
  2. Use appropriate dtype: bfloat16 for training, float16 for inference
  3. Apply to Q and K only: Not to V (the values carry content and are left unrotated); see the sketch after this list
  4. Consider partial application: Can apply to subset of dimensions
  5. Test extrapolation: Verify behavior on longer sequences
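
As an example of practice 3, here is a minimal sketch of an attention forward pass (assuming a rotary module with the same interface as the RotaryPositionEmbedding class above, and PyTorch 2.x for scaled_dot_product_attention) that rotates the queries and keys but leaves the values untouched:

import torch.nn.functional as F

def attention_with_rope(q, k, v, rope):
    # q, k, v: [batch, seq_len, num_heads, head_dim]
    seq_len = q.shape[1]

    # Rotate queries and keys only; values are left unrotated
    q, k = rope(q, k, seq_len=seq_len)

    # Standard scaled dot-product attention over the rotated q/k
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [batch, heads, seq, head_dim]
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2)  # back to [batch, seq_len, num_heads, head_dim]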

Implementation Tips

Fused Kernel

@torch.jit.script
def fused_rope(q, k, cos, sin):
    """Fused RoPE application"""
    # Rotate Q and K in one scripted call
    # (apply_rotary stands in for the rotation helper defined earlier)
    q_rot = apply_rotary(q, cos, sin)
    k_rot = apply_rotary(k, cos, sin)
    return q_rot, k_rot

Memory-Efficient Version

def memory_efficient_rope(qk, positions):
    """Apply RoPE without storing a full cos/sin cache"""
    batch, seq_len, heads, dim = qk.shape
    for i in range(seq_len):
        pos = positions[i]
        freqs = compute_freqs_at_position(pos, dim)
        qk[:, i] = rotate_at_position(qk[:, i], freqs)
    return qk

If you found this explanation helpful, consider sharing it with others.
