Batch Normalization Explained: Why It Makes Deep Learning Work Better
The Problem: Internal Covariate Shift
When training a deep network, each layer's input distribution changes as the weights of all preceding layers are updated. A layer that was learning to process inputs centered around 5 with a spread of 2 might suddenly receive inputs centered around 50 with a spread of 20 after a weight update in an earlier layer. This forces every layer to continuously re-adapt to shifting input statistics, which slows convergence and requires very small learning rates to prevent the training from becoming unstable.
The original batch normalization paper called this "internal covariate shift," and while subsequent research has debated whether this is the precise mechanism that makes batch normalization work, the practical effects are well documented. Networks with batch normalization typically train faster, tolerate higher learning rates, and often reach better final accuracy than otherwise identical networks without it. The technique works, even if the theoretical explanation is still evolving.
Without normalization, activations in deep networks tend to either grow or shrink as they propagate through layers. If the average activation magnitude increases with depth, later layers receive very large inputs, their gradients become large, and training becomes unstable. If activations shrink, gradients vanish, and the early layers stop learning. Batch normalization prevents both scenarios by enforcing a consistent activation scale at every layer.
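The growth and shrinkage described above can be reproduced with a toy experiment. The sketch below is only an illustration under assumed settings (a stack of 20 random linear layers of width 256 with ReLU, and two arbitrary weight scales); the function name and all values are hypothetical, and the normalization step omits the learned scale and shift.

```python
import numpy as np

def final_activation_scale(weight_std, depth=20, width=256, normalize=False, seed=0):
    """Return the std of the activations after `depth` linear + ReLU layers."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((128, width))
    for _ in range(depth):
        w = rng.standard_normal((width, width)) * weight_std
        x = x @ w
        if normalize:
            # Per-feature standardization over the batch (no gamma/beta here).
            x = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 1e-5)
        x = np.maximum(x, 0.0)  # ReLU
    return float(x.std())

shrinking = final_activation_scale(weight_std=0.05)   # weights slightly too small
growing = final_activation_scale(weight_std=0.2)      # weights slightly too large
normalized_small = final_activation_scale(weight_std=0.05, normalize=True)
normalized_large = final_activation_scale(weight_std=0.2, normalize=True)
```

In this setup the unnormalized scale collapses toward zero or blows up by many orders of magnitude, while the two normalized runs end at nearly the same scale regardless of the weight magnitude.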
How Batch Normalization Works
The computation has four steps, applied independently to each feature (channel) in the layer's output. First, compute the mean of the feature across all examples in the current batch. Second, compute the variance of the feature across the same examples. Third, normalize each value by subtracting the batch mean and dividing by the square root of the sum of the batch variance and a small epsilon (typically 1e-5), which guards against division by zero. Fourth, scale and shift the normalized values using two learned parameters, gamma and beta.
The formula is: y = gamma * (x - mean) / sqrt(variance + epsilon) + beta. The gamma and beta parameters are learned during training, just like the weights and biases of regular layers. This is critical: if the optimal representation for a layer happens to require a non-zero mean or a variance different from 1, the network can learn the appropriate gamma and beta to achieve it. Batch normalization does not force activations to have zero mean and unit variance; it normalizes them first and then lets the network learn the optimal scale and shift.
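As a minimal sketch of the four steps and the formula above, the following NumPy function handles a 2D input of shape (batch, features). The function name and the example data are illustrative assumptions, and the backward pass is left out entirely.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode forward pass. x: (batch, features); gamma, beta: (features,)."""
    mean = x.mean(axis=0)                     # step 1: per-feature batch mean
    var = x.var(axis=0)                       # step 2: per-feature batch variance
    x_hat = (x - mean) / np.sqrt(var + eps)   # step 3: normalize
    return gamma * x_hat + beta               # step 4: learned scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(32, 4))

# With gamma = 1 and beta = 0 the output is simply standardized.
y = batch_norm_forward(x, gamma=np.ones(4), beta=np.zeros(4))

# Other values of gamma and beta move the output to a different scale and mean.
y_shifted = batch_norm_forward(x, gamma=np.full(4, 2.0), beta=np.full(4, 3.0))
```

The second call shows why gamma and beta matter: the output now has a per-feature mean of about 3 and a standard deviation of about 2, so the layer is not locked into zero mean and unit variance.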
The normalization is computed per feature, not per example. In a convolutional network with 64 feature maps, batch normalization computes 64 means and 64 variances, one for each feature map, averaged across both the batch dimension and the spatial dimensions (height and width). This means that each feature map is normalized to have consistent statistics regardless of the specific images in the batch or their spatial positions.
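For the convolutional case, a sketch along the following lines shows which axes are reduced. It assumes an (N, C, H, W) memory layout, which is a convention rather than something fixed by the method, and the function name and tensor sizes are made up for the example.

```python
import numpy as np

def batch_norm_2d(x, gamma, beta, eps=1e-5):
    """x: (N, C, H, W); gamma, beta: (C,). One mean and variance per channel."""
    # Reduce over batch (0) and spatial (2, 3) axes, keeping the channel axis.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)   # shape (1, C, 1, 1)
    var = x.var(axis=(0, 2, 3), keepdims=True)     # shape (1, C, 1, 1)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)

rng = np.random.default_rng(0)
feature_maps = rng.normal(loc=3.0, scale=2.0, size=(8, 64, 16, 16))
out = batch_norm_2d(feature_maps, gamma=np.ones(64), beta=np.zeros(64))

per_channel_mean = out.mean(axis=(0, 2, 3))   # 64 values, one per feature map
per_channel_var = out.var(axis=(0, 2, 3))
```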
Training vs Inference Behavior
During training, the mean and variance are computed from the current mini-batch. This introduces a source of noise because different mini-batches have different statistics. With a batch size of 32, the batch mean and variance are estimates based on just 32 examples, which may not represent the full dataset accurately. This noise acts as a mild regularizer, similar to dropout, which is one reason batch normalization can slightly improve generalization.
During inference (when the trained model processes new data), you typically process one example at a time or use a small batch, making the mini-batch statistics unreliable. Batch normalization solves this by maintaining running estimates of the population mean and variance during training, updated with an exponential moving average at each training step. At inference time, these fixed running statistics replace the mini-batch statistics, ensuring that the model's output is deterministic and does not depend on what other examples happen to be in the same batch.
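The running-statistics mechanism can be sketched as a small class. This is a simplified illustration, not any framework's implementation: the class name, the momentum value of 0.1, and the use of the biased batch variance for the running estimate are all assumptions, and details such as the momentum convention and variance bias correction vary between frameworks.

```python
import numpy as np

class BatchNorm1D:
    """Toy batch normalization layer with running statistics (forward only)."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            m = self.momentum
            # Exponential moving average of the batch statistics.
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Inference: fixed statistics, independent of the current batch.
            mean, var = self.running_mean, self.running_var
        return self.gamma * (x - mean) / np.sqrt(var + self.eps) + self.beta

bn = BatchNorm1D(num_features=3)
rng = np.random.default_rng(0)
for _ in range(200):                       # simulated training steps
    bn(rng.normal(loc=5.0, scale=2.0, size=(32, 3)))

bn.training = False
test_batch = rng.normal(loc=5.0, scale=2.0, size=(4, 3))
alone = bn(test_batch[:1])                 # example processed by itself
in_batch = bn(test_batch)[:1]              # same example inside a batch
```

After the simulated training steps the running mean and variance sit close to the data's true mean of 5 and variance of 4, and in evaluation mode the first example gets the same output whether it is processed alone or with other examples.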
Switching between training mode (using batch statistics) and evaluation mode (using running statistics) is essential. Forgetting to switch to evaluation mode before inference is a common bug that produces inconsistent and often degraded predictions, because the batch statistics of a small inference batch are poor estimates of the true population statistics.
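The effect of that bug can be shown with a small sketch. The running statistics below are hypothetical values standing in for what training would have saved, and `predict` is an illustrative helper rather than a real API; in a real framework the flag corresponds to whether the model was left in training mode.

```python
import numpy as np

def normalize(x, mean, var, eps=1e-5):
    return (x - mean) / np.sqrt(var + eps)

# Hypothetical running statistics saved at the end of training.
running_mean = np.array([5.0, 5.0])
running_var = np.array([4.0, 4.0])

def predict(batch, use_batch_stats):
    """Normalize a batch and return the result for its first example."""
    if use_batch_stats:      # the "forgot to switch to evaluation mode" case
        mean, var = batch.mean(axis=0), batch.var(axis=0)
    else:                    # correct evaluation-mode behavior
        mean, var = running_mean, running_var
    return normalize(batch, mean, var)[0]

rng = np.random.default_rng(0)
example = np.array([[6.0, 4.0]])
# The same example placed in two different small inference batches.
batch_a = np.vstack([example, rng.normal(5.0, 2.0, size=(3, 2))])
batch_b = np.vstack([example, rng.normal(5.0, 2.0, size=(3, 2))])

eval_a, eval_b = predict(batch_a, False), predict(batch_b, False)
train_a, train_b = predict(batch_a, True), predict(batch_b, True)
```

With the running statistics, the example maps to about [0.5, -0.5] in both batches. With batch statistics, its normalized value changes depending on which other examples share the batch.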
Why It Works: Current Understanding
The original explanation, that batch normalization reduces internal covariate shift, has been questioned by subsequent research. A 2018 paper by Santurkar et al. showed that batch normalization does not actually reduce internal covariate shift by the strict definition, and that adding artificial covariate shift to batch-normalized networks does not degrade performance. Instead, they argued that batch normalization smooths the loss landscape, making the optimization surface less rugged and allowing larger learning rate steps without overshooting minima.
The smoothing explanation aligns with the observed practical benefits. A smoother loss landscape means that gradient descent is less likely to oscillate, diverge, or get stuck in sharp local minima. This is why batch normalization allows higher learning rates: the gradients point in more consistent directions, so taking larger steps is safer. Higher learning rates mean faster convergence, and on a smoother landscape the optimizer tends to find wider, flatter minima, which are often associated with better generalization to new data.
Batch normalization also helps by decoupling the magnitude of the weights from the magnitude of the activations. Without normalization, the scale of a layer's output depends on the magnitude of its weights. With normalization, only the direction of the weight vector matters, not its length, because any scaling is immediately normalized away. This simplifies the optimization problem and makes the effective learning rate for each layer more uniform.
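The scale invariance described above is easy to check numerically. The sketch below uses a hypothetical `standardize` helper (batch normalization without gamma and beta) and an arbitrary scale factor of 10; the two results agree up to the tiny effect of epsilon.

```python
import numpy as np

def standardize(z, eps=1e-5):
    """Per-feature normalization over the batch, without gamma and beta."""
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 10))    # batch of inputs
w = rng.standard_normal((10, 5))     # weights of a linear layer

out = standardize(x @ w)
out_scaled = standardize(x @ (10.0 * w))   # same direction, 10x the length

max_diff = float(np.abs(out - out_scaled).max())
```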
Placement in the Network
The original paper placed batch normalization between the linear transformation and the activation function: Linear -> BatchNorm -> ReLU. This is still the most common placement in convolutional networks. An alternative placement, Linear -> ReLU -> BatchNorm, normalizes the activated outputs and is reported to work comparably in many cases. Modern architectures like ResNet use the original placement, and changing it typically has minimal impact on final performance.
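The two orderings differ in what the next layer receives, which the following sketch makes concrete. It uses a dense layer as a stand-in for a convolution and a hypothetical `standardize` helper without gamma and beta, so it illustrates only the ordering, not a full layer.

```python
import numpy as np

def standardize(z, eps=1e-5):
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

def relu(z):
    return np.maximum(z, 0.0)

def bn_before_activation(x, w):
    """Linear -> BatchNorm -> ReLU (the original placement)."""
    return relu(standardize(x @ w))

def bn_after_activation(x, w):
    """Linear -> ReLU -> BatchNorm (the alternative placement)."""
    return standardize(relu(x @ w))

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 10))
w = rng.standard_normal((10, 5))

out_before = bn_before_activation(x, w)   # non-negative, mean above zero
out_after = bn_after_activation(x, w)     # zero mean, contains negative values
```

With the original placement the next layer sees non-negative values; with the alternative it sees zero-mean values for every feature.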
In residual networks, batch normalization is applied within the residual branch, before the addition of the skip connection. The pre-activation ResNet variant moves batch normalization to before the convolutional layer (BatchNorm -> ReLU -> Conv), which was shown to slightly improve performance for very deep networks. Both placements are widely used, and the choice is often determined by the specific architecture rather than a general rule.
Limitations and Alternatives
Batch normalization depends on the batch size. With very small batches (fewer than roughly 8 to 16 examples), the batch statistics are noisy estimates of the true statistics, which degrades performance. This is a problem for tasks that require small batch sizes due to memory constraints, such as training on high-resolution images or 3D data. It also makes batch normalization problematic in settings where examples are processed one at a time, like online learning or some reinforcement learning scenarios.
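How quickly the batch statistics become noisy can be estimated with a short simulation. The sketch below draws batches from a synthetic population with mean 5 and standard deviation 2 (arbitrary values chosen for illustration) and measures how much the batch mean varies from batch to batch; the function name is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
population = rng.normal(loc=5.0, scale=2.0, size=100_000)

def spread_of_batch_means(batch_size, num_batches=1000):
    """Std of the batch mean across many randomly drawn batches."""
    batches = rng.choice(population, size=(num_batches, batch_size))
    return float(batches.mean(axis=1).std())

spread_small = spread_of_batch_means(batch_size=2)
spread_large = spread_of_batch_means(batch_size=64)
```

For independent samples the spread of the batch mean scales as sigma / sqrt(batch_size), so a batch of 2 gives a mean estimate several times noisier than a batch of 64.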
Layer normalization, which normalizes across the feature dimension within each individual example rather than across the batch, does not depend on batch size. It computes the mean and variance for each example independently. Layer normalization has become the standard for transformer architectures, where batch sizes vary and the model must work identically for a single example as for a batch. GPT, BERT, and most other transformer models use layer normalization, or a close variant of it, rather than batch normalization.
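A minimal NumPy sketch of layer normalization follows, assuming the features live on the last axis. The function name and shapes are illustrative; the only change from batch normalization is the axis over which the statistics are taken.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """x: (..., features). Statistics are computed separately for each example."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
gamma, beta = np.ones(8), np.zeros(8)

full = layer_norm(x, gamma, beta)        # whole batch at once
single = layer_norm(x[:1], gamma, beta)  # first example on its own
```

Because no statistic crosses the batch axis, the first row of `full` matches `single`, and there is no separate training and inference behavior to manage.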
Group normalization divides channels into groups and normalizes within each group per example, combining aspects of layer normalization and batch normalization. It performs well across a range of batch sizes and is a good choice when batch size varies during training. Instance normalization, which normalizes each channel of each example independently, is standard in style transfer and image generation tasks where per-instance statistics carry important information about style.
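Group normalization can be sketched with a reshape, again assuming an (N, C, H, W) layout. The function below is an illustration without learned scale and shift, and its name is hypothetical. Setting the number of groups equal to the number of channels gives instance normalization, and a single group normalizes over all channels and positions of each example.

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    """x: (N, C, H, W), with C divisible by num_groups. No gamma/beta here."""
    n, c, h, w = x.shape
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    # Statistics per example and per group; the batch axis is never reduced.
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    return ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, scale=3.0, size=(4, 8, 6, 6))

grouped_out = group_norm(x, num_groups=4)    # 4 groups of 2 channels
instance_out = group_norm(x, num_groups=8)   # one group per channel
single_out = group_norm(x[:1], num_groups=4) # same result for a batch of one
```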
RMSNorm (Root Mean Square Normalization) simplifies layer normalization by removing the mean subtraction step and only normalizing by the root mean square of the activations. This reduces computation slightly while maintaining performance. RMSNorm is used in LLaMA, Mistral, and several other recent language models.
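A sketch of RMSNorm, assuming features on the last axis and a learned per-feature gain (named `gain` here for illustration; the epsilon value is also an assumption):

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    """x: (..., features). Divide by the root mean square; no mean subtraction."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gain * x / rms

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=1.0, size=(4, 8))
out = rms_norm(x, gain=np.ones(8))

row_rms = np.sqrt(np.mean(out ** 2, axis=-1))   # close to 1 for every row
row_mean = out.mean(axis=-1)                    # not zero: the mean is kept
```

Compared with layer normalization, the output keeps whatever mean the input had, which is the computation RMSNorm saves.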
Batch normalization stabilizes deep network training by normalizing layer outputs to consistent statistics, enabling higher learning rates and faster convergence. Current evidence suggests that it works largely by smoothing the optimization landscape rather than by reducing internal covariate shift. As a default, use batch normalization for CNNs and layer normalization (or RMSNorm) for transformers. Remember to switch to evaluation mode before inference.