Matryoshka Embeddings
Learn about nested representations that enable flexible dimension reduction without retraining models.
Matryoshka embeddings enable flexible dimension reduction through nested representations: train once, then deploy at any dimension by simple truncation.
Dimension Trade-offs
| Dimension | Accuracy | Memory | Speed | Use Case |
|---|---|---|---|---|
| 768D (Full) | 100% | 100.0% | 1x | Full quality, research |
| 512D (Large) | 99.2% | 66.7% | 1.5x | Production, high quality |
| 256D (Medium) | 97.5% | 33.3% | 3x | Balanced performance |
| 128D (Small) | 94.8% | 16.7% | 6x | Fast search, mobile |
| 64D (Tiny) | 89.3% | 8.3% | 12x | Edge devices, IoT |
| 32D (Micro) | 82.1% | 4.2% | 24x | Extreme constraints |
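The memory column is just float32 arithmetic (4 bytes per dimension); as a quick sanity check, here is the calculation for a hypothetical corpus of one million vectors:

```python
# Back-of-envelope memory check for the table above (float32 = 4 bytes/dim),
# assuming a hypothetical corpus of one million vectors.
BYTES_PER_FLOAT = 4
NUM_VECTORS = 1_000_000

for dim in [768, 512, 256, 128, 64, 32]:
    total_gb = dim * BYTES_PER_FLOAT * NUM_VECTORS / 1024**3
    print(f"{dim:>4}D: {total_gb:.2f} GB ({dim / 768:.1%} of full size)")

#  768D: 2.86 GB (100.0% of full size)
#  128D: 0.48 GB (16.7% of full size)
```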
Implementation Example
Training with Matryoshka Loss
```python
import math

import torch
import torch.nn.functional as F

def matryoshka_loss(embeddings, labels,
                    dims=[768, 512, 256, 128]):
    """Multi-scale contrastive loss."""
    total_loss = 0
    for dim in dims:
        # Truncate to current dimension
        truncated = embeddings[:, :dim]
        # Re-normalize after truncation
        truncated = F.normalize(truncated, dim=1)
        # Compute contrastive loss (assumes a contrastive_loss helper is defined)
        loss = contrastive_loss(truncated, labels)
        # Weight by dimension importance
        weight = math.sqrt(dim / max(dims))
        total_loss += weight * loss
    return total_loss
```
Flexible Inference
```python
# Dynamic dimension selection (assumes a trained `model` and a
# cosine_similarity helper are available)
def get_embedding(text, target_dim=None,
                  memory_limited=False, speed_critical=False):
    # Get full embedding
    full_emb = model.encode(text)
    if target_dim is None:
        # Auto-select based on constraints
        if memory_limited:
            target_dim = 128
        elif speed_critical:
            target_dim = 64
        else:
            target_dim = 768
    # Simple truncation
    return full_emb[:target_dim]

# Usage
query_emb = get_embedding(query, 128)
doc_embs = [get_embedding(d, 128) for d in docs]
scores = cosine_similarity(query_emb, doc_embs)
```
Matryoshka Embeddings Best Practices
When to Use
- Multiple deployment targets (cloud/edge)
- Dynamic resource constraints
- A/B testing different dimensions
- Gradual quality degradation acceptable
- Need for backwards compatibility
Dimension Guidelines
- 768D: Maximum quality, research
- 256-512D: Production systems
- 128D: Mobile applications
- 64D: Real-time, high-volume
- 32D: Extreme edge devices
Key Insight: Train once with Matryoshka loss, deploy everywhere. The same model can serve high-quality requests on powerful servers and low-latency requests on edge devices, simply by truncating embeddings.
The Matryoshka Principle
Like Russian nesting dolls, Matryoshka embeddings contain accurate representations at multiple scales within a single embedding:
```
768D: [█████████████████████████████████] Full representation
512D: [██████████████████████]            98% accuracy retained
256D: [███████████]                       95% accuracy retained
128D: [██████]                            92% accuracy retained
 64D: [███]                               87% accuracy retained
```
How It Works
Traditional vs Matryoshka
Traditional Embeddings:
```python
# Need separate models for different dimensions
model_768 = train_model(dim=768)  # Full model
model_256 = train_model(dim=256)  # Retrain for smaller
model_128 = train_model(dim=128)  # Retrain again
```
Matryoshka Embeddings:
```python
# Single model, multiple dimensions
model = train_matryoshka_model(dims=[768, 512, 256, 128, 64])

# Use any dimension at inference
embedding_768 = model.encode(text)[:768]  # Full
embedding_256 = model.encode(text)[:256]  # Truncated
embedding_128 = model.encode(text)[:128]  # More truncated
```
Matryoshka Representation Learning (MRL)
The Loss Function
Train with a multi-scale contrastive loss:

$$\mathcal{L}_{\text{MRL}} = \sum_{m \in \mathcal{M}} \lambda_m \cdot \mathcal{L}\left(E[:m];\; y\right)$$

Where:
- $\mathcal{M} = \{d_1, d_2, \ldots, d_k\}$ = set of nested dimensions
- $E[:m]$ = first $m$ dimensions of the embedding
- $\lambda_m$ = weight for dimension $m$
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MatryoshkaModel(nn.Module):
    def __init__(self, encoder, dims=[768, 512, 256, 128, 64, 32]):
        super().__init__()
        self.encoder = encoder
        self.dims = sorted(dims, reverse=True)
        self.projection = nn.Linear(encoder.config.hidden_size, max(dims))

    def forward(self, input_ids, attention_mask):
        # Get base embeddings (mean-pooled token states)
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state.mean(dim=1)
        # Project to max dimension
        embeddings = self.projection(embeddings)
        # Normalize full embedding
        return F.normalize(embeddings, p=2, dim=-1)

    def matryoshka_loss(self, query_embs, doc_embs, temperature=0.07):
        """Multi-scale in-batch contrastive loss.

        query_embs and doc_embs are aligned views: position i in each
        batch is a positive pair; all other pairs serve as negatives.
        """
        total_loss = 0
        weights = [1.0] + [0.5] * (len(self.dims) - 1)  # Decreasing weights
        for dim, weight in zip(self.dims, weights):
            # Truncate both views to the current dimension, then re-normalize
            q = F.normalize(query_embs[:, :dim], p=2, dim=-1)
            d = F.normalize(doc_embs[:, :dim], p=2, dim=-1)
            # Similarity of every query against every document
            similarity = torch.matmul(q, d.T) / temperature
            # The matching document for query i sits at index i
            batch_size = query_embs.shape[0]
            targets = torch.arange(batch_size, device=query_embs.device)
            # Cross-entropy loss over in-batch candidates
            loss = F.cross_entropy(similarity, targets)
            # Weighted contribution
            total_loss += weight * loss
        return total_loss / sum(weights)
```
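A minimal usage sketch, assuming the class above plus a Hugging Face transformers encoder; bert-base-uncased is an arbitrary stand-in (any encoder exposing `config.hidden_size` works), and the example texts are invented:

```python
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # arbitrary encoder choice
model = MatryoshkaModel(AutoModel.from_pretrained("bert-base-uncased"))

# Two aligned "views": queries and their positive documents
queries = ["what are matryoshka embeddings", "how to truncate vectors"]
docs = ["nested representations allow truncation", "take the first m dimensions"]
q = tok(queries, padding=True, return_tensors="pt")
d = tok(docs, padding=True, return_tensors="pt")

q_emb = model(q["input_ids"], q["attention_mask"])  # shape (2, 768)
d_emb = model(d["input_ids"], d["attention_mask"])
loss = model.matryoshka_loss(q_emb, d_emb)  # averaged multi-scale loss
```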
Training Strategy
Progressive Training
Train with increasing complexity:
```python
def progressive_matryoshka_training(model, dataloader, epochs=10):
    """Train Matryoshka model progressively."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # Start with smaller dimensions
    active_dims = [32, 64]
    for epoch in range(epochs):
        # Gradually add larger dimensions
        if epoch == 3:
            active_dims.append(128)
        if epoch == 5:
            active_dims.append(256)
        if epoch == 7:
            active_dims.extend([512, 768])
        for batch in dataloader:
            embeddings = model(batch['input_ids'], batch['attention_mask'])
            # Loss only for active dimensions
            loss = 0
            for dim in active_dims:
                truncated = embeddings[:, :dim]
                truncated = F.normalize(truncated, p=2, dim=-1)
                # contrastive_loss is an assumed helper
                dim_loss = contrastive_loss(truncated, batch['labels'])
                loss += dim_loss / len(active_dims)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```
Importance-Weighted Dimensions
Earlier dimensions are more important:
```python
def importance_weighted_loss(embeddings, labels, dims, alpha=0.5):
    """Weight loss by dimension importance."""
    total_loss = 0
    for i, dim in enumerate(dims):
        # Exponentially decreasing importance
        weight = alpha ** i
        truncated = embeddings[:, :dim]
        truncated = F.normalize(truncated, p=2, dim=-1)
        loss = contrastive_loss(truncated, labels)
        total_loss += weight * loss
    return total_loss
```
Inference and Deployment
Dynamic Dimension Selection
```python
class AdaptiveMatryoshkaIndex:
    def __init__(self, model, documents):
        self.model = model
        self.dims = [768, 512, 256, 128, 64, 32]
        # Pre-compute embeddings at max dimension
        with torch.no_grad():
            full_embeddings = [model.encode(doc) for doc in documents]
        self.full_embeddings = torch.stack(full_embeddings)

    def search(self, query, k=10, max_latency_ms=100):
        """Search with a latency constraint."""
        # Estimate dimension based on latency budget
        if max_latency_ms < 20:
            dim = 32
        elif max_latency_ms < 50:
            dim = 64
        elif max_latency_ms < 100:
            dim = 128
        else:
            dim = 256
        # Encode query at selected dimension
        query_emb = self.model.encode(query)[:dim]
        query_emb = F.normalize(query_emb, p=2, dim=-1)
        # Truncate document embeddings
        doc_embs = self.full_embeddings[:, :dim]
        doc_embs = F.normalize(doc_embs, p=2, dim=-1)
        # Compute similarities
        similarities = torch.matmul(query_emb, doc_embs.T)
        # Get top-k
        top_k = torch.topk(similarities, k)
        return top_k.indices, top_k.values
```
Memory-Aware Deployment
```python
def deploy_matryoshka_model(model, memory_budget_mb):
    """Configure model for memory constraints."""
    bytes_per_float = 4
    num_vectors = 30000  # assumed corpus size
    # Largest dimension that fits the budget across all stored vectors
    max_dim = (memory_budget_mb * 1024 * 1024) / bytes_per_float / num_vectors
    # Select the appropriate nested dimension
    if max_dim >= 768:
        return 768
    elif max_dim >= 512:
        return 512
    elif max_dim >= 256:
        return 256
    elif max_dim >= 128:
        return 128
    else:
        return 64
```
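As a worked example, using the 30,000-vector corpus assumed inside the helper, a 100 MB budget lands on the full dimension:

```python
# 100 MB * 1024^2 bytes / 4 bytes per float / 30,000 vectors ≈ 873 floats/vector
dim = deploy_matryoshka_model(model, memory_budget_mb=100)
print(dim)  # 768, since 873 >= 768; a 10 MB budget (≈87 floats) would select 64
```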
Performance Analysis
Dimension vs Accuracy Trade-off
| Dimension | Relative Size | Accuracy | Speed | Use Case |
|---|---|---|---|---|
| 768 | 100% | 100% | 1× | Research, high-quality |
| 512 | 67% | 99.2% | 1.5× | Production servers |
| 256 | 33% | 97.5% | 3× | Balanced performance |
| 128 | 17% | 94.8% | 6× | Mobile, real-time |
| 64 | 8% | 89.3% | 12× | Edge devices |
| 32 | 4% | 82.1% | 24× | IoT, extreme constraints |
Benchmark Results
```python
import time

def benchmark_dimensions(model, test_data):
    """Compare performance across dimensions."""
    results = {}
    for dim in [768, 512, 256, 128, 64, 32]:
        # Truncate embeddings
        query_embs = model.encode(test_data['queries'])[:, :dim]
        doc_embs = model.encode(test_data['documents'])[:, :dim]
        # Measure accuracy (compute_recall_at_k is an assumed helper)
        accuracy = compute_recall_at_k(query_embs, doc_embs, k=10)
        # Measure speed
        start = time.time()
        for _ in range(1000):
            similarities = cosine_similarity(query_embs[:10], doc_embs)
        latency = (time.time() - start) / 1000
        # Measure memory
        memory_mb = (query_embs.nbytes + doc_embs.nbytes) / 1024 / 1024
        results[dim] = {
            'accuracy': accuracy,
            'latency_ms': latency * 1000,
            'memory_mb': memory_mb,
        }
    return results
```
Advanced Techniques
1. Learned Truncation
Not all dimensions are equally important:
```python
class LearnedTruncation(nn.Module):
    def __init__(self, full_dim=768):
        super().__init__()
        self.importance = nn.Parameter(torch.ones(full_dim))

    def forward(self, embeddings, target_dim):
        # Sort dimensions by learned importance
        importance_sorted = torch.argsort(self.importance, descending=True)
        # Select the most important dimensions
        selected = importance_sorted[:target_dim]
        # Reorder embeddings accordingly
        return embeddings[:, selected]
```
2. Cascaded Retrieval
Use multiple dimensions for refinement:
```python
def cascaded_search(query, documents, k=10):
    """Multi-stage retrieval with increasing precision."""
    # Stage 1: Fast filtering with 32D
    emb_32 = encode(query)[:32]
    candidates = search_32d(emb_32, top_k=1000)
    # Stage 2: Rerank with 128D
    emb_128 = encode(query)[:128]
    candidates = rerank_128d(emb_128, candidates, top_k=100)
    # Stage 3: Final ranking with 768D
    emb_768 = encode(query)[:768]
    results = rerank_768d(emb_768, candidates, top_k=k)
    return results
```
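The funnel shape is the central design choice here: the cheap 32D pass scans the whole collection but only needs to keep the true matches somewhere in its top 1000, while the expensive 768D comparisons touch just 100 candidates. Because recall lost at an early stage cannot be recovered later, the intermediate top_k cutoffs should be set generously and validated against full-dimension search.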
3. Dimension Prediction
Predict optimal dimension per query:
```python
class DimensionPredictor(nn.Module):
    def __init__(self, input_dim=768):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 6),  # 6 dimension options
            nn.Softmax(dim=-1),
        )
        self.dims = [32, 64, 128, 256, 512, 768]

    def forward(self, query_embedding):
        # Predict dimension probabilities
        probs = self.mlp(query_embedding)
        # Select the most likely dimension
        dim_idx = torch.argmax(probs)
        return self.dims[dim_idx]
```
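Note that `torch.argmax` is not differentiable, so this predictor cannot be trained end-to-end through a retrieval loss as written; a practical route is to train it as a classifier against offline labels (for example, the smallest dimension whose top-k matches full-dimension search), or to relax the selection with a soft approximation such as Gumbel-Softmax.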
Practical Applications
1. Semantic Search at Scale
```python
class ScalableSearch:
    def __init__(self, model, documents, index_budget_gb=10):
        self.model = model
        # Calculate dimension based on budget (float32 = 4 bytes/dim)
        num_docs = len(documents)
        bytes_per_doc = index_budget_gb * 1e9 / num_docs
        self.dim = min(768, int(bytes_per_doc / 4))
        # Index at the selected dimension
        # (build_index is assumed to embed and index documents at self.dim)
        self.index = self.build_index(documents, self.dim)

    def search(self, query, k=10):
        query_emb = self.model.encode(query)[:self.dim]
        return self.index.search(query_emb, k)
```
2. Real-time Recommendation
```python
import time

def real_time_recommendations(user_embedding, items, latency_budget_ms=50):
    """Get recommendations within a latency budget."""
    # Start with the smallest dimension
    dim = 32
    results = None
    while dim <= 768:
        start = time.time()
        # Truncate embeddings
        user_emb = user_embedding[:dim]
        item_embs = items[:, :dim]
        # Compute scores
        scores = cosine_similarity(user_emb, item_embs)
        elapsed_ms = (time.time() - start) * 1000
        if elapsed_ms < latency_budget_ms:
            results = scores
            dim *= 2  # Try a higher dimension
        else:
            break  # Use the previous result
    return results
```
3. Progressive Loading
```python
import numpy as np

class ProgressiveEmbedding:
    def __init__(self, embedding_path):
        # Load embeddings in chunks
        self.dims = [32, 64, 128, 256, 512, 768]
        self.chunks = {}
        for i, dim in enumerate(self.dims):
            start = 0 if i == 0 else self.dims[i - 1]
            end = dim
            chunk_path = f"{embedding_path}.{start}_{end}"
            self.chunks[dim] = np.load(chunk_path)

    def get_embedding(self, dim):
        """Load only the required dimensions."""
        if dim not in self.dims:
            # Round up to the next stored dimension
            dim = min(d for d in self.dims if d >= dim)
        # Concatenate the required prefix chunks
        embedding = []
        for d in self.dims:
            if d <= dim:
                embedding.append(self.chunks[d])
            else:
                break
        return np.concatenate(embedding)
```
Best Practices
Training Tips
- Use multiple dimensions: Train with at least 4-6 nested dimensions
- Weight by importance: Give more weight to smaller dimensions
- Normalize at each scale: Re-normalize after truncation (see the sketch after this list)
- Progressive training: Start with small dimensions, add larger ones
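
On the re-normalization tip: a truncated prefix of a unit vector is no longer unit-length, so raw dot products over prefixes are inconsistently scaled. A minimal sketch with random stand-in vectors:

```python
import torch
import torch.nn.functional as F

# Two unit-normalized 768D embeddings (random stand-ins for real vectors)
a = F.normalize(torch.randn(768), dim=0)
b = F.normalize(torch.randn(768), dim=0)

# Truncated prefixes have norm < 1, so their raw dot product is not a cosine
a128, b128 = a[:128], b[:128]
print(a128.norm())  # well below 1.0 (≈ sqrt(128/768) ≈ 0.41 for random directions)

# Re-normalize after truncation to recover a proper cosine similarity
cos_128 = torch.dot(F.normalize(a128, dim=0), F.normalize(b128, dim=0))
```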
Deployment Tips
- Profile your constraints: Measure latency/memory requirements
- Use cascaded search: Coarse-to-fine retrieval
- Cache appropriately: Store different dimensions separately
- Monitor quality: Track accuracy at deployed dimensions
Related Concepts
- Dense Embeddings - Full-dimensional representations
- Quantization Effects - Alternative compression
- Multi-Vector Late Interaction - Token-level representations
References
- Kusupati et al., "Matryoshka Representation Learning," NeurIPS 2022.
- Wortsman et al., "Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time," ICML 2022.
- Chen et al., "Multi-Scale Contrastive Learning for Embedding Compression."