Cross-Encoder vs Bi-Encoder
Understand the fundamental differences between independent and joint encoding architectures for neural retrieval systems.
The choice between cross-encoders and bi-encoders is fundamental to building effective neural search systems, each offering distinct trade-offs between speed and accuracy.
Detailed Comparison
| Aspect | Bi-Encoder | Cross-Encoder |
|---|---|---|
| Architecture | Two separate encoders | Single joint encoder |
| Input | Query and document encoded separately | `[CLS] Query [SEP] Document [SEP]` |
| Output | Dense vectors | Relevance score |
| Pre-computation | ✅ Yes | ❌ No |
| Latency | ~50 ms for 1M docs | ~10 ms per pair |
| Use Case | First-stage retrieval | Re-ranking |
Best Practices for Encoder Selection
When to Use Bi-Encoder
- Large-scale retrieval (millions of docs)
- Real-time search requirements
- Need for pre-computed embeddings
- Semantic similarity search
- First-stage candidate generation
When to Use Cross-Encoder
- Small candidate sets (<1000)
- Maximum accuracy required
- Re-ranking top results
- Question answering tasks
- Fact verification
Hybrid Approach: Use a bi-encoder for initial retrieval of the top-100 candidates, then a cross-encoder to re-rank them into the final top-10. This balances speed and accuracy; a concrete sketch follows.
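As an illustration, here is a minimal sketch of that two-stage setup using the sentence-transformers library. The checkpoint names are examples only; any bi-encoder/cross-encoder pair with compatible training works:

```python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

bi_encoder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")          # example checkpoint
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")   # example checkpoint

corpus = [
    "Machine learning is a subset of AI that enables systems to learn from data",
    "Gradient descent is an optimization algorithm used to minimize functions",
]
corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True)  # pre-computed once

query = "what is machine learning"
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

# Stage 1: fast retrieval of top candidates with the bi-encoder
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=100)[0]

# Stage 2: precise re-ranking of the candidates with the cross-encoder
pairs = [(query, corpus[hit["corpus_id"]]) for hit in hits]
rerank_scores = cross_encoder.predict(pairs)
reranked = sorted(zip(hits, rerank_scores), key=lambda x: x[1], reverse=True)[:10]
```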
Core Architectural Differences
Bi-Encoder (Dual Encoder)
- Independent encoding of queries and documents
- Pre-computable document embeddings
- Fast similarity computation via dot product
- Scalable to millions of documents
Cross-Encoder
- Joint encoding of query-document pairs
- Full attention between query and document tokens
- High accuracy but computationally expensive
- Suitable for re-ranking small candidate sets
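The difference in inputs is easy to see at the tokenizer level. A small sketch with the Hugging Face transformers tokenizer (the query and document strings are arbitrary examples):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

query = "what is machine learning"
document = "Machine learning is a subset of AI that enables systems to learn from data"

# Bi-encoder: query and document are tokenized (and later encoded) independently
query_inputs = tokenizer(query, return_tensors="pt")
doc_inputs = tokenizer(document, return_tensors="pt")

# Cross-encoder: one joint sequence, [CLS] query [SEP] document [SEP]
pair_inputs = tokenizer(query, document, return_tensors="pt")
print(tokenizer.decode(pair_inputs["input_ids"][0]))  # shows the [CLS]/[SEP] layout
```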
Bi-Encoder Architecture
How It Works
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class BiEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        # Separate encoders for queries and documents (weights may also be shared)
        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)

    def encode_query(self, query_tokens):
        outputs = self.query_encoder(**query_tokens)
        # Use the [CLS] token (or mean pooling) as the query representation
        query_embedding = outputs.pooler_output
        return F.normalize(query_embedding, p=2, dim=-1)

    def encode_document(self, doc_tokens):
        outputs = self.doc_encoder(**doc_tokens)
        doc_embedding = outputs.pooler_output
        return F.normalize(doc_embedding, p=2, dim=-1)

    def score(self, query_embedding, doc_embedding):
        # Simple dot product (cosine similarity, since embeddings are L2-normalized)
        return torch.sum(query_embedding * doc_embedding, dim=-1)
```
Training with Contrastive Loss
Bi-encoders are typically trained with a contrastive (InfoNCE-style) loss over in-batch negatives:

$$\mathcal{L} = -\log \frac{\exp\!\big(s(q, d^+)/\tau\big)}{\sum_{d \in D} \exp\!\big(s(q, d)/\tau\big)}$$

Where:
- $s(q, d)$ = similarity score between query $q$ and document $d$
- $d^+$ = positive (relevant) document
- $D$ = all documents in the batch (the other queries' positives serve as negatives)
- $\tau$ = temperature parameter
```python
def in_batch_negatives_loss(query_embs, doc_embs, temperature=0.07):
    """Contrastive loss with in-batch negatives."""
    # Similarity of every query against every document in the batch
    similarities = torch.matmul(query_embs, doc_embs.T) / temperature
    # Positive pairs lie on the diagonal: query i matches document i
    labels = torch.arange(len(query_embs)).to(query_embs.device)
    # Cross-entropy loss
    loss = F.cross_entropy(similarities, labels)
    return loss
```
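A quick smoke test, continuing with the imports above and random normalized embeddings (batch size 4 and dimension 128 are arbitrary choices), just to make the expected shapes concrete:

```python
batch, dim = 4, 128
query_embs = F.normalize(torch.randn(batch, dim), dim=-1)
doc_embs = F.normalize(torch.randn(batch, dim), dim=-1)
loss = in_batch_negatives_loss(query_embs, doc_embs)
print(loss.item())  # roughly log(batch) ≈ 1.39 for random, untrained embeddings
```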
Cross-Encoder Architecture
How It Works
```python
class CrossEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(768, 1)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Joint encoding of [CLS] query [SEP] document [SEP]
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        # Use the [CLS] token representation for classification
        cls_output = outputs.last_hidden_state[:, 0]
        # Compute relevance score
        score = self.classifier(cls_output)
        return torch.sigmoid(score)
```
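Inference on a single pair then looks like this (a usage sketch, assuming the class above and a matching BERT tokenizer; the query and document strings are examples):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = CrossEncoder()
model.eval()

inputs = tokenizer(
    "what is gradient descent",                            # query
    "Gradient descent is an optimization algorithm ...",   # document
    truncation=True,
    return_tensors='pt'
)
with torch.no_grad():
    relevance = model(**inputs)   # sigmoid score in (0, 1)
print(relevance.item())
```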
Training with Binary Classification
```python
def train_cross_encoder(model, dataloader, optimizer, tokenizer):
    criterion = nn.BCELoss()
    for batch in dataloader:
        # Prepare joint input: [CLS] query [SEP] document [SEP]
        inputs = tokenizer(
            batch['queries'],
            batch['documents'],
            truncation=True,
            padding=True,
            return_tensors='pt'
        )
        # Forward pass: one relevance score per pair
        scores = model(**inputs).squeeze(-1)
        # Binary labels: 1 for relevant, 0 for non-relevant
        loss = criterion(scores, batch['labels'].float())
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
Two-Stage Retrieval Pipeline
In practice, the two architectures are usually combined in a two-stage pipeline:
Stage 1: Retrieval (Bi-Encoder)
```python
import faiss
import numpy as np
from tqdm import tqdm

class DenseRetriever:
    def __init__(self, bi_encoder, documents):
        self.encoder = bi_encoder
        self.index = self.build_index(documents)

    def build_index(self, documents):
        # Pre-compute all document embeddings (documents assumed already tokenized)
        doc_embeddings = []
        for doc in tqdm(documents):
            with torch.no_grad():
                embedding = self.encoder.encode_document(doc)   # shape (1, dim)
            doc_embeddings.append(embedding)
        # Build a FAISS inner-product index over the embeddings
        embeddings = torch.cat(doc_embeddings, dim=0).numpy()
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        return index

    def retrieve(self, query, k=100):
        # Encode the query (query assumed already tokenized)
        with torch.no_grad():
            query_emb = self.encoder.encode_query(query).numpy()
        # Fast nearest-neighbor search over the index
        scores, indices = self.index.search(query_emb, k)
        return indices[0], scores[0]
```
Stage 2: Re-ranking (Cross-Encoder)
```python
class Reranker:
    def __init__(self, cross_encoder):
        self.model = cross_encoder

    def rerank(self, query, documents, k=10):
        # Score each query-document pair individually
        # (uses the same Hugging Face `tokenizer` as the cross-encoder)
        scores = []
        for doc in documents:
            inputs = tokenizer(
                query,
                doc,
                truncation=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                score = self.model(**inputs).item()
            scores.append(score)
        # Sort by descending score and keep the top k
        ranked_indices = np.argsort(scores)[::-1][:k]
        return ranked_indices, [scores[i] for i in ranked_indices]
```
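Scoring pairs one at a time wastes GPU throughput. A batched variant (a sketch, assuming the same tokenizer and model as above) pads a list of pairs and scores them in chunks:

```python
def rerank_batched(model, query, documents, batch_size=32):
    """Score all query-document pairs in padded mini-batches."""
    scores = []
    for start in range(0, len(documents), batch_size):
        chunk = documents[start:start + batch_size]
        inputs = tokenizer(
            [query] * len(chunk),   # repeat the query for each document in the chunk
            chunk,
            truncation=True,
            padding=True,
            return_tensors='pt'
        )
        with torch.no_grad():
            scores.extend(model(**inputs).squeeze(-1).tolist())
    return scores
```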
Complete Pipeline
```python
def hybrid_search(query, corpus, bi_encoder, cross_encoder, k=10):
    """Two-stage retrieval and re-ranking."""
    # Stage 1: fast retrieval with the bi-encoder
    # (in production the index is built once up front, not per query)
    retriever = DenseRetriever(bi_encoder, corpus)
    candidate_indices, _ = retriever.retrieve(query, k=100)
    candidates = [corpus[i] for i in candidate_indices]

    # Stage 2: accurate re-ranking with the cross-encoder
    reranker = Reranker(cross_encoder)
    final_indices, final_scores = reranker.rerank(query, candidates, k=k)

    # Map re-ranked positions back to the original corpus
    results = []
    for idx, score in zip(final_indices, final_scores):
        original_idx = candidate_indices[idx]
        results.append({
            'document': corpus[original_idx],
            'score': score,
            'index': original_idx
        })
    return results
```
Performance Comparison
Speed Analysis
| Stage | Bi-Encoder | Cross-Encoder |
|---|---|---|
| Indexing | O(n), one-time | Not applicable |
| Query encoding | O(1) | O(n) — the query is re-encoded jointly with each document |
| Scoring | O(1) dot product | O(L²) full attention per pair |
| Total for 1M docs | ~50 ms | ~3 hours |
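The last row follows directly from the per-pair cost: assuming roughly 10 ms per pair (the figure used in the comparison table above), exhaustively scoring a 1M-document corpus takes about $10^6 \times 10\,\text{ms} = 10^4\,\text{s} \approx 2.8$ hours, which is why cross-encoders are only ever applied to small candidate sets.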
Quality Metrics (MS MARCO)
| Model | MRR@10 | Recall@100 | Latency |
|---|---|---|---|
| BM25 | 18.7 | 85.7 | 20 ms |
| Bi-Encoder (DPR) | 31.2 | 95.2 | 50 ms |
| Cross-Encoder | 39.2 | N/A | 10 s/doc |
| Bi-Encoder + Cross-Encoder | 38.5 | 95.2 | 150 ms |
Optimization Techniques
Bi-Encoder Optimizations
```python
# 1. Hard negative mining
def mine_hard_negatives(query, positive_doc, corpus, bi_encoder, k=10):
    """Find challenging negative examples: documents the bi-encoder
    ranks highly but that are not the labeled positive."""
    # `search` is assumed to return the top-k documents for the query
    results = bi_encoder.search(query, k=k + 1)
    hard_negatives = [doc for doc in results if doc != positive_doc]
    return hard_negatives[:k]

# 2. Distillation from a cross-encoder teacher
def distill_bi_encoder(student_bi, teacher_cross, data):
    """Knowledge distillation: train the bi-encoder to mimic cross-encoder scores."""
    for query, docs in data:
        # Teacher scores from the (frozen) cross-encoder
        teacher_scores = teacher_cross.score_pairs(query, docs)
        # Train the student bi-encoder to match them (optimizer step omitted)
        student_scores = student_bi.score_pairs(query, docs)
        loss = F.mse_loss(student_scores, teacher_scores)
        loss.backward()
```
Cross-Encoder Optimizations
```python
# 1. Lightweight models
class MiniCrossEncoder(nn.Module):
    """Distilled cross-encoder for faster inference."""
    def __init__(self):
        super().__init__()
        # Use DistilBERT or TinyBERT as the backbone
        self.encoder = AutoModel.from_pretrained('distilbert-base-uncased')
        self.classifier = nn.Linear(768, 1)

# 2. Caching strategies
class CachedCrossEncoder:
    def __init__(self, model, cache_size=10000):
        self.model = model
        self.cache = LRUCache(cache_size)  # any LRU cache, e.g. cachetools.LRUCache

    def score(self, query, doc):
        cache_key = hash((query, doc))
        if cache_key in self.cache:
            return self.cache[cache_key]
        score = self.model.score(query, doc)
        self.cache[cache_key] = score
        return score
```
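For string inputs, Python's built-in memoization covers much of the same ground without a hand-rolled cache. A minimal alternative sketch; `score_pair` is a hypothetical wrapper and assumes a `tokenizer` and cross-encoder `model` in scope:

```python
from functools import lru_cache

@lru_cache(maxsize=10_000)
def score_pair(query: str, doc: str) -> float:
    # Hypothetical wrapper: tokenize the pair and run the cross-encoder once,
    # caching the result for repeated (query, doc) pairs
    inputs = tokenizer(query, doc, truncation=True, return_tensors='pt')
    with torch.no_grad():
        return model(**inputs).item()
```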
Choosing the Right Architecture
Decision Framework
```python
def choose_architecture(requirements):
    """Select an architecture based on corpus size, latency budget, and accuracy needs."""
    # Pure bi-encoder for large-scale, low-latency retrieval
    if requirements['corpus_size'] > 1e6 and requirements['latency_ms'] < 100:
        return 'bi-encoder'
    # Pure cross-encoder for small corpora where accuracy is critical
    if requirements['corpus_size'] < 1000 and requirements['accuracy_critical']:
        return 'cross-encoder'
    # Hybrid pipeline for balanced performance at moderate-to-large scale
    if requirements['corpus_size'] > 1e4:
        return 'bi-encoder + cross-encoder'
    return 'cross-encoder'
```
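For example, with the thresholds hard-coded above:

```python
print(choose_architecture({
    'corpus_size': 5e6,         # 5M documents
    'latency_ms': 50,           # 50 ms latency budget
    'accuracy_critical': False,
}))  # -> 'bi-encoder'

print(choose_architecture({
    'corpus_size': 500,         # small corpus
    'latency_ms': 500,
    'accuracy_critical': True,
}))  # -> 'cross-encoder'
```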
Use Case Examples
Bi-Encoder Only:
- Semantic search engines
- Similar item recommendation
- Large-scale document retrieval
- Real-time question answering
Cross-Encoder Only:
- Fact verification
- Answer selection
- Duplicate detection
- Small corpus QA
Hybrid (Both):
- Web search engines
- Enterprise search
- E-commerce search
- Academic paper search
Advanced Architectures
Poly-Encoder
Strikes a balance between bi- and cross-encoders:
```python
class PolyEncoder(nn.Module):
    """Multiple learned attention codes for richer query-document interaction."""
    def __init__(self, num_codes=64):
        super().__init__()
        self.context_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.candidate_encoder = AutoModel.from_pretrained('bert-base-uncased')
        # Learned codes that attend over the context token embeddings
        self.poly_codes = nn.Parameter(torch.randn(num_codes, 768))
```
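The `__init__` above only sets up the pieces; the scoring step (a sketch of the late interaction described in Humeau et al., with un-batched tensors for readability) looks roughly like this:

```python
import torch
import torch.nn.functional as F

def poly_score(ctx_token_embs, cand_emb, poly_codes):
    """Poly-encoder scoring sketch.
    ctx_token_embs: [L, d] token embeddings of the query/context
    cand_emb:       [d]    single candidate (document) embedding
    poly_codes:     [m, d] learned codes
    """
    # 1. Each code attends over the context tokens -> m global context vectors
    attn = F.softmax(poly_codes @ ctx_token_embs.T, dim=-1)   # [m, L]
    ctx_vectors = attn @ ctx_token_embs                        # [m, d]
    # 2. The candidate embedding attends over those m context vectors
    weights = F.softmax(ctx_vectors @ cand_emb, dim=-1)        # [m]
    ctx_final = weights @ ctx_vectors                          # [d]
    # 3. Final relevance is a dot product, so candidates remain pre-computable
    return torch.dot(ctx_final, cand_emb)
```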
ColBERT
Late interaction with token-level matching:
```python
class ColBERT(nn.Module):
    """Multi-vector representation with late interaction."""
    def score(self, query_tokens, doc_tokens):
        # MaxSim: for each query token, take its best-matching document token,
        # then sum over query tokens
        scores = torch.matmul(query_tokens, doc_tokens.T)   # [num_q_tokens, num_d_tokens]
        max_scores = scores.max(dim=-1).values              # best doc token per query token
        return max_scores.sum()
```
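A toy check, continuing with the imports used above and random L2-normalized token embeddings (the shapes are arbitrary), just to make the MaxSim shapes concrete:

```python
q_tokens = F.normalize(torch.randn(3, 8), dim=-1)   # 3 query tokens, 8-dim
d_tokens = F.normalize(torch.randn(5, 8), dim=-1)   # 5 document tokens, 8-dim
print(ColBERT().score(q_tokens, d_tokens))          # sum of per-query-token max similarities
```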
Best Practices
- Start with bi-encoder for initial system
- Add cross-encoder when accuracy plateaus
- Use hard negatives for training
- Implement caching for cross-encoder
- Monitor latency in production
- A/B test hybrid configurations
Related Concepts
- Dense Embeddings - Foundation of bi-encoders
- Multi-Vector Late Interaction - Advanced architectures
- Sparse vs Dense - Retrieval paradigms
References
- Karpukhin et al. "Dense Passage Retrieval for Open-Domain Question Answering"
- Humeau et al. "Poly-encoders: Architectures and Pre-training Strategies"
- Reimers & Gurevych "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks"
- Nogueira & Cho "Passage Re-ranking with BERT"