How to Train a Neural Network
This guide covers the training techniques that experienced practitioners use routinely. It assumes familiarity with the basic training loop (forward pass, loss computation, backpropagation, weight update) and focuses on the choices and optimizations that determine whether training succeeds or fails.
Configure the Learning Rate
The learning rate is often the single most important hyperparameter. Too high and training diverges; too low and training stalls or converges to a poor solution. The best value depends on the model, dataset, batch size, and optimizer, but several principles apply broadly.
Start with Adam at 0.001. Adam with its default hyperparameters (lr=0.001, beta1=0.9, beta2=0.999) is a reliable starting point for most tasks. It adapts per-parameter learning rates and includes momentum, making it robust without extensive tuning. If you later want to squeeze out maximum performance, SGD with momentum (lr=0.01-0.1, momentum=0.9) sometimes generalizes slightly better, but it requires careful tuning.
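This sketch shows what "per-parameter learning rates plus momentum" means. It implements the Adam update with the default hyperparameters in plain Python for a single scalar parameter. The function name adam_minimize and the toy objective are illustrative assumptions. In real training you would use a framework optimizer such as torch.optim.Adam instead.

```python
import math
from typing import Callable


def adam_minimize(
    grad_fn: Callable[[float], float],
    x0: float,
    lr: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    steps: int = 1000,
) -> float:
    """Run Adam on a single scalar parameter and return its final value."""
    x = x0
    m = 0.0  # first moment: moving average of gradients (the momentum term)
    v = 0.0  # second moment: moving average of squared gradients
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # Bias correction: m and v start at zero, so early estimates are too small.
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        # Dividing by sqrt(v_hat) normalizes the step for this parameter, so the
        # step size stays near lr whatever the raw gradient scale.
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Toy objective f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
x_final = adam_minimize(lambda x: 2 * (x - 3), x0=0.0, steps=20_000)
print(f"{x_final:.3f}")
```

Each step is divided by the running RMS of the gradient. As a result, the first step has size close to lr whether the gradient is 6 or 6,000. This scale invariance is much of why Adam tolerates poorly scaled problems without tuning.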
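Use warm-up for transformers. Start with a very small learning rate (for example, 1/100th of the target) and increase it linearly to the target over the first 1,000 to 10,000 steps. Transformers are sensitive to large early updates. One widely cited explanation is that early in training, Adam's second-moment estimates rest on only a few gradients and are unreliable, so full-size steps can push the parameters into an unstable region. Warm-up keeps early updates small until the parameters and optimizer statistics settle. Only then does the learning rate reach its full value.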
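Here is a minimal sketch of linear warm-up as a pure function of the step count. The 1/100th starting factor follows the text. The function name warmup_lr and the example values (target 3e-4, 1,000 warm-up steps) are illustrative assumptions.

```python
def warmup_lr(step: int, target_lr: float, warmup_steps: int, start_factor: float = 0.01) -> float:
    """Linearly ramp from start_factor * target_lr at step 0 to target_lr at warmup_steps."""
    if step >= warmup_steps:
        return target_lr  # warm-up finished; a decay schedule would take over here
    progress = step / warmup_steps  # fraction of warm-up completed, in [0, 1)
    return target_lr * (start_factor + (1.0 - start_factor) * progress)


for step in (0, 250, 500, 1000, 2000):
    print(step, warmup_lr(step, target_lr=3e-4, warmup_steps=1000))
```

In PyTorch, torch.optim.lr_scheduler.LinearLR produces the same ramp when stepped once per batch. Set start_factor=0.01 and total_iters to the warm-up length.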
Decay the learning rate. Cosine annealing (smoothly decreasing the rate along a half cosine curve to near zero) is one of the most widely used schedules. It requires no manual milestone selection and typically outperforms a constant learning rate. An alternative is reduce-on-plateau: monitor validation loss and multiply the learning rate by a factor (typically 0.1) when it has not improved for a specified number of epochs.
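Both schedules are short enough to write out. This sketch uses assumed names: cosine_lr and ReduceOnPlateau are not library APIs. The plateau logic cuts the rate by factor once the loss has failed to improve for more than patience consecutive epochs.

```python
import math


def cosine_lr(step: int, total_steps: int, max_lr: float, min_lr: float = 0.0) -> float:
    """Cosine annealing from max_lr at step 0 down to min_lr at total_steps."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


class ReduceOnPlateau:
    """Multiply the LR by `factor` after more than `patience` epochs without improvement."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10, min_delta: float = 0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        """Record one epoch's validation loss and return the LR to use next."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0  # give the new LR a fresh patience window
        return self.lr


print([round(cosine_lr(s, 100, 0.1), 4) for s in (0, 25, 50, 75, 100)])
```

PyTorch ships both schedules as torch.optim.lr_scheduler.CosineAnnealingLR (T_max, eta_min) and torch.optim.lr_scheduler.ReduceLROnPlateau (factor, patience). Prefer those in real code.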
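Learning rate finder. Run a short training pass (often a few hundred iterations, at most about one epoch) while exponentially increasing the learning rate from very small (1e-7) to very large (10). Plot loss against learning rate on a log scale. A good learning rate lies where the loss is still falling steeply. That is somewhat below the point of minimum loss (a common rule of thumb is one-tenth of the learning rate at the minimum) and well before the loss starts to rise sharply. This technique, popularized by Leslie Smith, gives a good initial estimate in a few minutes.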
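The range test reduces to two pieces: an exponentially spaced list of learning rates and a rule for reading the resulting curve. In this sketch, train_step is a hypothetical callable that trains on one batch at a given learning rate and returns the loss. Stopping once the loss exceeds 4x the best value is an assumed convention. suggest_lr implements one common rule of thumb (one-tenth of the learning rate at minimum loss), not the only valid choice. The synthetic loss curve exists only to exercise the code.

```python
import math
from typing import Callable


def lr_range(num_steps: int, start_lr: float = 1e-7, end_lr: float = 10.0) -> list[float]:
    """Exponentially spaced learning rates from start_lr to end_lr, inclusive."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]


def run_range_test(train_step: Callable[[float], float], lrs: list[float],
                   diverge_factor: float = 4.0) -> tuple[list[float], list[float]]:
    """Call train_step(lr) for each LR in turn; stop early once the loss explodes."""
    losses: list[float] = []
    best = math.inf
    for lr in lrs:
        loss = train_step(lr)  # hypothetical: one batch of training at this LR
        losses.append(loss)
        best = min(best, loss)
        if loss > diverge_factor * best:
            break  # loss has blown up; higher LRs add no information
    return lrs[:len(losses)], losses


def suggest_lr(lrs: list[float], losses: list[float], divisor: float = 10.0) -> float:
    """Rule of thumb: one-tenth of the LR at which the recorded loss was lowest."""
    best_index = min(range(len(losses)), key=losses.__getitem__)
    return lrs[best_index] / divisor


# Synthetic stand-in for real training: the loss is lowest near lr = 1e-3.
def synthetic_step(lr: float) -> float:
    return (math.log10(lr) + 3.0) ** 2 + 0.1


tested_lrs, losses = run_range_test(synthetic_step, lr_range(81))
print(f"tested {len(tested_lrs)} LRs, suggested lr = {suggest_lr(tested_lrs, losses):.1e}")
```

Real losses are noisy. Smooth the curve (for example, with an exponential moving average) before picking a point.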
Apply Regularization
Regularization prevents overfitting, where the model memorizes training data instead of learning generalizable patterns. The right combination and strength of regularization techniques depends on the ratio of model capacity to dataset size.
Dropout. During training, randomly zero out neuron outputs with probability p (typically 0.1 for transformers, 0.2 to 0.5 for CNNs and feedforward networks). Scale the surviving outputs by 1/(1 - p) so their expected value is unchanged. At inference time, dropout is turned off. Dropout forces the network to learn redundant representations that do not depend on any single neuron. Apply dropout after attention layers and feedforward layers in transformers, and after dense layers in CNNs and MLPs.
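This is a minimal pure-Python sketch of inverted dropout. In this variant, survivors are rescaled during training, so nothing needs to change at inference. The function and its list-of-floats interface are illustrative. Frameworks provide the same operation as a layer, for example torch.nn.Dropout.

```python
import random
from typing import Optional


def dropout(values: list[float], p: float, training: bool = True,
            rng: Optional[random.Random] = None) -> list[float]:
    """Inverted dropout: zero each value with probability p, scale survivors by 1/(1 - p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)  # identity at inference time (or when p is zero)
    rng = rng or random.Random()
    keep_scale = 1.0 / (1.0 - p)  # keeps each output's expected value unchanged
    return [0.0 if rng.random() < p else v * keep_scale for v in values]


activations = [1.0] * 10_000
dropped = dropout(activations, p=0.5, rng=random.Random(0))
print(sum(dropped) / len(dropped))  # close to 1.0: the expected value is preserved
```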
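Weight decay (L2 regularization). Add a penalty proportional to the squared magnitude of the weights to the loss function. This pushes weights toward zero, preferring simpler models. For SGD, a coefficient of 0.0001 to 0.001 is typical. With plain SGD, adding the penalty to the loss is equivalent to shrinking the weights directly. With Adam the two are not equivalent. A penalty added to the loss passes through Adam's per-parameter scaling, so parameters with large historical gradients are regularized less than intended. AdamW (Adam with decoupled weight decay) instead shrinks the weights directly, separately from the gradient-based update. This is mathematically cleaner and empirically better. A decay of 0.01 is a common default for AdamW (it is PyTorch's default).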
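The difference between L2 regularization inside Adam and AdamW's decoupled weight decay shows up even in a single update. This sketch takes one Adam step from a fresh optimizer state on a scalar weight, under assumed values (w = 1, lr = 0.1, decay 0.01). adam_step is a hypothetical helper. Its decoupled branch shrinks the weight by lr times the decay coefficient before the adaptive update.

```python
import math


def adam_step(w: float, g: float, lr: float, weight_decay: float, decoupled: bool,
              beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """One Adam step from a fresh optimizer state (t = 1) on a scalar weight."""
    if decoupled:
        # AdamW: shrink the weight directly, outside the adaptive update.
        w = w * (1 - lr * weight_decay)
    else:
        # Adam + L2: the penalty's gradient (weight_decay * w) joins the loss
        # gradient and is therefore rescaled by Adam's normalization.
        g = g + weight_decay * w
    m = (1 - beta1) * g      # first moment after one step (starting from 0)
    v = (1 - beta2) * g * g  # second moment after one step (starting from 0)
    m_hat = m / (1 - beta1)  # bias correction for t = 1
    v_hat = v / (1 - beta2)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)


# Zero loss gradient: only the regularizer acts.
print(adam_step(1.0, 0.0, lr=0.1, weight_decay=0.01, decoupled=False))
print(adam_step(1.0, 0.0, lr=0.1, weight_decay=0.01, decoupled=True))
```

With a zero loss gradient, the L2 version moves the weight from 1.0 to about 0.9. Adam normalizes the tiny penalty gradient up to a full-size step. AdamW applies exactly the requested 0.1% shrinkage. With a large loss gradient, the L2 penalty is swamped and has almost no effect. Decoupling makes the amount of decay independent of the gradient history.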
Data augmentation. For images: random horizontal flips, random crops (with padding), color jitter, and random rotation. For text: token masking, synonym replacement, and back-translation. For tabular data: Gaussian noise injection and mixup (blending two examples and their labels). Augmentation increases the effective dataset size and forces the model to learn features that are invariant to superficial variation.
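Two of the tabular augmentations are simple enough to sketch directly. Features are plain lists of floats, and labels are one-hot lists. The function names, the noise standard deviation, and the mixup parameter alpha = 0.2 are assumptions for illustration. Mixup draws its blending weight from a Beta(alpha, alpha) distribution. With a small alpha, most blends stay close to one of the two originals.

```python
import random


def gaussian_noise(features: list[float], std: float, rng: random.Random) -> list[float]:
    """Add zero-mean Gaussian noise to each numeric feature."""
    return [x + rng.gauss(0.0, std) for x in features]


def mixup(x1: list[float], y1: list[float], x2: list[float], y2: list[float],
          alpha: float, rng: random.Random) -> tuple[list[float], list[float], float]:
    """Blend two examples and their one-hot labels with lam ~ Beta(alpha, alpha)."""
    lam = rng.betavariate(alpha, alpha)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam


rng = random.Random(0)
noisy = gaussian_noise([1.0, 2.0, 3.0], std=0.05, rng=rng)
x, y, lam = mixup([0.0, 0.0], [1.0, 0.0], [10.0, 20.0], [0.0, 1.0], alpha=0.2, rng=rng)
print(noisy, x, y)
```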
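Label smoothing. Instead of training with hard one-hot labels (1 for the true class, 0 elsewhere), mix them with a uniform distribution. With smoothing factor ε and K classes, the true class gets 1 - ε + ε/K and every other class gets ε/K. For example, with ε = 0.1 and 10 classes, the targets are 0.91 and 0.01. This prevents the model from becoming overconfident and tends to improve calibration. A smoothing factor of 0.1 is a common default for classification tasks.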
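This sketch builds smoothed targets with the standard formulation: 1 - ε on the true class plus ε/K spread over all K classes. This is the same convention as the label_smoothing argument of PyTorch's CrossEntropyLoss. The sketch then compares two predictions under hard and smoothed targets. The example probabilities are made up to illustrate the effect.

```python
import math


def smooth_labels(true_class: int, num_classes: int, epsilon: float = 0.1) -> list[float]:
    """(1 - epsilon) on the true class, plus epsilon spread uniformly over all classes."""
    targets = [epsilon / num_classes] * num_classes
    targets[true_class] += 1.0 - epsilon
    return targets


def cross_entropy(probs: list[float], targets: list[float]) -> float:
    """Cross-entropy between a predicted distribution and (possibly soft) targets."""
    return -sum(t * math.log(p) for p, t in zip(probs, targets) if t > 0)


hard = smooth_labels(0, 10, epsilon=0.0)      # plain one-hot target
smoothed = smooth_labels(0, 10, epsilon=0.1)  # 0.91 on class 0, 0.01 elsewhere
confident = [0.999] + [0.001 / 9] * 9         # near-certain prediction
moderate = [0.91] + [0.01] * 9                # prediction equal to the smoothed targets

print(cross_entropy(confident, hard), cross_entropy(moderate, hard))
print(cross_entropy(confident, smoothed), cross_entropy(moderate, smoothed))
```

Under hard labels, the near-certain prediction scores better, so training keeps pushing confidence toward 1. Under smoothed labels, the loss is lowest when the prediction matches the smoothed targets. That is what discourages overconfidence.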
Early stopping. Monitor validation loss and save the model whenever it reaches a new best value. Stop training when validation loss has not improved for a patience period (typically 5 to 20 epochs); the saved best model is your final model. It is one of the simplest and most widely used regularization techniques.
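Early stopping is a small amount of bookkeeping around the training loop. In this sketch, EarlyStopping is a hypothetical helper. model_state stands in for whatever you would checkpoint (in PyTorch, typically model.state_dict()). The validation losses are invented to show when the helper fires.

```python
import copy
import math


class EarlyStopping:
    """Track the best validation loss, keep a copy of the best state, and signal when to stop."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = math.inf
        self.best_state = None
        self.epochs_without_improvement = 0

    def step(self, val_loss: float, model_state) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            # Deep copy so later training cannot modify the saved snapshot.
            self.best_state = copy.deepcopy(model_state)
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate([0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.5]):
    if stopper.step(val_loss, {"epoch": epoch}):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss, stopper.best_state)
```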
Use Mixed Precision Training
Mixed precision training uses FP16 (half precision, 16-bit floating point) for most computations while maintaining FP32 (full precision, 32-bit) master copies of the weights. This roughly halves the memory used by activations. It can also substantially increase throughput (often around 2x on matrix-heavy workloads) on GPUs whose tensor cores accelerate FP16 math (NVIDIA Volta and later). Weight memory does not shrink as much, because the FP32 master copies are kept.
The key components are: an FP32 master copy of all parameters (to prevent accumulated rounding errors), FP16 copies used for the forward and backward pass (for speed), and loss scaling (multiplying the loss by a large factor before backpropagation and dividing the gradients afterward) to prevent small gradient values from underflowing to zero in FP16.
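Python's struct module can round values to IEEE half precision (format code "e"). That is enough to demonstrate the underflow problem and how loss scaling avoids it. The gradient value 1e-8 and the fixed scale of 1024 are assumptions for illustration. Real implementations such as PyTorch's GradScaler adjust the scale dynamically. They lower it when gradients overflow to inf and raise it after a run of stable steps.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8     # a small gradient value, representable in FP32
scale = 1024.0  # loss-scale factor; a power of two, so scaling adds no rounding error

unscaled_fp16 = to_fp16(grad)        # underflows: below FP16's smallest subnormal (~6e-8)
scaled_fp16 = to_fp16(grad * scale)  # backward pass on the scaled loss: survives in FP16
recovered = scaled_fp16 / scale      # unscale in higher precision before the update

print(unscaled_fp16, recovered)
```

Without scaling, the gradient becomes exactly 0.0 in FP16. With scaling, it survives the backward pass and is recovered to within about 0.1% after unscaling.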
PyTorch's torch.amp module (torch.cuda.amp in older releases) and TensorFlow's mixed precision API handle these details automatically. Enabling mixed precision takes a few lines of code. It gives faster training and lower memory usage, typically with no measurable loss in accuracy. There is rarely a reason not to use it on supported hardware.
BF16 (brain floating point 16) is an alternative to FP16 that uses the same 16 bits but allocates more to the exponent (8 bits versus 5) and fewer to the mantissa (7 bits versus 10). BF16 therefore has the same dynamic range as FP32, at the cost of lower precision. This usually removes the need for loss scaling and makes training more stable. It is supported on NVIDIA A100 and later GPUs, Google TPUs, and Apple M-series chips.
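This sketch shows the trade-off between range and precision by rounding the same values to both formats. FP16 rounding uses Python's struct format code "e". BF16 is emulated by keeping the top 16 bits of the FP32 bit pattern with round-to-nearest-even, since BF16 is defined as a shortened FP32. The helper names and test values are illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE half precision (1 sign, 5 exponent, 10 mantissa bits)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values beyond FP16's range; hardware would produce inf.
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x: float) -> float:
    """Round to bfloat16 (1 sign, 8 exponent, 7 mantissa bits) via the FP32 bit pattern.

    Uses round-to-nearest-even on the 16 discarded bits; NaN and inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties round to an even result
    bits = (bits + rounding_bias) & 0xFFFF0000   # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in (1e-8, 1e5, 1.001):
    print(f"{value!r:>8}  fp16={to_fp16(value)!r:<22} bf16={to_bf16(value)!r}")
```

1e-8 underflows to zero in FP16 but survives in BF16. 1e5 overflows FP16 (largest finite value 65504) but not BF16. In exchange, BF16's 7-bit mantissa cannot distinguish 1.001 from 1.0, while FP16 can.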
Scale with Distributed Training
When one GPU is not enough (the model is too large, or training takes too long), distributed training spreads the work across multiple GPUs or multiple machines.
Data parallelism is the simplest approach. Each GPU holds a copy of the full model and processes a different batch of data. Gradients are averaged across GPUs before the weight update, so all copies stay synchronized. If one GPU can fit a batch of 32, four GPUs effectively process a batch of 128. PyTorch DistributedDataParallel (DDP) and TensorFlow's MirroredStrategy implement this.
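This pure-Python simulation of the averaging step shows why data parallelism keeps replicas in sync. With equal-sized shards, averaging the per-GPU mean gradients gives exactly the full-batch gradient. It is not a DDP API. mse_grad, the one-weight model, and the toy data are assumptions for illustration. In PyTorch, DistributedDataParallel performs the averaging for you with an all-reduce during the backward pass.

```python
def mse_grad(w: float, batch: list[tuple[float, float]]) -> float:
    """Gradient of the mean squared error of the model y = w * x over a batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


num_gpus = 4
per_gpu_batch = 32
global_batch = [(i / 128, 2.0 * i / 128 + 1.0) for i in range(num_gpus * per_gpu_batch)]  # toy data
w = 0.5

# Each simulated GPU receives an equal, disjoint shard and computes a local gradient.
shards = [global_batch[k * per_gpu_batch:(k + 1) * per_gpu_batch] for k in range(num_gpus)]
local_grads = [mse_grad(w, shard) for shard in shards]

# All-reduce: average the local gradients so every replica applies the identical update.
averaged = sum(local_grads) / num_gpus
full_batch = mse_grad(w, global_batch)
print(averaged, full_batch)
```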
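Linear scaling rule. When increasing the effective batch size through data parallelism, increase the learning rate proportionally: if you quadruple the batch size, quadruple the learning rate. A k-times larger batch gives k times fewer updates per epoch, each with a less noisy gradient. Scaling the step size by k keeps the total distance the weights move per epoch roughly the same. The rule was established for SGD with momentum and tends to break down at very large batch sizes. With Adam, some practitioners use gentler (for example, square-root) scaling instead. Combine either rule with a warm-up period to let the model adjust to the larger learning rate.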
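The rule itself is one line. This sketch also includes square-root scaling as an alternative heuristic that some practitioners prefer for adaptive optimizers. scaled_lr and the example values are assumptions for illustration. Whichever rule you use should still be paired with warm-up.

```python
def scaled_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "linear") -> float:
    """Rescale a learning rate tuned at base_batch for use at new_batch."""
    k = new_batch / base_batch
    if rule == "linear":
        return base_lr * k  # linear scaling rule: k times the batch, k times the LR
    if rule == "sqrt":
        return base_lr * k ** 0.5  # gentler heuristic sometimes used with Adam
    raise ValueError(f"unknown rule: {rule!r}")


# Tuned on one GPU with batch 32; data parallelism over 4 GPUs gives batch 128.
print(scaled_lr(0.1, 32, 128))           # SGD with momentum, linear rule
print(scaled_lr(1e-3, 32, 128, "sqrt"))  # Adam, square-root rule
```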
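Model parallelism splits the model across multiple GPUs when a single GPU cannot fit the entire model. Tensor parallelism splits individual layers across GPUs. Pipeline parallelism assigns different layers to different GPUs and passes intermediate activations between them. Model parallelism is more complex to implement than data parallelism. It is usually reserved for models whose parameters, gradients, and optimizer state do not fit on one GPU. That threshold arrives earlier than parameter counts suggest. With Adam in mixed precision, training state takes roughly 16 bytes per parameter before activations, so a 7B-parameter model already needs about 112 GB.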
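A back-of-the-envelope check shows when a model stops fitting on one GPU. This sketch assumes the common accounting for mixed-precision Adam: FP16 weights and gradients, FP32 master weights, and two FP32 Adam moments, for 16 bytes per parameter. It ignores activations, temporary buffers, and framework overhead, so real requirements are higher. The 80 GB threshold is an assumed GPU size.

```python
# Bytes per parameter for mixed-precision training with Adam (activations excluded).
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_gradients": 2,
    "fp32_master_weights": 4,
    "adam_first_moment_fp32": 4,
    "adam_second_moment_fp32": 4,
}


def training_state_gb(num_params: float) -> float:
    """Approximate memory in GB (1e9 bytes) for weights, gradients, and optimizer state."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


for params in (1e9, 7e9, 70e9):
    gb = training_state_gb(params)
    print(f"{params / 1e9:>4.0f}B params: ~{gb:,.0f} GB -> fits on one 80 GB GPU: {gb < 80}")
```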
Gradient accumulation simulates larger batch sizes on limited hardware. Instead of updating weights after every batch, accumulate gradients over multiple batches and update once. Four accumulation steps with batch size 8 give the same gradient as batch size 32, provided each micro-batch's loss is divided by the number of accumulation steps. Memory use stays at what batch size 8 needs. Layers that compute batch statistics, such as batch normalization, still see batches of 8. This is the simplest way to get the benefits of larger batches without more GPUs.
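This sketch checks the equivalence numerically on a one-weight linear model with mean squared error, in pure Python rather than a framework. mse_grad and the toy data are assumptions. The comment shows the usual PyTorch pattern for the same idea.

```python
def mse_grad(w: float, batch: list[tuple[float, float]]) -> float:
    """Gradient of the mean squared error of the model y = w * x over a batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(i / 32, 3.0 * i / 32) for i in range(32)]  # toy data on the line y = 3x
w = 0.0
accum_steps = 4
micro_batch_size = len(data) // accum_steps  # 8

accumulated = 0.0
for k in range(accum_steps):
    micro_batch = data[k * micro_batch_size:(k + 1) * micro_batch_size]
    # Scale each micro-batch gradient by 1/accum_steps. In PyTorch this is
    # (loss / accum_steps).backward(), with optimizer.step() and
    # optimizer.zero_grad() called once every accum_steps micro-batches.
    accumulated += mse_grad(w, micro_batch) / accum_steps

full_batch = mse_grad(w, data)
print(accumulated, full_batch)

lr = 0.1
w -= lr * accumulated  # a single optimizer step covering all 32 examples
```

Forgetting the division makes the accumulated gradient accum_steps times too large. That silently multiplies the effective learning rate.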
Monitor and Diagnose
Effective monitoring catches problems early, before they waste hours or days of GPU time.
Training and validation loss curves are the primary diagnostic. Both should decrease together in early training. If training loss decreases but validation loss increases, the model is overfitting (add regularization, get more data, or reduce model size). If both remain high, the model is underfitting (increase model size, train longer, reduce regularization, or increase learning rate).
Gradient norms reveal training stability. A sudden spike in gradient norm often precedes a training collapse. Log the global gradient norm at every step. Apply gradient clipping (rescaling all gradients together whenever their global norm exceeds a threshold, typically 1.0) to keep exploding gradients from destabilizing training.
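This sketch clips by global norm over a flat list of gradient values. It returns the pre-clipping norm so the same call serves both monitoring and clipping. The function name and the flat-list representation are illustrative. In PyTorch, torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) does this across all parameter tensors and returns the total norm, which you can log.

```python
import math


def clip_by_global_norm(grads: list[float], max_norm: float) -> tuple[list[float], float]:
    """Rescale all gradients together when their combined L2 norm exceeds max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping,
    which is the value to log when monitoring stability.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm  # keeps the direction, shrinks the length to max_norm
        grads = [g * scale for g in grads]
    return grads, total_norm


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)
```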
Learning rate tracking confirms that your schedule is working as intended. Log the current learning rate alongside the loss to verify that warm-up, decay, and any adaptive changes are happening at the right times.
Per-class metrics reveal whether the model is performing well across all categories or just the common ones. A model with 95% overall accuracy might have 0% accuracy on a rare but important class. Confusion matrices, per-class precision and recall, and stratified evaluation are essential for any classification task.
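A confusion matrix and per-class precision and recall need only counting. This sketch reproduces the scenario from the text with a made-up toy dataset: 95% overall accuracy hides a rare class that the model never predicts. The function names are illustrative. Libraries such as scikit-learn provide equivalents (confusion_matrix, classification_report).

```python
def confusion_matrix(y_true: list[int], y_pred: list[int], num_classes: int) -> list[list[int]]:
    """matrix[t][p] counts examples whose true class is t and predicted class is p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_metrics(matrix: list[list[int]]) -> dict[int, tuple[float, float]]:
    """Return {class: (precision, recall)}; an empty denominator is reported as 0.0."""
    n = len(matrix)
    results = {}
    for c in range(n):
        tp = matrix[c][c]
        predicted_c = sum(matrix[r][c] for r in range(n))  # column sum
        actual_c = sum(matrix[c])                          # row sum
        precision = tp / predicted_c if predicted_c else 0.0
        recall = tp / actual_c if actual_c else 0.0
        results[c] = (precision, recall)
    return results


# 19 examples of a common class 0 and 1 of a rare class 1; the model always predicts 0.
y_true = [0] * 19 + [1]
y_pred = [0] * 20
cm = confusion_matrix(y_true, y_pred, num_classes=2)
accuracy = sum(cm[c][c] for c in range(2)) / len(y_true)
print(cm, accuracy, per_class_metrics(cm))
```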
Sample outputs provide qualitative checks that metrics alone cannot. For generative models, periodically examine generated samples. For classifiers, review the most confidently wrong predictions. These qualitative checks often reveal problems (systematic biases, repeated errors, degenerate outputs) that aggregate metrics obscure.
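Reviewing the most confidently wrong predictions takes only a filter and a sort. Each prediction here is an (example_id, true_label, predicted_label, confidence) tuple. Confidence is the model's probability for its predicted class. The tuple layout, function name, and example records are hypothetical.

```python
def most_confident_mistakes(
    predictions: list[tuple[str, int, int, float]], k: int = 5
) -> list[tuple[str, int, int, float]]:
    """Return the k wrong predictions with the highest confidence, most confident first."""
    wrong = [p for p in predictions if p[1] != p[2]]
    return sorted(wrong, key=lambda p: p[3], reverse=True)[:k]


predictions = [  # hypothetical records
    ("img_001", 0, 0, 0.97),
    ("img_002", 1, 0, 0.99),
    ("img_003", 2, 1, 0.55),
    ("img_004", 1, 2, 0.88),
    ("img_005", 2, 2, 0.60),
]
for example_id, true_label, pred_label, confidence in most_confident_mistakes(predictions, k=2):
    print(f"{example_id}: true={true_label} predicted={pred_label} confidence={confidence:.2f}")
```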
Effective neural network training combines the right learning rate schedule (warm-up plus cosine annealing), appropriate regularization (dropout, weight decay, augmentation, early stopping), mixed precision for speed and memory efficiency, distributed training for scale, and careful monitoring of loss curves and gradient norms. These techniques are not optional refinements; they are the difference between training that converges to a good solution and training that wastes compute without producing a useful model.