Matryoshka Embeddings


Learn about nested representations that enable flexible dimension reduction without retraining models.


Matryoshka embeddings enable flexible dimension reduction through nested representations - train once, deploy at any dimension by simple truncation.

Interactive Matryoshka Visualization

[Interactive demo: truncating the full 768D embedding to 128D keeps 16.7% of the dimensions while retaining roughly 94.8% accuracy, cuts memory by about 83.3%, and yields approximately a 6x speed-up. The demo also illustrates the nested representation structure and the truncation process (preserved vs. truncated dimensions).]

Dimension Trade-offs

| Dimension | Accuracy | Memory | Speed | Use Case |
|---|---|---|---|---|
| 768D (Full) | 100% | 100.0% | 1x | Full quality, research |
| 512D (Large) | 99.2% | 66.7% | 1.5x | Production, high quality |
| 256D (Medium) | 97.5% | 33.3% | 3x | Balanced performance |
| 128D (Small) | 94.8% | 16.7% | 6x | Fast search, mobile |
| 64D (Tiny) | 89.3% | 8.3% | 12x | Edge devices, IoT |
| 32D (Micro) | 82.1% | 4.2% | 24x | Extreme constraints |
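
As a rough sanity check on the memory column, here is a small sketch of how raw index size scales with embedding dimension. The corpus size of one million documents and float32 storage are assumptions for illustration, not figures from the table:

# Rough index-size estimate for a hypothetical corpus of 1M float32 vectors
NUM_DOCS = 1_000_000          # assumed corpus size (illustrative only)
BYTES_PER_FLOAT = 4           # float32

for dim in [768, 512, 256, 128, 64, 32]:
    size_gb = NUM_DOCS * dim * BYTES_PER_FLOAT / 1024**3
    print(f"{dim:>3}D: {size_gb:.2f} GB  ({dim / 768:.1%} of the full index)")

# 768D ≈ 2.86 GB, 256D ≈ 0.95 GB, 128D ≈ 0.48 GB, 32D ≈ 0.12 GB

Brute-force similarity cost also scales roughly linearly with dimension, which is consistent with the 6x speed figure listed for 128D (768 / 128 = 6).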

Implementation Example

Training with Matryoshka Loss

import math

import torch
import torch.nn.functional as F

def matryoshka_loss(embeddings, labels,
                    dims=[768, 512, 256, 128]):
    """Multi-scale contrastive loss.

    Assumes a contrastive_loss(embeddings, labels) helper
    (e.g. an InfoNCE-style loss) is defined elsewhere.
    """
    total_loss = 0
    
    for dim in dims:
        # Truncate to current dimension
        truncated = embeddings[:, :dim]
        
        # Normalize
        truncated = F.normalize(truncated, dim=1)
        
        # Compute contrastive loss
        loss = contrastive_loss(truncated, labels)
        
        # Weight by dimension importance
        weight = math.sqrt(dim / max(dims))
        total_loss += weight * loss
    
    return total_loss

Flexible Inference

# Dynamic dimension selection
def get_embedding(text, target_dim=None,
                  memory_limited=False, speed_critical=False):
    # Get the full embedding from a Matryoshka-trained encoder
    full_emb = model.encode(text)

    if target_dim is None:
        # Auto-select based on deployment constraints
        if memory_limited:
            target_dim = 128
        elif speed_critical:
            target_dim = 64
        else:
            target_dim = 768

    # Simple truncation (re-normalize before computing cosine similarity)
    return full_emb[:target_dim]

# Usage
query_emb = get_embedding(query, 128)
doc_embs = [get_embedding(d, 128) for d in docs]
scores = cosine_similarity(query_emb, doc_embs)

Matryoshka Embeddings Best Practices

When to Use

  • Multiple deployment targets (cloud/edge)
  • Dynamic resource constraints
  • A/B testing different dimensions
  • Gradual quality degradation acceptable
  • Need for backwards compatibility

Dimension Guidelines

  • 768D: Maximum quality, research
  • 256-512D: Production systems
  • 128D: Mobile applications
  • 64D: Real-time, high-volume
  • 32D: Extreme edge devices

Key Insight: Train once with Matryoshka loss, deploy everywhere. The same model can serve high-quality requests on powerful servers and low-latency requests on edge devices, simply by truncating embeddings.

The Matryoshka Principle

Like Russian nesting dolls, Matryoshka embeddings contain accurate representations at multiple scales within a single embedding:

768D: [█████████████████████████████████]  Full representation
512D: [██████████████████████]             99.2% accuracy retained
256D: [███████████]                        97.5% accuracy retained
128D: [██████]                             94.8% accuracy retained
 64D: [███]                                89.3% accuracy retained

How It Works

Traditional vs Matryoshka

Traditional Embeddings:

# Need separate models for different dimensions
model_768 = train_model(dim=768)   # Full model
model_256 = train_model(dim=256)   # Retrain for smaller
model_128 = train_model(dim=128)   # Retrain again

Matryoshka Embeddings:

# Single model, multiple dimensions
model = train_matryoshka_model(dims=[768, 512, 256, 128, 64])

# Use any dimension at inference
embedding_768 = model.encode(text)[:768]   # Full
embedding_256 = model.encode(text)[:256]   # Truncated
embedding_128 = model.encode(text)[:128]   # More truncated

Matryoshka Representation Learning (MRL)

The Loss Function

Train with multi-scale contrastive loss:

$$\mathcal{L}_{\mathrm{MRL}} = \sum_{m \in M} \lambda_m \cdot \mathcal{L}_{\mathrm{contrastive}}(E_{[:m]})$$

Where:

  • M = {d_1, d_2, ..., d_k}: the set of nested dimensions
  • E[:m]: the first m dimensions of the embedding
  • λ_m: the weight for dimension m
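
For example, with M = {768, 512, 256, 128} and the square-root weighting used in the training snippet above, λ_m = sqrt(m / 768) (one reasonable choice, not the only one), the per-dimension weights come out to approximately:

λ_768 ≈ 1.00,  λ_512 ≈ 0.82,  λ_256 ≈ 0.58,  λ_128 ≈ 0.41

so the full-width loss dominates while the smaller prefixes still receive a meaningful share of the gradient.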

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class MatryoshkaModel(nn.Module):
    def __init__(self, encoder, dims=[768, 512, 256, 128, 64, 32]):
        super().__init__()
        self.encoder = encoder
        self.dims = sorted(dims, reverse=True)
        self.projection = nn.Linear(encoder.config.hidden_size, max(dims))

    def forward(self, input_ids, attention_mask):
        # Get base embeddings (mean-pooled token representations)
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state.mean(dim=1)

        # Project to max dimension
        embeddings = self.projection(embeddings)

        # Normalize full embedding
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        return embeddings

    def matryoshka_loss(self, embeddings, temperature=0.07):
        """Multi-scale contrastive loss"""
        total_loss = 0
        weights = [1.0] + [0.5] * (len(self.dims) - 1)  # Decreasing weights

        # In-batch labels: each example acts as its own positive in this
        # simplified sketch; in practice, anchors are scored against
        # separately encoded positives.
        batch_size = embeddings.shape[0]
        labels = torch.arange(batch_size).to(embeddings.device)

        for dim, weight in zip(self.dims, weights):
            # Truncate to current dimension
            truncated = embeddings[:, :dim]

            # Re-normalize after truncation
            truncated = F.normalize(truncated, p=2, dim=-1)

            # Compute scaled similarity matrix
            similarity = torch.matmul(truncated, truncated.T) / temperature

            # Cross-entropy loss
            loss = F.cross_entropy(similarity, labels)

            # Weighted contribution
            total_loss += weight * loss

        return total_loss / sum(weights)

Training Strategy

Progressive Training

Train with increasing complexity:

def progressive_matryoshka_training(model, dataloader, epochs=10):
    """Train Matryoshka model progressively"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # Start with smaller dimensions
    active_dims = [32, 64]

    for epoch in range(epochs):
        # Gradually add larger dimensions
        if epoch == 3:
            active_dims.append(128)
        if epoch == 5:
            active_dims.append(256)
        if epoch == 7:
            active_dims.extend([512, 768])

        for batch in dataloader:
            embeddings = model(batch['input_ids'], batch['attention_mask'])

            # Loss only for active dimensions
            loss = 0
            for dim in active_dims:
                truncated = embeddings[:, :dim]
                truncated = F.normalize(truncated, p=2, dim=-1)
                dim_loss = contrastive_loss(truncated, batch['labels'])
                loss += dim_loss / len(active_dims)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Importance-Weighted Dimensions

The leading dimensions of the embedding carry the most information, so they can be weighted more heavily:

def importance_weighted_loss(embeddings, labels, dims, alpha=0.5):
    """Weight loss by dimension importance"""
    total_loss = 0

    for i, dim in enumerate(dims):
        # Exponentially decreasing importance
        weight = alpha ** i

        truncated = embeddings[:, :dim]
        truncated = F.normalize(truncated, p=2, dim=-1)

        loss = contrastive_loss(truncated, labels)
        total_loss += weight * loss

    return total_loss

Inference and Deployment

Dynamic Dimension Selection

class AdaptiveMatryoshkaIndex:
    def __init__(self, model, documents):
        self.model = model
        self.dims = [768, 512, 256, 128, 64, 32]

        # Pre-compute embeddings at max dimension
        with torch.no_grad():
            full_embeddings = []
            for doc in documents:
                emb = model.encode(doc)
                full_embeddings.append(emb)
            self.full_embeddings = torch.stack(full_embeddings)

    def search(self, query, k=10, max_latency_ms=100):
        """Search with latency constraint"""
        # Estimate dimension based on latency budget
        if max_latency_ms < 20:
            dim = 32
        elif max_latency_ms < 50:
            dim = 64
        elif max_latency_ms < 100:
            dim = 128
        else:
            dim = 256

        # Encode query at selected dimension
        query_emb = self.model.encode(query)[:dim]
        query_emb = F.normalize(query_emb, p=2, dim=-1)

        # Truncate document embeddings
        doc_embs = self.full_embeddings[:, :dim]
        doc_embs = F.normalize(doc_embs, p=2, dim=-1)

        # Compute similarities
        similarities = torch.matmul(query_emb, doc_embs.T)

        # Get top-k
        top_k = torch.topk(similarities, k)
        return top_k.indices, top_k.values

Memory-Aware Deployment

def deploy_matryoshka_model(model, memory_budget_mb, num_vectors=30000):
    """Pick the largest embedding dimension that fits the memory budget"""
    bytes_per_float = 4

    # Maximum dimension per vector that still fits the budget
    max_dim = (memory_budget_mb * 1024 * 1024) / bytes_per_float / num_vectors

    # Select the largest supported dimension below that limit
    if max_dim >= 768:
        return 768
    elif max_dim >= 512:
        return 512
    elif max_dim >= 256:
        return 256
    elif max_dim >= 128:
        return 128
    else:
        return 64

Performance Analysis

Dimension vs Accuracy Trade-off

| Dimension | Relative Size | Accuracy | Speed | Use Case |
|---|---|---|---|---|
| 768 | 100% | 100% | 1× | Research, high quality |
| 512 | 67% | 99.2% | 1.5× | Production servers |
| 256 | 33% | 97.5% | 3× | Balanced performance |
| 128 | 17% | 94.8% | 6× | Mobile, real-time |
| 64 | 8% | 89.3% | 12× | Edge devices |
| 32 | 4% | 82.1% | 24× | IoT, extreme constraints |

Benchmark Results

import time

def benchmark_dimensions(model, test_data):
    """Compare performance across dimensions"""
    results = {}

    for dim in [768, 512, 256, 128, 64, 32]:
        # Truncate embeddings
        query_embs = model.encode(test_data['queries'])[:, :dim]
        doc_embs = model.encode(test_data['documents'])[:, :dim]

        # Measure accuracy (compute_recall_at_k is assumed to be defined elsewhere)
        accuracy = compute_recall_at_k(query_embs, doc_embs, k=10)

        # Measure speed
        start = time.time()
        for _ in range(1000):
            similarities = cosine_similarity(query_embs[:10], doc_embs)
        latency = (time.time() - start) / 1000

        # Measure memory
        memory_mb = (query_embs.nbytes + doc_embs.nbytes) / 1024 / 1024

        results[dim] = {
            'accuracy': accuracy,
            'latency_ms': latency * 1000,
            'memory_mb': memory_mb
        }

    return results

Advanced Techniques

1. Learned Truncation

Not all dimensions are equally important:

class LearnedTruncation(nn.Module):
    def __init__(self, full_dim=768):
        super().__init__()
        self.importance = nn.Parameter(torch.ones(full_dim))

    def forward(self, embeddings, target_dim):
        # Sort dimensions by learned importance
        importance_sorted = torch.argsort(self.importance, descending=True)

        # Select most important dimensions
        selected = importance_sorted[:target_dim]

        # Reorder embeddings
        truncated = embeddings[:, selected]
        return truncated

2. Cascaded Retrieval

Use multiple dimensions for refinement:

def cascaded_search(query, documents, k=10):
    """Multi-stage retrieval with increasing precision"""
    # Stage 1: Fast filtering with 32D
    emb_32 = encode(query)[:32]
    candidates = search_32d(emb_32, top_k=1000)

    # Stage 2: Rerank with 128D
    emb_128 = encode(query)[:128]
    candidates = rerank_128d(emb_128, candidates, top_k=100)

    # Stage 3: Final ranking with 768D
    emb_768 = encode(query)[:768]
    results = rerank_768d(emb_768, candidates, top_k=k)

    return results

3. Dimension Prediction

Predict optimal dimension per query:

class DimensionPredictor(nn.Module):
    def __init__(self, input_dim=768):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 6),  # 6 dimension options
            nn.Softmax(dim=-1)
        )
        self.dims = [32, 64, 128, 256, 512, 768]

    def forward(self, query_embedding):
        # Predict dimension probabilities
        probs = self.mlp(query_embedding)

        # Select dimension
        dim_idx = torch.argmax(probs)
        return self.dims[dim_idx]

Practical Applications

1. Semantic Search at Scale

class ScalableSearch:
    def __init__(self, model, documents, index_budget_gb=10):
        self.model = model

        # Calculate dimension based on budget
        num_docs = len(documents)
        bytes_per_doc = index_budget_gb * 1e9 / num_docs
        self.dim = min(768, int(bytes_per_doc / 4))

        # Index at selected dimension
        self.index = self.build_index(documents, self.dim)

    def search(self, query, k=10):
        query_emb = self.model.encode(query)[:self.dim]
        return self.index.search(query_emb, k)

2. Real-time Recommendation

def real_time_recommendations(user_embedding, items, latency_budget_ms=50):
    """Get recommendations within latency budget"""
    # Start with smallest dimension
    dim = 32
    results = None

    while dim <= 768:
        start = time.time()

        # Truncate embeddings
        user_emb = user_embedding[:dim]
        item_embs = items[:, :dim]

        # Compute scores
        scores = cosine_similarity(user_emb, item_embs)

        elapsed_ms = (time.time() - start) * 1000
        if elapsed_ms < latency_budget_ms:
            results = scores
            dim *= 2  # Try higher dimension
        else:
            break  # Use previous result

    return results

3. Progressive Loading

import numpy as np

class ProgressiveEmbedding:
    def __init__(self, embedding_path):
        # Load embeddings in chunks
        self.dims = [32, 64, 128, 256, 512, 768]
        self.chunks = {}

        for i, dim in enumerate(self.dims):
            start = 0 if i == 0 else self.dims[i - 1]
            end = dim
            chunk_path = f"{embedding_path}.{start}_{end}"
            self.chunks[dim] = np.load(chunk_path)

    def get_embedding(self, dim):
        """Load only required dimensions"""
        if dim not in self.dims:
            dim = min(d for d in self.dims if d >= dim)

        # Concatenate required chunks
        embedding = []
        for d in self.dims:
            if d <= dim:
                embedding.append(self.chunks[d])
            else:
                break

        return np.concatenate(embedding)

Best Practices

Training Tips

  1. Use multiple dimensions: Train with at least 4-6 nested dimensions
  2. Weight by importance: Give more weight to smaller dimensions
  3. Normalize at each scale: Re-normalize after truncation (see the sketch after this list)
  4. Progressive training: Start with small dimensions, add larger ones
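
As a quick illustration of tip 3, here is a minimal NumPy sketch (a toy vector, independent of any particular model) showing why truncated prefixes must be re-normalized before cosine scoring: truncating a unit-norm embedding leaves a vector whose norm is less than 1, so raw dot products are no longer cosine similarities.

import numpy as np

# Toy unit-norm "embedding" with 8 dimensions, truncated to its first 4
rng = np.random.default_rng(0)
full = rng.normal(size=8)
full /= np.linalg.norm(full)               # unit norm at full dimension

prefix = full[:4]
print(np.linalg.norm(prefix))              # < 1.0: truncation breaks unit norm

prefix /= np.linalg.norm(prefix)           # re-normalize before cosine scoring
print(np.linalg.norm(prefix))              # 1.0 again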

Deployment Tips

  1. Profile your constraints: Measure latency/memory requirements
  2. Use cascaded search: Coarse-to-fine retrieval
  3. Cache appropriately: Store different dimensions separately
  4. Monitor quality: Track accuracy at deployed dimensions

References

  • Kusupati et al. "Matryoshka Representation Learning"
  • Wortsman et al. "Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time"
  • Chen et al. "Multi-Scale Contrastive Learning for Embedding Compression"
