Contrastive Loss Functions
Master contrastive loss functions including InfoNCE, NT-Xent, and Triplet Loss for representation learning and self-supervised training.
Contrastive Loss: Learning Representations by Comparison
Contrastive losses learn representations by pulling similar samples together and pushing dissimilar samples apart in embedding space, forming the foundation of modern self-supervised and multimodal learning.
Key Properties
InfoNCE
- Maximizes a lower bound on mutual information
- Scales with batch size
- Requires many negatives
- Used in CLIP, MoCo
NT-Xent
- Normalized, temperature-scaled similarities
- Symmetric over both augmented views
- Strong data augmentation is critical
- Used in SimCLR
Triplet Loss
- Direct margin optimization
- Hard negative mining
- Simpler formulation
- Used in FaceNet, metric learning
Core Principle
Contrastive learning operates on a simple principle, illustrated in the toy sketch after this list:
- Maximize agreement between augmented views of the same data (positives)
- Minimize agreement between different data points (negatives)
- Learn representations that capture semantic similarity
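To make this concrete, here is a minimal, self-contained sketch (with assumed toy vectors, not taken from the original text) that measures cosine agreement between two "augmented views" of the same point versus an unrelated point:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy "embeddings": two slightly perturbed views of the same sample and one unrelated sample.
x = torch.randn(128)
view1 = F.normalize(x + 0.05 * torch.randn(128), dim=0)  # positive view
view2 = F.normalize(x + 0.05 * torch.randn(128), dim=0)  # positive view
other = F.normalize(torch.randn(128), dim=0)             # negative sample

pos_agreement = torch.dot(view1, view2)  # high: both views share the same underlying content
neg_agreement = torch.dot(view1, other)  # near zero: unrelated content

print(f"positive agreement: {pos_agreement.item():.3f}, "
      f"negative agreement: {neg_agreement.item():.3f}")
```

A contrastive loss turns this gap into a training signal: it rewards high agreement for positives and penalizes high agreement for negatives.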
InfoNCE Loss
Definition
InfoNCE (Information Noise-Contrastive Estimation) maximizes a lower bound on the mutual information between an anchor and its positive:

$$\mathcal{L}_{\text{InfoNCE}} = -\log \frac{\exp\!\big(\text{sim}(z_i, z_i^{+}) / \tau\big)}{\exp\!\big(\text{sim}(z_i, z_i^{+}) / \tau\big) + \sum_{j} \exp\!\big(\text{sim}(z_i, z_j^{-}) / \tau\big)}$$

Where:
- $z_i$ is the anchor embedding
- $z_i^{+}$ is the positive embedding
- $z_j^{-}$ are the negative embeddings
- $\tau$ is the temperature parameter
- $\text{sim}(\cdot, \cdot)$ is the similarity function (usually cosine)
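As a quick numeric check (with toy similarity values assumed purely for illustration), the loss for one anchor with a single positive and two negatives is just a softmax cross-entropy over the scaled similarities:

```python
import math

# Assumed toy cosine similarities and temperature, for illustration only.
sim_pos = 0.8            # sim(z_i, z_i+)
sim_negs = [0.2, -0.1]   # sim(z_i, z_j-)
tau = 0.07

logits = [sim_pos / tau] + [s / tau for s in sim_negs]
denom = sum(math.exp(l) for l in logits)
loss = -math.log(math.exp(sim_pos / tau) / denom)
print(f"InfoNCE for this anchor: {loss:.4f}")  # very small, since the positive dominates at tau=0.07
```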
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCE(nn.Module):
    """
    InfoNCE loss for contrastive learning.
    Used in CLIP, MoCo, and other vision-language models.
    """
    def __init__(self, temperature=0.07, similarity='cosine'):
        super().__init__()
        self.temperature = temperature
        self.similarity = similarity

    def compute_similarity(self, z1, z2):
        """Compute pairwise similarity between two sets of embeddings."""
        if self.similarity == 'cosine':
            z1 = F.normalize(z1, dim=-1)
            z2 = F.normalize(z2, dim=-1)
            return torch.matmul(z1, z2.T)
        elif self.similarity == 'dot':
            return torch.matmul(z1, z2.T)
        else:
            raise ValueError(f"Unknown similarity: {self.similarity}")

    def forward(self, anchors, positives, negatives=None):
        """
        Args:
            anchors: [batch_size, embedding_dim]
            positives: [batch_size, embedding_dim]
            negatives: [batch_size, num_negatives, embedding_dim] or None
        """
        batch_size = anchors.shape[0]

        # Positive similarities: diagonal of the anchor-positive matrix
        pos_sim = torch.diagonal(self.compute_similarity(anchors, positives))  # [batch_size]

        if negatives is None:
            # Use the other anchors in the batch as negatives
            all_sim = self.compute_similarity(anchors, anchors)
            # Mask out self-similarities
            mask = torch.eye(batch_size, dtype=torch.bool, device=anchors.device)
            neg_sim = all_sim.masked_fill(mask, float('-inf'))  # [batch_size, batch_size]
        else:
            # Use the provided per-anchor negatives: [batch_size, num_negatives, embedding_dim]
            if self.similarity == 'cosine':
                a = F.normalize(anchors, dim=-1)
                n = F.normalize(negatives, dim=-1)
            else:
                a, n = anchors, negatives
            neg_sim = torch.einsum('bd,bkd->bk', a, n)  # [batch_size, num_negatives]

        # InfoNCE is cross-entropy with the positive at index 0
        logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / self.temperature
        labels = torch.zeros(batch_size, dtype=torch.long, device=anchors.device)
        return F.cross_entropy(logits, labels)


# Usage example
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
loss_fn = InfoNCE(temperature=0.07)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop (augment and dataloader are assumed to be defined elsewhere)
for images, _ in dataloader:
    # Create two augmented views of the same images
    view1 = augment(images)
    view2 = augment(images)

    # Get embeddings
    z1 = model(view1)
    z2 = model(view2)

    # Compute loss (symmetrized)
    loss = (loss_fn(z1, z2) + loss_fn(z2, z1)) / 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
NT-Xent Loss (SimCLR)
Definition
For a positive pair $(i, j)$ taken from the $2N$ augmented samples in a batch, the Normalized Temperature-scaled Cross Entropy (NT-Xent) loss is:

$$\ell_{i,j} = -\log \frac{\exp\!\big(\text{sim}(z_i, z_j) / \tau\big)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\!\big(\text{sim}(z_i, z_k) / \tau\big)}$$

The total loss averages $\ell_{i,j}$ over all positive pairs in both directions.
Implementation
```python
class NTXentLoss(nn.Module):
    """
    NT-Xent loss used in SimCLR.
    Symmetric loss over augmented pairs.
    """
    def __init__(self, temperature=0.5, use_cosine_similarity=True):
        super().__init__()
        self.temperature = temperature
        self.use_cosine_similarity = use_cosine_similarity

    def forward(self, z1, z2):
        """
        Args:
            z1: Embeddings from first augmentation [batch_size, embedding_dim]
            z2: Embeddings from second augmentation [batch_size, embedding_dim]
        """
        batch_size = z1.shape[0]

        # Normalize if using cosine similarity
        if self.use_cosine_similarity:
            z1 = F.normalize(z1, dim=1)
            z2 = F.normalize(z2, dim=1)

        # Concatenate representations: [2 * batch_size, dim]
        representations = torch.cat([z1, z2], dim=0)

        # Full similarity matrix: [2 * batch_size, 2 * batch_size]
        similarity_matrix = torch.matmul(representations, representations.T)

        # Positive pairs are (i, i + batch_size) and (i + batch_size, i)
        idx = torch.arange(batch_size, device=z1.device)
        pos_mask = torch.zeros_like(similarity_matrix, dtype=torch.bool)
        pos_mask[idx, idx + batch_size] = True
        pos_mask[idx + batch_size, idx] = True

        # Self-similarities are excluded from both positives and negatives
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z1.device)

        positives = similarity_matrix[pos_mask].view(2 * batch_size, 1)
        negatives = similarity_matrix[~(pos_mask | self_mask)].view(2 * batch_size, -1)

        # Cross-entropy with the positive at index 0
        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        labels = torch.zeros(2 * batch_size, dtype=torch.long, device=z1.device)
        return F.cross_entropy(logits, labels)


# SimCLR model: encoder plus a projection head
class SimCLR(nn.Module):
    def __init__(self, encoder, projection_dim=128):
        super().__init__()
        self.encoder = encoder
        self.projection_head = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, projection_dim)
        )

    def forward(self, x):
        features = self.encoder(x)
        projections = self.projection_head(features)
        return features, projections


# Training (ResNet50, strong_augment, and dataloader are assumed to be defined elsewhere)
simclr = SimCLR(ResNet50())
criterion = NTXentLoss(temperature=0.5)

for images in dataloader:
    # Strong augmentations
    aug1 = strong_augment(images)
    aug2 = strong_augment(images)

    # Get projections
    _, proj1 = simclr(aug1)
    _, proj2 = simclr(aug2)

    # Compute symmetric loss
    loss = criterion(proj1, proj2)
```
Triplet Loss
Definition
Triplet loss directly optimizes the relative distances between an anchor, a positive, and a negative:

$$\mathcal{L}_{\text{triplet}} = \max\big(d(a, p) - d(a, n) + \text{margin},\; 0\big)$$

Where:
- $a$ is the anchor
- $p$ is the positive sample
- $n$ is the negative sample
- $\text{margin}$ is the minimum required separation between the positive and negative distances
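For example (with illustrative values, not taken from the original text), if $d(a, p) = 0.8$, $d(a, n) = 1.5$, and the margin is $1.0$:

$$\mathcal{L}_{\text{triplet}} = \max(0.8 - 1.5 + 1.0,\; 0) = 0.3$$

The negative is not yet a full margin farther from the anchor than the positive, so the loss is nonzero and its gradient pushes the negative further away.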
Implementation
```python
class TripletLoss(nn.Module):
    """
    Triplet loss with hard negative mining.
    """
    def __init__(self, margin=1.0, distance='euclidean', mining='hard'):
        super().__init__()
        self.margin = margin
        self.distance = distance
        self.mining = mining

    def compute_distance(self, x1, x2):
        """Compute pairwise distances."""
        if self.distance == 'euclidean':
            return torch.cdist(x1, x2, p=2)
        elif self.distance == 'cosine':
            x1_norm = F.normalize(x1, dim=1)
            x2_norm = F.normalize(x2, dim=1)
            return 1 - torch.matmul(x1_norm, x2_norm.T)
        else:
            raise ValueError(f"Unknown distance: {self.distance}")

    def hard_negative_mining(self, anchor, positive, negatives):
        """Select negatives for each anchor according to the mining strategy."""
        # Anchor-positive distances (diagonal of the pairwise matrix)
        pos_dist = torch.diagonal(self.compute_distance(anchor, positive))
        # Anchor-negative distances: [batch_size, num_negatives]
        neg_dist = self.compute_distance(anchor, negatives)

        if self.mining == 'hard':
            # Hardest negative: the closest one
            neg_dist, _ = neg_dist.min(dim=1)
        elif self.mining == 'semi-hard':
            # Semi-hard: negatives farther than the positive, then take the closest of those
            mask = neg_dist > pos_dist.unsqueeze(1)
            neg_dist = torch.where(
                mask, neg_dist,
                torch.tensor(float('inf'), device=neg_dist.device)
            )
            neg_dist, _ = neg_dist.min(dim=1)
        elif self.mining == 'all':
            # Average over all negatives
            neg_dist = neg_dist.mean(dim=1)

        return pos_dist, neg_dist

    def forward(self, anchor, positive, negative):
        """
        Args:
            anchor: [batch_size, embedding_dim]
            positive: [batch_size, embedding_dim]
            negative: [batch_size, num_negatives, embedding_dim]
        """
        if negative.dim() == 2:
            negative = negative.unsqueeze(1)

        # Flatten the negatives into a shared pool for mining
        negative_flat = negative.view(-1, negative.shape[-1])

        pos_dist, neg_dist = self.hard_negative_mining(anchor, positive, negative_flat)

        # Hinge on the margin between positive and negative distances
        loss = F.relu(pos_dist - neg_dist + self.margin)
        return loss.mean()


# FaceNet-style embedding model
class FaceNet(nn.Module):
    def __init__(self, backbone, embedding_dim=128):
        super().__init__()
        self.backbone = backbone
        self.embedding_layer = nn.Linear(backbone.output_dim, embedding_dim)

    def forward(self, x):
        features = self.backbone(x)
        embeddings = self.embedding_layer(features)
        # L2 normalize for angular distance
        return F.normalize(embeddings, p=2, dim=1)


# Triplet mining and batch construction
class TripletDataLoader:
    def __init__(self, dataset, batch_size=32, p=4, k=4):
        """
        P-K sampling: P classes, K samples per class.
        """
        self.dataset = dataset
        self.p = p
        self.k = k
        self.batch_size = p * k

    def generate_triplets(self, embeddings, labels):
        """Generate all valid (anchor, positives, negatives) groups in a batch."""
        triplets = []
        indices = torch.arange(len(labels), device=labels.device)
        for i in range(len(embeddings)):
            anchor = embeddings[i]
            anchor_label = labels[i]

            # Positives: same class, different sample
            positive_mask = (labels == anchor_label) & (indices != i)
            positives = embeddings[positive_mask]

            # Negatives: different class
            negative_mask = labels != anchor_label
            negatives = embeddings[negative_mask]

            if len(positives) > 0 and len(negatives) > 0:
                triplets.append((anchor, positives, negatives))
        return triplets
```
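A minimal usage sketch with assumed toy tensors (shapes chosen only to illustrate the expected inputs):

```python
triplet_loss = TripletLoss(margin=1.0, distance='euclidean', mining='hard')

# Assumed toy batch: 32 anchor/positive pairs, 10 candidate negatives each, 128-d embeddings
anchor = torch.randn(32, 128)
positive = torch.randn(32, 128)
negative = torch.randn(32, 10, 128)

loss = triplet_loss(anchor, positive, negative)
print(loss.item())
```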
Advanced Techniques
MoCo (Momentum Contrast)
```python
import copy  # needed for deep-copying the encoder into the key branch


class MoCo(nn.Module):
    """
    Momentum Contrast for Unsupervised Visual Representation Learning.
    """
    def __init__(self, encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K  # Queue size
        self.m = m  # Momentum coefficient
        self.T = T  # Temperature

        # Query encoder and a momentum-updated key encoder
        self.encoder_q = encoder
        self.encoder_k = copy.deepcopy(encoder)

        # Projection heads
        self.projection_q = nn.Linear(encoder.output_dim, dim)
        self.projection_k = nn.Linear(encoder.output_dim, dim)

        # Stop gradient for the key branch
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        for param in self.projection_k.parameters():
            param.requires_grad = False

        # Queue of negative keys
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def momentum_update(self):
        """Momentum update of the key encoder and projection head."""
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(
            self.projection_q.parameters(), self.projection_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def dequeue_and_enqueue(self, keys):
        """Update the queue with new keys (assumes K is divisible by the batch size)."""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        # Replace the oldest keys with the newest
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, x_q, x_k):
        """
        Args:
            x_q: Query images
            x_k: Key images (augmented version of x_q)
        """
        # Query features
        q = self.projection_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)

        # Key features (no gradient through the key branch)
        with torch.no_grad():
            self.momentum_update()
            k = self.projection_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)

        # Positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # Logits: Nx(1+K), with the positive at index 0
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T

        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=q.device)

        # Update the queue with the new keys
        self.dequeue_and_enqueue(k)

        return F.cross_entropy(logits, labels)
```
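A minimal training-loop sketch for the class above (`encoder`, `augment`, and `dataloader` are assumed placeholders not defined in the original text):

```python
# Assumed: `encoder` exposes `output_dim`; `augment` returns a transformed batch.
moco = MoCo(encoder, dim=128, K=65536, m=0.999, T=0.07)
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, moco.parameters()),  # only the query branch is trained
    lr=0.03, momentum=0.9, weight_decay=1e-4
)

for images, _ in dataloader:
    x_q = augment(images)  # query view
    x_k = augment(images)  # key view

    loss = moco(x_q, x_k)  # InfoNCE over queue negatives; also updates the queue

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```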
CLIP (Contrastive Language-Image Pretraining)
```python
import numpy as np


class CLIP(nn.Module):
    """
    CLIP: Learning Transferable Visual Models From Natural Language Supervision.
    """
    def __init__(self, image_encoder, text_encoder, embed_dim=512, temperature=0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.temperature = temperature

        # Projection layers into the shared embedding space
        self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
        self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)

        # Learnable temperature, stored as a log-scale logit multiplier
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature))

    def encode_image(self, image):
        features = self.image_encoder(image)
        embeddings = self.image_projection(features)
        return F.normalize(embeddings, dim=-1)

    def encode_text(self, text):
        features = self.text_encoder(text)
        embeddings = self.text_projection(features)
        return F.normalize(embeddings, dim=-1)

    def forward(self, images, texts):
        """
        Compute the symmetric contrastive loss between images and texts.
        """
        # Get normalized embeddings
        image_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)

        # Compute similarity matrices in both directions
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        # Diagonal elements are the positive (matching) pairs
        batch_size = images.shape[0]
        labels = torch.arange(batch_size, device=images.device)

        # Cross-entropy loss in both directions
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)

        return (loss_i2t + loss_t2i) / 2


# Zero-shot classification with CLIP (tokenize is assumed to be defined elsewhere)
def zero_shot_classifier(clip_model, class_names, templates):
    """
    Create a zero-shot classifier from text descriptions.
    """
    with torch.no_grad():
        text_features = []
        for class_name in class_names:
            # Average the class embedding over multiple prompt templates
            class_embeddings = []
            for template in templates:
                text = template.format(class_name)
                text_tokens = tokenize(text)
                class_embeddings.append(clip_model.encode_text(text_tokens))
            class_embedding = torch.stack(class_embeddings).mean(dim=0)
            text_features.append(class_embedding)

        text_features = torch.stack(text_features)
        text_features = F.normalize(text_features, dim=-1)
    return text_features


# Inference
def classify_image(image, clip_model, text_features):
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        image_features = F.normalize(image_features, dim=-1)

        # Scaled cosine similarities turned into class probabilities
        similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    return similarities
```
Temperature Parameter Analysis
```python
def analyze_temperature_effect(embeddings, temperatures=[0.01, 0.07, 0.5, 1.0, 2.0]):
    """
    Analyze how temperature affects the contrastive softmax distribution.
    """
    results = {}

    for temp in temperatures:
        # Pairwise similarities
        similarities = torch.matmul(embeddings, embeddings.T)

        # Temperature scaling
        scaled_sim = similarities / temp

        # Softmax distribution over the batch
        probs = F.softmax(scaled_sim, dim=1)

        # Concentration statistics
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1).mean()
        max_prob = probs.max(dim=1)[0].mean()

        results[temp] = {
            'entropy': entropy.item(),
            'max_prob': max_prob.item(),
            'effective_negatives': 1 / max_prob.item()
        }

    return results
```
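A quick usage sketch (with random normalized embeddings assumed for illustration). Lower temperatures concentrate the softmax on the most similar samples, which shows up as lower entropy and higher maximum probability:

```python
# Assumed toy batch of 256 normalized 128-d embeddings
embeddings = F.normalize(torch.randn(256, 128), dim=1)

for temp, stats in analyze_temperature_effect(embeddings).items():
    print(f"tau={temp}: entropy={stats['entropy']:.2f}, "
          f"max_prob={stats['max_prob']:.3f}")
```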
Practical Guidelines
1. Choosing the Right Loss
| Loss Function | Use Case | Key Advantage | Limitation |
|---|---|---|---|
| InfoNCE | Large-scale pretraining | Scales with batch size | Requires many negatives |
| NT-Xent | Self-supervised vision | Works well with strong augmentations | Needs large batches for enough negatives |
| Triplet | Metric learning | Direct margin control | Hard negative mining needed |
| MoCo | Limited batch size | Negative queue decouples negatives from batch size | Momentum update complexity |
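As a practical convenience, the classes defined above can be wired behind a single selection function (a hypothetical helper, not part of the original text):

```python
def build_contrastive_loss(name, **kwargs):
    """Hypothetical convenience factory over the losses defined in this article."""
    losses = {
        'infonce': InfoNCE,       # large-scale pretraining, many negatives
        'nt_xent': NTXentLoss,    # self-supervised vision with strong augmentations
        'triplet': TripletLoss,   # metric learning with explicit margins
    }
    if name not in losses:
        raise ValueError(f"Unknown loss: {name}")
    return losses[name](**kwargs)


criterion = build_contrastive_loss('nt_xent', temperature=0.5)
```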
2. Hyperparameter Selection
```python
import numpy as np


def optimal_temperature(num_negatives):
    """
    Heuristic for temperature selection based on the number of negatives.
    """
    # Lower temperature for more negatives
    return 1.0 / np.sqrt(num_negatives)


def optimal_margin(embedding_dim):
    """
    Heuristic for the triplet margin based on embedding dimension.
    """
    # Scale the margin with the size of the embedding space
    return 0.2 * np.sqrt(embedding_dim)
```
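Plugging in typical values (these come straight from the heuristics above, not from tuned experiments):

```python
print(optimal_temperature(4096))  # 1/sqrt(4096) ≈ 0.0156
print(optimal_margin(128))        # 0.2 * sqrt(128) ≈ 2.26
```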
3. Data Augmentation Strategy
```python
from torchvision import transforms


class ContrastiveAugmentation:
    """
    Strong augmentations for contrastive learning.
    """
    def __init__(self, strength='strong'):
        self.strength = strength

        if strength == 'strong':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([
                    transforms.GaussianBlur(kernel_size=23)
                ], p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:  # weak
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

    def __call__(self, x):
        return self.transform(x)
```
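Contrastive training needs two independent augmentations of each image. A small wrapper makes this explicit (a common pattern; the `TwoViews` name here is a hypothetical helper, not from the original text):

```python
class TwoViews:
    """Apply the same augmentation pipeline twice to produce two independent views."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return self.transform(x), self.transform(x)


# Usage: each dataset item now yields a (view1, view2) pair
two_view_augment = TwoViews(ContrastiveAugmentation(strength='strong'))
```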
Common Pitfalls
1. Collapse Prevention
```python
def prevent_collapse(model, embeddings, std_target=0.1):
    """
    Monitor embedding variance and return a penalty that discourages collapse.
    Add the returned value to the training loss.
    """
    # Per-dimension standard deviation, averaged over dimensions
    std = embeddings.std(dim=0).mean()

    if std < std_target:
        print("Warning: Potential collapse detected!")

    # Hinge penalty: non-zero only when the variance drops below the target,
    # so minimizing the total loss pushes the variance back up.
    return F.relu(std_target - std)
```
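Used inside the training loop (a sketch; `loss_fn`, `model`, `z1`, and `z2` refer to the earlier InfoNCE example), the penalty is simply added to the contrastive loss:

```python
z = torch.cat([z1, z2], dim=0)                       # embeddings from both views
loss = loss_fn(z1, z2) + prevent_collapse(model, z)  # contrastive loss + variance penalty
```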
2. Negative Sampling
```python
class HardNegativeSampler:
    """
    Efficient hard negative sampling for contrastive learning.
    Mixes the hardest negatives (most similar, different label) with random ones.
    """
    def __init__(self, num_negatives=1000, hard_ratio=0.5):
        self.num_negatives = num_negatives
        self.hard_ratio = hard_ratio

    def sample(self, anchor, anchor_label, candidates, labels):
        """
        Args:
            anchor: [embedding_dim]
            anchor_label: scalar label of the anchor
            candidates: [num_candidates, embedding_dim]
            labels: [num_candidates]
        """
        # Restrict the pool to candidates with a different label
        negative_mask = labels != anchor_label
        negative_pool = candidates[negative_mask]

        # Cosine similarity of the anchor to every candidate negative
        similarities = F.cosine_similarity(anchor.unsqueeze(0), negative_pool, dim=1)

        num_hard = min(int(self.num_negatives * self.hard_ratio), len(negative_pool))
        num_easy = min(self.num_negatives - num_hard, len(negative_pool))

        # Hard negatives: highest similarity despite the wrong label
        hard_indices = similarities.topk(num_hard)[1]

        # Easy negatives: random draws from the same pool
        easy_indices = torch.randperm(len(negative_pool), device=negative_pool.device)[:num_easy]

        return torch.cat([negative_pool[hard_indices], negative_pool[easy_indices]], dim=0)
```
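A usage sketch with assumed toy tensors, producing per-anchor negatives that can be fed to the InfoNCE loss defined earlier:

```python
sampler = HardNegativeSampler(num_negatives=64, hard_ratio=0.5)

# Assumed toy memory bank of candidate embeddings with class labels
candidates = torch.randn(1000, 128)
labels = torch.randint(0, 10, (1000,))

anchor = torch.randn(128)
negatives = sampler.sample(anchor, anchor_label=3, candidates=candidates, labels=labels)

print(negatives.shape)  # [64, 128]; stack per-anchor results into [batch, 64, 128] for InfoNCE
```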
Related Concepts
- Cross-Entropy Loss - Classification losses
- MSE/MAE Loss - Regression losses
- Focal Loss - Imbalanced classification
- Self-Supervised Learning - Training without labels
- Vision Transformers - Often trained with contrastive losses