Multi-Vector Late Interaction
Explore ColBERT and other multi-vector retrieval models that use fine-grained token-level matching for superior search quality.
Multi-vector models like ColBERT achieve state-of-the-art retrieval quality by maintaining fine-grained token representations and computing similarity through late interaction.
ColBERT Architecture
Contextualized late interaction over BERT: every query and document token is encoded as its own low-dimensional contextual vector, and the final score is computed with MaxSim over the token interaction matrix (query tokens × document tokens).
Model Comparison
- ColBERT: Contextualized late interaction over BERT
- Poly-Encoder: Multiple learned attention codes per context
- DPR Multi-Vector: Multiple dense vectors per passage
- SPLADE: Learned sparse representations over the vocabulary
Technical Implementation
ColBERT Scoring
# Late interaction scoring (pseudocode): each query token keeps only
# its best-matching document token (MaxSim), then the maxima are summed
score = 0
for q_token in query_tokens:
    max_sim = -inf
    for d_token in doc_tokens:
        sim = cosine(q_token, d_token)
        max_sim = max(max_sim, sim)
    score += max_sim

# Final score (optionally normalized by query length)
score = score / len(query_tokens)
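The double loop above collapses into a few tensor operations. A minimal runnable sketch in PyTorch, with random placeholder embeddings standing in for real encoder output:

import torch
import torch.nn.functional as F

def maxsim_score(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> torch.Tensor:
    # Late-interaction score: sum over query tokens of the max
    # cosine similarity against all document tokens
    q = F.normalize(query_embs, p=2, dim=-1)   # [num_q, dim]
    d = F.normalize(doc_embs, p=2, dim=-1)     # [num_d, dim]
    sim = q @ d.T                              # [num_q, num_d] cosine matrix
    return sim.max(dim=-1).values.sum()        # MaxSim per query token, summed

# Toy example with placeholder embeddings
query_embs = torch.randn(8, 128)   # 8 query tokens, 128-dim
doc_embs = torch.randn(120, 128)   # 120 document tokens
print(float(maxsim_score(query_embs, doc_embs)))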
Indexing Strategy
# Document indexing (pseudocode)
doc_embeddings = []
for token in document:
    # Contextualized embedding of the token
    token_emb = bert(token, context)
    if multi_vector:
        # Project to multiple vectors
        vectors = [proj_i(token_emb) for proj_i in projections]
    else:
        vectors = [token_emb]
    doc_embeddings.append(vectors)

# Store in index
index.add(doc_id, doc_embeddings)
Multi-Vector Late Interaction: Advantages and Trade-offs
Advantages
- Fine-grained token-level matching
- Better semantic understanding
- Handles long documents well
- Interpretable similarity scores
- State-of-the-art retrieval quality
Trade-offs
- Larger index size (10-100x)
- Higher query latency
- More complex implementation
- Requires specialized indexing
- Higher memory requirements
Best for: High-precision retrieval tasks, question answering, fact verification, and scenarios where accuracy is more important than speed.
The Late Interaction Paradigm
Traditional dense retrieval compresses entire documents into single vectors, losing fine-grained information. Multi-vector models preserve token-level representations:
Single-Vector (BERT)
Document → BERT → [CLS] token → Single vector
Query → BERT → [CLS] token → Single vector
Score = cosine(query_vec, doc_vec)
Multi-Vector (ColBERT)
Document → BERT → All tokens → Multiple vectors
Query → BERT → All tokens → Multiple vectors
Score = sum of per-query-token max similarities
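A compact sketch of the difference between the two scoring rules (the shapes and random embeddings are illustrative only):

import torch
import torch.nn.functional as F

dim = 128
q_cls, d_cls = torch.randn(dim), torch.randn(dim)              # pooled [CLS] embeddings
q_toks, d_toks = torch.randn(16, dim), torch.randn(180, dim)   # per-token embeddings

# Single-vector: one cosine similarity between pooled vectors
single_score = F.cosine_similarity(q_cls, d_cls, dim=0)

# Multi-vector: every query token picks its best document token, then sum
sim_matrix = F.normalize(q_toks, dim=-1) @ F.normalize(d_toks, dim=-1).T
multi_score = sim_matrix.max(dim=-1).values.sum()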
ColBERT Architecture
The MaxSim Operation
ColBERT's core scoring function sums, over query tokens, each token's maximum similarity to any document token:
S(q, d) = Σ_i max_j ( E_qi · E_dj )
Where:
- Eqi = embedding of query token i
- Edj = embedding of document token j
- Each query token finds its best match in the document
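As a quick numerical check of the formula (the similarity values are made up):

import torch

# Similarity matrix for 2 query tokens x 3 document tokens
sim = torch.tensor([[0.9, 0.2, 0.4],
                    [0.1, 0.7, 0.3]])

# MaxSim: best document match per query token, then sum over query tokens
score = sim.max(dim=-1).values.sum()  # 0.9 + 0.7 = 1.6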
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class ColBERT(nn.Module):
    def __init__(self, bert_model, dim=128):
        super().__init__()
        self.bert = bert_model
        self.linear = nn.Linear(768, dim)
        self.dim = dim

    def encode_query(self, query_tokens):
        # Encode query with BERT
        outputs = self.bert(query_tokens)
        embeddings = outputs.last_hidden_state
        # Project to a lower dimension
        embeddings = self.linear(embeddings)
        # L2-normalize so dot products are cosine similarities
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        # Add a [Q] marker to query embeddings
        # (the original ColBERT instead prepends a special [Q] input token)
        query_marker = torch.zeros(1, self.dim)
        query_marker[0, 0] = 1  # Special query indicator
        embeddings = embeddings + query_marker
        return embeddings

    def encode_document(self, doc_tokens):
        # Encode document (no [Q] marker)
        outputs = self.bert(doc_tokens)
        embeddings = outputs.last_hidden_state
        embeddings = self.linear(embeddings)
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        # Add a [D] marker
        doc_marker = torch.zeros(1, self.dim)
        doc_marker[0, 1] = 1  # Special document indicator
        embeddings = embeddings + doc_marker
        return embeddings

    def score(self, query_embeddings, doc_embeddings):
        # Compute all pairwise token similarities
        scores = torch.matmul(query_embeddings, doc_embeddings.T)
        # MaxSim: max over document tokens for each query token
        max_scores = scores.max(dim=-1).values
        # Sum over query tokens
        total_score = max_scores.sum()
        return total_score
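A usage sketch for the class above, assuming a Hugging Face transformers BERT encoder; the model name is illustrative, and the .squeeze(0) calls drop the batch dimension so score() sees 2-D token matrices:

from transformers import AutoModel, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = ColBERT(AutoModel.from_pretrained("bert-base-uncased"), dim=128)

query = tokenizer("who wrote hamlet", return_tensors="pt")["input_ids"]
doc = tokenizer("Hamlet is a tragedy written by William Shakespeare.",
                return_tensors="pt")["input_ids"]

with torch.no_grad():
    q_embs = model.encode_query(query).squeeze(0)    # [num_query_tokens, 128]
    d_embs = model.encode_document(doc).squeeze(0)   # [num_doc_tokens, 128]
    print(float(model.score(q_embs, d_embs)))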
Indexing and Retrieval
Efficient Indexing
import torch

class ColBERTIndex:
    def __init__(self, model, documents):
        self.model = model
        self.doc_embeddings = []
        self.doc_lengths = []
        self.doc_ids = []

        # Encode all documents up front
        for doc_id, doc in enumerate(documents):
            embeddings = model.encode_document(doc)
            self.doc_embeddings.append(embeddings)
            self.doc_lengths.append(len(embeddings))
            self.doc_ids.append(doc_id)

        # Flatten into one matrix for efficient slicing
        self.all_embeddings = torch.cat(self.doc_embeddings)

    def search(self, query, k=10):
        # Encode the query once
        query_embs = self.model.encode_query(query)

        # Exhaustively score every document with MaxSim
        scores = []
        offset = 0
        for length in self.doc_lengths:
            doc_embs = self.all_embeddings[offset:offset + length]
            score = self.model.score(query_embs, doc_embs)
            scores.append(score)
            offset += length

        # Return the top-k (doc_id, score) pairs
        top_k = torch.topk(torch.stack(scores), k)
        return [(self.doc_ids[i], scores[i]) for i in top_k.indices]
Approximate Search
For large-scale retrieval:
import faiss
import numpy as np
from collections import defaultdict

class ApproximateColBERT:
    def __init__(self, model, documents, nprobe=32):
        self.model = model
        self.nprobe = nprobe

        # Build a token-level index: one entry per document token
        embeddings = []
        doc_mapping = []  # Maps embedding row -> (doc_id, token_id)
        for doc_id, doc in enumerate(documents):
            doc_embs = model.encode_document(doc)  # assumes [num_tokens, dim]
            for token_id, emb in enumerate(doc_embs):
                embeddings.append(emb.detach().numpy())
                doc_mapping.append((doc_id, token_id))

        # Create a FAISS IVF-PQ index over all token embeddings
        embeddings = np.array(embeddings, dtype="float32")
        self.index = faiss.IndexIVFPQ(
            faiss.IndexFlatIP(128),       # Coarse quantizer
            128,                          # Dimension
            1000,                         # Number of clusters
            32,                           # Subquantizers
            8,                            # Bits per subquantizer
            faiss.METRIC_INNER_PRODUCT    # Return similarities, not L2 distances
        )
        self.index.train(embeddings)
        self.index.add(embeddings)
        self.index.nprobe = nprobe
        self.doc_mapping = doc_mapping

    def search(self, query, k=10):
        query_embs = self.model.encode_query(query)  # assumes [num_tokens, dim]

        # Approximate MaxSim: for each query token, look up its nearest
        # document tokens and keep only the best score per document
        scores_per_doc = defaultdict(float)
        for q_emb in query_embs:
            q = q_emb.detach().numpy().reshape(1, -1).astype("float32")
            distances, indices = self.index.search(q, 100)

            doc_scores = defaultdict(float)
            for dist, idx in zip(distances[0], indices[0]):
                doc_id, _ = self.doc_mapping[idx]
                doc_scores[doc_id] = max(doc_scores[doc_id], dist)

            # Accumulate each query token's best match into the document totals
            for doc_id, score in doc_scores.items():
                scores_per_doc[doc_id] += score

        # Return the top-k documents by approximate score
        sorted_docs = sorted(scores_per_doc.items(),
                             key=lambda x: x[1], reverse=True)
        return sorted_docs[:k]
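In practice, approximate token search is typically followed by exact MaxSim re-scoring of the surviving candidates. A minimal sketch assuming the classes above; get_doc_embeddings is a hypothetical lookup of a document's stored token embeddings:

def search_with_rerank(approx_index, model, query, k=10, candidates=200):
    # Stage 1: cheap approximate candidate generation
    candidate_docs = approx_index.search(query, k=candidates)

    # Stage 2: exact MaxSim over each candidate's full token embeddings
    query_embs = model.encode_query(query)
    reranked = []
    for doc_id, _ in candidate_docs:
        doc_embs = get_doc_embeddings(doc_id)  # hypothetical embedding store lookup
        reranked.append((doc_id, float(model.score(query_embs, doc_embs))))

    return sorted(reranked, key=lambda x: x[1], reverse=True)[:k]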
Other Multi-Vector Models
1. Poly-Encoder
Uses multiple attention codes:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolyEncoder(nn.Module):
    def __init__(self, bert_model, num_codes=64):
        super().__init__()
        self.bert = bert_model
        # Learnable codes that attend over the context tokens
        self.codes = nn.Parameter(torch.randn(num_codes, 768))

    def encode_context(self, context):
        outputs = self.bert(context)
        hidden = outputs.last_hidden_state.squeeze(0)  # [seq_len, 768]

        # Attention over context tokens using the codes
        attention = torch.matmul(self.codes, hidden.T)
        attention = F.softmax(attention, dim=-1)

        # Weighted average per code
        poly_embs = torch.matmul(attention, hidden)
        return poly_embs  # [num_codes, dim]

    def encode_candidate(self, candidate):
        outputs = self.bert(candidate)
        return outputs.pooler_output.squeeze(0)  # Single vector [dim]

    def score(self, context_embs, candidate_emb):
        # Attention-weighted scoring over the code embeddings
        scores = torch.matmul(context_embs, candidate_emb)
        attention = F.softmax(scores, dim=0)
        final_score = (attention * scores).sum()
        return final_score
2. SPLADE (Sparse + Dense)
Learned sparse representations:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SPLADE(nn.Module):
    def __init__(self, bert_model, vocab_size=30522):
        super().__init__()
        # bert_model is assumed to be a masked-LM model (e.g. BertForMaskedLM)
        # whose output head projects hidden states onto the vocabulary
        self.bert = bert_model
        self.vocab_size = vocab_size

    def encode(self, tokens):
        # Per-token logits over the vocabulary: [batch, seq_len, vocab]
        logits = self.bert(tokens).logits

        # Max pooling over the sequence dimension
        scores = logits.max(dim=1).values

        # Sparsify with ReLU and log saturation
        sparse = torch.log(1 + F.relu(scores))
        return sparse  # [batch, vocab_size]
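Scoring with SPLADE is then a sparse dot product in vocabulary space, which maps naturally onto an inverted index. A toy sketch assuming the encoder above:

def splade_score(query_vec, doc_vec):
    # Both vectors live in vocabulary space; most entries are zero, so in
    # production this is served from an inverted index keyed by term id
    return (query_vec * doc_vec).sum(dim=-1)

# q_vec = splade.encode(query_tokens)   # [1, vocab_size]
# d_vec = splade.encode(doc_tokens)     # [1, vocab_size]
# score = splade_score(q_vec, d_vec)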
3. DPR Multi-Vector
Multiple passage representations:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiVectorDPR(nn.Module):
    def __init__(self, bert_model, num_vectors=5):
        super().__init__()
        self.bert = bert_model
        # One projection head per "view" of the passage
        self.projections = nn.ModuleList([
            nn.Linear(768, 768) for _ in range(num_vectors)
        ])

    def encode(self, passage):
        outputs = self.bert(passage)
        cls_token = outputs.pooler_output

        # Generate multiple normalized views of the same passage
        vectors = []
        for projection in self.projections:
            vec = projection(cls_token)
            vec = F.normalize(vec, p=2, dim=-1)
            vectors.append(vec)
        return torch.stack(vectors)  # [num_vectors, dim]
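The class above only covers encoding. One simple way to score a query against the multiple views is to keep the best-matching one; the max-over-views rule below is an assumption for illustration, not part of the original DPR recipe:

def multi_vector_score(query_vec, passage_vectors):
    # query_vec: [dim]; passage_vectors: [num_vectors, dim]
    # Keep the most similar view of the passage for this query
    sims = passage_vectors @ query_vec
    return sims.max()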
Performance Comparison
Retrieval Quality (MS MARCO)
| Model | MRR@10 | Recall@1000 | Index Size |
|---|---|---|---|
| BM25 | 18.7 | 85.7 | 0.5 GB |
| DPR (single) | 31.2 | 95.2 | 21 GB |
| ANCE | 33.0 | 95.9 | 21 GB |
| ColBERT | 36.0 | 97.0 | 154 GB |
| ColBERTv2 | 39.7 | 98.4 | 25 GB |
Latency Analysis
import time
import numpy as np

# Benchmark different retrieval approaches (illustrative sketch: the corpus
# embeddings, cosine_similarity, maxsim and index objects are assumed to exist)
def benchmark_retrieval(model, queries, corpus, method):
    times = []
    for query in queries:
        start = time.time()
        if method == 'single_vector':
            q_emb = model.encode_query_single(query)
            scores = cosine_similarity(q_emb, corpus_embeddings)
        elif method == 'colbert':
            q_embs = model.encode_query_multi(query)
            scores = []
            for doc_embs in corpus_multi_embeddings:
                score = maxsim(q_embs, doc_embs)
                scores.append(score)
        elif method == 'colbert_indexed':
            results = index.search(query, k=1000)
        times.append(time.time() - start)
    return np.mean(times)

# Typical results:
#   Single vector:     ~5 ms / query
#   ColBERT (naive):   ~200 ms / query
#   ColBERT (indexed): ~50 ms / query
Optimization Techniques
1. Compression
Reduce index size:
from sklearn.decomposition import PCA

# Dimension reduction: 768 -> 128 dimensions
pca = PCA(n_components=128)
embeddings_128d = pca.fit_transform(embeddings_768d)

# Quantization (quantize_embeddings is a placeholder for a scalar or
# product quantization step; see the sketch below)
embeddings_int8 = quantize_embeddings(embeddings_128d)

# Combined: ~6x index-size reduction with <2% quality loss
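One possible form of the quantization step, shown as a simple symmetric int8 scheme; this is an illustrative stand-in for the quantize_embeddings placeholder above, not ColBERT's exact compression scheme:

import numpy as np

def quantize_embeddings(embs: np.ndarray):
    # Symmetric per-matrix scaling into the int8 range
    scale = np.abs(embs).max() / 127.0
    return (embs / scale).round().astype(np.int8), scale

def dequantize(quantized: np.ndarray, scale: float) -> np.ndarray:
    return quantized.astype(np.float32) * scale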
2. Centroid Interaction
Speed up scoring:
def centroid_interaction(query_embs, doc_centroids, top_k=100):
    # First stage: score cheap per-document centroid representations
    centroid_scores = maxsim(query_embs, doc_centroids)

    # Second stage: fully score only the top-k surviving documents
    top_docs = centroid_scores.topk(top_k).indices
    final_scores = []
    for doc_id in top_docs:
        doc_embs = get_full_embeddings(doc_id)
        score = maxsim(query_embs, doc_embs)
        final_scores.append(score)
    return final_scores
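The per-document centroids can be computed offline from each document's token embeddings; the single mean-pooled centroid below is the simplest possible choice, shown only as a sketch:

import torch

def document_centroid(doc_token_embs: torch.Tensor) -> torch.Tensor:
    # doc_token_embs: [num_tokens, dim] -> one centroid vector per document
    return doc_token_embs.mean(dim=0)

# doc_centroids = torch.stack([document_centroid(e) for e in all_doc_embeddings])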
3. Denoised Supervision
Improve training:
def denoised_colbert_loss(query, positive, negatives, tau=0.01):
    # Encode the query, the positive passage, and the hard negatives
    q_embs = model.encode_query(query)
    pos_embs = model.encode_document(positive)
    neg_embs = [model.encode_document(neg) for neg in negatives]

    # MaxSim scores
    pos_score = maxsim(q_embs, pos_embs)
    neg_scores = [maxsim(q_embs, neg) for neg in neg_embs]

    # Temperature-scaled contrastive loss
    numerator = torch.exp(pos_score / tau)
    denominator = numerator + sum(torch.exp(s / tau) for s in neg_scores)
    loss = -torch.log(numerator / denominator)
    return loss
Best Practices
1. Token Length Management
# Limit document length for indexing efficiency
MAX_DOC_LENGTH = 180  # ColBERT default

def prepare_documents(documents):
    processed = []
    for doc in documents:
        tokens = tokenizer(doc, max_length=MAX_DOC_LENGTH, truncation=True)
        processed.append(tokens)
    return processed
2. Query Augmentation
# Prepend a [Q] marker so queries are encoded differently from documents
def augment_query(query):
    return f"[Q] {query}"
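In the original ColBERT, query augmentation also pads the query with [MASK] tokens up to a fixed length so BERT can "expand" it. A sketch assuming a Hugging Face tokenizer (in practice [Q] is registered as a special or unused vocabulary token):

def augment_and_pad_query(query, tokenizer, max_len=32):
    # Prepend [Q], truncate, then pad with [MASK] token ids up to max_len
    ids = tokenizer(f"[Q] {query}", add_special_tokens=True)["input_ids"][:max_len]
    ids += [tokenizer.mask_token_id] * (max_len - len(ids))
    return ids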
3. Hybrid Retrieval
def hybrid_search(query, k=100):
    # Stage 1: BM25 for cheap initial candidates -> [(doc_id, bm25_score), ...]
    bm25_results = bm25_search(query, k=1000)

    # Stage 2: ColBERT re-scoring of the candidates
    colbert_scores = {}
    for doc_id, _ in bm25_results:
        colbert_scores[doc_id] = colbert_model.score(query, documents[doc_id])

    # Combine the two scores with a fixed interpolation weight
    final_scores = []
    for doc_id, bm25_score in bm25_results:
        combined = 0.3 * bm25_score + 0.7 * colbert_scores[doc_id]
        final_scores.append((doc_id, combined))

    return sorted(final_scores, key=lambda x: x[1], reverse=True)[:k]
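Because BM25 and ColBERT scores live on different scales, it is common to normalize each score list before interpolating; the 0.3/0.7 weights above are illustrative rather than tuned values. A minimal min-max normalization helper:

def min_max_normalize(scores):
    # Map a list of raw scores into [0, 1] so the two systems are comparable
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]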
Related Concepts
- Dense Embeddings - Single-vector baselines
- Sparse vs Dense - Comparing retrieval paradigms
- Matryoshka Embeddings - Efficient multi-scale representations
References
- Khattab & Zaharia (2020). "ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT." SIGIR.
- Santhanam et al. (2022). "ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction." NAACL.
- Humeau et al. (2020). "Poly-encoders: Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring." ICLR.
- Formal et al. (2021). "SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking." SIGIR.