KL Divergence
Understand Kullback-Leibler divergence, the fundamental measure of difference between probability distributions used in VAEs, information theory, and model compression.
KL Divergence: Measuring Distribution Differences
Kullback-Leibler (KL) divergence quantifies how one probability distribution differs from another, serving as a cornerstone in variational inference, generative modeling, and information theory.
Interactive KL Divergence Explorer
Visualize how KL divergence measures the "distance" between distributions:
[Interactive explorer: adjust the parameters of distribution P (the true distribution) and distribution Q (the approximation) and compare the probability-distribution plot with the resulting divergence visualization.]
KL Divergence Properties
Key Properties
- Non-negative: KL(P||Q) ≥ 0
- Zero if and only if P = Q
- Asymmetric: KL(P||Q) ≠ KL(Q||P) (see the numeric sketch below)
- Not a true metric (no triangle inequality)
- Unbounded (can be infinite)
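These properties are easy to check on a small example. The following sketch (a standalone illustration, assuming nothing beyond PyTorch) evaluates both directions of the divergence for two discrete distributions:

```python
import torch

# Two discrete distributions over three outcomes
p = torch.tensor([0.90, 0.05, 0.05])
q = torch.tensor([1/3, 1/3, 1/3])

def kl(a, b):
    # KL(A||B) = sum_x A(x) * log(A(x) / B(x)), in nats
    return torch.sum(a * torch.log(a / b)).item()

print(f"KL(P||Q) = {kl(p, q):.4f}")  # ≈ 0.70, non-negative
print(f"KL(Q||P) = {kl(q, p):.4f}")  # ≈ 0.93, a different value: KL is asymmetric
print(f"KL(P||P) = {kl(p, p):.4f}")  # 0.0: zero exactly when the distributions match
```

Both divergences are non-negative, the two directions give noticeably different values, and the divergence of a distribution from itself is exactly zero.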
Applications
- Variational inference (ELBO)
- VAE regularization
- Information theory
- Model compression
- Distribution matching
Mathematical Definition
Discrete Distributions
For discrete probability distributions P and Q defined over the same support X:

KL(P || Q) = Σ_{x ∈ X} P(x) log( P(x) / Q(x) )
Continuous Distributions
For continuous probability densities p(x) and q(x):

KL(P || Q) = ∫ p(x) log( p(x) / q(x) ) dx
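For two univariate Gaussians the integral has a closed form. As an illustrative sketch (not code referenced elsewhere on this page), the snippet below compares PyTorch's built-in analytical KL with the textbook formula:

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

p = Normal(loc=0.0, scale=1.0)   # N(0, 1)
q = Normal(loc=1.0, scale=2.0)   # N(1, 4)

# Built-in analytical KL between the two Gaussians
kl_builtin = kl_divergence(p, q)

# Textbook closed form:
# KL(N(mu1, s1^2) || N(mu2, s2^2))
#   = log(s2 / s1) + (s1^2 + (mu1 - mu2)^2) / (2 * s2^2) - 0.5
mu1, s1, mu2, s2 = 0.0, 1.0, 1.0, 2.0
kl_manual = math.log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2 * s2**2) - 0.5

print(kl_builtin.item(), kl_manual)  # both ≈ 0.4431
```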
Information-Theoretic Interpretation
KL divergence equals the expected extra bits needed to encode data from P using a code optimized for Q (bits when using base-2 logarithms, nats for natural logarithms):

KL(P || Q) = H(P, Q) - H(P)
Where:
- H(P, Q) is the cross-entropy
- H(P) is the entropy of P
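This identity is easy to verify numerically. The short sketch below (an illustrative example with arbitrary distributions, not code used elsewhere on this page) computes both sides in bits:

```python
import torch

p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.4, 0.4, 0.2])

# Use base-2 logs so the quantities are measured in bits
entropy_p = -torch.sum(p * torch.log2(p))          # H(P)
cross_entropy_pq = -torch.sum(p * torch.log2(q))   # H(P, Q)
kl_pq = torch.sum(p * torch.log2(p / q))           # KL(P||Q)

# KL(P||Q) == H(P, Q) - H(P), up to floating-point error
print(kl_pq.item(), (cross_entropy_pq - entropy_p).item())
```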
Implementation
PyTorch Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class KLDivergenceLoss(nn.Module):
    """
    KL Divergence loss for various distributions
    """
    def __init__(self, reduction='batchmean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, log_p, log_q=None, p=None, q=None):
        """
        Compute KL(P||Q) in different formats

        Args:
            log_p: Log probabilities of P
            log_q: Log probabilities of Q (optional)
            p: Probabilities of P (optional)
            q: Probabilities of Q (optional)
        """
        if log_p is not None and log_q is not None:
            # Both in log space
            kl_div = torch.exp(log_p) * (log_p - log_q)
        elif p is not None and q is not None:
            # Both in probability space
            kl_div = p * torch.log(p / (q + 1e-10))
        elif log_p is not None and q is not None:
            # Mixed format
            kl_div = torch.exp(log_p) * (log_p - torch.log(q + 1e-10))
        else:
            raise ValueError("Invalid input format")

        # Handle reduction
        if self.reduction == 'none':
            return kl_div
        elif self.reduction == 'sum':
            return kl_div.sum()
        elif self.reduction == 'mean':
            return kl_div.mean()
        elif self.reduction == 'batchmean':
            return kl_div.sum() / kl_div.size(0)
        return kl_div


# Using PyTorch's built-in KL divergence
def kl_divergence_builtin(p_logits, q_logits):
    """Using F.kl_div for log-space computation"""
    log_p = F.log_softmax(p_logits, dim=-1)
    log_q = F.log_softmax(q_logits, dim=-1)
    # Note: F.kl_div expects (input, target) = (log_q, p)
    # where target is in probability space
    return F.kl_div(log_q, torch.exp(log_p), reduction='batchmean')


# Analytical KL for Gaussians
def kl_divergence_gaussian(mu_p, logvar_p, mu_q=None, logvar_q=None):
    """
    Analytical KL divergence between two Gaussians

    If mu_q and logvar_q are None, assumes standard normal N(0,1)
    """
    if mu_q is None:
        mu_q = torch.zeros_like(mu_p)
    if logvar_q is None:
        logvar_q = torch.zeros_like(logvar_p)

    # KL(N(μ_p, σ_p) || N(μ_q, σ_q)), elementwise over latent dimensions
    kl = 0.5 * (
        logvar_q - logvar_p - 1
        + torch.exp(logvar_p - logvar_q)
        + ((mu_p - mu_q) ** 2) * torch.exp(-logvar_q)
    )
    return kl.sum(dim=-1).mean()
```
TensorFlow/Keras Implementation
```python
import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp


def kl_divergence_loss(y_true, y_pred):
    """
    KL divergence loss for Keras

    Expects y_true to be probabilities and y_pred to be log probabilities
    """
    y_true = tf.clip_by_value(y_true, 1e-10, 1.0)
    return tf.reduce_mean(
        tf.reduce_sum(y_true * (tf.math.log(y_true) - y_pred), axis=-1)
    )


# Using TensorFlow Probability
def kl_divergence_distributions(dist_p, dist_q):
    """
    KL divergence between TFP distributions
    """
    return tfp.distributions.kl_divergence(dist_p, dist_q)


# Analytical KL for common distributions
def analytical_kl_normal(mu1, sigma1, mu2, sigma2):
    """Analytical KL between two normal distributions"""
    kl = tf.math.log(sigma2 / sigma1) + \
         (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2) - 0.5
    return kl


def analytical_kl_categorical(p, q):
    """KL between categorical distributions"""
    p = tf.clip_by_value(p, 1e-10, 1.0)
    q = tf.clip_by_value(q, 1e-10, 1.0)
    return tf.reduce_sum(p * tf.math.log(p / q), axis=-1)
```
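As a quick sanity check (a hedged sketch assuming batches of probability vectors), the custom loss above can be compared against Keras's built-in `KLDivergence` loss. Note that the custom version expects log probabilities for the prediction, while the built-in expects probabilities:

```python
import tensorflow as tf

p = tf.constant([[0.7, 0.2, 0.1]])   # "true" distribution
q = tf.constant([[0.3, 0.4, 0.3]])   # approximation

# Custom loss above expects log probabilities for the prediction
custom_kl = kl_divergence_loss(p, tf.math.log(q))

# Keras's built-in loss expects probabilities for both arguments
builtin_kl = tf.keras.losses.KLDivergence()(p, q)

print(custom_kl.numpy(), builtin_kl.numpy())  # both ≈ 0.34 nats
```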
VAE: KL Divergence in Action
VAE Loss with KL Regularization
```python
class VAE(nn.Module):
    """
    Variational Autoencoder with KL divergence regularization
    """
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Latent space parameters
        self.mu_layer = nn.Linear(hidden_dim, latent_dim)
        self.logvar_layer = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Reparameterization trick"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """
    VAE loss = Reconstruction loss + β * KL divergence

    Args:
        recon_x: Reconstructed input
        x: Original input
        mu: Mean of latent distribution
        logvar: Log variance of latent distribution
        beta: Weight for KL term (β-VAE)
    """
    # Reconstruction loss (negative log likelihood)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')

    # KL divergence between q(z|x) and p(z) = N(0,1)
    # Analytical formula for Gaussian KL
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total loss
    return recon_loss + beta * kl_loss


# Training VAE
vae = VAE(784, 400, 20)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(epochs):
    for batch in dataloader:
        optimizer.zero_grad()

        # Forward pass
        recon, mu, logvar = vae(batch)

        # Compute loss
        loss = vae_loss(recon, batch, mu, logvar, beta=1.0)

        # Backward pass
        loss.backward()
        optimizer.step()
```
Forward vs Reverse KL
Understanding the Asymmetry
```python
def compare_kl_directions(p_samples, q_model):
    """
    Compare forward and reverse KL divergence
    """
    # Forward KL: KL(P||Q) - mean-seeking
    # Q tries to cover all modes of P
    # (compute_forward_kl / compute_reverse_kl are placeholders for
    #  estimators defined elsewhere, e.g. the Monte Carlo versions below)
    forward_kl = compute_forward_kl(p_samples, q_model)

    # Reverse KL: KL(Q||P) - mode-seeking
    # Q focuses on high-probability regions of P
    reverse_kl = compute_reverse_kl(q_model, p_samples)

    return {
        'forward': forward_kl,   # Used in: maximum likelihood
        'reverse': reverse_kl,   # Used in: variational inference
        'symmetric': (forward_kl + reverse_kl) / 2  # Symmetrized (Jeffreys) KL, not JS divergence
    }


class ForwardKL:
    """
    Forward KL: Minimize KL(data||model)
    Mean-seeking behavior - covers all data modes
    """
    def fit(self, data_samples):
        # Maximum likelihood estimation: minimizing -E_data[log q] is
        # equivalent to minimizing KL(data||model) up to a constant
        return -torch.mean(self.model.log_prob(data_samples))


class ReverseKL:
    """
    Reverse KL: Minimize KL(model||data)
    Mode-seeking behavior - focuses on a single mode
    """
    def fit(self, data_distribution):
        # Variational inference: cross term of KL(model||data);
        # the model entropy term E_q[log q] is omitted here for brevity
        samples = self.model.sample(n_samples)
        return -torch.mean(data_distribution.log_prob(samples))
```
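To make the asymmetry concrete, the standalone sketch below (an illustration, not part of the classes above) fits a single Gaussian to a bimodal mixture by minimizing either direction of the KL. Under forward KL the fitted Gaussian spreads to cover both modes; under reverse KL it tends to settle on one of them:

```python
import torch
from torch.distributions import Normal, Categorical, MixtureSameFamily

# Bimodal target: equal mixture of N(-3, 0.5) and N(+3, 0.5)
target = MixtureSameFamily(
    Categorical(probs=torch.tensor([0.5, 0.5])),
    Normal(torch.tensor([-3.0, 3.0]), torch.tensor([0.5, 0.5]))
)

def fit_gaussian(direction, steps=2000, n_samples=512):
    mu = torch.tensor(0.5, requires_grad=True)        # slight offset to break symmetry
    log_sigma = torch.tensor(0.0, requires_grad=True)
    opt = torch.optim.Adam([mu, log_sigma], lr=1e-2)

    for _ in range(steps):
        q = Normal(mu, log_sigma.exp())
        if direction == 'forward':
            # KL(P||Q) up to a constant: -E_P[log q] (maximum likelihood)
            x = target.sample((n_samples,))
            loss = -q.log_prob(x).mean()
        else:
            # Reverse KL(Q||P): E_Q[log q - log p], reparameterized samples
            z = q.rsample((n_samples,))
            loss = (q.log_prob(z) - target.log_prob(z)).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    return mu.item(), log_sigma.exp().item()

print("forward KL fit:", fit_gaussian('forward'))  # wide Gaussian near 0, covering both modes
print("reverse KL fit:", fit_gaussian('reverse'))  # narrow Gaussian near one mode (±3)
```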
JS Divergence: Symmetric Alternative
```python
def jensen_shannon_divergence(p, q):
    """
    Jensen-Shannon divergence: symmetric version of KL

    JS(P||Q) = 0.5 * KL(P||M) + 0.5 * KL(Q||M)
    where M = 0.5 * (P + Q)
    """
    m = 0.5 * (p + q)

    # Compute both KL divergences
    kl_pm = torch.sum(p * torch.log(p / (m + 1e-10) + 1e-10), dim=-1)
    kl_qm = torch.sum(q * torch.log(q / (m + 1e-10) + 1e-10), dim=-1)

    # JS divergence
    js = 0.5 * kl_pm + 0.5 * kl_qm
    return js


def js_divergence_gaussian(mu1, sigma1, mu2, sigma2):
    """
    JS divergence between two Gaussians (no closed form)
    Approximated numerically on a discretized grid of points
    """
    n_points = 1000

    dist1 = torch.distributions.Normal(mu1, sigma1)
    dist2 = torch.distributions.Normal(mu2, sigma2)

    # Evaluation grid spanning both distributions (assumes mu1 <= mu2)
    grid = torch.linspace(mu1 - 4 * sigma1, mu2 + 4 * sigma2, n_points)

    # Evaluate the densities and normalize them so they behave like
    # discrete distributions over the grid
    p = torch.exp(dist1.log_prob(grid))
    q = torch.exp(dist2.log_prob(grid))
    p = p / p.sum()
    q = q / q.sum()

    return jensen_shannon_divergence(p, q)
```
Applications
1. Model Distillation
```python
class KnowledgeDistillation:
    """
    Distill knowledge from teacher to student using KL divergence
    """
    def __init__(self, teacher, student, temperature=3.0, alpha=0.7):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha

    def distillation_loss(self, inputs, labels):
        # Teacher predictions (soft targets)
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
            soft_targets = F.softmax(teacher_logits / self.temperature, dim=-1)

        # Student predictions
        student_logits = self.student(inputs)
        soft_predictions = F.log_softmax(student_logits / self.temperature, dim=-1)

        # KL divergence loss (scaled by T^2)
        kl_loss = F.kl_div(soft_predictions, soft_targets, reduction='batchmean')
        kl_loss *= self.temperature ** 2

        # Standard cross-entropy loss
        ce_loss = F.cross_entropy(student_logits, labels)

        # Combined loss
        return self.alpha * kl_loss + (1 - self.alpha) * ce_loss
```
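A minimal usage sketch follows; `teacher_model`, `student_model`, and `train_loader` are placeholders for any classifier pair with matching output dimensions and any `(inputs, labels)` data loader, not objects defined on this page:

```python
# Hypothetical models and data loader (placeholders)
distiller = KnowledgeDistillation(teacher_model, student_model,
                                  temperature=4.0, alpha=0.7)
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)

for inputs, labels in train_loader:
    optimizer.zero_grad()
    loss = distiller.distillation_loss(inputs, labels)
    loss.backward()
    optimizer.step()
```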
2. Variational Inference
```python
class VariationalInference:
    """
    Approximate posterior using variational inference
    """
    def __init__(self, prior, likelihood, latent_dim):
        self.prior = prior
        self.likelihood = likelihood

        # Variational parameters
        self.var_mu = nn.Parameter(torch.randn(latent_dim))
        self.var_logvar = nn.Parameter(torch.randn(latent_dim))

    def elbo(self, x):
        """
        Evidence Lower Bound (ELBO)
        ELBO = E_q[log p(x|z)] - KL(q(z|x)||p(z))
        """
        # Sample from variational distribution (reparameterized)
        q_dist = torch.distributions.Normal(
            self.var_mu,
            torch.exp(0.5 * self.var_logvar)
        )
        z = q_dist.rsample()

        # Reconstruction term
        log_likelihood = self.likelihood.log_prob(x, z)

        # KL divergence term (analytical Gaussian KL, defined earlier)
        kl_div = kl_divergence_gaussian(
            self.var_mu, self.var_logvar,
            self.prior.mu, self.prior.logvar
        )

        # ELBO (to maximize)
        elbo = log_likelihood - kl_div
        return -elbo  # Return negative for minimization
```
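A sketch of how this class might be optimized, assuming a `prior` object exposing `mu` and `logvar` attributes and a `likelihood` with a `log_prob(x, z)` method as used above (both placeholders, along with the data `x`):

```python
# Placeholder prior/likelihood objects matching the interface above
vi = VariationalInference(prior, likelihood, latent_dim=16)
optimizer = torch.optim.Adam([vi.var_mu, vi.var_logvar], lr=1e-2)

for step in range(1000):
    optimizer.zero_grad()
    loss = vi.elbo(x)   # negative ELBO
    loss.backward()
    optimizer.step()
```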
3. Mutual Information Estimation
```python
class MutualInformationEstimator:
    """
    Estimate mutual information using KL divergence

    I(X;Y) = KL(P(X,Y)||P(X)P(Y))
    """
    def __init__(self, method='mine'):
        self.method = method

    def estimate_mi(self, x, y):
        if self.method == 'mine':
            return self.mine_estimator(x, y)
        elif self.method == 'nwj':
            return self.nwj_estimator(x, y)
        elif self.method == 'infonce':
            return self.infonce_estimator(x, y)
        # nwj_estimator and infonce_estimator are not shown here

    def mine_estimator(self, x, y):
        """
        Mutual Information Neural Estimation (MINE)
        """
        # Statistics network that scores (x, y) pairs; in practice it is
        # trained over many batches rather than re-created per call
        T = nn.Sequential(
            nn.Linear(x.shape[-1] + y.shape[-1], 100),
            nn.ReLU(),
            nn.Linear(100, 1)
        )

        # Joint samples
        joint = torch.cat([x, y], dim=-1)

        # Marginal samples (shuffle y to break the dependence on x)
        y_shuffle = y[torch.randperm(y.shape[0])]
        marginal = torch.cat([x, y_shuffle], dim=-1)

        # MINE lower bound: E_joint[T] - log E_marginal[exp(T)]
        t_joint = T(joint)
        t_marginal = T(marginal)
        mi_lower_bound = torch.mean(t_joint) - torch.log(torch.mean(torch.exp(t_marginal)))

        return mi_lower_bound
```
Practical Considerations
1. Numerical Stability
```python
def stable_kl_divergence(log_p, log_q):
    """
    Numerically stable KL divergence computation

    Works directly in log space: the ratio p/q is never formed explicitly,
    so small probabilities do not underflow inside a logarithm.
    """
    # Clamp log_q to avoid -inf when Q assigns (numerically) zero probability
    log_q = torch.clamp(log_q, min=-30.0)

    # KL(P||Q) = sum_x P(x) * (log P(x) - log Q(x))
    kl = torch.sum(torch.exp(log_p) * (log_p - log_q), dim=-1)
    return kl
```
2. Approximating KL for Complex Distributions
```python
def monte_carlo_kl(p_dist, q_dist, n_samples=1000):
    """
    Monte Carlo approximation of KL divergence
    """
    # Sample from P
    samples = p_dist.sample((n_samples,))

    # Compute log probabilities
    log_p = p_dist.log_prob(samples)
    log_q = q_dist.log_prob(samples)

    # Monte Carlo estimate of E_P[log p - log q]
    kl = torch.mean(log_p - log_q)
    return kl


def importance_weighted_kl(p_dist, q_dist, n_samples=1000):
    """
    Importance-weighted KL divergence estimation
    """
    # Sample from Q (importance distribution)
    samples = q_dist.sample((n_samples,))

    # Compute importance weights w = p/q
    log_p = p_dist.log_prob(samples)
    log_q = q_dist.log_prob(samples)
    weights = torch.exp(log_p - log_q)

    # Weighted estimate: E_Q[w * (log p - log q)] = E_P[log p - log q]
    kl = torch.mean(weights * (log_p - log_q))
    return kl
```
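As a sanity check, the Monte Carlo estimator above can be compared with the exact analytical value for a pair of Gaussians (an illustrative sketch using torch.distributions):

```python
from torch.distributions import Normal, kl_divergence

p_dist = Normal(0.0, 1.0)
q_dist = Normal(0.5, 1.5)

exact = kl_divergence(p_dist, q_dist)                       # ≈ 0.1832
approx = monte_carlo_kl(p_dist, q_dist, n_samples=100_000)

print(exact.item(), approx.item())  # agree up to Monte Carlo noise
```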
3. Annealing KL in VAEs
```python
class AnnealedVAE(VAE):
    """
    VAE with KL annealing for better optimization
    """
    def __init__(self, *args, anneal_steps=1000, **kwargs):
        super().__init__(*args, **kwargs)
        self.anneal_steps = anneal_steps
        self.current_step = 0

    def get_kl_weight(self):
        """Anneal KL weight from 0 to 1"""
        return min(1.0, self.current_step / self.anneal_steps)

    def loss(self, x, recon, mu, logvar):
        # Reconstruction loss
        recon_loss = F.mse_loss(recon, x, reduction='sum')

        # KL loss with annealing
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_weight = self.get_kl_weight()

        # Update step counter
        self.current_step += 1

        return recon_loss + kl_weight * kl_loss
```
Common Pitfalls
1. Zero Probability Issues
```python
# Problem: KL diverges when Q(x) = 0 but P(x) > 0
# Solution: Add smoothing
def smoothed_kl(p, q, epsilon=1e-10):
    q_smooth = (1 - epsilon) * q + epsilon * torch.ones_like(q) / q.shape[-1]
    return torch.sum(p * torch.log(p / q_smooth), dim=-1)
```
2. Choosing Direction
```python
# Forward KL: Use when you want Q to cover all of P
# Good for: density estimation, avoiding mode collapse

# Reverse KL: Use when you want Q to focus on high-probability regions of P
# Good for: variational inference, compression
```
3. Gradient Issues
```python
# Problem: Gradients can vanish/explode with KL
# Solution: Clip gradients and use appropriate scaling
def train_with_kl(model, target_dist, training_steps=1000):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for step in range(training_steps):
        # Forward pass (kl_divergence: any differentiable KL estimator,
        # e.g. torch.distributions.kl.kl_divergence for supported pairs)
        kl_loss = kl_divergence(model.dist, target_dist)

        # Backward with gradient clipping
        optimizer.zero_grad()
        kl_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
```
Related Concepts
- Cross-Entropy Loss - Related information measure
- VAE - Uses KL divergence regularization
- Information Theory - Theoretical foundation
- Variational Inference - Optimization via KL
- Model Compression - Knowledge distillation