Contrastive Learning
Learn representations by pulling similar samples together and pushing dissimilar ones apart
Contrastive learning has revolutionized self-supervised representation learning by teaching models to distinguish between similar and dissimilar samples without explicit labels. This approach powers systems like CLIP, SimCLR, and modern embedding models.
Interactive Learning Visualization
Similarity Matrix
|  | Cat | Cat+ | Dog | Dog+ | Car | Car+ | Plane | Plane+ |
|---|---|---|---|---|---|---|---|---|
| Cat | 1.00 | **1.00** | -0.09 | 0.07 | 0.17 | -0.02 | -1.00 | -1.00 |
| Cat+ | **1.00** | 1.00 | -0.02 | 0.14 | 0.10 | -0.09 | -1.00 | -1.00 |
| Dog | -0.09 | -0.02 | 1.00 | **0.99** | -1.00 | -0.99 | 0.08 | 0.01 |
| Dog+ | 0.07 | 0.14 | **0.99** | 1.00 | -0.97 | -1.00 | -0.09 | -0.15 |
| Car | 0.17 | 0.10 | -1.00 | -0.97 | 1.00 | **0.98** | -0.16 | -0.09 |
| Car+ | -0.02 | -0.09 | -0.99 | -1.00 | **0.98** | 1.00 | 0.03 | 0.10 |
| Plane | -1.00 | -1.00 | 0.08 | -0.09 | -0.16 | 0.03 | 1.00 | **1.00** |
| Plane+ | -1.00 | -1.00 | 0.01 | -0.15 | -0.09 | 0.10 | **1.00** | 1.00 |

Diagonal: self-similarity (1.00). Bold: positive pairs, which should have high similarity.
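The values above are pairwise cosine similarities between the eight embeddings. A minimal sketch of how such a matrix is computed in PyTorch:

```python
import torch
import torch.nn.functional as F

def cosine_similarity_matrix(embeddings):
    """Pairwise cosine similarities for an [N, D] batch of embeddings."""
    z = F.normalize(embeddings, dim=1)  # unit-length rows
    return z @ z.T                      # [N, N]; the diagonal is always 1.0
```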
Loss Function
InfoNCE Loss
PyTorch Implementation
```python
import torch
import torch.nn.functional as F

def info_nce_loss(features, temperature=0.07):
    """
    InfoNCE loss for contrastive learning.
    features: [2N, D] where the first N rows are anchors, the next N their positives
    """
    batch_size = features.shape[0] // 2

    # Normalize features so dot products are cosine similarities
    features = F.normalize(features, dim=1)

    # Compute similarity matrix
    similarity = torch.matmul(features, features.T) / temperature

    # Labels: row i's positive sits at index i + N (and vice versa)
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ]).to(features.device)

    # Mask out self-similarity
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=features.device)
    similarity = similarity.masked_fill(mask, -float('inf'))

    # Cross-entropy over each row treats the positive as the correct "class"
    loss = F.cross_entropy(similarity, labels)
    return loss
```
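A toy usage sketch of the function above, with a random linear encoder and additive noise standing in for a real model and real augmentations (both are placeholders, not a recommended setup):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128))  # placeholder encoder
batch = torch.randn(8, 3, 32, 32)                # N = 8 "images"
view_a = batch + 0.1 * torch.randn_like(batch)   # crude stand-ins for augmentation
view_b = batch + 0.1 * torch.randn_like(batch)

features = torch.cat([encoder(view_a), encoder(view_b)], dim=0)  # [2N, D]
loss = info_nce_loss(features, temperature=0.07)
loss.backward()
```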
Key Concepts in Contrastive Learning
Core Principles
- Pull positive pairs together
- Push negative pairs apart
- Learn without labels (self-supervised)
- Temperature controls concentration
- More negatives improve representations
Applications
- SimCLR for vision
- CLIP for vision-language
- Sentence-BERT for text
- MoCo for unsupervised learning
- BYOL without negatives
Core Principles
The Contrastive Objective
Contrastive learning optimizes representations by:
- Pulling Together: Positive pairs (augmentations of the same sample)
- Pushing Apart: Negative pairs (different samples)
- Learning Invariances: Robust features across transformations
InfoNCE Loss
The InfoNCE (Noise Contrastive Estimation) loss is the foundation:
L = -log( exp(sim(a,p)/τ) / (exp(sim(a,p)/τ) + Σₙ exp(sim(a,n)/τ)) )
Where:
- a: anchor sample
- p: positive sample
- n: negative samples (summed over in the denominator)
- τ: temperature parameter
- sim: similarity function (usually cosine)
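To make the formula concrete, here is a tiny worked example with one anchor, one positive, and two negatives (the similarity values are made up for illustration):

```python
import math

tau = 0.1
sim_ap = 0.9           # anchor-positive similarity
sim_an = [0.2, -0.1]   # anchor-negative similarities

numerator = math.exp(sim_ap / tau)
denominator = numerator + sum(math.exp(s / tau) for s in sim_an)
loss = -math.log(numerator / denominator)
print(f"InfoNCE loss: {loss:.4f}")  # small, because the positive already dominates
```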
Key Components
1. Data Augmentation
Visual Domain:
- Random cropping
- Color jittering
- Gaussian blur
- Random flipping
Text Domain:
- Token dropout
- Paraphrasing
- Back-translation
- Span corruption
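As a concrete example of the visual-domain augmentations listed above, here is a SimCLR-style pipeline sketched with torchvision; the crop size, jitter strengths, and blur kernel are illustrative choices rather than values from any specific paper:

```python
from torchvision import transforms

# Two independent samples from this pipeline applied to the same image form a positive pair.
simclr_augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23),
    transforms.ToTensor(),
])
```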
2. Temperature Scaling
The temperature parameter τ controls the concentration of the distribution:
- Low τ (0.01-0.1): Sharp distribution, harder negatives
- High τ (0.5-1.0): Smooth distribution, softer learning
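The effect is easy to see by applying softmax to the same similarity scores at two temperatures (the scores here are illustrative):

```python
import torch

sims = torch.tensor([0.9, 0.5, 0.1, -0.3])  # one positive and three negatives
for tau in (0.05, 0.5):
    print(tau, torch.softmax(sims / tau, dim=0))
# At tau = 0.05 nearly all probability mass sits on the top entry (hard negatives dominate);
# at tau = 0.5 the distribution is much flatter.
```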
3. Negative Sampling
More negatives generally improve representations:
- In-batch negatives: Other samples in the minibatch
- Memory bank: Store and reuse past embeddings
- Hard negative mining: Focus on challenging examples
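A minimal sketch of the hard-negative idea: rank each anchor's in-batch negatives by similarity and keep only the top-k hardest ones (the function name and interface are illustrative):

```python
import torch
import torch.nn.functional as F

def hardest_negative_indices(anchors, negatives, k=5):
    """For each anchor, return the indices of its k most similar (hardest) negatives."""
    a = F.normalize(anchors, dim=1)     # [N, D]
    n = F.normalize(negatives, dim=1)   # [M, D]
    sims = a @ n.T                      # [N, M] cosine similarities
    return sims.topk(k, dim=1).indices  # [N, k]
```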
Popular Methods
SimCLR (Vision)
```python
def simclr_loss(z_i, z_j, temperature=0.07):
    """SimCLR (NT-Xent) loss for two batches of image representations."""
    # Normalize embeddings
    z_i = F.normalize(z_i, dim=1)
    z_j = F.normalize(z_j, dim=1)

    # Concatenate the two views
    representations = torch.cat([z_i, z_j], dim=0)

    # Compute similarity matrix
    similarity_matrix = torch.matmul(
        representations, representations.T
    ) / temperature

    # Labels: each sample's positive is its counterpart in the other view
    batch_size = z_i.shape[0]
    labels = torch.cat([
        torch.arange(batch_size) + batch_size,
        torch.arange(batch_size)
    ]).to(z_i.device)

    # Mask out self-similarity
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z_i.device)
    similarity_matrix = similarity_matrix.masked_fill(mask, -float('inf'))

    # Compute loss
    loss = F.cross_entropy(similarity_matrix, labels)
    return loss
```
CLIP (Vision-Language)
```python
def clip_loss(image_embeddings, text_embeddings, temperature=0.07):
    """CLIP-style symmetric contrastive loss."""
    # Normalize embeddings
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)

    # Compute similarity logits, scaled by the temperature
    logits = torch.matmul(
        image_embeddings, text_embeddings.T
    ) / temperature

    # Labels: diagonal elements are the matching image-text pairs
    labels = torch.arange(len(logits), device=logits.device)

    # Symmetric loss over both retrieval directions
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2
```
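A toy check of the function above with random embeddings standing in for real image and text encoder outputs; because random pairs carry no alignment, the loss should land near chance level (roughly ln of the batch size):

```python
torch.manual_seed(0)
image_emb = torch.randn(16, 512)
text_emb = torch.randn(16, 512)
print(clip_loss(image_emb, text_emb).item())  # roughly ln(16) ≈ 2.8 for unaligned embeddings
```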
Advanced Techniques
1. Momentum Contrast (MoCo)
Maintains a queue of negative samples:
```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoCo(nn.Module):
    def __init__(self, encoder, dim=128, K=65536, m=0.999):
        super().__init__()
        self.K = K  # queue size
        self.m = m  # momentum coefficient

        # Query encoder and a momentum-updated copy for keys
        self.encoder_q = encoder
        self.encoder_k = copy.deepcopy(encoder)

        # Keys receive no gradients; the key encoder is updated by EMA instead
        for param in self.encoder_k.parameters():
            param.requires_grad = False

        # Queue of past key embeddings used as negatives
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
```
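The snippet above sets up the two encoders and the queue but omits how they are maintained during training. A hedged sketch of the two standard steps, written as additional methods of the MoCo class above (the exact update rules may differ from any particular release):

```python
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """EMA update: the key encoder slowly tracks the query encoder."""
        for p_q, p_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            p_k.data = p_k.data * self.m + p_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Replace the oldest entries in the queue with the newest batch of keys."""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0           # simplifying assumption for this sketch
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.queue_ptr[0] = (ptr + batch_size) % self.K
```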
2. SwAV (Clustering)
Combines contrastive learning with clustering:
```python
def swav_loss(z1, z2, prototypes, temperature=0.1):
    """SwAV loss with online clustering (sinkhorn is assumed to be defined)."""
    # Scores of each view against the prototype vectors
    scores1 = torch.matmul(z1, prototypes.T) / temperature
    scores2 = torch.matmul(z2, prototypes.T) / temperature

    # Soft cluster assignments via Sinkhorn-Knopp (used as targets)
    q1 = sinkhorn(scores1)
    q2 = sinkhorn(scores2)

    # Predicted assignment distributions
    p1 = F.softmax(scores1, dim=1)
    p2 = F.softmax(scores2, dim=1)

    # Swapped prediction: each view predicts the other view's assignment
    loss = -torch.mean(
        torch.sum(q1 * torch.log(p2), dim=1) +
        torch.sum(q2 * torch.log(p1), dim=1)
    )
    return loss
```
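The `sinkhorn` call above is assumed rather than defined. Here is a minimal sketch of Sinkhorn-Knopp normalization, which iteratively rescales rows and columns so that the soft assignments are balanced across prototypes (the iteration count and normalization order follow common practice, not any specific codebase):

```python
import torch

@torch.no_grad()
def sinkhorn(scores, n_iters=3):
    """Balanced soft assignments from a [batch, n_prototypes] score matrix
    (scores are assumed to be temperature-scaled already, as in swav_loss)."""
    Q = torch.exp(scores).T        # [n_prototypes, batch]
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True)  # balance prototypes (rows)
        Q /= K
        Q /= Q.sum(dim=0, keepdim=True)  # normalize each sample (columns)
        Q /= B
    return (Q * B).T               # [batch, n_prototypes]; each row sums to 1
```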
Applications
1. Visual Representation Learning
- Pre-training on unlabeled images
- Transfer learning for downstream tasks
- Few-shot classification
2. Language Model Pre-training
- Sentence embeddings (Sentence-BERT)
- Cross-lingual alignment
- Document similarity
3. Multimodal Learning
- Vision-language models (CLIP, ALIGN)
- Audio-visual correspondence
- Cross-modal retrieval
Best Practices
Training Tips
- Large Batch Sizes: More negatives improve learning (256-8192)
- Strong Augmentations: Encourage invariance learning
- Projection Head: Non-linear projection improves representations (see the sketch after this list)
- Learning Rate Schedule: Cosine annealing with warmup
- Weight Decay: Prevent representation collapse
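A minimal sketch of the projection head mentioned above, following the common SimCLR-style recipe of a small MLP used only during pre-training and discarded before downstream evaluation (the layer sizes are illustrative):

```python
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Maps encoder features h to the space where the contrastive loss is applied."""
    def __init__(self, in_dim=2048, hidden_dim=2048, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, h):
        return self.net(h)
```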
Common Pitfalls
- Representation Collapse: all samples map to the same point. Solution: use stop-gradient or asymmetric architectures.
- False Negatives: semantically similar samples treated as negatives. Solution: careful data curation, supervised fine-tuning.
- Temperature Sensitivity: performance varies strongly with τ. Solution: grid search or a learnable temperature (see the sketch below).
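A hedged sketch of the learnable-temperature option, following the CLIP convention of learning the log of the inverse temperature and clamping the resulting scale for stability (the initial value and clamp bound are common choices, not requirements):

```python
import torch
import torch.nn as nn

class LogitScale(nn.Module):
    """Learnable inverse temperature, parameterized in log space."""
    def __init__(self, init_temperature=0.07, max_scale=100.0):
        super().__init__()
        self.log_scale = nn.Parameter(torch.log(torch.tensor(1.0 / init_temperature)))
        self.max_scale = max_scale

    def forward(self, logits):
        # Scale raw cosine similarities; clamping prevents the scale from exploding.
        return logits * self.log_scale.exp().clamp(max=self.max_scale)
```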
Performance Considerations
| Method | Batch Size | Memory | Training Time | ImageNet Acc |
|---|---|---|---|---|
| SimCLR | 4096 | High | 1000 epochs | 76.5% |
| MoCo v3 | 256 | Medium | 300 epochs | 76.7% |
| SwAV | 4096 | High | 800 epochs | 75.3% |
| BYOL | 4096 | High | 1000 epochs | 74.3% |
Future Directions
Emerging Trends
- Masked Autoencoders: Combining masking with contrastive learning
- Equivariant Representations: Learning transformation-aware features
- Hierarchical Contrastive: Multi-scale representation learning
- Efficient Negatives: Reducing computational requirements
Open Challenges
- Theoretical understanding of contrastive learning
- Optimal augmentation strategies
- Scaling to trillion-parameter models
- Combining with other self-supervised objectives
Conclusion
Contrastive learning has become a cornerstone of modern representation learning, enabling models to learn powerful features from unlabeled data. The interactive visualization above demonstrates how the contrastive objective shapes the embedding space, pulling positive pairs together while pushing negatives apart.
The success of methods like CLIP and SimCLR shows that contrastive learning can rival or exceed supervised pre-training, opening new possibilities for learning from the vast amounts of unlabeled data available in the real world.