Multi-Vector Late Interaction

Explore ColBERT and other multi-vector retrieval models that use fine-grained token-level matching for superior search quality.

Multi-vector models like ColBERT achieve state-of-the-art retrieval quality by maintaining fine-grained token representations and computing similarity through late interaction.

Interactive ColBERT Visualization

Model Configuration

ColBERT Architecture

Contextualized late interaction over BERT. Uses MaxSim to compute final scores. Each token keeps its own contextualized vector for fine-grained matching.

Token Interaction Matrix

(Interactive heatmap: pairwise similarity scores between the query tokens of "What is machine learning?" and the document tokens of "Machine learning is artificial intelligence method"; brighter cells indicate higher similarity, and each query token's brightest cell is its MaxSim match.)
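
A matrix like this can be computed directly from the token embeddings. Below is a minimal PyTorch sketch; random embeddings stand in for real ColBERT outputs, and both sides are L2-normalized so dot products equal cosine similarities:

import torch
import torch.nn.functional as F

# Toy stand-ins: 5 query tokens and 8 document tokens, 128-dim each
query_embs = F.normalize(torch.randn(5, 128), p=2, dim=-1)
doc_embs = F.normalize(torch.randn(8, 128), p=2, dim=-1)

# Pairwise cosine similarities -- the interaction matrix above: [|q|, |d|]
sim = torch.matmul(query_embs, doc_embs.T)

# Each row's maximum is that query token's MaxSim contribution
maxsim_per_query_token = sim.max(dim=-1).values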

Model Comparison

ColBERT

Acc: 95% · Eff: 70%

Contextualized late interaction over BERT

Vectors: Per token
Interaction: MaxSim

Poly-Encoder

Acc: 90% · Eff: 80%

Multiple attention codes for representation

Vectors: Per document
Interaction: Attention

DPR Multi-Vector

Acc: 85% · Eff: 60%

Multiple dense passages per document

Vectors: Per document
Interaction: Average

SPLADE

Acc: 92% · Eff: 50%

Sparse + dense multi-vector

Vectors: Per token
Interaction: Weighted Sum

Technical Implementation

ColBERT Scoring

# Late interaction (MaxSim) scoring
score = 0.0
for q_token in query_tokens:
    max_sim = float('-inf')
    for d_token in doc_tokens:
        sim = cosine(q_token, d_token)
        max_sim = max(max_sim, sim)
    score += max_sim

# Final score (ColBERT uses the raw sum; dividing by the query
# length gives a length-normalized variant)
score = score / len(query_tokens)
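
The nested loops spell out the algorithm; in practice the same score is computed with a single matrix multiplication. A minimal vectorized sketch in PyTorch, assuming both embedding matrices are already L2-normalized:

import torch

def maxsim_score(query_embs, doc_embs):
    # query_embs: [|q|, dim], doc_embs: [|d|, dim]
    sim = torch.matmul(query_embs, doc_embs.T)  # all pairwise cosine similarities
    # Max over document tokens, summed over query tokens;
    # divide by query_embs.shape[0] for the length-normalized variant above
    return sim.max(dim=-1).values.sum()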

Memory Efficiency

Index Size: ~10× dense
Query Latency: 143 ms
Accuracy: 95.0%

Indexing Strategy

# Document indexing
doc_embeddings = []
for token in document:
    # Contextualized embedding
    token_emb = bert(token, context)
    if multi_vector:
        # Project to multiple vectors
        vectors = [proj_i(token_emb) 
                  for proj_i in projections]
    else:
        vectors = [token_emb]
    doc_embeddings.append(vectors)
    
# Store in index
index.add(doc_id, doc_embeddings)

Retrieval Pipeline

1. Encode query tokens to vectors
2. Retrieve candidate documents
3. Compute late interaction scores
4. Rank by aggregated scores
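
Roughly, these four steps map onto code like the sketch below; colbert, candidate_index, and fetch_doc_embeddings are placeholder names used for illustration, not a specific library API:

def retrieve(query, k=10):
    # 1. Encode query tokens to vectors
    query_embs = colbert.encode_query(query)

    # 2. Retrieve candidate documents, e.g. via an ANN index over token vectors
    candidate_ids = candidate_index.candidates(query_embs, n=1000)

    # 3. Compute late interaction (MaxSim) scores for each candidate
    scored = []
    for doc_id in candidate_ids:
        doc_embs = fetch_doc_embeddings(doc_id)
        scored.append((doc_id, maxsim_score(query_embs, doc_embs)))

    # 4. Rank by aggregated scores and return the top-k
    return sorted(scored, key=lambda x: x[1], reverse=True)[:k]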

Multi-Vector Late Interaction Advantages

Advantages

  • Fine-grained token-level matching
  • Better semantic understanding
  • Handles long documents well
  • Interpretable similarity scores
  • State-of-the-art retrieval quality

Trade-offs

  • Larger index size (10-100x)
  • Higher query latency
  • More complex implementation
  • Requires specialized indexing
  • Higher memory requirements

Best for: High-precision retrieval tasks, question answering, fact verification, and scenarios where accuracy is more important than speed.

The Late Interaction Paradigm

Traditional dense retrieval compresses entire documents into single vectors, losing fine-grained information. Multi-vector models preserve token-level representations:

Single-Vector (BERT)

Document → BERT → [CLS] token → Single vector
Query → BERT → [CLS] token → Single vector
Score = cosine(query_vec, doc_vec)

Multi-Vector (ColBERT)

Document → BERT → All tokens → Multiple vectors
Query → BERT → All tokens → Multiple vectors
Score = sum of max similarities
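
Side by side, the two scoring styles look like this in PyTorch (a minimal sketch assuming pre-computed, L2-normalized embeddings: one pooled vector per text for the single-vector model, one vector per token for ColBERT):

import torch
import torch.nn.functional as F

# Single-vector: one cosine similarity between pooled embeddings
def single_vector_score(query_vec, doc_vec):
    return F.cosine_similarity(query_vec, doc_vec, dim=-1)

# Multi-vector: each query token is matched to its best document token
def multi_vector_score(query_embs, doc_embs):
    sim = torch.matmul(query_embs, doc_embs.T)  # [|q|, |d|]
    return sim.max(dim=-1).values.sum()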

ColBERT Architecture

The MaxSim Operation

ColBERT's core scoring function:

S_{q,d} = \sum_{i \in |q|} \max_{j \in |d|} E_{q_i} \cdot E_{d_j}^{T}

Where:

  • E_{q_i} = embedding of query token i
  • E_{d_j} = embedding of document token j
  • Each query token finds its best match in the document
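
For example, with a two-token query and a three-token document whose pairwise similarities are [0.9, 0.2, 0.1] for the first query token and [0.3, 0.8, 0.4] for the second, the score is max(0.9, 0.2, 0.1) + max(0.3, 0.8, 0.4) = 0.9 + 0.8 = 1.7: each query token contributes only its single best match, so unrelated document tokens do not dilute the score.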

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class ColBERT(nn.Module):
    def __init__(self, bert_model, dim=128):
        super().__init__()
        self.bert = bert_model
        self.linear = nn.Linear(768, dim)
        self.dim = dim

    def encode_query(self, query_tokens):
        # Encode query
        outputs = self.bert(query_tokens)
        embeddings = outputs.last_hidden_state
        # Project to lower dimension
        embeddings = self.linear(embeddings)
        # Normalize
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        # Add [Q] marker to query embeddings
        query_marker = torch.zeros(1, self.dim)
        query_marker[0, 0] = 1  # Special query indicator
        embeddings = embeddings + query_marker
        return embeddings

    def encode_document(self, doc_tokens):
        # Encode document (no [Q] marker)
        outputs = self.bert(doc_tokens)
        embeddings = outputs.last_hidden_state
        embeddings = self.linear(embeddings)
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        # Add [D] marker
        doc_marker = torch.zeros(1, self.dim)
        doc_marker[0, 1] = 1  # Special doc indicator
        embeddings = embeddings + doc_marker
        return embeddings

    def score(self, query_embeddings, doc_embeddings):
        # Compute all pairwise similarities
        scores = torch.matmul(query_embeddings, doc_embeddings.T)
        # MaxSim: max over document tokens for each query token
        max_scores = scores.max(dim=-1).values
        # Sum over query tokens
        total_score = max_scores.sum()
        return total_score

Indexing and Retrieval

Efficient Indexing

class ColBERTIndex:
    def __init__(self, model, documents):
        self.model = model
        self.doc_embeddings = []
        self.doc_lengths = []
        self.doc_ids = []

        # Encode all documents
        for doc_id, doc in enumerate(documents):
            embeddings = model.encode_document(doc)
            self.doc_embeddings.append(embeddings)
            self.doc_lengths.append(len(embeddings))
            self.doc_ids.append(doc_id)

        # Flatten for efficient search
        self.all_embeddings = torch.cat(self.doc_embeddings)

    def search(self, query, k=10):
        # Encode query
        query_embs = self.model.encode_query(query)

        # Score all documents
        scores = []
        offset = 0
        for length in self.doc_lengths:
            doc_embs = self.all_embeddings[offset:offset + length]
            score = self.model.score(query_embs, doc_embs)
            scores.append(score)
            offset += length

        # Get top-k
        top_k = torch.topk(torch.tensor(scores), k)
        return [(self.doc_ids[i], scores[i]) for i in top_k.indices]

For large-scale retrieval:

import faiss
import numpy as np
from collections import defaultdict

class ApproximateColBERT:
    def __init__(self, model, documents, nprobe=32):
        self.model = model
        self.nprobe = nprobe

        # Build inverted index
        embeddings = []
        doc_mapping = []  # Maps embedding to (doc_id, token_id)
        for doc_id, doc in enumerate(documents):
            doc_embs = model.encode_document(doc)
            for token_id, emb in enumerate(doc_embs):
                embeddings.append(emb)
                doc_mapping.append((doc_id, token_id))

        # Create FAISS index
        embeddings = np.array(embeddings)
        self.index = faiss.IndexIVFPQ(
            faiss.IndexFlatIP(128),  # Base (coarse quantizer) index
            128,    # Dimension
            1000,   # Number of clusters
            32,     # Subquantizers
            8       # Bits per subquantizer
        )
        self.index.train(embeddings)
        self.index.add(embeddings)
        self.index.nprobe = nprobe
        self.doc_mapping = doc_mapping

    def search(self, query, k=10):
        query_embs = self.model.encode_query(query)

        # Find nearest tokens for each query token
        scores_per_doc = defaultdict(float)
        for q_emb in query_embs:
            # Search for nearest document tokens
            distances, indices = self.index.search(q_emb.reshape(1, -1), 100)

            # Accumulate MaxSim scores: best match per document
            doc_scores = defaultdict(float)
            for dist, idx in zip(distances[0], indices[0]):
                doc_id, _ = self.doc_mapping[idx]
                doc_scores[doc_id] = max(doc_scores[doc_id], dist)

            # Add to total scores
            for doc_id, score in doc_scores.items():
                scores_per_doc[doc_id] += score

        # Get top-k documents
        sorted_docs = sorted(scores_per_doc.items(),
                             key=lambda x: x[1], reverse=True)
        return sorted_docs[:k]

Other Multi-Vector Models

1. Poly-Encoder

Uses multiple attention codes:

class PolyEncoder(nn.Module):
    def __init__(self, bert_model, num_codes=64):
        super().__init__()
        self.bert = bert_model
        self.codes = nn.Parameter(torch.randn(num_codes, 768))

    def encode_context(self, context):
        outputs = self.bert(context)
        hidden = outputs.last_hidden_state

        # Attention over context tokens using the learned codes
        attention = torch.matmul(self.codes, hidden.T)
        attention = F.softmax(attention, dim=-1)

        # Weighted average of context token embeddings
        poly_embs = torch.matmul(attention, hidden)
        return poly_embs  # [num_codes, dim]

    def encode_candidate(self, candidate):
        outputs = self.bert(candidate)
        return outputs.pooler_output  # Single vector

    def score(self, context_embs, candidate_emb):
        # Attention-weighted scoring
        scores = torch.matmul(context_embs, candidate_emb)
        attention = F.softmax(scores, dim=0)
        final_score = (attention * scores).sum()
        return final_score

2. SPLADE (Sparse + Dense)

Learned sparse representations:

class SPLADE(nn.Module):
    def __init__(self, bert_model, vocab_size=30522):
        super().__init__()
        self.bert = bert_model
        self.vocab_size = vocab_size

    def encode(self, tokens):
        outputs = self.bert(tokens)
        hidden = outputs.last_hidden_state

        # Project to vocabulary size
        logits = self.bert.cls(hidden)  # [batch, seq_len, vocab]

        # Max pooling over sequence
        scores = logits.max(dim=1).values

        # Sparsify with ReLU and log
        sparse = torch.log(1 + F.relu(scores))
        return sparse  # [batch, vocab_size]

3. DPR Multi-Vector

Multiple passage representations:

class MultiVectorDPR(nn.Module):
    def __init__(self, bert_model, num_vectors=5):
        super().__init__()
        self.bert = bert_model
        self.projections = nn.ModuleList([
            nn.Linear(768, 768) for _ in range(num_vectors)
        ])

    def encode(self, passage):
        outputs = self.bert(passage)
        cls_token = outputs.pooler_output

        # Generate multiple views
        vectors = []
        for projection in self.projections:
            vec = projection(cls_token)
            vec = F.normalize(vec, p=2, dim=-1)
            vectors.append(vec)
        return torch.stack(vectors)  # [num_vectors, dim]

Performance Comparison

Retrieval Quality (MS MARCO)

Model          MRR@10   Recall@1000   Index Size
BM25           18.7     85.7          0.5 GB
DPR (single)   31.2     95.2          21 GB
ANCE           33.0     95.9          21 GB
ColBERT        36.0     97.0          154 GB
ColBERTv2      39.7     98.4          25 GB

Latency Analysis

import time
import numpy as np

# Benchmark different approaches
def benchmark_retrieval(model, queries, corpus, method):
    times = []
    for query in queries:
        start = time.time()
        if method == 'single_vector':
            q_emb = model.encode_query_single(query)
            scores = cosine_similarity(q_emb, corpus_embeddings)
        elif method == 'colbert':
            q_embs = model.encode_query_multi(query)
            scores = []
            for doc_embs in corpus_multi_embeddings:
                score = maxsim(q_embs, doc_embs)
                scores.append(score)
        elif method == 'colbert_indexed':
            results = index.search(query, k=1000)
        times.append(time.time() - start)
    return np.mean(times)

# Results (typical)
# Single vector:    5 ms
# ColBERT naive:    200 ms
# ColBERT indexed:  50 ms

Optimization Techniques

1. Compression

Reduce index size:

# Dimension reduction
embeddings_128d = pca.fit_transform(embeddings_768d)

# Quantization
embeddings_int8 = quantize_embeddings(embeddings_128d)

# Combined: 6× reduction with <2% quality loss
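
As a concrete sketch of the same recipe using scikit-learn and NumPy (embeddings_768d and the exact savings are illustrative; actual figures depend on the corpus and model):

import numpy as np
from sklearn.decomposition import PCA

# Dimension reduction: 768 -> 128 dimensions
pca = PCA(n_components=128)
embeddings_128d = pca.fit_transform(embeddings_768d)

# Simple symmetric int8 quantization using a global scale
scale = np.abs(embeddings_128d).max() / 127.0
embeddings_int8 = np.round(embeddings_128d / scale).astype(np.int8)

# At query time, dequantize with embeddings_int8.astype(np.float32) * scale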

2. Centroid Interaction

Speed up scoring:

def centroid_interaction(query_embs, doc_centroids, top_k=100):
    # First stage: score centroids
    centroid_scores = maxsim(query_embs, doc_centroids)

    # Second stage: score top-k documents fully
    top_docs = centroid_scores.topk(top_k).indices
    final_scores = []
    for doc_id in top_docs:
        doc_embs = get_full_embeddings(doc_id)
        score = maxsim(query_embs, doc_embs)
        final_scores.append(score)
    return final_scores

3. Denoised Supervision

Improve training:

def denoised_colbert_loss(query, positive, negatives, tau=0.01):
    # Encode
    q_embs = model.encode_query(query)
    pos_embs = model.encode_document(positive)
    neg_embs = [model.encode_document(neg) for neg in negatives]

    # Scores
    pos_score = maxsim(q_embs, pos_embs)
    neg_scores = [maxsim(q_embs, neg) for neg in neg_embs]

    # Denoised contrastive loss
    numerator = torch.exp(pos_score / tau)
    denominator = numerator + sum(torch.exp(s / tau) for s in neg_scores)
    loss = -torch.log(numerator / denominator)
    return loss

Best Practices

1. Token Length Management

# Limit document length for efficiency
MAX_DOC_LENGTH = 180  # ColBERT default

def prepare_documents(documents):
    processed = []
    for doc in documents:
        tokens = tokenizer(doc, max_length=MAX_DOC_LENGTH, truncation=True)
        processed.append(tokens)
    return processed

2. Query Augmentation

# Add [Q] tokens for better discrimination
def augment_query(query):
    return f"[Q] {query}"

3. Hybrid Retrieval

def hybrid_search(query, k=100):
    # Stage 1: BM25 for initial candidates as (doc_id, bm25_score) pairs
    bm25_results = bm25_search(query, k=1000)

    # Stage 2: ColBERT reranking of the candidates
    colbert_scores = {}
    for doc_id, _ in bm25_results:
        colbert_scores[doc_id] = colbert_model.score(query, documents[doc_id])

    # Combine scores with a weighted sum
    final_scores = []
    for doc_id, bm25_score in bm25_results:
        combined = 0.3 * bm25_score + 0.7 * colbert_scores[doc_id]
        final_scores.append((doc_id, combined))

    return sorted(final_scores, key=lambda x: x[1], reverse=True)[:k]

References

  • Khattab & Zaharia "ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT"
  • Santhanam et al. "ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction"
  • Humeau et al. "Poly-encoders: Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring"
  • Formal et al. "SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking"
