He/Kaiming Initialization
Master He (Kaiming) initialization, the standard weight-initialization scheme for ReLU networks, designed to prevent vanishing and exploding signals in deep architectures.
He/Kaiming Initialization: Optimizing for ReLU Networks
He initialization (also known as Kaiming initialization) is a weight initialization method designed for networks that use ReLU activations. It accounts for ReLU's asymmetry: negative inputs are zeroed out, which halves the signal variance passed forward and blocks their gradients.
Interactive ReLU Network Explorer
[Interactive visualization: layer activation statistics and a variance-evolution comparison, illustrating how He initialization maintains signal variance in ReLU networks.]
Layer-by-Layer Analysis
The snapshot below is taken from the explorer for an 8-layer ReLU network with poorly scaled weights: the activation standard deviation collapses toward zero within a few layers, and several layers go almost entirely dead.
Layer | Pre-Act Mean | Pre-Act Std | Post-Act Mean | Post-Act Std | Dead Neurons | Health |
---|---|---|---|---|---|---|
1 | -0.031 | 0.215 | 0.068 | 0.099 | 51.9% | ✗ Poor |
2 | 0.006 | 0.005 | 0.006 | 0.005 | 0.4% | ✗ Poor |
3 | 0.001 | 0.001 | 0.001 | 0.001 | 0.5% | ✗ Poor |
4 | -0.001 | 0.000 | 0.000 | 0.000 | 98.8% | ✗ Poor |
5 | 0.000 | 0.000 | 0.000 | 0.000 | 0.9% | ✗ Poor |
6 | 0.000 | 0.000 | 0.000 | 0.000 | 1.8% | ✗ Poor |
7 | 0.000 | 0.000 | 0.000 | 0.000 | 100.0% | ✗ Poor |
8 | 0.000 | 0.000 | 0.000 | 0.000 | 100.0% | ✗ Poor |
Mathematical Foundation
He Initialization Formula
$$W \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{\text{in}}}\right) \qquad\Longleftrightarrow\qquad \text{Var}(W) = \frac{2}{n_{\text{in}}}$$
ReLU Variance Correction
For a zero-mean, symmetric pre-activation $x$, ReLU discards the negative half of the distribution, so $\mathbb{E}[\text{ReLU}(x)^2] = \tfrac{1}{2}\,\text{Var}(x)$. He initialization doubles the weight variance, relative to a linear layer, to cancel this factor of one half.
Why He Initialization for ReLU?
The ReLU Challenge
ReLU (Rectified Linear Unit) introduces unique challenges:
- Asymmetric Activation: f(x) = max(0, x)
- Dead Neurons: Negative inputs produce zero output
- Gradient Asymmetry: Zero gradient for negative inputs
- Variance Reduction: for a zero-mean input, ReLU cuts the propagated second moment in half (see the quick check below)
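A quick numerical check of the last point (a minimal sketch; the sample size and seed are arbitrary):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)  # zero-mean, unit-variance input

print(x.var().item())                              # ~1.0
print(torch.relu(x).pow(2).mean().item())          # ~0.5: ReLU keeps half the second moment
print((torch.relu(x) == 0).float().mean().item())  # ~0.5: half the units output zero
```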
Xavier's Limitation with ReLU
Xavier initialization assumes a symmetric activation that behaves roughly like the identity around zero (such as tanh), so the activation passes variance through unchanged:
$$\text{Var}(W) = \frac{2}{n_{\text{in}} + n_{\text{out}}}, \qquad \text{Var}(f(x)) \approx \text{Var}(x) \ \text{ for symmetric } f \text{ near } 0$$
But ReLU breaks this assumption by zeroing out the negative half of a zero-mean input:
$$\mathbb{E}[\text{ReLU}(x)^2] = \tfrac{1}{2}\,\text{Var}(x)$$
so the signal shrinks by a factor of two at every layer.
The He Solution
Core Insight
Account for ReLU's variance reduction by doubling the weight variance relative to the symmetric-activation case:
$$\text{Var}(W) = \frac{2}{n_{\text{in}}}$$
This compensates for the halving effect of ReLU, maintaining signal variance.
Mathematical Derivation
For a layer $y^{(l)} = Wx + b$ with fan-in $n_{\text{in}}$, independent zero-mean weights, and inputs $x = \text{ReLU}(y^{(l-1)})$ from the previous ReLU layer:
$$\text{Var}\!\left(y^{(l)}\right) = n_{\text{in}}\,\text{Var}(W)\,\mathbb{E}\!\left[x^2\right]$$
To maintain variance, we want $\text{Var}(y^{(l)}) = \text{Var}(y^{(l-1)})$ at every layer.
Considering ReLU's effect, $\mathbb{E}[x^2] = \tfrac{1}{2}\,\text{Var}(y^{(l-1)})$, so
$$\text{Var}\!\left(y^{(l)}\right) = \tfrac{1}{2}\,n_{\text{in}}\,\text{Var}(W)\,\text{Var}\!\left(y^{(l-1)}\right) \quad\Longrightarrow\quad \text{Var}(W) = \frac{2}{n_{\text{in}}}$$
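The recursion can also be checked empirically. The sketch below assumes a 20-layer, 512-unit ReLU stack fed with standard-normal inputs (both sizes are arbitrary) and prints the per-layer activation standard deviation for He-scaled versus Xavier-scaled weights:

```python
import math
import torch

torch.manual_seed(0)
depth, width, batch = 20, 512, 4096


def layer_stds(weight_std):
    """Propagate a random batch through `depth` ReLU layers; return the activation std per layer."""
    x = torch.randn(batch, width)
    stds = []
    for _ in range(depth):
        W = torch.randn(width, width) * weight_std
        x = torch.relu(x @ W)
        stds.append(x.std().item())
    return stds


he = layer_stds(math.sqrt(2.0 / width))            # He: Var(W) = 2 / n_in
xavier = layer_stds(math.sqrt(2.0 / (2 * width)))  # Xavier: Var(W) = 2 / (n_in + n_out)

print("He     layers 1/10/20:", [round(he[i], 3) for i in (0, 9, 19)])      # roughly constant
print("Xavier layers 1/10/20:", [round(xavier[i], 3) for i in (0, 9, 19)])  # shrinks ~1/√2 per layer
```

With the He scaling the standard deviation should stay roughly constant across all 20 layers, while the Xavier scaling loses about a factor of √2 per layer, exactly as the derivation predicts.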
He Initialization Variants
1. He Normal
Draw weights from a normal distribution with variance $2/n_{\text{in}}$:
$$W \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{\text{in}}}\right)$$
2. He Uniform
Draw weights from a uniform distribution with the same variance:
$$W \sim \mathcal{U}\!\left(-\sqrt{\frac{6}{n_{\text{in}}}},\ \sqrt{\frac{6}{n_{\text{in}}}}\right)$$
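A small sanity check using PyTorch's built-in initializers (the 256×512 layer shape is arbitrary): both variants should report the same standard deviation, √(2/n_in) = 0.0625 for n_in = 512.

```python
import math
import torch
import torch.nn as nn

fan_in = 512
w_normal = torch.empty(256, fan_in)
w_uniform = torch.empty(256, fan_in)

nn.init.kaiming_normal_(w_normal, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_uniform_(w_uniform, mode='fan_in', nonlinearity='relu')

print(math.sqrt(2.0 / fan_in))  # target std = 0.0625
print(w_normal.std().item())    # ≈ 0.0625
print(w_uniform.std().item())   # ≈ 0.0625 (uniform bound is sqrt(3) × std)
```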
3. Generalized He
For activation functions other than ReLU, the same scheme applies with an activation-specific gain (see the `calculate_gain` check below):
$$\text{std}(W) = \frac{\text{gain}}{\sqrt{n_{\text{in}}}} \qquad\Longleftrightarrow\qquad \text{Var}(W) = \frac{\text{gain}^2}{n_{\text{in}}}$$
Where the gain depends on the activation:
- ReLU: √2
- Leaky ReLU: √(2 / (1 + α²))
- ELU: 1.0
- SELU: 3/4
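These gains can be read from PyTorch's `torch.nn.init.calculate_gain` helper for ReLU, Leaky ReLU, and SELU; the ELU value of 1.0 above is this article's recommendation rather than something the helper reports:

```python
import torch.nn as nn

print(nn.init.calculate_gain('relu'))              # 1.414... = √2
print(nn.init.calculate_gain('leaky_relu', 0.01))  # √(2 / (1 + 0.01²)) ≈ 1.414
print(nn.init.calculate_gain('selu'))              # 0.75 = 3/4
```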
Implementation
PyTorch Implementation
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='relu'):
    """He/Kaiming normal initialization.

    Args:
        tensor: Tensor to initialize
        a: Negative slope for leaky_relu
        mode: 'fan_in' or 'fan_out'
        nonlinearity: Type of activation function
    """
    fan = calculate_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        tensor.normal_(0, std)
    return tensor


def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='relu'):
    """He/Kaiming uniform initialization.

    Args:
        tensor: Tensor to initialize
        a: Negative slope for leaky_relu
        mode: 'fan_in' or 'fan_out'
        nonlinearity: Type of activation function
    """
    fan = calculate_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # uniform bound that yields the same std
    with torch.no_grad():
        tensor.uniform_(-bound, bound)
    return tensor


def calculate_gain(nonlinearity, param=None):
    """Calculate the gain for different activation functions."""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d',
                  'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        negative_slope = 0.01 if param is None else param
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4
    else:
        raise ValueError(f"Unsupported nonlinearity: {nonlinearity}")


def calculate_fan(tensor, mode):
    """Calculate fan_in or fan_out."""
    if tensor.dim() < 2:
        raise ValueError("Fan can't be computed for tensor with fewer than 2 dimensions")
    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:  # convolutional layers
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in if mode == 'fan_in' else fan_out


# Using PyTorch's built-in initialization
class ReLUNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        # Apply He initialization
        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc3.weight, mode='fan_in', nonlinearity='relu')

        # Initialize biases to zero
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)
        nn.init.zeros_(self.fc3.bias)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
TensorFlow/Keras Implementation
```python
import tensorflow as tf
from tensorflow.keras import layers, initializers

# Using Keras initializers
model = tf.keras.Sequential([
    layers.Dense(
        256,
        activation='relu',
        kernel_initializer=initializers.HeNormal(),
        bias_initializer='zeros'
    ),
    layers.Dense(
        128,
        activation='relu',
        kernel_initializer=initializers.HeNormal(),
        bias_initializer='zeros'
    ),
    layers.Dense(
        10,
        kernel_initializer=initializers.HeNormal()
    )
])


# Custom implementation
class HeInitializer(tf.keras.initializers.Initializer):
    def __init__(self, seed=None):
        self.seed = seed

    def __call__(self, shape, dtype=None):
        # Fan-in is the product of all dimensions except the last (output) one,
        # which also covers convolutional kernels of shape (..., in, out)
        fan_in = 1
        for dim in shape[:-1]:
            fan_in *= int(dim)
        stddev = tf.sqrt(2.0 / fan_in)
        return tf.random.normal(
            shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed
        )
```
Advanced Techniques
Mode Selection: Fan-in vs Fan-out
```python
import torch.nn as nn


def choose_initialization_mode(network_depth, layer_position):
    """Choose between fan_in and fan_out based on network architecture.

    Args:
        network_depth: Total number of layers
        layer_position: Current layer index
    """
    if layer_position < network_depth // 2:
        # Early layers: preserve the forward signal
        return 'fan_in'
    else:
        # Later layers: preserve gradient flow
        return 'fan_out'


class AdaptiveHeInit:
    """Adaptive He initialization based on layer position."""

    def __init__(self, network_depth):
        self.network_depth = network_depth

    def initialize_layer(self, layer, layer_idx):
        mode = choose_initialization_mode(self.network_depth, layer_idx)
        nn.init.kaiming_normal_(
            layer.weight,
            mode=mode,
            nonlinearity='relu'
        )
```
Leaky ReLU and Variants
```python
import torch.nn as nn


def initialize_for_activation(layer, activation_type, **kwargs):
    """Initialize weights based on the activation function that follows the layer."""
    if activation_type == 'relu':
        nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
    elif activation_type == 'leaky_relu':
        negative_slope = kwargs.get('negative_slope', 0.01)
        nn.init.kaiming_normal_(
            layer.weight,
            a=negative_slope,
            nonlinearity='leaky_relu'
        )
    elif activation_type == 'elu':
        # ELU has different variance properties
        nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
        layer.weight.data *= 0.8  # Empirical adjustment
    elif activation_type == 'prelu':
        # PReLU learns the negative slope; 0.25 is its default initial value
        nn.init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
    elif activation_type == 'selu':
        # SELU has self-normalizing properties; use LeCun-style scaling (gain = 1)
        nn.init.kaiming_normal_(layer.weight, nonlinearity='linear')
```
Residual Networks
```python
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """He initialization for residual connections."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)

        # He initialization for conv layers
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')

        # Zero-initialize the last BN in each residual branch so the block
        # starts out close to an identity mapping
        nn.init.constant_(self.bn2.weight, 0)
        nn.init.constant_(self.bn2.bias, 0)

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        out = F.relu(out)
        return out
```
Comparison with Other Methods
Method | Formula | Best For | Key Assumption |
---|---|---|---|
He | Var(W) = 2 / n_in | ReLU, Leaky ReLU | Asymmetric activation |
Xavier | Var(W) = 2 / (n_in + n_out) | Tanh, Sigmoid | Symmetric activation |
LeCun | Var(W) = 1 / n_in | SELU | Self-normalizing |
Orthogonal | Q from QR | RNNs | Preserve angles |
LSUV | Layer-wise normalization | Very deep nets | Unit variance |
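To make the first three formulas concrete, here is the weight standard deviation each one implies for a hypothetical 512 → 256 linear layer (the sizes are just an example):

```python
import math

n_in, n_out = 512, 256

print("He:     sqrt(2 / n_in)           =", round(math.sqrt(2 / n_in), 4))            # 0.0625
print("Xavier: sqrt(2 / (n_in + n_out)) =", round(math.sqrt(2 / (n_in + n_out)), 4))  # 0.051
print("LeCun:  sqrt(1 / n_in)           =", round(math.sqrt(1 / n_in), 4))            # 0.0442
```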
Practical Guidelines
1. Activation Function Matching
```python
import torch.nn as nn

INIT_MAPPING = {
    'relu':       ('kaiming_normal_', {'nonlinearity': 'relu'}),
    'leaky_relu': ('kaiming_normal_', {'nonlinearity': 'leaky_relu', 'a': 0.01}),
    'tanh':       ('xavier_normal_', {}),
    'sigmoid':    ('xavier_uniform_', {}),
    'selu':       ('kaiming_normal_', {'nonlinearity': 'linear'}),
    'gelu':       ('kaiming_normal_', {'nonlinearity': 'relu'}),  # Approximation
    'swish':      ('kaiming_normal_', {'nonlinearity': 'relu'}),  # Approximation
}


def auto_initialize(model):
    """Automatically initialize based on activation functions."""
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Find the activation that follows this layer
            # (find_next_activation is assumed to be provided elsewhere)
            next_activation = find_next_activation(model, name)
            if next_activation in INIT_MAPPING:
                init_fn_name, kwargs = INIT_MAPPING[next_activation]
                init_fn = getattr(nn.init, init_fn_name)
                init_fn(module.weight, **kwargs)
```
2. Debugging Initialization
```python
import torch
import torch.nn as nn


def analyze_initialization(model, input_shape):
    """Analyze activation statistics after initialization."""
    model.eval()
    hooks = []
    activations = {}

    def hook_fn(name):
        def hook(module, input, output):
            activations[name] = {
                'mean': output.mean().item(),
                'std': output.std().item(),
                'dead_neurons': (output == 0).float().mean().item(),
                'max': output.max().item(),
                'min': output.min().item()
            }
        return hook

    # Register hooks on every ReLU
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Forward pass with random input
    with torch.no_grad():
        x = torch.randn(32, *input_shape)
        _ = model(x)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    return activations


# Usage
stats = analyze_initialization(model, (3, 224, 224))
for layer, stat in stats.items():
    print(f"{layer}: mean={stat['mean']:.3f}, std={stat['std']:.3f}, "
          f"dead={stat['dead_neurons']:.1%}")
```
3. Fixing Dead ReLUs
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmartHeInitializer:
    """He initialization with dead-neuron prevention."""

    def __init__(self, dead_neuron_threshold=0.1):
        self.threshold = dead_neuron_threshold

    def initialize(self, module, test_input):
        """Initialize and verify there are no excessive dead neurons."""
        max_attempts = 5
        bias_offset = 0.0
        for attempt in range(max_attempts):
            # Apply He initialization; re-apply the accumulated bias offset so it
            # is not wiped out by the re-initialization on later attempts
            nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            nn.init.constant_(module.bias, bias_offset)

            # Test for dead neurons
            with torch.no_grad():
                output = F.relu(module(test_input))
                dead_ratio = (output == 0).float().mean()

            if dead_ratio < self.threshold:
                return True

            # Add a small positive bias on the next attempt if too many neurons are dead
            bias_offset += 0.01

        print(f"Warning: Could not reduce dead neurons below {self.threshold}")
        return False
```
Common Pitfalls
1. Wrong Activation Pairing
```python
# ❌ BAD: He init with tanh
nn.init.kaiming_normal_(layer.weight)
output = torch.tanh(layer(input))  # Suboptimal: the He gain assumes ReLU

# ✅ GOOD: He init with ReLU
nn.init.kaiming_normal_(layer.weight)
output = F.relu(layer(input))      # Designed for ReLU
```
2. Ignoring Batch Normalization
```python
# ❌ BAD: Standard He init with BN
nn.init.kaiming_normal_(conv.weight)

# ✅ GOOD: Adjusted for BN
nn.init.kaiming_normal_(conv.weight, mode='fan_out')
# BN will normalize the activations anyway, but fan_out helps gradient flow
```
3. Mode Confusion
```python
# ❌ BAD: Always using fan_in
nn.init.kaiming_normal_(layer.weight, mode='fan_in')

# ✅ GOOD: Consider the use case
# Forward pass critical: use fan_in
# Backward pass critical: use fan_out
# Convolutional layers with BN: use fan_out
```
Performance Impact
Convergence Speed
Initialization | Epochs to 90% Accuracy | Final Accuracy |
---|---|---|
Random (σ=0.01) | Failed to converge | N/A |
Xavier (with ReLU) | 25 | 92.3% |
He (with ReLU) | 15 | 94.7% |
He + BN | 12 | 95.2% |
Deep Network Stability
For a 50-layer ReLU network (a layer-wise gradient-norm sketch follows this list):
- Random Init: Gradient vanishing by layer 10
- Xavier Init: Gradient vanishing by layer 30
- He Init: Stable gradients through all layers
- He + Residual: stable gradient flow even at extreme depth
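A hedged sketch of how such a comparison can be measured (the 50×256 architecture, batch size, and the simple `sum()` loss are arbitrary choices): build a deep ReLU MLP, run one backward pass, and compare per-layer gradient norms under He versus Xavier initialization.

```python
import torch
import torch.nn as nn


def make_mlp(depth=50, width=256, init='he'):
    """Deep ReLU MLP with the requested weight initialization."""
    layers = []
    for _ in range(depth):
        linear = nn.Linear(width, width)
        if init == 'he':
            nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
        else:
            nn.init.xavier_normal_(linear.weight)
        nn.init.zeros_(linear.bias)
        layers += [linear, nn.ReLU()]
    return nn.Sequential(*layers)


def gradient_norms(model, width=256):
    """One forward/backward pass; return the gradient norm of each Linear layer's weights."""
    x = torch.randn(64, width)
    model(x).sum().backward()
    return [m.weight.grad.norm().item() for m in model if isinstance(m, nn.Linear)]


torch.manual_seed(0)
for init in ('he', 'xavier'):
    norms = gradient_norms(make_mlp(init=init))
    print(init, "first/middle/last layer grad norm:",
          [round(norms[i], 6) for i in (0, 24, 49)])
```

In this setup the Xavier-initialized network's early-layer gradient norms should come out orders of magnitude smaller than its last-layer norms, while the He-initialized network keeps them on a comparable scale.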
Advanced Research
Fixup Initialization
For very deep networks without normalization:
```python
import math

import torch.nn as nn


def fixup_initialization(layer, depth, layer_idx):
    """Fixup-style initialization for deep networks without BN."""
    if isinstance(layer, nn.Linear):
        # He-style std, scaled down by depth**(-1/4)
        nn.init.normal_(layer.weight,
                        std=math.sqrt(2 / layer.in_features) * depth ** (-1 / 4))
        if layer_idx == depth - 1:
            # Zero-initialize the final layer
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)
```
Lottery Ticket Hypothesis
```python
import torch
import torch.nn as nn


def lottery_ticket_init(model, sparsity=0.9):
    """Initialize with structured sparsity."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # He initialization
            nn.init.kaiming_normal_(module.weight)

            # Create mask for lottery tickets
            mask = (torch.rand_like(module.weight) > sparsity).float()
            module.weight.data *= mask

            # Scale remaining weights
            module.weight.data /= (1 - sparsity)
```
Related Concepts
- Xavier Initialization - For symmetric activations
- Batch Normalization - Reduces initialization sensitivity
- Gradient Flow - Understanding signal propagation
- ReLU Activation - The activation He init optimizes for
- Residual Networks - Architectural solution to depth