Domain Adaptation
Adapt embeddings from source to target domains while preserving knowledge
Domain adaptation enables models trained on one domain (source) to perform well on a different but related domain (target). This is crucial when labeled data is scarce in the target domain but abundant in a related source domain.
Interactive Adaptation Simulator
The simulator presents three panels: Adaptation Configuration, Feature Distribution Shift, and Adaptation Performance.
Adaptation Methods
- Fine-tuning: Continue training on the target domain
- Adapter Layers: Add domain-specific adapter modules
- Elastic Weight Consolidation: Preserve important weights
- Multi-task Learning: Train on multiple domains jointly
- Domain Prompting: Use domain-specific prompts (see the sketch after this list)
Implementation Examples
Domain-Adaptive Fine-tuning
```python
# Gradual unfreezing: start by training only the top layers,
# then unfreeze deeper layers as training progresses.
from torch.optim import AdamW

def adaptive_finetune(model, data, num_epochs=20):
    # Freeze all layers initially
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze the top two layers first
    for layer in model.layers[-2:]:
        for param in layer.parameters():
            param.requires_grad = True

    # Train with a small learning rate; frozen parameters receive no
    # gradients, so passing all parameters to the optimizer is safe.
    optimizer = AdamW(model.parameters(), lr=1e-5)

    # Gradually unfreeze more layers every few epochs
    for epoch in range(num_epochs):
        if epoch % 5 == 0:
            unfreeze_next_layer(model)           # assumed helper
        train_epoch(model, data, optimizer)      # assumed helper
```
Domain Adversarial Training
```python
import torch.nn as nn

class DANN(nn.Module):
    def __init__(self, encoder, n_classes):
        super().__init__()
        self.encoder = encoder
        self.task_classifier = nn.Linear(768, n_classes)
        self.domain_classifier = nn.Sequential(
            GradientReversal(),      # flips gradients flowing back into the encoder
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.Linear(128, 2),       # source vs. target
        )

    def forward(self, x):
        features = self.encoder(x)
        task_output = self.task_classifier(features)
        domain_output = self.domain_classifier(features)
        return task_output, domain_output
```
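The GradientReversal layer used above is not a built-in PyTorch module. A minimal sketch built on `torch.autograd.Function` follows; the fixed default strength `lambd` and the optional per-call override are choices made here so the same layer also fits the fuller DANN example later in the post.

```python
import torch
import torch.nn as nn

class _ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity on the forward pass, negated (scaled) gradient on the backward pass.
        return -ctx.lambd * grad_output, None

class GradientReversal(nn.Module):
    def __init__(self, lambd=1.0):
        super().__init__()
        self.lambd = lambd

    def forward(self, x, lambd=None):
        return _ReverseGrad.apply(x, self.lambd if lambd is None else lambd)
```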
Domain Adaptation Best Practices
Strategies
- Start with a pre-trained general model
- Use gradual unfreezing
- Apply regularization to prevent forgetting
- Mix source and target data when available
- Monitor performance on both domains
Common Pitfalls
- Catastrophic forgetting of source knowledge
- Overfitting to small target datasets
- Negative transfer from unrelated domains
- Distribution mismatch during inference
- Insufficient adaptation epochs
The Domain Shift Problem
Distribution Mismatch
When we deploy models in new domains, we encounter:
- Covariate Shift: P(X) changes but P(Y|X) remains the same
- Label Shift: P(Y) changes but P(X|Y) remains the same
- Concept Drift: P(Y|X) changes over time
Real-World Examples
| Source Domain | Target Domain | Challenge |
|---|---|---|
| General Web Text | Medical Records | Specialized terminology |
| News Articles | Social Media | Informal language |
| English Reviews | Spanish Reviews | Language + culture |
| Synthetic Data | Real Sensors | Noise patterns |
Adaptation Strategies
1. Fine-Tuning
The simplest approach - continue training on target data:
```python
def fine_tune(model, source_data, target_data, config):
    # Schematic loop: fit, train_step, mix_batches, and should_stop are assumed helpers.
    # Pre-train on the source domain
    model.fit(source_data, epochs=config.source_epochs)

    # Fine-tune on the target domain with a smaller learning rate
    optimizer = Adam(model.parameters(), lr=config.lr * 0.1)

    for epoch in range(config.target_epochs):
        # Optional: mix source and target data
        if config.mix_ratio > 0:
            batch = mix_batches(source_data, target_data, config.mix_ratio)
        else:
            batch = target_data.sample()

        loss = model.train_step(batch)

        # Early stopping based on target validation
        if should_stop(loss, patience=5):
            break

    return model
```
2. Adapter Layers
Parameter-efficient adaptation without forgetting:
```python
import torch.nn as nn

class AdapterLayer(nn.Module):
    def __init__(self, hidden_size, adapter_size=64):
        super().__init__()
        self.down_project = nn.Linear(hidden_size, adapter_size)
        self.up_project = nn.Linear(adapter_size, hidden_size)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Keep the original path
        residual = x
        # Adapter path: project down, apply non-linearity, project back up
        x = self.down_project(x)
        x = self.activation(x)
        x = self.up_project(x)
        # Residual connection
        return residual + x

class AdapterBERT(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False
        # Add one adapter per transformer layer
        self.adapters = nn.ModuleList([
            AdapterLayer(768) for _ in range(12)
        ])
```
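To see why this is parameter-efficient, compare trainable versus total parameters after wrapping a backbone. The sketch below assumes the Hugging Face `transformers` package and the `bert-base-uncased` checkpoint, both illustrative choices; note that AdapterBERT as written above still needs a forward pass that routes each layer's hidden states through its adapter, which is model-specific and omitted here.

```python
# Rough parameter-efficiency check (library and checkpoint are assumptions).
from transformers import AutoModel

bert = AutoModel.from_pretrained("bert-base-uncased")
model = AdapterBERT(bert)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} "
      f"({100 * trainable / total:.1f}% of weights updated)")
```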
3. Elastic Weight Consolidation (EWC)
Prevents catastrophic forgetting:
```python
import torch.nn.functional as F

class EWC:
    def __init__(self, model, source_data, lambda_ewc=0.4):
        self.model = model
        self.lambda_ewc = lambda_ewc
        # Compute the Fisher Information Matrix on source data
        self.fisher = self.compute_fisher(source_data)
        # Store the optimal source parameters (detached snapshots)
        self.optimal_params = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }

    def compute_fisher(self, data):
        """Estimate the importance of each parameter."""
        fisher = {}
        self.model.eval()
        for batch in data:
            self.model.zero_grad()
            output = self.model(batch.input)
            loss = F.cross_entropy(output, batch.target)
            loss.backward()
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    if name not in fisher:
                        fisher[name] = param.grad.data.clone() ** 2
                    else:
                        fisher[name] += param.grad.data.clone() ** 2
        # Normalize by the number of batches
        for name in fisher:
            fisher[name] /= len(data)
        return fisher

    def penalty(self):
        """EWC penalty term."""
        loss = 0
        for name, param in self.model.named_parameters():
            if name in self.fisher:
                loss += (self.fisher[name] *
                         (param - self.optimal_params[name]) ** 2).sum()
        return self.lambda_ewc * loss
```
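During target-domain training the penalty is simply added to the task loss. A minimal sketch of that loop, assuming the EWC class above and generic `source_loader` / `target_loader` iterables that yield batches with `.input` and `.target` (all placeholder names):

```python
# Hypothetical target-domain loop using the EWC class above;
# the loaders and optimizer settings are assumptions.
import torch
import torch.nn.functional as F

ewc = EWC(model, source_loader, lambda_ewc=0.4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for batch in target_loader:
    optimizer.zero_grad()
    logits = model(batch.input)
    loss = F.cross_entropy(logits, batch.target) + ewc.penalty()
    loss.backward()
    optimizer.step()
```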
4. Domain-Adversarial Training (DANN)
Learn domain-invariant features:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DANN(nn.Module):
    def __init__(self, feature_extractor, task_classifier, domain_classifier):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.task_classifier = task_classifier
        # domain_classifier is expected to end in a sigmoid (probabilities)
        self.domain_classifier = domain_classifier
        self.gradient_reversal = GradientReversal()

    def forward(self, x, alpha=1.0):
        # Extract features
        features = self.feature_extractor(x)
        # Task prediction
        task_output = self.task_classifier(features)
        # Domain prediction through the gradient-reversal layer
        reversed_features = self.gradient_reversal(features, alpha)
        domain_output = self.domain_classifier(reversed_features)
        return task_output, domain_output

    def train_step(self, source_batch, target_batch):
        # Source domain: task loss plus domain loss (label 0 = source)
        src_task, src_domain = self(source_batch.x)
        task_loss = F.cross_entropy(src_task, source_batch.y)
        src_domain_loss = F.binary_cross_entropy(
            src_domain, torch.zeros_like(src_domain)
        )
        # Target domain: domain loss only, no task labels (label 1 = target)
        _, tgt_domain = self(target_batch.x)
        tgt_domain_loss = F.binary_cross_entropy(
            tgt_domain, torch.ones_like(tgt_domain)
        )
        # Combined loss
        total_loss = task_loss + src_domain_loss + tgt_domain_loss
        return total_loss
```
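In the original DANN recipe the reversal strength alpha is not fixed; it ramps from 0 to 1 over training so the domain loss does not dominate early on. A sketch of that schedule, assuming the DANN class above and paired source/target loaders (loader names, optimizer choice, and epoch count are assumptions):

```python
# Hypothetical training loop with the standard DANN alpha schedule.
import math
import torch
import torch.nn.functional as F

optimizer = torch.optim.SGD(dann.parameters(), lr=1e-3, momentum=0.9)
num_epochs = 20
steps_per_epoch = min(len(source_loader), len(target_loader))

for epoch in range(num_epochs):
    for step, (src, tgt) in enumerate(zip(source_loader, target_loader)):
        # Training progress p in [0, 1]; alpha = 2 / (1 + exp(-10 p)) - 1
        p = (epoch * steps_per_epoch + step) / (num_epochs * steps_per_epoch)
        alpha = 2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0

        src_task, src_domain = dann(src.x, alpha)
        _, tgt_domain = dann(tgt.x, alpha)

        # Domain outputs are assumed to be sigmoid probabilities (see note above).
        loss = (F.cross_entropy(src_task, src.y)
                + F.binary_cross_entropy(src_domain, torch.zeros_like(src_domain))
                + F.binary_cross_entropy(tgt_domain, torch.ones_like(tgt_domain)))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```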
Advanced Techniques
1. Self-Training / Pseudo-Labeling
Use model predictions as labels:
```python
def self_training(model, source_data, target_data,
                  num_iterations=5, threshold=0.9):
    # Initial training on source
    model.fit(source_data)

    for iteration in range(num_iterations):
        # Generate pseudo-labels for the target domain
        pseudo_labels = []
        for batch in target_data:
            predictions = model.predict(batch)      # class probabilities
            confidence = predictions.max(dim=1)[0]

            # Only keep high-confidence predictions
            mask = confidence > threshold
            if mask.any():
                pseudo_labels.append({
                    'x': batch[mask],
                    'y': predictions[mask].argmax(dim=1)
                })

        # Retrain with pseudo-labels
        combined_data = source_data + pseudo_labels
        model.fit(combined_data)

        # Gradually lower the confidence threshold
        threshold *= 0.95
```
2. Maximum Mean Discrepancy (MMD)
Minimize distribution distance:
```python
import torch

def mmd_loss(source_features, target_features, kernel='rbf'):
    """Maximum Mean Discrepancy for domain alignment.

    Assumes source and target batches have the same size so the
    kernel matrices are conformable.
    """

    def rbf_kernel(x, y, gamma=1.0):
        """RBF kernel matrices for MMD."""
        xx = torch.matmul(x, x.t())
        yy = torch.matmul(y, y.t())
        xy = torch.matmul(x, y.t())

        rx = xx.diag().unsqueeze(0).expand_as(xx)
        ry = yy.diag().unsqueeze(0).expand_as(yy)

        # Squared pairwise distances
        dxx = rx.t() + rx - 2 * xx
        dyy = ry.t() + ry - 2 * yy
        dxy = rx.t() + ry - 2 * xy

        return (torch.exp(-gamma * dxx),
                torch.exp(-gamma * dyy),
                torch.exp(-gamma * dxy))

    kxx, kyy, kxy = rbf_kernel(source_features, target_features)
    mmd = kxx.mean() + kyy.mean() - 2 * kxy.mean()
    return mmd
```
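In practice MMD is used as an auxiliary term next to the task loss, pulling source and target feature batches toward the same distribution. A minimal sketch, assuming an encoder/classifier pair and paired loaders (all names and the weight `lambda_mmd` are placeholder assumptions):

```python
# Hypothetical use of mmd_loss as an auxiliary alignment term.
import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(
    list(encoder.parameters()) + list(classifier.parameters()), lr=1e-4
)
lambda_mmd = 0.5  # weight of the alignment term

for src, tgt in zip(source_loader, target_loader):
    f_src = encoder(src.x)
    f_tgt = encoder(tgt.x)

    task_loss = F.cross_entropy(classifier(f_src), src.y)
    align_loss = mmd_loss(f_src, f_tgt)

    loss = task_loss + lambda_mmd * align_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```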
Evaluation Strategies
1. Target Domain Performance
Primary metric - accuracy on target test set:
```python
def evaluate_adaptation(model, source_test, target_train, target_test):
    results = {
        'source_accuracy': model.evaluate(source_test),
        'target_accuracy': model.evaluate(target_test),
        'adaptation_gap': None,
    }

    # Compare against a model trained on target data from scratch
    baseline_accuracy = train_from_scratch(target_train).evaluate(target_test)
    results['adaptation_gap'] = results['target_accuracy'] - baseline_accuracy

    return results
```
2. Feature Alignment Metrics
Measure distribution alignment:
```python
import torch
from scipy.stats import wasserstein_distance

def compute_alignment_metrics(source_features, target_features):
    metrics = {}

    # Proxy A-distance (see the sketch below)
    metrics['a_distance'] = compute_a_distance(
        source_features, target_features
    )

    # Correlation alignment (CORAL): distance between feature second moments
    cs = torch.matmul(source_features.t(), source_features)
    ct = torch.matmul(target_features.t(), target_features)
    metrics['coral_loss'] = torch.norm(cs - ct, 'fro') ** 2

    # Earth Mover's Distance (scipy's 1-D metric over flattened feature values)
    src = source_features.detach().cpu().numpy().ravel()
    tgt = target_features.detach().cpu().numpy().ravel()
    metrics['emd'] = wasserstein_distance(src, tgt)

    return metrics
```
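The `compute_a_distance` helper above is not defined in this post. The usual proxy A-distance trains a classifier to separate source from target features and converts its error into a distance, d_A = 2(1 - 2·err). A minimal sketch using scikit-learn; the linear SVM, the 50/50 split, and numpy-array inputs are assumptions:

```python
# Hypothetical proxy A-distance: train a domain classifier and map its
# test error to d_A = 2 * (1 - 2 * err).
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split

def compute_a_distance(source_features, target_features):
    X = np.vstack([source_features, target_features])
    y = np.concatenate([
        np.zeros(len(source_features)),   # 0 = source
        np.ones(len(target_features)),    # 1 = target
    ])
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, stratify=y, random_state=0
    )
    clf = LinearSVC(C=1.0, max_iter=5000).fit(X_train, y_train)
    err = 1.0 - clf.score(X_test, y_test)
    # err = 0.5 (domains indistinguishable) gives d_A = 0; err = 0 gives d_A = 2.
    return 2.0 * (1.0 - 2.0 * err)
```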
Best Practices
1. Data Considerations
- Data Quality: Clean target data is crucial
- Data Quantity: Even small target datasets help
- Data Diversity: Cover target domain variations
2. Training Strategies
```python
# Gradual unfreezing: unfreeze one more layer (from the top) at each stage.
def gradual_unfreeze(model, target_data):
    layers = list(model.children())
    for i in range(1, len(layers) + 1):
        # Unfreeze the top i layers
        for layer in layers[-i:]:
            for param in layer.parameters():
                param.requires_grad = True
        # Train for a few epochs at each stage
        train_epochs(model, target_data, epochs=2)  # assumed helper
```
3. Hyperparameter Guidelines
| Parameter | Recommended Range | Notes |
|---|---|---|
| Learning Rate | 1e-5 to 1e-4 | Lower than source training |
| Batch Size | 8-32 | Smaller for limited target data |
| Epochs | 3-10 | Avoid overfitting |
| Warmup Steps | 10% of total | Stabilize training |
| Mix Ratio | 0.1-0.3 | Source:Target ratio |
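These settings map naturally onto the `config` object used by the `fine_tune` function earlier. A hypothetical dataclass with defaults taken from the table; the field names mirror that example, and anything beyond those names is an assumption:

```python
# Hypothetical configuration object for the fine_tune() example above;
# defaults come from the table and are starting points, not fixed rules.
from dataclasses import dataclass

@dataclass
class AdaptationConfig:
    lr: float = 1e-4              # fine_tune uses lr * 0.1, i.e. 1e-5
    batch_size: int = 16
    source_epochs: int = 3
    target_epochs: int = 5
    warmup_fraction: float = 0.1  # ~10% of total steps
    mix_ratio: float = 0.2        # source:target mixing ratio

config = AdaptationConfig()
```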
Common Pitfalls
1. Catastrophic Forgetting
Model forgets source knowledge:
Solutions:
- Use adapter layers
- Apply EWC or similar regularization
- Mix source and target data
- Lower learning rates
2. Negative Transfer
Source hurts target performance:
Solutions:
- Careful source selection
- Domain-adversarial training
- Start from general pre-trained models
3. Overfitting to Target
Model memorizes small target dataset:
Solutions:
- Strong regularization
- Data augmentation
- Early stopping
- Ensemble methods
Future Directions
Emerging Techniques
- Meta-Learning: Learn to adapt quickly
- Continuous Adaptation: Online learning in deployment
- Multi-Source Adaptation: Leverage multiple source domains
- Test-Time Adaptation: Adapt during inference (see the sketch below)
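Test-time adaptation can be as simple as updating a small set of parameters online to reduce prediction entropy on incoming unlabeled batches, in the spirit of entropy-minimization methods such as TENT. A minimal sketch, assuming a model that contains normalization layers with affine parameters and an unlabeled `test_loader` (both assumptions):

```python
# Test-time adaptation sketch: adapt only normalization-layer affine
# parameters by minimizing prediction entropy on the test stream.
import torch
import torch.nn as nn

def collect_norm_params(model):
    params = []
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm)):
            params += list(module.parameters())
    return params

optimizer = torch.optim.SGD(collect_norm_params(model), lr=1e-3)

model.train()  # let normalization statistics track the test distribution
for batch in test_loader:
    logits = model(batch)
    probs = logits.softmax(dim=-1)
    # Entropy of the predictions; minimizing it sharpens decisions on target data
    entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1).mean()
    optimizer.zero_grad()
    entropy.backward()
    optimizer.step()
```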
Research Challenges
- Unsupervised domain adaptation
- Open-set domain adaptation
- Domain generalization
- Federated domain adaptation
Conclusion
Domain adaptation is essential for deploying models in real-world scenarios where training and deployment distributions differ. The interactive visualization above demonstrates various adaptation techniques and their effects on model performance across domains.
Success in domain adaptation requires careful consideration of the domain gap, appropriate technique selection, and thorough evaluation. As models become more powerful, effective domain adaptation becomes increasingly critical for practical AI systems.