Cross-Encoder vs Bi-Encoder


Understand the fundamental differences between independent and joint encoding architectures for neural retrieval systems.


The choice between cross-encoders and bi-encoders is fundamental to building effective neural search systems, each offering distinct trade-offs between speed and accuracy.

Bi-Encoder Architecture Overview

[Diagram: a query encoder (BERT) and a document encoder (BERT) run independently; the 768-D query vector is compared against N pre-computed 768-D document vectors via cosine similarity to produce relevance scores.]

Two-Stage Retrieval Pipeline

  • Stage 1 (Retrieval): bi-encoder for fast candidate selection, ~50ms for 1M docs
  • Stage 2 (Re-ranking): cross-encoder for precise scoring, ~10ms per pair

Flow: 1M documents → bi-encoder → 100 candidates → cross-encoder → top 10 final results

Example Ranking Results

| Rank | Score | Document |
|------|-------|----------|
| 1 | 0.552 | Machine learning is a subset of AI that enables systems to learn from data |
| 2 | 0.457 | Gradient descent is an optimization algorithm used to minimize functions |
| 3 | 0.265 | Deep learning uses multiple layers of neural networks for complex patterns |
| 4 | 0.094 | Supervised learning requires labeled training data for model training |
| 5 | 0.079 | Neural networks are computing systems inspired by biological neural networks |

Performance at a Glance

| Metric | Bi-Encoder | Cross-Encoder |
|--------|------------|---------------|
| Speed | Very fast | Slow |
| Accuracy | Good | Excellent |
| Scalability | Millions of documents | Thousands of pairs |
| Memory | Moderate | Low |

Detailed Comparison

| Aspect | Bi-Encoder | Cross-Encoder |
|--------|------------|---------------|
| Architecture | Two separate encoders | Single joint encoder |
| Input | Query and doc separately | [CLS] Query [SEP] Doc [SEP] |
| Output | Dense vectors | Relevance score |
| Pre-computation | ✅ Yes | ❌ No |
| Latency | ~50ms for 1M docs | ~10ms per pair |
| Use Case | First-stage retrieval | Re-ranking |
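
To make the Input row concrete, here is how each format is built with a Hugging Face tokenizer; a minimal sketch, with placeholder query and document strings:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Bi-encoder: query and document are tokenized separately
query_inputs = tokenizer("what is machine learning", return_tensors='pt')
doc_inputs = tokenizer("Machine learning is a subset of AI.", return_tensors='pt')

# Cross-encoder: one joint sequence, [CLS] query [SEP] document [SEP]
pair_inputs = tokenizer("what is machine learning",
                        "Machine learning is a subset of AI.",
                        return_tensors='pt')
print(tokenizer.convert_ids_to_tokens(pair_inputs['input_ids'][0]))
# ['[CLS]', 'what', 'is', 'machine', 'learning', '[SEP]', 'machine', ..., '[SEP]']
```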

Best Practices for Encoder Selection

When to Use Bi-Encoder

  • Large-scale retrieval (millions of docs)
  • Real-time search requirements
  • Need for pre-computed embeddings
  • Semantic similarity search
  • First-stage candidate generation

When to Use Cross-Encoder

  • Small candidate sets (<1000)
  • Maximum accuracy required
  • Re-ranking top results
  • Question answering tasks
  • Fact verification

Hybrid Approach: Use a bi-encoder for initial retrieval of the top-100 candidates, then a cross-encoder to re-rank them into the final top-10. This balances speed and accuracy.
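
For reference, the same two-stage pattern can be sketched with off-the-shelf models from the sentence-transformers library (the checkpoint names are common public models, and the tiny corpus is illustrative):

```python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

corpus = [
    "Machine learning is a subset of AI that enables systems to learn from data",
    "Gradient descent is an optimization algorithm used to minimize functions",
    "Neural networks are computing systems inspired by biological neural networks",
]
corpus_embs = bi_encoder.encode(corpus, convert_to_tensor=True)

query = "what is machine learning"
query_emb = bi_encoder.encode(query, convert_to_tensor=True)

# Stage 1: bi-encoder retrieval (top_k would be ~100 on a real corpus)
hits = util.semantic_search(query_emb, corpus_embs, top_k=3)[0]

# Stage 2: cross-encoder re-ranks the candidates
pairs = [(query, corpus[hit['corpus_id']]) for hit in hits]
ce_scores = cross_encoder.predict(pairs)
for (q, doc), s in sorted(zip(pairs, ce_scores), key=lambda x: x[1], reverse=True):
    print(f"{s:.3f}  {doc}")
```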

Core Architectural Differences

Bi-Encoder (Dual Encoder)

  • Independent encoding of queries and documents
  • Pre-computable document embeddings
  • Fast similarity computation via dot product
  • Scalable to millions of documents

Cross-Encoder

  • Joint encoding of query-document pairs
  • Full attention between query and document tokens
  • High accuracy but computationally expensive
  • Suitable for re-ranking small candidate sets

Bi-Encoder Architecture

How It Works

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class BiEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)

    def encode_query(self, query_tokens):
        outputs = self.query_encoder(**query_tokens)
        # Use the [CLS] pooler output; mean pooling is a common alternative
        query_embedding = outputs.pooler_output
        return F.normalize(query_embedding, p=2, dim=-1)

    def encode_document(self, doc_tokens):
        outputs = self.doc_encoder(**doc_tokens)
        doc_embedding = outputs.pooler_output
        return F.normalize(doc_embedding, p=2, dim=-1)

    def score(self, query_embedding, doc_embedding):
        # Dot product of L2-normalized vectors = cosine similarity
        return torch.sum(query_embedding * doc_embedding, dim=-1)
```

Training with Contrastive Loss

$$\mathcal{L} = -\log \frac{e^{s(q,\, d^+)/\tau}}{\sum_{d' \in D} e^{s(q,\, d')/\tau}}$$

Where:

  • s(q, d) = Similarity score
  • d^+ = Positive document
  • D = All documents in batch
  • τ = Temperature parameter

```python
def in_batch_negatives_loss(query_embs, doc_embs, temperature=0.07):
    """Contrastive loss with in-batch negatives."""
    # All pairwise similarities: row i holds query i against every document
    similarities = torch.matmul(query_embs, doc_embs.T) / temperature
    # Positive pairs sit on the diagonal
    labels = torch.arange(len(query_embs), device=query_embs.device)
    # Cross-entropy treats each row as a classification over documents
    return F.cross_entropy(similarities, labels)
```
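
A quick sanity check of the loss on random, hypothetical embeddings:

```python
# 8 query/document pairs, 768-d L2-normalized embeddings
q = F.normalize(torch.randn(8, 768), dim=-1)
d = F.normalize(torch.randn(8, 768), dim=-1)
print(in_batch_negatives_loss(q, d))  # scalar; roughly log(8) ≈ 2.08 for random inputs
```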

Cross-Encoder Architecture

How It Works

```python
class CrossEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(768, 1)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Joint encoding of [CLS] query [SEP] document [SEP]
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Use the [CLS] token representation for classification
        cls_output = outputs.last_hidden_state[:, 0]
        # Compute the relevance score
        score = self.classifier(cls_output)
        return torch.sigmoid(score)
```

Training with Binary Classification

```python
def train_cross_encoder(model, tokenizer, dataloader, optimizer):
    criterion = nn.BCELoss()
    for batch in dataloader:
        # Prepare input: [CLS] query [SEP] document [SEP]
        inputs = tokenizer(
            batch['queries'],
            batch['documents'],
            truncation=True,
            padding=True,
            return_tensors='pt',
        )
        # Forward pass; squeeze (B, 1) -> (B,) to match the labels
        scores = model(**inputs).squeeze(-1)
        # Binary labels: 1 for relevant, 0 for non-relevant
        loss = criterion(scores, batch['labels'].float())
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

Two-Stage Retrieval Pipeline

The optimal approach combines both architectures:

Stage 1: Retrieval (Bi-Encoder)

```python
import faiss
import numpy as np
from tqdm import tqdm

class DenseRetriever:
    def __init__(self, bi_encoder, documents):
        # Assumes the encoder accepts raw text and returns a (1, dim) tensor
        self.encoder = bi_encoder
        self.index = self.build_index(documents)

    def build_index(self, documents):
        # Pre-compute every document embedding once, offline
        doc_embeddings = []
        for doc in tqdm(documents):
            embedding = self.encoder.encode_document(doc)
            doc_embeddings.append(embedding)
        embeddings = torch.cat(doc_embeddings).detach().numpy()
        # Inner-product index; with L2-normalized vectors this is cosine similarity
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        return index

    def retrieve(self, query, k=100):
        query_emb = self.encoder.encode_query(query).detach().numpy()
        # Fast nearest-neighbor search over the pre-built index
        scores, indices = self.index.search(query_emb, k)
        return indices[0], scores[0]
```
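
Note that `IndexFlatIP` performs exact search. For corpora well beyond a few million vectors, an approximate index is the usual swap; below is a sketch using FAISS's HNSW index, with illustrative (untuned) parameters:

```python
def build_ann_index(embeddings, m=32, ef_search=64):
    """Approximate nearest-neighbor index for larger corpora."""
    dim = embeddings.shape[1]
    # HNSW graph index with inner-product metric; m controls graph connectivity
    index = faiss.IndexHNSWFlat(dim, m, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efSearch = ef_search  # higher = better recall, slower queries
    index.add(embeddings)
    return index
```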

Stage 2: Re-ranking (Cross-Encoder)

```python
class Reranker:
    def __init__(self, cross_encoder, tokenizer):
        self.model = cross_encoder
        self.tokenizer = tokenizer

    def rerank(self, query, documents, k=10):
        # Score each query-document pair with the cross-encoder
        scores = []
        for doc in documents:
            inputs = self.tokenizer(query, doc, truncation=True, return_tensors='pt')
            with torch.no_grad():
                score = self.model(**inputs).item()
            scores.append(score)
        # Sort descending by score and keep the top k
        ranked_indices = np.argsort(scores)[::-1][:k]
        return ranked_indices, [scores[i] for i in ranked_indices]
```

Complete Pipeline

```python
def hybrid_search(query, corpus, bi_encoder, cross_encoder, tokenizer, k=10):
    """Two-stage retrieval and re-ranking."""
    # Stage 1: fast candidate retrieval with the bi-encoder
    # (in practice, build the retriever once and reuse it across queries)
    retriever = DenseRetriever(bi_encoder, corpus)
    n_candidates = min(100, len(corpus))
    candidate_indices, _ = retriever.retrieve(query, k=n_candidates)
    candidates = [corpus[i] for i in candidate_indices]

    # Stage 2: accurate re-ranking with the cross-encoder
    reranker = Reranker(cross_encoder, tokenizer)
    final_indices, final_scores = reranker.rerank(query, candidates, k=k)

    # Map re-ranked positions back to the original corpus
    results = []
    for idx, score in zip(final_indices, final_scores):
        original_idx = candidate_indices[idx]
        results.append({
            'document': corpus[original_idx],
            'score': score,
            'index': int(original_idx),
        })
    return results
```
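
A minimal usage sketch, assuming `bi_encoder`, `cross_encoder`, and `tokenizer` are already instantiated and trained, and that the encoders accept raw strings:

```python
corpus = [
    "Machine learning is a subset of AI that enables systems to learn from data",
    "Gradient descent is an optimization algorithm used to minimize functions",
    "Deep learning uses multiple layers of neural networks for complex patterns",
]
results = hybrid_search("what is machine learning", corpus,
                        bi_encoder, cross_encoder, tokenizer, k=2)
for r in results:
    print(f"{r['score']:.3f}  {r['document']}")
```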

Performance Comparison

Speed Analysis

| Stage | Bi-Encoder | Cross-Encoder |
|-------|------------|---------------|
| Indexing | O(n), one-time | Not applicable |
| Query encoding | O(1) | O(n), one pass per document |
| Scoring | O(1) dot product | O(L²) full attention |
| Total for 1M docs | ~50ms | ~3 hours |

The ~3 hour figure follows directly from the per-pair cost: 1M pairs × ~10ms per pair ≈ 10,000s ≈ 2.8 hours.

Quality Metrics (MS MARCO)

| Model | MRR@10 | Recall@100 | Latency |
|-------|--------|------------|---------|
| BM25 | 18.7 | 85.7 | 20ms |
| Bi-Encoder (DPR) | 31.2 | 95.2 | 50ms |
| Cross-Encoder | 39.2 | N/A | 10s/doc |
| Bi-Encoder + Cross-Encoder | 38.5 | 95.2 | 150ms |

Optimization Techniques

Bi-Encoder Optimizations

```python
# 1. Hard negative mining
def mine_hard_negatives(query, positive_doc, corpus, bi_encoder, k=10):
    """Find challenging negatives: documents ranked highly but not relevant."""
    # Retrieve similar-but-wrong documents
    results = bi_encoder.search(query, k=k + 1)
    hard_negatives = [doc for doc in results if doc != positive_doc]
    return hard_negatives[:k]

# 2. Distillation from a cross-encoder teacher
def distill_bi_encoder(student_bi, teacher_cross, data):
    """Knowledge distillation: train the bi-encoder to match cross-encoder scores."""
    for query, docs in data:
        # Get teacher scores, then train the student to match them
        teacher_scores = teacher_cross.score_pairs(query, docs)
        student_scores = student_bi.score_pairs(query, docs)
        loss = F.mse_loss(student_scores, teacher_scores)
        loss.backward()
```

Cross-Encoder Optimizations

```python
from cachetools import LRUCache  # pip install cachetools

# 1. Lightweight models
class MiniCrossEncoder(nn.Module):
    """Distilled cross-encoder for faster inference."""
    def __init__(self):
        super().__init__()
        # Use DistilBERT or TinyBERT as the backbone
        self.encoder = AutoModel.from_pretrained('distilbert-base-uncased')
        self.classifier = nn.Linear(768, 1)

# 2. Caching strategies
class CachedCrossEncoder:
    def __init__(self, model, cache_size=10000):
        self.model = model
        self.cache = LRUCache(maxsize=cache_size)

    def score(self, query, doc):
        # Reuse scores for repeated query-document pairs
        cache_key = (query, doc)
        if cache_key in self.cache:
            return self.cache[cache_key]
        score = self.model.score(query, doc)
        self.cache[cache_key] = score
        return score
```

Choosing the Right Architecture

Decision Framework

```python
def choose_architecture(requirements):
    """Select an architecture based on corpus size, latency, and accuracy needs."""
    # Pure bi-encoder for large-scale, real-time search
    if requirements['corpus_size'] > 1e6 and requirements['latency_ms'] < 100:
        return 'bi-encoder'
    # Pure cross-encoder for small, accuracy-critical corpora
    if requirements['corpus_size'] < 1000 and requirements['accuracy_critical']:
        return 'cross-encoder'
    # Hybrid pipeline for balanced performance
    if requirements['corpus_size'] > 1e4:
        return 'bi-encoder + cross-encoder'
    return 'cross-encoder'
```
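
For example, a large corpus with a tight latency budget resolves to the pure bi-encoder (the requirement values are illustrative):

```python
requirements = {
    'corpus_size': 5e6,        # 5M documents
    'latency_ms': 50,          # strict real-time budget
    'accuracy_critical': False,
}
print(choose_architecture(requirements))  # -> 'bi-encoder'
```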

Use Case Examples

Bi-Encoder Only:

  • Semantic search engines
  • Similar item recommendation
  • Large-scale document retrieval
  • Real-time question answering

Cross-Encoder Only:

  • Fact verification
  • Answer selection
  • Duplicate detection
  • Small corpus QA

Hybrid (Both):

  • Web search engines
  • Enterprise search
  • E-commerce search
  • Academic paper search

Advanced Architectures

Poly-Encoder

A middle ground between bi-encoders and cross-encoders:

```python
class PolyEncoder(nn.Module):
    """Multiple learned attention codes for richer query-candidate interaction."""
    def __init__(self, num_codes=64):
        super().__init__()
        self.context_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.candidate_encoder = AutoModel.from_pretrained('bert-base-uncased')
        # m learned codes that attend over the context token embeddings
        self.poly_codes = nn.Parameter(torch.randn(num_codes, 768))
```
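
The snippet above only declares the components. Below is a hedged sketch of the scoring step as a standalone function; it is simplified from the Humeau et al. formulation, with batching omitted and shapes noted in comments:

```python
def poly_score(poly_codes, context_tokens, candidate_emb):
    """Poly-encoder scoring for one context/candidate pair.

    poly_codes:     (m, 768) learned codes
    context_tokens: (seq_len, 768) token embeddings from the context encoder
    candidate_emb:  (768,) pooled candidate embedding (pre-computable offline)
    """
    # 1. Each code attends over the context tokens -> m global context vectors
    attn = torch.softmax(poly_codes @ context_tokens.T, dim=-1)    # (m, seq_len)
    context_vecs = attn @ context_tokens                           # (m, 768)
    # 2. The candidate attends over the m context vectors
    weights = torch.softmax(context_vecs @ candidate_emb, dim=-1)  # (m,)
    final_context = weights @ context_vecs                         # (768,)
    # 3. The final score is a dot product, as in a bi-encoder
    return final_context @ candidate_emb
```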

ColBERT

Late interaction with token-level matching:

```python
class ColBERT(nn.Module):
    """Multi-vector representation with late interaction."""
    def score(self, query_tokens, doc_tokens):
        # MaxSim: for each query token, keep its best-matching document token
        scores = torch.matmul(query_tokens, doc_tokens.T)
        max_scores = scores.max(dim=-1).values
        # The relevance score is the sum of per-query-token maxima
        return max_scores.sum()
```
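
A toy invocation with random, L2-normalized token embeddings (dimensions are illustrative):

```python
query_tokens = F.normalize(torch.randn(5, 128), dim=-1)   # 5 query tokens
doc_tokens = F.normalize(torch.randn(40, 128), dim=-1)    # 40 document tokens
print(ColBERT().score(query_tokens, doc_tokens))          # scalar MaxSim score
```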

Best Practices

  1. Start with bi-encoder for initial system
  2. Add cross-encoder when accuracy plateaus
  3. Use hard negatives for training
  4. Implement caching for cross-encoder
  5. Monitor latency in production
  6. A/B test hybrid configurations

References

  • Karpukhin et al. (2020). "Dense Passage Retrieval for Open-Domain Question Answering." EMNLP.
  • Humeau et al. (2020). "Poly-encoders: Architectures and Pre-training Strategies." ICLR.
  • Reimers & Gurevych (2019). "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks." EMNLP.
  • Nogueira & Cho (2019). "Passage Re-ranking with BERT." arXiv:1901.04085.
  • Khattab & Zaharia (2020). "ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT." SIGIR.
