Focal Loss for Imbalanced Classification

Master focal loss, the game-changing loss function that addresses extreme class imbalance by down-weighting easy examples and focusing training on the hard ones.


Focal Loss: Solving Extreme Class Imbalance

Focal Loss revolutionized object detection by addressing the extreme foreground-background class imbalance in dense detectors, enabling one-stage detectors to match two-stage detector accuracy.

Interactive Focal Loss Explorer

Visualize how focal loss dynamically scales the cross-entropy loss based on prediction confidence:

The explorer exposes three parameters: the focusing parameter γ (0 reduces focal loss to plain cross-entropy; higher values focus more on hard examples), the class weight α (for imbalanced datasets), and p, the model's predicted probability for the correct class. For example, with γ = 2, α = 0.25, and p = 0.9, the modulating factor (1 - p)^γ equals 0.01, so the cross-entropy loss of 0.1054 is scaled down by a factor of 100, and the α weight reduces it further to a focal loss of roughly 0.0003.

Example Scenarios

Easy Example

p = 0.95 (well-classified)
CE Loss: 0.051
Focal Loss: 0.000
Focal loss is over 99% lower (with γ = 2, α = 0.25)

Hard Example

p = 0.3 (misclassified)
CE Loss: 1.204
Focal Loss: 0.147
Focal loss is reduced far less than for the easy example, so hard examples dominate the gradient
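
These two scenarios can be reproduced in a few lines of Python. This is a minimal sketch assuming γ = 2 and α = 0.25 (the defaults used throughout this page); focal_loss_value is an illustrative helper, not library code:

import math

def focal_loss_value(p, gamma=2.0, alpha=0.25):
    """Focal loss for a single positive example predicted with probability p."""
    return -alpha * (1 - p) ** gamma * math.log(p)

for p in (0.95, 0.3):                       # easy vs. hard example
    ce = -math.log(p)
    fl = focal_loss_value(p)
    print(f"p={p}: CE={ce:.3f}  FL={fl:.3f}  ({100 * (1 - fl / ce):.1f}% lower)")
# p=0.95: CE=0.051  FL=0.000  (99.9% lower)
# p=0.3:  CE=1.204  FL=0.147  (87.8% lower)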

Key Insights

When γ = 0, focal loss reduces to cross-entropy
Higher γ increases focus on hard examples
α balances positive/negative class importance
Effective for extreme class imbalance (1:1000+)

The Problem: Extreme Class Imbalance

In dense object detection:

  • 100,000+ candidate locations per image
  • Less than 10 are positive (contain objects)
  • Imbalance ratio of 1:10,000
  • Easy negatives dominate the gradient
  • Hard examples get lost in the noise (the sketch below makes this concrete)
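
A toy calculation shows why this matters. Assuming illustrative probabilities (well-classified negatives at p ≈ 0.99, hard positives at p ≈ 0.3), the easy negatives contribute the bulk of the total cross-entropy loss, while with the focal modulating factor (γ = 2) the handful of hard positives dominates:

import math

n_easy, n_hard = 100_000, 10      # candidate boxes: easy negatives vs. hard positives
p_easy, p_hard = 0.99, 0.30       # assumed predicted probabilities of the correct class
gamma = 2.0

ce_easy = n_easy * -math.log(p_easy)        # total CE from easy negatives  (~1005)
ce_hard = n_hard * -math.log(p_hard)        # total CE from hard positives  (~12)
fl_easy = ce_easy * (1 - p_easy) ** gamma   # focal-modulated totals        (~0.10)
fl_hard = ce_hard * (1 - p_hard) ** gamma   #                               (~5.9)

print(f"Cross-entropy   easy : hard = {ce_easy:.0f} : {ce_hard:.0f}")
print(f"Focal (gamma=2) easy : hard = {fl_easy:.2f} : {fl_hard:.2f}")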

The Focal Loss Solution

Core Formula

Focal loss adds a modulating factor to cross-entropy:

FL(p_t) = -α_t (1 - p_t)^γ log(p_t)

Where:

  • p_t is the model's estimated probability for the correct class
  • γ ≥ 0 is the focusing parameter
  • α_t ∈ [0, 1] is the weighting factor

Breaking Down the Components

1. Standard Cross-Entropy

CE(p, y) = -log(p_t)

Where:

p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{if } y = 0 \end{cases}

2. Modulating Factor

(1 - p_t)^γ
  • When p_t is small (hard example): factor ≈ 1
  • When p_t is large (easy example): factor ≈ 0

3. Balanced Variant

Adding the class weight α_t to the modulated cross-entropy gives the α-balanced focal loss used in practice (the same form as the core formula above):

FL(p_t) = -α_t (1 - p_t)^γ log(p_t)

Implementation

PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    Focal Loss for dense object detection.

    Args:
        alpha: Weighting factor in [0, 1] to balance positive/negative examples
        gamma: Focusing parameter for modulating loss
        reduction: 'none' | 'mean' | 'sum'
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Predictions (raw logits)
            targets: Ground truth labels (0 or 1)
        """
        # Convert logits to probabilities
        p = torch.sigmoid(inputs)

        # Calculate cross entropy
        ce_loss = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )

        # Get p_t: probability of the correct class
        p_t = p * targets + (1 - p) * (1 - targets)

        # Calculate focal loss: down-weight easy examples with (1 - p_t)^gamma
        focal_loss = ce_loss * ((1 - p_t) ** self.gamma)

        # Apply alpha weighting
        if self.alpha >= 0:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            focal_loss = alpha_t * focal_loss

        # Reduce loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


# Multi-class Focal Loss
class MultiClassFocalLoss(nn.Module):
    """Focal loss for multi-class classification."""

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # Can be None, a scalar, or a tensor of class weights
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: [batch_size, num_classes] logits
            targets: [batch_size] class indices
        """
        # Calculate cross entropy
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Get probability of the correct class
        p = F.softmax(inputs, dim=1)
        p_t = p.gather(1, targets.view(-1, 1)).squeeze(1)

        # Calculate focal term
        focal_term = (1 - p_t) ** self.gamma

        # Calculate focal loss
        focal_loss = focal_term * ce_loss

        # Apply alpha weighting if provided
        if self.alpha is not None:
            if isinstance(self.alpha, (float, int)):
                alpha_t = self.alpha
            else:
                # Assuming alpha is a tensor with per-class weights
                alpha_t = self.alpha.gather(0, targets)
            focal_loss = alpha_t * focal_loss

        # Reduce
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
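
A minimal usage sketch for the two classes above; the random tensors stand in for real model outputs and labels:

# Binary case: 8 samples
logits = torch.randn(8, requires_grad=True)
targets = torch.randint(0, 2, (8,)).float()

criterion = FocalLoss(alpha=0.25, gamma=2.0)
loss = criterion(logits, targets)
loss.backward()                                  # behaves like any other nn.Module loss

# Multi-class case: 8 samples, 5 classes
mc_logits = torch.randn(8, 5, requires_grad=True)
mc_targets = torch.randint(0, 5, (8,))
mc_loss = MultiClassFocalLoss(gamma=2.0)(mc_logits, mc_targets)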

TensorFlow/Keras Implementation

import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K


def focal_loss(alpha=0.25, gamma=2.0):
    """
    Create a focal loss function.

    Args:
        alpha: Balancing parameter
        gamma: Focusing parameter

    Returns:
        Loss function for Keras model.compile()
    """
    def focal_loss_fixed(y_true, y_pred):
        # Clip predictions to prevent log(0)
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        # Calculate p_t
        p_t = tf.where(K.equal(y_true, 1), y_pred, 1 - y_pred)

        # Calculate alpha_t
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)

        # Calculate focal loss
        focal_weight = alpha_t * K.pow((1 - p_t), gamma)
        cross_entropy = -K.log(p_t)
        loss = focal_weight * cross_entropy

        return K.mean(loss)

    return focal_loss_fixed


# Usage in a model
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

model.compile(
    optimizer='adam',
    loss=focal_loss(alpha=0.25, gamma=2.0),
    metrics=['accuracy']
)

RetinaNet: Focal Loss in Action

Architecture Integration

import math


class RetinaNet(nn.Module):
    """RetinaNet: One-stage detector with Focal Loss."""

    def __init__(self, num_classes, backbone='resnet50'):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = self._build_backbone(backbone)
        self.fpn = FeaturePyramidNetwork()
        self.classification_head = ClassificationSubnet(num_classes)
        self.regression_head = RegressionSubnet()

        # Initialize classification head with focal loss bias
        self._init_focal_loss_bias()

    def _init_focal_loss_bias(self, prior_prob=0.01):
        """
        Initialize bias to predict background with high probability.
        This prevents instability in early training with focal loss.
        """
        for module in self.classification_head.modules():
            if isinstance(module, nn.Conv2d) and module.out_channels == self.num_classes:
                # Set bias such that sigmoid(bias) = prior_prob
                bias_value = -math.log((1 - prior_prob) / prior_prob)
                torch.nn.init.constant_(module.bias, bias_value)

    def compute_loss(self, predictions, targets):
        cls_preds, reg_preds = predictions
        cls_targets, reg_targets = targets

        # Focal loss for classification
        focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
        cls_loss = focal_loss(cls_preds, cls_targets)

        # Smooth L1 loss for regression (only on positive anchors)
        reg_loss = F.smooth_l1_loss(
            reg_preds[cls_targets > 0],
            reg_targets[cls_targets > 0]
        )

        return cls_loss + reg_loss
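
The bias trick is worth a quick sanity check: setting the bias to -log((1 - π)/π) makes the initial sigmoid output exactly π, so every anchor starts out predicting background and the rare positives do not blow up the loss at step one. A few lines confirm the arithmetic:

import math

prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)   # ≈ -4.595
initial_p = 1 / (1 + math.exp(-bias_value))             # sigmoid(bias_value)
print(f"bias = {bias_value:.3f}, initial foreground probability = {initial_p:.3f}")  # 0.010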

Hyperparameter Selection

Gamma (γ) Selection

def analyze_gamma_effect(gammas=[0, 0.5, 1, 2, 5]):
    """Analyze effect of gamma on loss distribution"""
    probabilities = torch.linspace(0.1, 0.9, 100)

    for gamma in gammas:
        focal_weights = (1 - probabilities) ** gamma
        ce_loss = -torch.log(probabilities)
        focal_loss = focal_weights * ce_loss

        print(f"Gamma={gamma}:")
        print(f"  Easy (p=0.9) weight: {(1 - 0.9) ** gamma:.4f}")
        print(f"  Hard (p=0.3) weight: {(1 - 0.3) ** gamma:.4f}")
        print(f"  Weight ratio: {((1 - 0.3) ** gamma) / ((1 - 0.9) ** gamma):.1f}x")
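
Calling it with the paper's default γ = 2 shows how steep the re-weighting already is; the commented lines are the values the function above prints for that setting:

analyze_gamma_effect(gammas=[2])
# Gamma=2:
#   Easy (p=0.9) weight: 0.0100
#   Hard (p=0.3) weight: 0.4900
#   Weight ratio: 49.0x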

Alpha (α) Selection

def calculate_alpha_from_frequency(class_frequencies):
    """
    Calculate alpha weights from class frequencies.

    Args:
        class_frequencies: List of sample counts per class

    Returns:
        Alpha weights for focal loss
    """
    total = sum(class_frequencies)

    # Inverse frequency weighting
    alphas = []
    for freq in class_frequencies:
        alpha = 1.0 - (freq / total)
        # Clip to a reasonable range
        alpha = max(0.1, min(0.9, alpha))
        alphas.append(alpha)

    return torch.tensor(alphas)


# Example: COCO dataset
# ~100k background windows, ~10 foreground
background_samples = 100000
foreground_samples = 10

alpha = calculate_alpha_from_frequency([background_samples, foreground_samples])
print(f"Alpha weights: {alpha}")  # approximately [0.1, 0.9]

Advanced Techniques

Online Hard Example Mining (OHEM) + Focal Loss

class FocalLossWithOHEM(nn.Module):
    """Combine Focal Loss with Online Hard Example Mining."""

    def __init__(self, alpha=0.25, gamma=2.0, ohem_ratio=3):
        super().__init__()
        self.focal_loss = FocalLoss(alpha, gamma, reduction='none')
        self.ohem_ratio = ohem_ratio

    def forward(self, predictions, targets):
        # Calculate per-sample focal loss
        losses = self.focal_loss(predictions, targets)

        # Separate positive and negative samples
        pos_mask = targets > 0
        neg_mask = ~pos_mask

        pos_losses = losses[pos_mask]
        neg_losses = losses[neg_mask]

        # Keep all positive losses
        num_pos = pos_losses.numel()

        # Select hard negatives (OHEM): at most ohem_ratio negatives per positive
        num_neg_keep = min(neg_losses.numel(), num_pos * self.ohem_ratio)
        if num_neg_keep > 0:
            neg_losses_sorted, _ = neg_losses.sort(descending=True)
            neg_losses_keep = neg_losses_sorted[:num_neg_keep]
        else:
            neg_losses_keep = neg_losses

        # Combine losses
        total_loss = torch.cat([pos_losses, neg_losses_keep]).mean()
        return total_loss

Class-Balanced Focal Loss

class ClassBalancedFocalLoss(nn.Module):
    """
    Class-Balanced Focal Loss from
    "Class-Balanced Loss Based on Effective Number of Samples".
    """

    def __init__(self, samples_per_class, beta=0.999, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        self.beta = beta

        # Calculate effective number of samples per class
        samples_per_class = torch.as_tensor(samples_per_class, dtype=torch.float32)
        effective_num = 1.0 - torch.pow(beta, samples_per_class)
        weights = (1.0 - beta) / effective_num
        # Normalize so the weights sum to the number of classes
        self.weights = weights / weights.sum() * len(weights)

    def forward(self, inputs, targets):
        # Standard focal loss
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        p = F.softmax(inputs, dim=1)
        p_t = p.gather(1, targets.view(-1, 1)).squeeze(1)
        focal_weight = (1 - p_t) ** self.gamma
        focal_loss = focal_weight * ce_loss

        # Apply class balancing
        weights_t = self.weights.to(inputs.device).gather(0, targets)
        balanced_focal_loss = weights_t * focal_loss

        return balanced_focal_loss.mean()
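
The weighting comes from the effective number of samples proposed in that paper: with n samples in a class and hyperparameter β ∈ [0, 1), each class is assigned

E_n = \frac{1 - \beta^{n}}{1 - \beta}, \qquad w \propto \frac{1}{E_n} = \frac{1 - \beta}{1 - \beta^{n}}

which is what the constructor above computes before normalizing the weights to sum to the number of classes.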

Adaptive Focal Loss

class AdaptiveFocalLoss(nn.Module):
    """Dynamically adjust gamma based on training progress."""

    def __init__(self, alpha=0.25, gamma_init=2.0, gamma_final=0.5):
        super().__init__()
        self.alpha = alpha
        self.gamma_init = gamma_init
        self.gamma_final = gamma_final
        self.current_epoch = 0
        self.total_epochs = 100

    def update_epoch(self, epoch, total_epochs):
        self.current_epoch = epoch
        self.total_epochs = total_epochs

    @property
    def gamma(self):
        # Linear decay of gamma from gamma_init to gamma_final
        progress = self.current_epoch / self.total_epochs
        return self.gamma_init + (self.gamma_final - self.gamma_init) * progress

    def forward(self, inputs, targets):
        p = torch.sigmoid(inputs)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = p * targets + (1 - p) * (1 - targets)

        # Use the current gamma
        focal_weight = (1 - p_t) ** self.gamma
        focal_loss = self.alpha * focal_weight * ce_loss

        return focal_loss.mean()
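
A sketch of how the schedule might be driven from a training loop; the dataloader, model, and optimizer are assumed to exist and are only hinted at in the comments:

criterion = AdaptiveFocalLoss(alpha=0.25, gamma_init=2.0, gamma_final=0.5)
num_epochs = 50

for epoch in range(num_epochs):
    criterion.update_epoch(epoch, num_epochs)   # gamma decays linearly toward gamma_final
    # for inputs, labels in dataloader:
    #     loss = criterion(model(inputs), labels.float())
    #     optimizer.zero_grad(); loss.backward(); optimizer.step()
    print(f"epoch {epoch}: gamma = {criterion.gamma:.2f}")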

Practical Tips

1. Initialization Strategy

def initialize_for_focal_loss(model, num_classes, prior_prob=0.01):
    """
    Initialize a model for stable focal loss training.

    Args:
        model: Classification model
        num_classes: Number of outputs of the final classification layer
        prior_prob: Prior probability of the positive class
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Kaiming initialization for conv layers
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

            # Special initialization for the final classification layer
            if 'cls' in name and module.out_channels == num_classes:
                # Bias initialization for focal loss: sigmoid(bias) = prior_prob
                bias_value = -math.log((1 - prior_prob) / prior_prob)
                nn.init.constant_(module.bias, bias_value)
                # Smaller weight initialization
                nn.init.normal_(module.weight, std=0.01)

2. Learning Rate Scheduling

def get_focal_loss_scheduler(optimizer, num_epochs):
    """Learning rate schedule optimized for focal loss."""
    def lr_lambda(epoch):
        if epoch < 10:
            # Warmup for focal loss stability
            return 0.1 + 0.9 * epoch / 10
        elif epoch < num_epochs * 0.8:
            # Constant learning rate
            return 1.0
        else:
            # Decay
            return 0.1

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
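
Wiring it up looks like any other LambdaLR schedule; the SGD settings here are placeholders:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = get_focal_loss_scheduler(optimizer, num_epochs=100)

for epoch in range(100):
    # ... run one epoch of training with the focal loss criterion ...
    scheduler.step()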

3. Monitoring Training

import numpy as np


class FocalLossMonitor:
    """Monitor focal loss training dynamics."""

    def __init__(self):
        self.easy_losses = []
        self.hard_losses = []

    def update(self, predictions, targets, threshold=0.7):
        with torch.no_grad():
            p = torch.sigmoid(predictions)
            p_t = p * targets + (1 - p) * (1 - targets)

            easy_mask = p_t > threshold
            hard_mask = ~easy_mask

            focal_loss = FocalLoss(reduction='none')
            losses = focal_loss(predictions, targets)

            if easy_mask.any():
                self.easy_losses.append(losses[easy_mask].mean().item())
            if hard_mask.any():
                self.hard_losses.append(losses[hard_mask].mean().item())

    def get_ratio(self):
        """Get the hard/easy loss ratio."""
        if self.easy_losses and self.hard_losses:
            return np.mean(self.hard_losses) / np.mean(self.easy_losses)
        return 0
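
A sketch of how the monitor might be called inside a training step; the random tensors stand in for whatever logits and binary targets your loop already produces:

monitor = FocalLossMonitor()

logits = torch.randn(16)                      # stand-ins for model outputs
targets = torch.randint(0, 2, (16,)).float()  # stand-ins for labels

monitor.update(logits, targets)               # records mean loss of easy vs. hard examples
print(f"hard/easy loss ratio: {monitor.get_ratio():.2f}")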

Common Pitfalls and Solutions

1. Training Instability

# Problem: Loss explodes in early training
# Solution: Proper initialization (see initialize_for_focal_loss above)
initialize_for_focal_loss(model, num_classes=num_classes, prior_prob=0.01)

# Problem: Unstable gradients early in training
# Solution: Use gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

2. Hyperparameter Sensitivity

# Problem: The wrong gamma leads to poor performance
# Solution: Grid search over gamma and alpha
for gamma in [0.5, 1.0, 2.0, 5.0]:
    for alpha in [0.25, 0.5, 0.75]:
        model = train_model(gamma=gamma, alpha=alpha)
        print(f"γ={gamma}, α={alpha}: mAP={evaluate(model)}")

3. Multi-Scale Training

# Problem: Focal loss is sensitive to object scale
# Solution: Multi-scale training
def multi_scale_focal_loss(predictions, targets, scales=[1.0, 0.5, 2.0]):
    total_loss = 0
    focal = FocalLoss()
    for scale in scales:
        scaled_preds = F.interpolate(predictions, scale_factor=scale)
        # Resize targets as well (nearest-neighbour by default) and keep them
        # as floats, since the binary FocalLoss above expects float targets
        scaled_targets = F.interpolate(targets.float(), scale_factor=scale)
        loss = focal(scaled_preds, scaled_targets)
        total_loss += loss / len(scales)
    return total_loss

Performance Impact

Detection Results (COCO Dataset)

Method          Backbone     Loss    AP     AP50   AP75   FPS
Faster R-CNN    ResNet-101   CE      36.2   59.1   39.0   5
SSD             ResNet-101   CE      31.2   50.4   33.3   22
RetinaNet       ResNet-101   Focal   39.1   59.1   42.3   18
FCOS            ResNet-101   Focal   41.0   60.7   44.1   20

If you found this explanation helpful, consider sharing it with others.
