How to Debug Neural Networks
Most neural network training failures have mundane causes: a data loading bug, an incorrect loss function, a learning rate that is too high or too low, or mislabeled training data. The systematic approach below catches these common problems before you waste time on exotic explanations.
Step 1: Verify the Data Pipeline
The most common source of training failure is bad data. Before debugging the model, confirm that the data reaching the model is correct.
Visualize batches. Display a batch of training images with their labels. Are the images the right way up? Are the labels correct? Are the images the expected resolution and color space? For text, print a batch of tokenized inputs and decode them back to text. Do they make sense?
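Check preprocessing. Verify that normalization is applied correctly (the per-channel mean and standard deviation match the values the model expects, such as the training-set statistics or those used to train a pretrained backbone), that augmentation produces reasonable variations (not so extreme that images become unrecognizable), and that inputs have the expected data type (for image or other continuous inputs, typically float32 rather than uint8 or int64).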
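As a minimal, framework-free sketch (plain Python lists stand in for tensors, and the names `normalize` and `check_normalized` are hypothetical), the check below standardizes raw 0-255 pixel values and then confirms that the result is floating point with a mean near 0 and a standard deviation near 1. The 0.1 tolerance is an arbitrary assumption; on a real batch you would run the same check per channel.

```python
import math


def normalize(pixels, mean, std):
    """Scale raw 0-255 pixel values to [0, 1], then standardize with the given statistics."""
    return [((p / 255.0) - mean) / std for p in pixels]


def check_normalized(values, tol=0.1):
    """Raise if a batch is not float-typed or not roughly zero-mean, unit-variance."""
    if not all(isinstance(v, float) for v in values):
        raise TypeError("expected floats; raw integer pixels were not converted")
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    if abs(mean) > tol or abs(std - 1.0) > tol:
        raise ValueError(f"mean={mean:.3f}, std={std:.3f}; expected about 0 and 1")
    return mean, std


# Statistics computed on the [0, 1]-scaled training data.
raw = [0, 64, 128, 192, 255] * 10
scaled = [p / 255.0 for p in raw]
data_mean = sum(scaled) / len(scaled)
data_std = math.sqrt(sum((s - data_mean) ** 2 for s in scaled) / len(scaled))
print(check_normalized(normalize(raw, data_mean, data_std)))
```

Passing `scaled` (standardization forgotten) raises a `ValueError`, and passing `raw` (no float conversion) raises a `TypeError`: the two preprocessing bugs the paragraph warns about.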
Check label distribution. Are the classes balanced? If 95% of your data belongs to one class, the model can reach 95% accuracy by always predicting that class, so track per-class metrics rather than accuracy alone. Extreme class imbalance calls for resampling (oversampling minority classes or undersampling the majority class) or a class-weighted loss.
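A quick sketch of this check in plain Python (the helper name `class_report` is hypothetical): it counts labels, reports the accuracy of always predicting the majority class, and derives inverse-frequency class weights with the same `n_samples / (n_classes * count)` heuristic that scikit-learn uses for `class_weight="balanced"`. Feeding such weights to a weighted loss is one remedy; resampling is another.

```python
from collections import Counter


def class_report(labels):
    """Return label counts, the majority-class baseline accuracy, and inverse-frequency weights."""
    counts = Counter(labels)
    total = len(labels)
    baseline = max(counts.values()) / total
    # Rare classes get larger weights; a perfectly balanced dataset gives 1.0 for every class.
    weights = {label: total / (len(counts) * n) for label, n in counts.items()}
    return counts, baseline, weights


counts, baseline, weights = class_report(["cat"] * 95 + ["dog"] * 5)
print(counts, baseline, weights)
```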
Check for data leakage. Ensure that training and validation sets do not share examples. If the same image appears in both sets (a common mistake when augmented or duplicated copies are created before the dataset is split), validation metrics will be misleadingly optimistic. For time series, split chronologically so that the model is never trained on data from after the validation period.
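One inexpensive way to catch exact duplicates across splits is to hash the raw bytes of every example and intersect the hash sets; for time series, the split itself should be a timestamp cutoff rather than a random shuffle. The sketch below assumes examples are available as `bytes` and records as `(timestamp, example)` pairs; both function names are hypothetical. Exact hashing will not catch near-duplicates such as resized or augmented copies, which need a perceptual hash or embedding similarity instead.

```python
import hashlib


def find_overlap(train, val):
    """Return indices of validation examples whose bytes also appear in the training set."""
    train_hashes = {hashlib.sha256(x).hexdigest() for x in train}
    return [i for i, x in enumerate(val) if hashlib.sha256(x).hexdigest() in train_hashes]


def chronological_split(records, cutoff):
    """Split (timestamp, example) records so all training data precedes all validation data."""
    train = [r for r in records if r[0] < cutoff]
    val = [r for r in records if r[0] >= cutoff]
    return train, val


print(find_overlap([b"img-a", b"img-b", b"img-c"], [b"img-d", b"img-b"]))  # [1]
```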
Step 2: Start with a Minimal Reproducible Setup
Before training on the full dataset, confirm that the model can overfit a tiny subset. Take 1 to 10 training examples and train the model on just these examples with no regularization. The model should reach near-zero training loss within a few hundred iterations.
If the model cannot memorize 10 examples, the problem is most likely in the model architecture or the training code, not the data or the regularization. Common causes: the loss function does not match the output layer (using MSE with a softmax output, for instance, or applying a softmax before a loss that already expects raw logits, such as PyTorch's nn.CrossEntropyLoss), the model's forward pass has a bug (tensor dimensions are wrong), or the optimizer is not updating the parameters (gradients are not reaching all layers, or some parameters were never passed to the optimizer).
Once the model can overfit a small subset, gradually increase the dataset size and add regularization. This incremental approach isolates problems by adding complexity one step at a time.
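The overfitting test is easy to automate. The sketch below uses logistic regression written in plain Python as a stand-in for "your model" (an assumption made to keep the example self-contained): it trains on eight hand-made, linearly separable examples with full-batch gradient descent and no regularization, and returns the loss at every step. With a real network you would run the same loop with your own model, loss, and optimizer and expect the same outcome: a training loss that keeps falling toward zero.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def overfit_tiny_subset(xs, ys, lr=0.5, steps=2000):
    """Full-batch gradient descent on unregularized logistic regression; returns per-step losses."""
    w = [0.0] * len(xs[0])
    b = 0.0
    n = len(xs)
    losses = []
    for _ in range(steps):
        grad_w = [0.0] * len(w)
        grad_b = 0.0
        loss = 0.0
        for x, y in zip(xs, ys):
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            pc = min(max(p, 1e-12), 1.0 - 1e-12)  # clamp only for the log, to avoid log(0)
            loss -= y * math.log(pc) + (1 - y) * math.log(1.0 - pc)
            err = p - y  # d(loss)/d(logit) for sigmoid + binary cross-entropy
            for j, xj in enumerate(x):
                grad_w[j] += err * xj
            grad_b += err
        losses.append(loss / n)
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return losses


xs = [[1.0, 0.5], [2.0, -1.0], [1.5, 2.0], [3.0, 0.0],
      [-1.0, 0.3], [-2.0, 1.0], [-1.5, -2.0], [-3.0, 0.1]]
ys = [1, 1, 1, 1, 0, 0, 0, 0]
losses = overfit_tiny_subset(xs, ys)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.5f}")
```

The first loss is log 2 because every prediction starts at 0.5; if the last value stayed near that level, the training code, not the data, would be the first suspect.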
Step 3: Read the Loss Curves
Training and validation loss curves are the primary diagnostic tool. Plot them and look for these patterns:
Training loss does not decrease. The most likely cause is a learning rate that is either too low (gradients produce negligible updates) or too high (updates overshoot and the loss oscillates or diverges). Try multiplying the learning rate by 10 and by 0.1 to find the right range. Also check that the loss function is correct and that gradients are flowing (not all zero).
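To see why multiplying the learning rate by 10 and by 0.1 separates the two failure modes, here is a toy one-parameter problem (an illustration, not a real network): gradient descent on `loss(w) = (w - 3)**2` starting from `w = 0`. Too small a step leaves the loss close to where it started, a moderate step converges, and a step past the stability limit (for this loss, any rate above 1.0) makes the loss grow instead of shrink. The function names are hypothetical.

```python
def final_loss(lr, steps=100, target=3.0):
    """Run gradient descent on (w - target)**2 from w = 0 and return the final loss."""
    w = 0.0
    for _ in range(steps):
        grad = 2.0 * (w - target)
        w -= lr * grad
    return (w - target) ** 2


def sweep(base_lr):
    """Evaluate base_lr and one order of magnitude on either side."""
    return {lr: final_loss(lr) for lr in (base_lr * 0.1, base_lr, base_lr * 10)}


for lr, loss in sweep(0.01).items():
    print(f"lr={lr:g}: final loss {loss:.3g}")
print(f"lr=1.1: final loss {final_loss(1.1):.3g} (diverged; the starting loss was 9)")
```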
Training loss decreases but is very noisy. The batch size may be too small, producing noisy gradient estimates; a learning rate that is somewhat too high can look similar. Increase the batch size or use gradient accumulation. Some noise is normal and even beneficial, but if the loss jumps dramatically between batches, the gradient estimates are too unreliable.
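Gradient accumulation works because, with equal-sized micro-batches and a mean-reduced loss, the average of the micro-batch gradients equals the gradient of the full batch. The sketch below checks that identity for a one-weight linear model with squared error (a deliberately tiny assumption). In PyTorch the same effect comes from dividing each micro-batch loss by the number of accumulation steps, calling `backward()` on each, and stepping the optimizer once, since gradients keep adding up in each parameter's `.grad` until they are zeroed.

```python
def mse_grad(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2.0 * x * (w * x - y) for x, y in batch) / len(batch)


def accumulated_grad(w, batch, micro_batch_size):
    """Average the gradients of equal-sized micro-batches, as accumulation does before one step."""
    micro_batches = [batch[i:i + micro_batch_size] for i in range(0, len(batch), micro_batch_size)]
    total = 0.0
    for mb in micro_batches:
        total += mse_grad(w, mb) / len(micro_batches)  # scale so the sum is an average
    return total


batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0),
         (0.5, 1.0), (1.5, 2.5), (2.5, 5.0), (3.5, 6.0)]
print(mse_grad(1.5, batch), accumulated_grad(1.5, batch, micro_batch_size=2))
```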
Training loss goes to NaN or infinity. Exploding gradients or numerical overflow. Reduce the learning rate, add gradient clipping, check for division by zero or the log of zero in the loss computation, and ensure that inputs are finite (no NaN or inf values in the data).
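A cheap guard is to check the inputs and the loss before every optimizer step and to clamp probabilities before taking logarithms, since `log(0)` is a frequent source of `inf` and `NaN` losses. The helpers below (`check_step` and `safe_log`, both hypothetical names) sketch that idea in plain Python; frameworks also ship their own tools for this, such as PyTorch's `torch.autograd.set_detect_anomaly`.

```python
import math


def check_step(inputs, loss):
    """Return None if it is safe to take an optimizer step, otherwise a reason string."""
    if not all(math.isfinite(v) for v in inputs):
        return "non-finite value in input batch"
    if not math.isfinite(loss):
        return f"loss is {loss}: lower the learning rate or add gradient clipping"
    return None


def safe_log(p, eps=1e-12):
    """log(p) with p clamped to at least eps, so p == 0 cannot produce -inf."""
    return math.log(max(p, eps))


print(check_step([0.1, float("nan")], 0.5))
print(check_step([0.1, 0.2], float("inf")))
print(safe_log(0.0))
```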
Training loss decreases, validation loss increases. Overfitting. The model is memorizing training data rather than learning generalizable patterns. Add regularization (dropout, weight decay, augmentation), reduce model size, or get more training data.
Both losses plateau at a high value. Underfitting. The model lacks the capacity, or has not trained long enough, to learn the patterns in the data. Increase model size, train longer, reduce regularization, or try a different architecture better suited to the data type.
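Validation loss is much lower than training loss. A modest gap in this direction is normal when dropout or heavy augmentation is active only during training, or when the training loss is averaged over an epoch while the model is still improving. A large, persistent gap, however, usually points to a bug: training examples leaking into the validation set, a validation set that is systematically easier than the training set (for example, a different label distribution), or a mistake in how the validation loss is computed.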
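The patterns above can be encoded as a rough first-pass triage. The function below is a hypothetical heuristic, not a standard tool: it takes per-epoch training and validation losses, and its thresholds (a 5% tolerance, a five-epoch plateau window, a validation loss under half the training loss, and a task-specific `high_loss` level) are arbitrary assumptions to tune for your own problem. It is meant to prompt a closer look at the curves, not to replace one.

```python
import math


def diagnose(train, val, high_loss, window=5, tol=0.05):
    """Map per-epoch train/val loss lists to one of the loss-curve patterns."""
    if not all(math.isfinite(x) for x in train + val):
        return "diverged: NaN or inf loss"
    if train[-1] >= train[0]:
        return "training loss not decreasing"
    if val[-1] > min(val) * (1 + tol):
        return "overfitting: validation loss rising from its minimum"
    recent = train[-window:]
    plateaued = (recent[0] - recent[-1]) < tol * recent[0]
    if plateaued and train[-1] > high_loss and val[-1] > high_loss:
        return "underfitting: both losses plateaued at a high value"
    if val[-1] < 0.5 * train[-1]:
        return "validation far below training: inspect the validation pipeline"
    return "no obvious pattern"


print(diagnose([2.0, 1.5, 1.0, 0.6, 0.3, 0.1], [2.0, 1.6, 1.2, 1.1, 1.3, 1.5], high_loss=1.0))
```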
Step 4: Check Gradients and Weights
When loss curves are not sufficient to diagnose the problem, inspect the internal state of the network during training.
Gradient norms per layer. Compute the L2 norm of gradients for each layer at each training step. If gradient norms decrease sharply from later layers to earlier layers, vanishing gradients are preventing early layers from learning. If gradient norms spike suddenly, exploding gradients are about to destabilize training.
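A framework-agnostic sketch of both checks: per-layer L2 norms computed from flattened gradients (ordered from input to output), the ratio of the earliest layer's norm to the last layer's as a vanishing-gradient signal, and a spike test against the median of recent total norms. In PyTorch the raw values would come from each parameter's `.grad` after `loss.backward()`, iterating over `model.named_parameters()`; the helper names and the factor-of-10 spike threshold here are assumptions.

```python
import math


def grad_norms(grads_by_layer):
    """L2 norm of each layer's flattened gradient, preserving the input-to-output order."""
    return {name: math.sqrt(sum(g * g for g in gs)) for name, gs in grads_by_layer.items()}


def first_to_last_ratio(norms):
    """Earliest layer's gradient norm divided by the last layer's; tiny values suggest vanishing."""
    values = list(norms.values())
    return values[0] / values[-1] if values[-1] > 0 else float("nan")


def is_spike(history, current, factor=10.0, window=50):
    """Flag a total gradient norm far above the median of the last `window` steps."""
    if not history:
        return False
    recent = sorted(history[-window:])
    median = recent[len(recent) // 2]  # upper median for even-length windows
    return current > factor * median


norms = grad_norms({"layer1": [3e-4, 4e-4], "layer2": [0.03, 0.04], "layer3": [3.0, 4.0]})
print(norms, first_to_last_ratio(norms))
```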
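Weight distributions. Histogram the weight values for each layer. After a standard initialization, weights are spread around zero with a scale set by the layer's width, and during healthy training the histograms should change gradually rather than collapse or blow up. If all weights shrink toward zero, weight decay may be overwhelming the learning signal. If weights grow very large, regularization may be insufficient or the learning rate too high.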
Activation statistics. Check the fraction of ReLU units that output zero for every input in a batch (dead neurons). If a large fraction of units is dead (more than about 30% is a common warning sign), the learning rate may be too high, or the initialization may be poor. For transformers, check that attention weights are not all concentrated on one position (degenerate attention).
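Both statistics are straightforward to compute from recorded activations. In the sketch below (hypothetical helper names, plain lists in place of tensors), a unit counts as dead if it outputs exactly zero for every example in the batch, and attention degeneracy is measured as the entropy of one attention distribution: zero when all mass sits on a single position, `log(n)` when it is spread uniformly over `n` positions. A unit that is silent on one batch may fire on another, so aggregate over several batches before acting on the number.

```python
import math


def dead_fraction(activations):
    """activations: per-example lists of ReLU outputs for one layer.
    Returns the fraction of units that output exactly 0 for every example."""
    num_units = len(activations[0])
    dead = sum(1 for j in range(num_units) if all(a[j] == 0.0 for a in activations))
    return dead / num_units


def attention_entropy(weights):
    """Entropy of one attention distribution; near 0 means all attention is on one position."""
    return sum(-w * math.log(w) for w in weights if w > 0)


batch = [[0.0, 1.2, 0.0, 0.5],
         [0.0, 0.0, 0.0, 2.0],
         [0.0, 0.3, 0.0, 0.0]]
print(dead_fraction(batch))  # 0.5 (units 0 and 2 never fire)
print(attention_entropy([1.0, 0.0, 0.0, 0.0]), attention_entropy([0.25] * 4))
```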
Gradient flow test. After a backward pass, check that every layer's parameters have non-zero gradients. If any layer has zero gradients, there is a disconnection in the computational graph, possibly caused by a detached tensor, an in-place operation that breaks autograd, or a layer that was accidentally frozen.
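The gradient flow test reduces to scanning every parameter after one backward pass. The sketch below assumes gradients have been collected into a dictionary mapping parameter names to flattened gradient values, with `None` for parameters that received no gradient; in PyTorch, a parameter's `.grad` is still `None` after `loss.backward()` if it did not take part in computing the loss or has `requires_grad=False`. The function name and parameter names are hypothetical.

```python
def params_without_gradient(grads):
    """grads: {parameter_name: list of gradient values, or None if no gradient was computed}.
    Returns names whose gradient is missing or identically zero, in input order."""
    return [name for name, g in grads.items() if g is None or all(v == 0.0 for v in g)]


grads = {
    "embed.weight": [0.1, -0.2, 0.05],
    "encoder.weight": [0.0, 0.0, 0.0],  # all zeros, e.g. every path to it runs through dead units
    "head.weight": None,                # no gradient, e.g. frozen, detached, or unused in the loss
}
print(params_without_gradient(grads))  # ['encoder.weight', 'head.weight']
```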
Step 5: Apply Targeted Fixes
Match the diagnosis to the fix. Do not apply multiple fixes simultaneously, as you will not know which one worked.
For overfitting: Add dropout (0.1 to 0.5), increase weight decay (try 10x), add data augmentation, reduce model size, or collect more training data. Early stopping is a cheap safeguard that is worth using in almost every setting.
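Early stopping needs only a little state: the best validation loss so far and a count of epochs without improvement. The class below is a minimal sketch; `patience` and `min_delta` are conventional parameter names, but the class itself is hypothetical rather than any specific library's API. In practice you would also save a checkpoint whenever `best` improves and restore it after stopping.

```python
class EarlyStopping:
    """Signal a stop once validation loss fails to improve by min_delta for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement: this is where you would save a checkpoint
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(val_loss):
        print(f"stopping at epoch {epoch}; best validation loss {stopper.best}")
        break
```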
For underfitting: Increase model capacity (more layers, wider layers), reduce regularization, train for more epochs, increase learning rate, or try a more expressive architecture.
For training instability: Reduce the learning rate, add gradient clipping (clipping the global gradient norm to 1.0 is a common starting point), add warm-up (start at about lr/100 and ramp up to the target rate), add layer normalization, or check for numerical issues in the data or loss computation.
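A plain-Python sketch of the two mechanical fixes: clipping by global norm rescales all gradients by one common factor (the idea behind PyTorch's `torch.nn.utils.clip_grad_norm_`), and a linear warm-up ramps the learning rate from lr/100 to the target value. The function names, the flat list of gradients, and the linear ramp shape are assumptions for illustration.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients by one common factor so their joint L2 norm is at most max_norm.
    Returns (clipped_grads, norm_before_clipping)."""
    total = math.sqrt(sum(g * g for g in grads))
    if total > max_norm:
        scale = max_norm / total
        grads = [g * scale for g in grads]
    return grads, total


def warmup_lr(step, base_lr, warmup_steps, start_factor=0.01):
    """Linear warm-up from base_lr * start_factor (lr/100 by default) to base_lr, then constant."""
    if step >= warmup_steps:
        return base_lr
    frac = step / warmup_steps
    return base_lr * (start_factor + (1.0 - start_factor) * frac)


print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))
print([warmup_lr(s, 1e-3, warmup_steps=100) for s in (0, 50, 100)])
```

Clipping the joint norm, rather than each gradient separately, preserves the direction of the update and only shortens it.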
For vanishing gradients: Switch from sigmoid/tanh to ReLU, add residual connections, use proper initialization (He for ReLU), and add batch or layer normalization.
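He initialization draws each weight of a ReLU layer from a zero-mean distribution with variance `2 / fan_in`, which compensates for ReLU zeroing roughly half of its inputs and keeps activation variance roughly constant from layer to layer. A plain-Python sketch of the normal variant follows (the function name is hypothetical; PyTorch provides this as `torch.nn.init.kaiming_normal_`).

```python
import math
import random


def he_normal(fan_in, fan_out, rng=random):
    """Return a fan_out x fan_in weight matrix with entries drawn from N(0, 2 / fan_in)."""
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


weights = he_normal(fan_in=256, fan_out=512, rng=random.Random(0))
flat = [w for row in weights for w in row]
mean = sum(flat) / len(flat)
std = math.sqrt(sum((w - mean) ** 2 for w in flat) / len(flat))
print(f"empirical std {std:.4f} vs target {math.sqrt(2.0 / 256):.4f}")
```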
For poor convergence: Try a different optimizer (switch from SGD to Adam or vice versa), tune the learning rate with a learning rate range test (an "LR finder"), try a different learning rate schedule, or increase batch size for more stable gradients.
Debug neural networks systematically: verify the data first, confirm the model can overfit a small subset, read the loss curves for specific patterns, check gradient and weight statistics for internal problems, and apply one targeted fix at a time. Most training failures have mundane causes (data bugs, wrong learning rate, mismatched loss function), and the systematic approach finds them quickly without guesswork.