Contrastive Loss Functions

Master contrastive loss functions including InfoNCE, NT-Xent, and Triplet Loss for representation learning and self-supervised training.


Contrastive Loss: Learning Representations by Comparison

Contrastive losses learn representations by pulling similar samples together and pushing dissimilar samples apart in embedding space, forming the foundation of modern self-supervised and multimodal learning.

Interactive Contrastive Learning Explorer

Visualize how different contrastive losses shape the embedding space:

[Interactive demo: configure the loss function, temperature, and animation speed (lower temperature creates sharper distributions) and watch the embedding space, the temperature's effect on similarity scores, and a live loss breakdown (total loss, positive similarity, average negative similarity, hard negative similarity) update alongside the InfoNCE formula.]

Key Properties

InfoNCE

  • Maximizes a lower bound on mutual information
  • Scales with batch size
  • Requires many negatives
  • Used in CLIP, MoCo

NT-Xent

  • Normalized temperature scaling
  • Symmetric over both augmented views
  • Strong data augmentation is critical
  • Used in SimCLR

Triplet Loss

  • Direct margin optimization
  • Hard negative mining
  • Simpler formulation
  • Used in FaceNet, metric learning

Core Principle

Contrastive learning operates on a simple principle:

  1. Maximize agreement between augmented views of the same data (positives)
  2. Minimize agreement between different data points (negatives)
  3. Learn representations that capture semantic similarity
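
This pull/push behaviour falls straight out of the gradients. The toy sketch below (an illustrative objective, not one of the losses covered later) rewards agreement with a positive and penalizes agreement with a negative, so a gradient step moves the anchor toward the positive and away from the negative.

import torch
import torch.nn.functional as F

# Toy embeddings; only the anchor is trainable here
anchor = torch.randn(8, requires_grad=True)
positive = torch.randn(8)
negative = torch.randn(8)

# Reward agreement with the positive, penalize agreement with the negative
loss = (-F.cosine_similarity(anchor, positive, dim=0)
        + F.cosine_similarity(anchor, negative, dim=0))
loss.backward()

# A gradient-descent step (anchor - lr * anchor.grad) increases similarity
# to the positive and decreases similarity to the negative
print(anchor.grad)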

InfoNCE Loss

Definition

InfoNCE (based on Noise-Contrastive Estimation) maximizes a lower bound on the mutual information between positive pairs:

L_InfoNCE = -log( exp(sim(z_i, z_i^+) / τ) / Σ_{k=1}^{2N} 1[k ≠ i] exp(sim(z_i, z_k) / τ) )

Where:

  • z_i is the anchor embedding
  • z_i^+ is the positive embedding
  • τ is the temperature parameter
  • sim(·, ·) is the similarity function (usually cosine)
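
As a quick check of the formula, the snippet below plugs in illustrative numbers (one positive and two negatives) and confirms that the InfoNCE value equals a cross-entropy whose target is the positive logit at index 0, which is exactly how the implementation below computes it.

import torch
import torch.nn.functional as F

tau = 0.07
pos_sim = torch.tensor(0.8)            # sim(z_i, z_i^+)
neg_sims = torch.tensor([0.3, 0.1])    # sim(z_i, z_k) for the negatives

# Direct evaluation of the formula
logits = torch.cat([pos_sim.view(1), neg_sims]) / tau
loss_manual = -torch.log(torch.exp(logits[0]) / torch.exp(logits).sum())

# Same value as cross-entropy with the positive at index 0
loss_ce = F.cross_entropy(logits.unsqueeze(0), torch.tensor([0]))

print(loss_manual.item(), loss_ce.item())  # matching values (up to float error)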

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCE(nn.Module):
    """
    InfoNCE loss for contrastive learning.
    Used in CLIP, MoCo, and other vision-language models.
    """
    def __init__(self, temperature=0.07, similarity='cosine'):
        super().__init__()
        self.temperature = temperature
        self.similarity = similarity

    def compute_similarity(self, z1, z2):
        """Compute pairwise similarity between two sets of embeddings"""
        if self.similarity == 'cosine':
            z1 = F.normalize(z1, dim=-1)
            z2 = F.normalize(z2, dim=-1)
            return torch.matmul(z1, z2.T)
        elif self.similarity == 'dot':
            return torch.matmul(z1, z2.T)
        else:
            raise ValueError(f"Unknown similarity: {self.similarity}")

    def forward(self, anchors, positives, negatives=None):
        """
        Args:
            anchors: [batch_size, embedding_dim]
            positives: [batch_size, embedding_dim]
            negatives: [batch_size, num_negatives, embedding_dim] or None
        """
        batch_size = anchors.shape[0]

        # Positive similarities: diagonal of the anchor-positive matrix
        pos_sim = torch.diagonal(self.compute_similarity(anchors, positives))  # [batch_size]

        if negatives is None:
            # Use the other samples in the batch as negatives
            all_sim = self.compute_similarity(anchors, anchors)
            # Mask out self-similarities
            mask = torch.eye(batch_size, dtype=torch.bool, device=anchors.device)
            neg_sim = all_sim.masked_fill(mask, float('-inf'))
        else:
            # Use the provided per-sample negatives: [batch_size, num_negatives]
            a, n = anchors, negatives
            if self.similarity == 'cosine':
                a = F.normalize(a, dim=-1)
                n = F.normalize(n, dim=-1)
            neg_sim = torch.matmul(n, a.unsqueeze(-1)).squeeze(-1)

        # InfoNCE is cross-entropy with the positive logit at index 0
        logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / self.temperature
        labels = torch.zeros(batch_size, dtype=torch.long, device=anchors.device)
        return F.cross_entropy(logits, labels)


# Usage example
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
loss_fn = InfoNCE(temperature=0.07)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for images, _ in dataloader:
    # Create two augmented views of the same batch
    view1 = augment(images)
    view2 = augment(images)

    # Get embeddings
    z1 = model(view1)
    z2 = model(view2)

    # Symmetrized loss
    loss = (loss_fn(z1, z2) + loss_fn(z2, z1)) / 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

NT-Xent Loss (SimCLR)

Definition

The Normalized Temperature-scaled Cross-Entropy loss applies InfoNCE symmetrically over the 2N augmented views of a batch of N samples:

L_NT-Xent(i, j) = -log( exp(sim(z_i, z_j) / τ) / Σ_{k=1}^{2N} 1[k ≠ i] exp(sim(z_i, z_k) / τ) )

Implementation

class NTXentLoss(nn.Module):
    """
    NT-Xent loss used in SimCLR.
    Symmetric loss over the two augmented views of each sample.
    """
    def __init__(self, temperature=0.5, use_cosine_similarity=True):
        super().__init__()
        self.temperature = temperature
        self.use_cosine_similarity = use_cosine_similarity

    def forward(self, z1, z2):
        """
        Args:
            z1: Embeddings from first augmentation [batch_size, embedding_dim]
            z2: Embeddings from second augmentation [batch_size, embedding_dim]
        """
        batch_size = z1.shape[0]

        # Normalize if using cosine similarity
        if self.use_cosine_similarity:
            z1 = F.normalize(z1, dim=1)
            z2 = F.normalize(z2, dim=1)

        # Concatenate representations from both views
        representations = torch.cat([z1, z2], dim=0)  # [2*batch_size, dim]

        # Full similarity matrix
        similarity_matrix = torch.matmul(representations, representations.T)

        # Self-similarities are excluded from both positives and negatives
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z1.device)
        # Positive pairs sit at (i, i + batch_size) and (i + batch_size, i)
        pos_mask = self_mask.roll(shifts=batch_size, dims=1)

        # Extract positive and negative similarities
        positives = similarity_matrix[pos_mask].view(2 * batch_size, 1)
        negatives = similarity_matrix[~(self_mask | pos_mask)].view(2 * batch_size, -1)

        # Cross-entropy with the positive logit at index 0
        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        labels = torch.zeros(2 * batch_size, dtype=torch.long, device=z1.device)
        return F.cross_entropy(logits, labels)


# SimCLR model: encoder plus projection head
class SimCLR(nn.Module):
    def __init__(self, encoder, projection_dim=128):
        super().__init__()
        self.encoder = encoder
        self.projection_head = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, projection_dim)
        )

    def forward(self, x):
        features = self.encoder(x)
        projections = self.projection_head(features)
        return features, projections


# Training
simclr = SimCLR(ResNet50())
criterion = NTXentLoss(temperature=0.5)

for images in dataloader:
    # Strong augmentations
    aug1 = strong_augment(images)
    aug2 = strong_augment(images)

    # The loss is computed on the projections, not the backbone features
    _, proj1 = simclr(aug1)
    _, proj2 = simclr(aug2)

    # Symmetric loss over the 2N views
    loss = criterion(proj1, proj2)

Triplet Loss

Definition

Triplet loss directly optimizes the relative distances:

L_triplet = max(0, d(a, p) - d(a, n) + margin)

Where:

  • a is the anchor
  • p is the positive sample
  • n is the negative sample
  • margin is the minimum separation
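
For concreteness: with illustrative values d(a, p) = 0.8, d(a, n) = 1.5, and margin = 1.0, the loss is max(0, 0.8 - 1.5 + 1.0) = 0.3; once the negative sits at least a full margin farther away than the positive, the loss hits zero and the triplet stops producing gradients.

import torch
import torch.nn.functional as F

margin = 1.0
d_ap = torch.tensor(0.8)   # anchor-positive distance
d_an = torch.tensor(1.5)   # anchor-negative distance

loss = F.relu(d_ap - d_an + margin)
print(loss.item())  # 0.3

# Negative pushed beyond the margin: this triplet is "solved"
print(F.relu(d_ap - torch.tensor(2.0) + margin).item())  # 0.0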

Implementation

class TripletLoss(nn.Module):
    """
    Triplet loss with hard negative mining.
    """
    def __init__(self, margin=1.0, distance='euclidean', mining='hard'):
        super().__init__()
        self.margin = margin
        self.distance = distance
        self.mining = mining

    def compute_distance(self, x1, x2):
        """Compute pairwise distances"""
        if self.distance == 'euclidean':
            return torch.cdist(x1, x2, p=2)
        elif self.distance == 'cosine':
            x1_norm = F.normalize(x1, dim=1)
            x2_norm = F.normalize(x2, dim=1)
            return 1 - torch.matmul(x1_norm, x2_norm.T)
        else:
            raise ValueError(f"Unknown distance: {self.distance}")

    def hard_negative_mining(self, anchor, positive, negatives):
        """Select negatives for each anchor according to the mining strategy"""
        pos_dist = self.compute_distance(anchor, positive)
        neg_dist = self.compute_distance(anchor, negatives)

        # Anchor-positive distances sit on the diagonal
        pos_dist = torch.diagonal(pos_dist)

        if self.mining == 'hard':
            # Hardest negative = closest negative
            neg_dist, _ = neg_dist.min(dim=1)
        elif self.mining == 'semi-hard':
            # Semi-hard: the closest negative that is still farther than the positive
            mask = neg_dist > pos_dist.unsqueeze(1)
            neg_dist = torch.where(
                mask, neg_dist,
                torch.tensor(float('inf'), device=neg_dist.device)
            )
            neg_dist, _ = neg_dist.min(dim=1)
        elif self.mining == 'all':
            # Average over all negatives
            neg_dist = neg_dist.mean(dim=1)

        return pos_dist, neg_dist

    def forward(self, anchor, positive, negative):
        """
        Args:
            anchor: [batch_size, embedding_dim]
            positive: [batch_size, embedding_dim]
            negative: [batch_size, num_negatives, embedding_dim]
        """
        if negative.dim() == 2:
            negative = negative.unsqueeze(1)

        # Flatten negatives so every anchor can be compared against all of them
        negative_flat = negative.view(-1, negative.shape[-1])

        # Mine negatives
        pos_dist, neg_dist = self.hard_negative_mining(
            anchor, positive, negative_flat
        )

        # Hinge loss on the margin between positive and negative distances
        loss = F.relu(pos_dist - neg_dist + self.margin)
        return loss.mean()


# FaceNet-style embedding model
class FaceNet(nn.Module):
    def __init__(self, backbone, embedding_dim=128):
        super().__init__()
        self.backbone = backbone
        self.embedding_layer = nn.Linear(backbone.output_dim, embedding_dim)

    def forward(self, x):
        features = self.backbone(x)
        embeddings = self.embedding_layer(features)
        # L2 normalize for angular distance
        return F.normalize(embeddings, p=2, dim=1)


# Triplet mining and batch construction
class TripletDataLoader:
    def __init__(self, dataset, batch_size=32, p=4, k=4):
        """
        P-K sampling: P classes, K samples per class.
        """
        self.dataset = dataset
        self.p = p
        self.k = k
        self.batch_size = p * k

    def generate_triplets(self, embeddings, labels):
        """Generate all valid triplets in the batch"""
        triplets = []
        for i in range(len(embeddings)):
            anchor = embeddings[i]
            anchor_label = labels[i]

            # Positives: same class, different sample
            index_mask = torch.arange(len(labels), device=labels.device) != i
            positives = embeddings[(labels == anchor_label) & index_mask]

            # Negatives: different class
            negatives = embeddings[labels != anchor_label]

            if len(positives) > 0 and len(negatives) > 0:
                triplets.append((anchor, positives, negatives))

        return triplets
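
To show how the pieces above fit together, here is a minimal training sketch under stated assumptions: `backbone` is any feature extractor exposing an `output_dim` attribute, `dataset` is a labeled dataset, and `pk_batches` is a hypothetical iterator yielding the P-K batches of `(images, labels)` that TripletDataLoader is meant to build.

# Minimal triplet training sketch; backbone, dataset, and pk_batches are assumed
model = FaceNet(backbone, embedding_dim=128)
criterion = TripletLoss(margin=0.2, distance='euclidean', mining='semi-hard')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loader = TripletDataLoader(dataset, p=4, k=4)

for images, labels in pk_batches:  # hypothetical P-K sampler output
    embeddings = model(images)

    # Average the loss over all valid (anchor, positives, negatives) groups
    triplets = loader.generate_triplets(embeddings, labels)
    if not triplets:
        continue
    loss = sum(
        criterion(anchor.unsqueeze(0), positives[:1], negatives.unsqueeze(0))
        for anchor, positives, negatives in triplets
    ) / len(triplets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()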

Advanced Techniques

MoCo (Momentum Contrast)

import copy


class MoCo(nn.Module):
    """
    Momentum Contrast for Unsupervised Visual Representation Learning.
    """
    def __init__(self, encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K  # Queue size
        self.m = m  # Momentum coefficient
        self.T = T  # Temperature

        # Query and key encoders
        self.encoder_q = encoder
        self.encoder_k = copy.deepcopy(encoder)

        # Projection heads
        self.projection_q = nn.Linear(encoder.output_dim, dim)
        self.projection_k = nn.Linear(encoder.output_dim, dim)

        # The key branch is updated by momentum only, never by gradients
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        for param in self.projection_k.parameters():
            param.requires_grad = False

        # Queue of negative keys
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def momentum_update(self):
        """Momentum update of the key encoder"""
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(
            self.projection_q.parameters(), self.projection_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def dequeue_and_enqueue(self, keys):
        """Replace the oldest keys in the queue with the newest
        (assumes K is divisible by the batch size)"""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, x_q, x_k):
        """
        Args:
            x_q: Query images
            x_k: Key images (augmented version of x_q)
        """
        # Query features
        q = self.projection_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)

        # Key features (no gradient)
        with torch.no_grad():
            self.momentum_update()
            k = self.projection_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)

        # Positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # Logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T

        # The positive key is at position 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=q.device)

        # Update queue with the new keys
        self.dequeue_and_enqueue(k)

        return F.cross_entropy(logits, labels)
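
A minimal training sketch for the module above, assuming an `encoder` with an `output_dim` attribute and the same placeholder `augment` and `dataloader` used earlier; the two augmented views of each image act as query and key.

# Minimal MoCo training sketch (encoder, augment, dataloader are assumed placeholders)
moco = MoCo(encoder, dim=128, K=65536, m=0.999, T=0.07)
optimizer = torch.optim.SGD(
    [p for p in moco.parameters() if p.requires_grad], lr=0.03, momentum=0.9
)

for images, _ in dataloader:
    x_q = augment(images)   # query view
    x_k = augment(images)   # key view

    # Forward also performs the momentum update and enqueues the new keys
    loss = moco(x_q, x_k)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()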

CLIP (Contrastive Language-Image Pretraining)

import numpy as np


class CLIP(nn.Module):
    """
    CLIP: Learning Transferable Visual Models From Natural Language Supervision.
    """
    def __init__(self, image_encoder, text_encoder, embed_dim=512, temperature=0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.temperature = temperature

        # Projection layers
        self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
        self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)

        # Learnable temperature (stored as a log-scale logit multiplier)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature))

    def encode_image(self, image):
        features = self.image_encoder(image)
        embeddings = self.image_projection(features)
        return F.normalize(embeddings, dim=-1)

    def encode_text(self, text):
        features = self.text_encoder(text)
        embeddings = self.text_projection(features)
        return F.normalize(embeddings, dim=-1)

    def forward(self, images, texts):
        """
        Compute the symmetric contrastive loss between images and texts.
        """
        # Get normalized embeddings
        image_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)

        # Compute similarity matrix
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        # Diagonal elements are the positive pairs
        batch_size = images.shape[0]
        labels = torch.arange(batch_size, device=images.device)

        # Cross-entropy loss in both directions
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)
        return (loss_i2t + loss_t2i) / 2


# Zero-shot classification with CLIP
def zero_shot_classifier(clip_model, class_names, templates):
    """
    Build a zero-shot classifier from text descriptions.
    """
    with torch.no_grad():
        text_features = []
        for class_name in class_names:
            # Average embeddings over multiple prompt templates
            class_embeddings = []
            for template in templates:
                text = template.format(class_name)
                text_tokens = tokenize(text)
                class_embeddings.append(clip_model.encode_text(text_tokens))
            class_embedding = torch.stack(class_embeddings).mean(dim=0)
            text_features.append(class_embedding)

        text_features = torch.stack(text_features)
        text_features = F.normalize(text_features, dim=-1)
    return text_features


# Inference
def classify_image(image, clip_model, text_features):
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        image_features = F.normalize(image_features, dim=-1)
        # Softmax over scaled similarities
        similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    return similarities
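
Putting the helpers together, here is a zero-shot inference sketch; the class names and prompt templates are purely illustrative, and `tokenize`, `image_encoder`, `text_encoder`, and `image` remain the placeholders used above (each assumed to return a single un-batched embedding per call).

# Zero-shot inference sketch; names, templates, and encoders are placeholders
clip_model = CLIP(image_encoder, text_encoder, embed_dim=512)

class_names = ["cat", "dog", "car"]                        # illustrative labels
templates = ["a photo of a {}.", "a blurry photo of a {}."]

# Encode the label set once, then reuse it for every query image
text_features = zero_shot_classifier(clip_model, class_names, templates)

# Probability of each class for a query image
probs = classify_image(image, clip_model, text_features)
print(class_names[int(probs.argmax())])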

Temperature Parameter Analysis

def analyze_temperature_effect(embeddings, temperatures=[0.01, 0.07, 0.5, 1.0, 2.0]):
    """
    Analyze how temperature affects the contrastive loss distribution.
    """
    results = {}
    for temp in temperatures:
        # Compute similarities
        similarities = torch.matmul(embeddings, embeddings.T)

        # Apply temperature scaling
        scaled_sim = similarities / temp

        # Softmax distribution over the batch
        probs = F.softmax(scaled_sim, dim=1)

        # Measure how concentrated the distribution is
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1).mean()
        max_prob = probs.max(dim=1)[0].mean()

        results[temp] = {
            'entropy': entropy.item(),
            'max_prob': max_prob.item(),
            'effective_negatives': 1 / max_prob.item()
        }
    return results
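
A usage sketch: running the analysis on a batch of normalized embeddings illustrates the usual trend, in which lower temperatures concentrate the softmax on the most similar sample (low entropy, high max probability), while higher temperatures spread probability mass over many negatives.

# Illustrative run on random normalized embeddings
embeddings = F.normalize(torch.randn(256, 128), dim=1)

for temp, stats in analyze_temperature_effect(embeddings).items():
    print(f"tau={temp}: entropy={stats['entropy']:.2f}, "
          f"max_prob={stats['max_prob']:.3f}, "
          f"effective_negatives={stats['effective_negatives']:.1f}")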

Practical Guidelines

1. Choosing the Right Loss

| Loss Function | Use Case | Key Advantage | Limitation |
|---|---|---|---|
| InfoNCE | Large-scale pretraining | Scales with batch size | Requires many negatives |
| NT-Xent | Self-supervised vision | Strong augmentations | Symmetric loss only |
| Triplet | Metric learning | Direct margin control | Hard negative mining needed |
| MoCo | Limited batch size | Memory bank | Momentum update complexity |

2. Hyperparameter Selection

def optimal_temperature(num_negatives):
    """
    Heuristic for temperature selection based on the number of negatives.
    """
    # Lower temperature for more negatives
    return 1.0 / np.sqrt(num_negatives)


def optimal_margin(embedding_dim):
    """
    Heuristic for the triplet margin based on embedding dimension.
    """
    # Scale the margin with the volume of the embedding space
    return 0.2 * np.sqrt(embedding_dim)

3. Data Augmentation Strategy

from torchvision import transforms


class ContrastiveAugmentation:
    """
    Strong augmentations for contrastive learning.
    """
    def __init__(self, strength='strong'):
        self.strength = strength
        if strength == 'strong':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([
                    transforms.GaussianBlur(kernel_size=23)
                ], p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:  # weak
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

    def __call__(self, x):
        return self.transform(x)
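
Usage sketch: each image is transformed twice to produce the two views that NT-Xent or InfoNCE treats as a positive pair; `image` here stands for any PIL image and `model` for the encoder being trained.

# Two independently augmented views of the same image form a positive pair
augment = ContrastiveAugmentation(strength='strong')

view1 = augment(image)   # image: a PIL image
view2 = augment(image)

# view1 and view2 are fed through the encoder and paired in the loss,
# e.g. loss = NTXentLoss()(model(view1.unsqueeze(0)), model(view2.unsqueeze(0)))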

Common Pitfalls

1. Collapse Prevention

def prevent_collapse(model, embeddings):
    """
    Monitor for representation collapse and return a variance penalty
    to add to the training loss.
    """
    # Mean per-dimension standard deviation across the batch
    std = embeddings.std(dim=0).mean()
    if std < 0.1:
        print("Warning: Potential collapse detected!")
        # Variance regularization: penalize how far the std falls below 0.1
        return 0.1 - std
    return 0
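
A sketch of how the returned penalty could be folded into a training step; `model`, `loss_fn`, and the two views are whatever contrastive setup is being trained, and the weighting factor of 1.0 is an illustrative choice.

# Adding the variance penalty to a contrastive training step (weight is illustrative)
z1, z2 = model(view1), model(view2)
contrastive_loss = loss_fn(z1, z2)

collapse_penalty = prevent_collapse(model, torch.cat([z1, z2], dim=0))
loss = contrastive_loss + 1.0 * collapse_penalty

optimizer.zero_grad()
loss.backward()
optimizer.step()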

2. Negative Sampling

class HardNegativeSampler:
    """
    Efficient hard negative sampling for contrastive learning.
    Mixes the hardest negatives (high similarity, wrong label) with random easy negatives.
    """
    def __init__(self, num_negatives=1000, hard_ratio=0.5):
        self.num_negatives = num_negatives
        self.hard_ratio = hard_ratio

    def sample(self, anchor, candidates, labels):
        """
        Args:
            anchor: [embedding_dim] anchor embedding
            candidates: [num_candidates, embedding_dim]
            labels: [num_candidates] candidate labels; labels[0] is taken
                    to be the anchor's label
        """
        # Keep only candidates whose label differs from the anchor's
        negatives = candidates[labels != labels[0]]

        # Similarity of the anchor to each negative candidate
        similarities = F.cosine_similarity(anchor.unsqueeze(0), negatives, dim=1)

        # Split the sampling budget between hard and easy negatives
        num_hard = min(int(self.num_negatives * self.hard_ratio), len(negatives))
        num_easy = min(self.num_negatives - num_hard, len(negatives))

        # Hard negatives: highest similarity despite the wrong label
        hard_indices = similarities.topk(num_hard)[1]

        # Easy negatives: sampled uniformly at random
        easy_indices = torch.randperm(len(negatives))[:num_easy]

        return torch.cat([negatives[hard_indices], negatives[easy_indices]])

If you found this explanation helpful, consider sharing it with others.
