Xavier/Glorot Initialization

Understand Xavier (Glorot) initialization, the weight initialization technique that maintains signal variance across layers for stable deep network training.

Xavier/Glorot Initialization: Maintaining Signal Variance

Xavier initialization (also known as Glorot initialization) is a weight initialization technique designed to maintain approximately equal variance of activations and gradients across all layers of a deep neural network.

Interactive Initialization Explorer

Visualize how different initialization methods affect gradient flow:

(Interactive demo: choose an initialization method, activation function, and network depth. The Network Activation Flow panel marks each layer as healthy (σ ≈ 1), vanishing (σ → 0), or exploding (σ → ∞), and a Weight Distribution panel shows the sampled weights.)

Xavier Principle
• Maintain variance: Var(input) ≈ Var(output)
• Uniform: ±√(6/(fan_in + fan_out))
• Normal: σ = √(2/(fan_in + fan_out))

Layer-wise Activation Statistics

Layer | Mean   | Std Dev | Min    | Max   | Status
1     | -0.007 | 0.683   | -1.000 | 1.000 | Suboptimal
2     |  0.000 | 0.511   | -0.986 | 0.989 | Suboptimal
3     | -0.006 | 0.427   | -0.983 | 0.987 | Suboptimal
4     |  0.005 | 0.371   | -0.952 | 0.932 | Suboptimal
5     | -0.003 | 0.328   | -0.879 | 0.870 | Suboptimal

The Problem: Vanishing and Exploding Gradients

Without proper initialization, deep networks suffer from the following failure modes (the sketch after this list illustrates the first two):

  1. Vanishing Gradients: Signals shrink exponentially with depth, so early layers receive almost no learning signal
  2. Exploding Gradients: Signals grow exponentially with depth, destabilizing training
  3. Dead Neurons: Activations get pushed into the flat regions of tanh or sigmoid, where gradients are near zero
  4. Slow Convergence: Training starts from poor initial conditions and wastes early updates recovering
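
A minimal simulation makes the first two failure modes concrete (a sketch; width, depth, batch size, and weight scales are arbitrary, and the Xavier scale in the last line anticipates the solution below):

import torch

def forward_std(weight_std, depth=20, width=256):
    """Push a random batch through `depth` tanh layers initialized with a
    given weight standard deviation, and report the final activation std."""
    x = torch.randn(128, width)
    for _ in range(depth):
        W = torch.randn(width, width) * weight_std
        x = torch.tanh(x @ W)
    return x.std().item()

print(forward_std(0.01))   # far too small: activations collapse toward 0 (vanishing)
print(forward_std(1.0))    # far too large: tanh saturates at ±1, gradients die
print(forward_std((2.0 / (256 + 256)) ** 0.5))  # Xavier scale: healthy spread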

The Xavier/Glorot Solution

The key insight: keep the variance of activations (and of the backpropagated gradients) approximately equal across layers.

Core Principle

For a linear layer with input x and weights W:

Var(y) = Var(x) · Var(W) · fan_in

To maintain variance: Var(W) = 1/fan_in

Considering Backpropagation

For gradients flowing backward:

Var(∂L/∂x) = Var(∂L/∂y) · Var(W) · fan_out

To maintain gradient variance: Var(W) = 1/fan_out

The Compromise

Xavier initialization averages both requirements:

Var(W) = 2/(fan_in + fan_out)
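
For example, a layer with fan_in = 256 and fan_out = 128 would get Var(W) = 2/(256 + 128) ≈ 0.0052, which lies between the forward-only target 1/256 ≈ 0.0039 and the backward-only target 1/128 ≈ 0.0078 (layer sizes chosen purely for illustration).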

Xavier Initialization Variants

1. Xavier Uniform

Draw weights from uniform distribution:

W ∼ U[−√(6/(fan_in + fan_out)), +√(6/(fan_in + fan_out))]

2. Xavier Normal

Draw weights from normal distribution:

W ∼ N(0, σ²), where σ = √(2/(fan_in + fan_out))
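
Both variants can be sampled directly with basic tensor operations; the sketch below (with an arbitrary 128×256 layer shape) mirrors the two formulas above and confirms that both yield Var(W) ≈ 2/(fan_in + fan_out).

import math
import torch

fan_in, fan_out = 256, 128   # hypothetical layer shape: weight is (fan_out, fan_in)

# Xavier uniform: W ~ U[-a, a] with a = sqrt(6 / (fan_in + fan_out))
a = math.sqrt(6.0 / (fan_in + fan_out))
W_uniform = torch.empty(fan_out, fan_in).uniform_(-a, a)

# Xavier normal: W ~ N(0, sigma^2) with sigma = sqrt(2 / (fan_in + fan_out))
sigma = math.sqrt(2.0 / (fan_in + fan_out))
W_normal = torch.empty(fan_out, fan_in).normal_(0.0, sigma)

print(W_uniform.var().item(), W_normal.var().item())  # both ≈ 2 / (fan_in + fan_out) ≈ 0.0052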

Implementation

PyTorch Implementation

import torch
import torch.nn as nn
import math


def xavier_uniform_(tensor, gain=1.0):
    """
    Xavier uniform initialization

    Args:
        tensor: Tensor to initialize
        gain: Scaling factor for different activations
    """
    fan_in, fan_out = calculate_fan_in_and_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        tensor.uniform_(-bound, bound)
    return tensor


def xavier_normal_(tensor, gain=1.0):
    """
    Xavier normal initialization

    Args:
        tensor: Tensor to initialize
        gain: Scaling factor for different activations
    """
    fan_in, fan_out = calculate_fan_in_and_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    with torch.no_grad():
        tensor.normal_(0, std)
    return tensor


def calculate_fan_in_and_out(tensor):
    """Calculate fan_in and fan_out for a tensor"""
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can't be computed for tensor with fewer than 2 dimensions")
    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # For convolutional layers
        receptive_field_size = tensor[0][0].numel()
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


# Using PyTorch's built-in initialization
class XavierNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        # Apply Xavier initialization
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.xavier_uniform_(self.fc3.weight)

        # Initialize biases to zero
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)
        nn.init.zeros_(self.fc3.bias)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))  # Xavier works well with tanh
        x = torch.tanh(self.fc2(x))
        x = self.fc3(x)
        return x
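
As a quick sanity check (a sketch, reusing the functions and imports above), the hand-rolled initializers should agree with PyTorch's built-ins up to sampling noise:

layer = nn.Linear(256, 128)            # fan_in = 256, fan_out = 128

xavier_uniform_(layer.weight)          # custom version defined above
print(layer.weight.var().item())       # ≈ 2 / (256 + 128) ≈ 0.0052

nn.init.xavier_uniform_(layer.weight)  # PyTorch built-in
print(layer.weight.var().item())       # same target variance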

Gain Values for Different Activations

def get_gain(activation):
    """Get the recommended gain value for different activation functions"""
    gains = {
        'linear': 1.0,
        'sigmoid': 1.0,
        'tanh': 5.0 / 3.0,  # ~1.667
        'relu': math.sqrt(2.0),  # For He initialization
        'leaky_relu': math.sqrt(2.0 / (1 + 0.01**2)),
        'selu': 3.0 / 4.0
    }
    return gains.get(activation, 1.0)


# Example usage
def initialize_layer(layer, activation='tanh'):
    gain = get_gain(activation)
    nn.init.xavier_uniform_(layer.weight, gain=gain)
    nn.init.zeros_(layer.bias)
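
PyTorch exposes the same lookup as torch.nn.init.calculate_gain, so the helper above can be cross-checked against it (a sketch, reusing get_gain from the snippet above):

for act in ['linear', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu']:
    print(f"{act:12s} custom={get_gain(act):.4f} torch={nn.init.calculate_gain(act):.4f}")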

Theoretical Foundation

Forward Pass Variance

For layer l with weights W(l):

Var(y^(l)) = n^(l-1) · Var(W^(l)) · Var(y^(l-1))

where n^(l-1) is the number of inputs to layer l.
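
This relation is easy to check numerically for a single layer with identity activation (a sketch; the sizes and weight variance are arbitrary):

import torch

n_in, n_out = 512, 256
w_var = 0.002                                  # arbitrary weight variance

x = torch.randn(10_000, n_in)                  # Var(x) ≈ 1
W = torch.randn(n_out, n_in) * w_var ** 0.5    # Var(W) ≈ w_var
y = x @ W.T

print(y.var().item())                  # measured Var(y)
print(n_in * w_var * x.var().item())   # predicted n_in · Var(W) · Var(x) ≈ 1.024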

Backward Pass Variance

For gradients:

Var(∂L/∂y^(l-1)) = n^(l) · Var(W^(l)) · Var(∂L/∂y^(l))

Optimal Variance

To maintain both forward and backward signal:

n^(l-1) · Var(W^(l)) = 1   and   n^(l) · Var(W^(l)) = 1

Xavier's compromise:

Var(W^(l)) = 2/(n^(l-1) + n^(l))

When to Use Xavier Initialization

Best For:

  • Tanh activations: Designed specifically for tanh
  • Sigmoid activations: Works well with sigmoid
  • Linear networks: Optimal for identity activation
  • Shallow networks: Excellent for 2-10 layers

Not Ideal For:

  • ReLU networks: Use He initialization instead
  • Very deep networks: Consider LSUV or careful normalization
  • Spiking networks: Different dynamics entirely

Comparison with Other Methods

Method     | Formula           | Best For           | Variance Assumption
Xavier     | 2/(n_in + n_out)  | Tanh, Sigmoid      | Symmetric, linear
He         | 2/n_in            | ReLU, Leaky ReLU   | Asymmetric (ReLU)
LeCun      | 1/n_in            | Efficient backprop | Forward pass only
Orthogonal | Q from QR         | RNNs               | Preserve norm
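
Most rows of this table map onto helpers in torch.nn.init; a rough mapping (LeCun initialization has no dedicated helper, but kaiming_normal_ with fan_in mode and a linear nonlinearity yields Var(W) = 1/n_in):

import torch.nn as nn

layer = nn.Linear(256, 128)

nn.init.xavier_uniform_(layer.weight)                                        # Xavier/Glorot
nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')                  # He
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')  # LeCun-style
nn.init.orthogonal_(layer.weight)                                            # Orthogonal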

Practical Tips

1. Check Activation Statistics

def check_activation_stats(model, input_batch):
    """Monitor activation statistics during forward pass"""
    activations = []

    def hook_fn(module, input, output):
        activations.append({
            'mean': output.mean().item(),
            'std': output.std().item(),
            'min': output.min().item(),
            'max': output.max().item()
        })

    hooks = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            hooks.append(module.register_forward_hook(hook_fn))

    with torch.no_grad():
        model(input_batch)

    for hook in hooks:
        hook.remove()

    return activations
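
A typical way to use this probe, reusing the XavierNetwork class defined earlier with a random batch standing in for real data:

model = XavierNetwork(input_dim=64, hidden_dim=128, output_dim=10)
batch = torch.randn(32, 64)

for i, stats in enumerate(check_activation_stats(model, batch), start=1):
    print(f"layer {i}: mean={stats['mean']:+.3f} std={stats['std']:.3f} "
          f"min={stats['min']:+.3f} max={stats['max']:+.3f}")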

2. Adaptive Initialization

class AdaptiveXavierInit:
    """Adapt initialization based on network depth"""

    def __init__(self, depth_factor=1.0):
        self.depth_factor = depth_factor

    def initialize(self, module, layer_idx, total_layers):
        if isinstance(module, nn.Linear):
            # Scale variance based on depth
            depth_scale = math.sqrt(total_layers / (layer_idx + 1))
            gain = self.depth_factor * depth_scale
            nn.init.xavier_uniform_(module.weight, gain=gain)
            nn.init.zeros_(module.bias)

3. Monitor Gradient Flow

def monitor_gradients(model):
    """Track gradient statistics during training"""
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            print(f"{name}: grad_norm={grad_norm:.4f}")
            if grad_norm < 1e-6:
                print("  WARNING: Vanishing gradient!")
            elif grad_norm > 100:
                print("  WARNING: Exploding gradient!")
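
This is intended to be called right after the backward pass; a minimal training-step sketch (model, data, loss, and optimizer are placeholders for illustration):

model = XavierNetwork(input_dim=64, hidden_dim=128, output_dim=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

x, target = torch.randn(32, 64), torch.randn(32, 10)

optimizer.zero_grad()
loss = criterion(model(x), target)
loss.backward()
monitor_gradients(model)   # inspect gradient norms before the parameter update
optimizer.step()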

Common Pitfalls

1. Wrong Activation Function

# ❌ BAD: Xavier with ReLU
nn.init.xavier_uniform_(layer.weight)
output = F.relu(layer(input))      # Suboptimal for ReLU

# ✅ GOOD: Xavier with Tanh
nn.init.xavier_uniform_(layer.weight)
output = torch.tanh(layer(input))  # Designed for tanh

2. Ignoring Depth

# ❌ BAD: Same init for all depths
for layer in layers:
    nn.init.xavier_uniform_(layer.weight)

# ✅ GOOD: Adjust for depth
for i, layer in enumerate(layers):
    depth_factor = math.sqrt(len(layers) / (i + 1))
    nn.init.xavier_uniform_(layer.weight, gain=depth_factor)

3. Forgetting Biases

# ❌ BAD: Random bias initialization
# Biases left uninitialized

# ✅ GOOD: Zero biases
nn.init.zeros_(layer.bias)

Advanced Techniques

Layer-wise Adaptive Rate Scaling (LARS)

def lars_init(model, base_lr=0.1):
    """Adapt learning rate per layer based on Xavier init"""
    for name, param in model.named_parameters():
        if 'weight' in name:
            fan_in, fan_out = calculate_fan_in_and_out(param)
            scale = math.sqrt(2.0 / (fan_in + fan_out))
            param.lr_scale = base_lr * scale
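
The lr_scale attribute set above is this article's own convention, not something PyTorch consumes automatically; one way to use it is through per-parameter optimizer groups, as in the sketch below (build_lars_optimizer is a hypothetical helper):

def build_lars_optimizer(model, base_lr=0.1):
    """Create one parameter group per parameter, using its lr_scale if present."""
    groups = [{'params': [p], 'lr': getattr(p, 'lr_scale', base_lr)}
              for p in model.parameters()]
    return torch.optim.SGD(groups, lr=base_lr)

model = XavierNetwork(input_dim=64, hidden_dim=128, output_dim=10)
lars_init(model, base_lr=0.1)
optimizer = build_lars_optimizer(model, base_lr=0.1)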

Variance Scaling Initializer

class VarianceScaling:
    """Generalized variance scaling initializer"""

    def __init__(self, scale=1.0, mode='fan_in', distribution='normal'):
        self.scale = scale
        self.mode = mode
        self.distribution = distribution

    def __call__(self, tensor):
        fan_in, fan_out = calculate_fan_in_and_out(tensor)

        if self.mode == 'fan_in':
            n = fan_in
        elif self.mode == 'fan_out':
            n = fan_out
        elif self.mode == 'fan_avg':
            n = (fan_in + fan_out) / 2.0

        if self.distribution == 'normal':
            std = math.sqrt(self.scale / n)
            tensor.normal_(0, std)
        else:  # uniform
            limit = math.sqrt(3.0 * self.scale / n)
            tensor.uniform_(-limit, limit)
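
With suitable scale/mode settings, this generalized initializer reproduces the schemes from the comparison table (a sketch; note that Xavier corresponds to scale=1.0 with mode='fan_avg', since 1/fan_avg = 2/(fan_in + fan_out)):

layer = nn.Linear(256, 128)

xavier_like = VarianceScaling(scale=1.0, mode='fan_avg', distribution='uniform')
he_like     = VarianceScaling(scale=2.0, mode='fan_in',  distribution='normal')
lecun_like  = VarianceScaling(scale=1.0, mode='fan_in',  distribution='normal')

with torch.no_grad():              # in-place init on a leaf parameter
    xavier_like(layer.weight)

print(layer.weight.var().item())   # ≈ 2 / (256 + 128) ≈ 0.0052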

If you found this explanation helpful, consider sharing it with others.
