Batch Normalization
An explanation of batch normalization, a technique that normalizes layer inputs to accelerate training and improve neural network performance.
Batch Normalization Explained
Batch normalization is a technique that normalizes the inputs to each layer in a neural network, accelerating training and improving model performance. It addresses the internal covariate shift problem by ensuring that layer inputs maintain consistent statistical properties throughout training.
Interactive Visualization
The demo below applies the batch normalization formula step by step to a small input batch.
Input Batch
| Sample | F0    | F1    | F2    | F3    |
|--------|-------|-------|-------|-------|
| 0      | -0.96 | 15.85 | 23.19 | 25.19 |
| 1      | -1.71 | 13.96 | 21.25 | 27.30 |
| 2      | -2.76 | 9.92  | 14.74 | 36.08 |
| 3      | 2.25  | 13.01 | 16.39 | 31.11 |
| 4      | -2.88 | 11.44 | 17.65 | 28.10 |
| 5      | -1.42 | 9.79  | 9.10  | 33.09 |
| 6      | 4.48  | 14.81 | 20.23 | 28.16 |
| 7      | 2.09  | 13.21 | 17.62 | 35.79 |
In the demo, the Normalized Data and Output tables are computed from this input batch, alongside panels showing the batch statistics, the learnable parameters γ and β, and the running statistics.
Distribution Evolution
For each feature (F0–F3), the demo shows three histograms: the input distribution (the original data), the normalized distribution (zero mean, unit variance), and the output distribution (γ=1, β=0 by default).
• Input: original data with varying means and scales per feature
• Normalized: centered at zero with unit variance (note the reference line at zero)
• Output: the network can learn an optimal scale (γ) and shift (β) for each feature
💡 Try changing the distributions and the γ/β parameters to see how strongly they affect the output!
Key Insights
- Batch normalization normalizes inputs to have zero mean and unit variance per feature
- Training mode uses batch statistics; inference mode uses running averages
- γ and β are learnable parameters that let the network undo the normalization if needed
- Benefits: faster training, higher learning rates, less sensitivity to initialization
- Mitigates internal covariate shift, the change in the distribution of layer inputs during training
- ε = 1e-5 is a small constant that prevents division by zero in the normalization
The Problem: Internal Covariate Shift
Batch normalization addresses a fundamental problem in deep learning called Internal Covariate Shift. This refers to the change in distribution of layer inputs during training, which makes training slower and less stable.
How Batch Normalization Works
Batch normalization normalizes the inputs to a layer by adjusting and scaling the activations. For each feature in a batch, it:
- Computes batch statistics: the mean (μ) and variance (σ²) of each feature across the batch
- Normalizes: transforms the inputs to have zero mean and unit variance
- Scales and shifts: applies learnable parameters γ (scale) and β (shift) to restore representational power
Mathematical Formula
The batch normalization transformation can be expressed as:
Batch mean: $\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i$
Batch variance: $\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$
Normalize: $\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$
Scale and shift: $y_i = \gamma \hat{x}_i + \beta$
Where:
- m is the batch size
- xᵢ are the input activations
- ε is a small constant (typically 1e-5) to prevent division by zero
- γ and β are learnable parameters
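As a minimal sketch of these equations (not the demo's actual code, and with an illustrative helper name), a per-feature NumPy implementation might look like this:

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Batch-normalize x of shape (batch_size, num_features), feature by feature."""
    mu = x.mean(axis=0)                     # batch mean per feature
    var = x.var(axis=0)                     # batch variance per feature (1/m convention)
    x_hat = (x - mu) / np.sqrt(var + eps)   # normalize to zero mean, unit variance
    y = gamma * x_hat + beta                # scale and shift with learnable parameters
    return y

# A batch of 8 samples with 4 features, similar in spirit to the demo's input batch
x = np.random.randn(8, 4) * 5.0 + 10.0
y = batch_norm_forward(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0), y.std(axis=0))        # roughly 0 and 1 per feature with gamma=1, beta=0
```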
Training vs Inference
Training Mode
- Uses current batch statistics (μ, σ²) for normalization
- Updates running averages of mean and variance using exponential moving average
- Learnable parameters γ and β are updated through backpropagation
Inference Mode
- Uses running averages computed during training instead of batch statistics
- This ensures consistent behavior regardless of batch size during inference
- No updates to running statistics or learnable parameters
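For concreteness, here is a short sketch of the two modes, assuming PyTorch as the framework (the article itself is framework-agnostic):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=4)   # learnable gamma/beta, running stats tracked

x = torch.randn(8, 4) * 5 + 10        # a batch of 8 samples with 4 features

bn.train()                            # training mode: normalize with batch statistics
y_train = bn(x)                       # also updates bn.running_mean / bn.running_var

bn.eval()                             # inference mode: normalize with the running averages
with torch.no_grad():
    y_eval = bn(x)                    # running statistics are left untouched here

print(bn.running_mean, bn.running_var)
```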
Key Benefits
1. Accelerated Training
- Enables higher learning rates by reducing sensitivity to parameter initialization
- Networks converge faster and more reliably
2. Improved Gradient Flow
- Reduces internal covariate shift - the change in distribution of layer inputs during training
- Helps mitigate vanishing/exploding gradient problems
3. Regularization Effect
- Acts as implicit regularization by adding noise through batch statistics
- Often reduces the need for other regularization techniques like dropout
4. Reduced Sensitivity to Initialization
- Networks become less dependent on careful weight initialization
- More robust training across different initialization schemes
Implementation Considerations
Placement in Network
```python
# Typical placement: after the linear transformation, before the activation
x = conv_layer(x)
x = batch_norm(x)
x = activation(x)
```
Channel-wise Normalization
- For convolutional layers, normalization is applied per channel
- Each channel has its own γ and β parameters
- Statistics are computed across batch, height, and width dimensions
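A rough NumPy sketch of channel-wise normalization for an input in NCHW layout (batch, channels, height, width) illustrates which axes the statistics are taken over:

```python
import numpy as np

def batch_norm_2d(x, gamma, beta, eps=1e-5):
    """Channel-wise batch norm for x of shape (N, C, H, W)."""
    # one mean/variance per channel, computed across batch (N), height (H), and width (W)
    mu = x.mean(axis=(0, 2, 3), keepdims=True)    # shape (1, C, 1, 1)
    var = x.var(axis=(0, 2, 3), keepdims=True)    # shape (1, C, 1, 1)
    x_hat = (x - mu) / np.sqrt(var + eps)
    # per-channel gamma and beta, broadcast over N, H, W
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)

x = np.random.randn(8, 3, 16, 16)                 # 8 feature maps with 3 channels
y = batch_norm_2d(x, gamma=np.ones(3), beta=np.zeros(3))
```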
Momentum Parameter
- Controls the update rate of running statistics
- Typical value: 0.9 (90% old value, 10% new batch value)
- Higher momentum = more stable running statistics
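With this convention, a sketch of the exponential moving average update looks like the following (note that some libraries, such as PyTorch, instead define momentum as the weight of the new batch value, so a default of 0.1 there corresponds to the 0.9 described here):

```python
import numpy as np

def update_running_stats(running_mean, running_var, batch_mean, batch_var, momentum=0.9):
    """Keep 90% of the old running value, blend in 10% of the new batch value."""
    new_mean = momentum * running_mean + (1 - momentum) * batch_mean
    new_var = momentum * running_var + (1 - momentum) * batch_var
    return new_mean, new_var

x = np.random.randn(8, 4)
running_mean, running_var = np.zeros(4), np.ones(4)   # common initialization
running_mean, running_var = update_running_stats(
    running_mean, running_var, x.mean(axis=0), x.var(axis=0)
)
```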
Variations and Extensions
Layer Normalization
- Normalizes across features instead of batch dimension
- Better for variable-length sequences (RNNs, Transformers)
- Not dependent on batch size
Group Normalization
- Divides channels into groups and normalizes within each group
- Effective for small batch sizes where batch statistics are unreliable
Instance Normalization
- Normalizes each sample independently
- Popular in style transfer and generative models
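The practical difference between these variants is mainly which axes the statistics are computed over. A rough NumPy sketch for an (N, C, H, W) tensor, with the affine scale/shift omitted and the group count chosen arbitrarily:

```python
import numpy as np

x = np.random.randn(8, 16, 32, 32)    # (batch N, channels C, height H, width W)

def normalize(t, axes, eps=1e-5):
    mu = t.mean(axis=axes, keepdims=True)
    var = t.var(axis=axes, keepdims=True)
    return (t - mu) / np.sqrt(var + eps)

batch_norm    = normalize(x, (0, 2, 3))   # per channel, across the whole batch
layer_norm    = normalize(x, (1, 2, 3))   # per sample, across all of its features
instance_norm = normalize(x, (2, 3))      # per sample and per channel

# group norm: split the 16 channels into 4 groups and normalize within each group
groups = 4
grouped = x.reshape(8, groups, 16 // groups, 32, 32)
group_norm = normalize(grouped, (2, 3, 4)).reshape(x.shape)
```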
When to Use Batch Normalization
Recommended For:
- Convolutional Neural Networks - Especially deep architectures
- Fully Connected Networks - When training deep feedforward networks
- Computer Vision Tasks - Image classification, detection, segmentation
Consider Alternatives For:
- Small Batch Sizes - Batch statistics become unreliable
- Recurrent Networks - Layer normalization often works better
- Online Learning - When batch statistics aren't available
Common Pitfalls
1. Batch Size Dependency
- Very small batches lead to noisy statistics
- With a batch size of 1, batch norm over convolutional feature maps reduces to instance norm
2. Training/Inference Mismatch
- Must ensure proper mode switching between training and inference
- Running statistics must be properly maintained
3. Learning Rate Adjustment
- Batch norm allows higher learning rates, but requires tuning
- Too high learning rates can still cause instability
Related Concepts
- Skip Connections - Often used together with batch norm in modern architectures
- Internal Covariate Shift - The problem batch normalization addresses
- Gradient Flow - How batch norm improves gradient propagation
- Layer Normalization - Alternative normalization technique
- Residual Networks - Architecture that popularized batch normalization