## The Danger of Batch Normalization in Deep Learning

If you have ever read the implementation of a deep learning model, chances are you have already encountered BatchNorm (Batch Normalization). It is a very common operation, used to accelerate the training of large models and to stabilise unstable ones. However, if you are a practitioner, you have quite possibly also struggled with this operation, which is notorious for causing problems. In this article, we review the problems we often encounter and propose a few solutions.

## What Is a Batch Normalisation Layer?

BatchNorm aims to solve the problem of internal covariate shift. For a given layer in a deep network, the output has a mean and a standard deviation across the dataset. During training, this mean and standard deviation are unconstrained and can drift as the weights of the preceding layers are updated, which can cause numerical stability issues. The BatchNorm operation attempts to remove this problem by normalising the layer’s output to zero mean and unit standard deviation, followed by a learnable scale and shift. However, evaluating the mean and standard deviation on the whole dataset at every step is too costly, so we only estimate them on the current batch of data.
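A minimal pure-Python sketch of the training-time computation for a single feature (one channel) may make this concrete. The function name is hypothetical, and the learnable scale `gamma`, shift `beta` and small constant `eps` are standard BatchNorm parameters that are not discussed in the text above.

```python
import math

def batch_norm_train(batch, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalise one feature over a batch, using statistics of that batch only.

    `gamma` (scale) and `beta` (shift) are the learnable affine parameters;
    `eps` avoids a division by zero when the batch variance is 0.
    """
    n = len(batch)
    mean = sum(batch) / n
    # Biased variance (divide by n), as used for normalisation during training
    var = sum((x - mean) ** 2 for x in batch) / n
    out = [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in batch]
    return out, mean, var

out, mean, var = batch_norm_train([1.0, 2.0, 3.0, 4.0])
print(mean, var)  # 2.5 1.25
```

Whatever the scale of the inputs, the normalised outputs have zero mean and (almost) unit variance within the batch, before `gamma` and `beta` are applied.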

Normalising with batch statistics works well in practice, but we cannot do the same at inference time: samples may arrive one at a time, so batch statistics no longer make sense. To solve this problem, modern implementations maintain a running average of the batch statistics, updated during training and used in their place at inference time.
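With a single sample, the batch variance is zero and every input is mapped to the same value, which is why batch statistics are unusable at inference. The sketch below, with a hypothetical `RunningBatchNorm` class for one feature, keeps running estimates using the update rule documented for PyTorch's BatchNorm layers: `momentum` defaults to 0.1, the running mean starts at 0 and the running variance at 1, and the stored variance is the unbiased estimate.

```python
import math

class RunningBatchNorm:
    """Single-feature BatchNorm that keeps running estimates for inference.

    Update rule (PyTorch convention):
    running = (1 - momentum) * running + momentum * batch_statistic
    """

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0  # same initial values as PyTorch
        self.running_var = 1.0

    def train_step(self, batch):
        """Batch Estimation Mode: needs at least 2 samples per batch."""
        n = len(batch)
        mean = sum(batch) / n
        var = sum((x - mean) ** 2 for x in batch) / n
        unbiased = var * n / (n - 1)  # the running variance stores the unbiased estimate
        m = self.momentum
        self.running_mean = (1 - m) * self.running_mean + m * mean
        self.running_var = (1 - m) * self.running_var + m * unbiased
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]

    def infer(self, x):
        """Inference Mode: a single sample, normalised with the running estimates."""
        return (x - self.running_mean) / math.sqrt(self.running_var + self.eps)
```

After one training step on `[1.0, 2.0, 3.0, 4.0]`, `running_mean` is only 0.25, a tenth of the batch mean of 2.5; it takes many updates for the estimate to catch up.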

## The Problem

In summary, the behaviour differs between training and inference. At training step \(t\), the batch statistics \(m_t\) and \(\sigma_t\) are used, but at inference time the running estimates \(\widehat{m_t}\) and \(\widehat{\sigma_t}\) are used instead. This difference is the root of all evil, as the validation and training metrics can be very different. More precisely, as the true statistics evolve during training, the running average often lags behind them, which can cause a significant discrepancy. In principle, if the batch is large and the model converges well, those quantities should end up the same. In practice, however, this often fails to happen or is impractical to achieve. For example, it is not obvious whether a large gap between the training and validation losses is due to severe overfitting or to those quantities not having converged yet.
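A toy simulation makes the “lagging behind” effect concrete. Assume, purely for illustration, that the true batch mean drifts by a constant \(d\) at every training step while the running mean is updated PyTorch-style with momentum \(\alpha\). In steady state, the running mean then trails the true value by \(d(1-\alpha)/\alpha\), which the hypothetical `ema_lag` helper below checks numerically.

```python
def ema_lag(drift_per_step, momentum=0.1, steps=500):
    """Return how far an exponential moving average trails a drifting statistic.

    Toy model: the true batch mean increases by `drift_per_step` at every
    training step, and the running mean is updated as
    running = (1 - momentum) * running + momentum * batch_mean.
    """
    batch_mean = 0.0
    running_mean = 0.0
    for _ in range(steps):
        batch_mean += drift_per_step
        running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    return batch_mean - running_mean

print(round(ema_lag(0.01), 6))                # 0.09, i.e. 0.01 * 0.9 / 0.1
print(round(ema_lag(0.01, momentum=0.5), 6))  # 0.01, i.e. 0.01 * 0.5 / 0.5
```

A larger momentum shrinks the lag but makes the estimate noisier, and as long as the statistic keeps drifting, the lag never disappears.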

More dangerously, we regularly observe that although the training loss converges, the validation loss can remain considerably higher because the BatchNorm mean and standard deviation never stabilise. We, the authors, are not entirely sure of the cause, but we believe this can happen when the minimum is highly degenerate, i.e. when it forms a valley of equally good solutions rather than a single point. For example, in a loss landscape like the one illustrated below, the model drifts randomly along the circular valley, causing the running average to lag behind forever.

## The Solution

The first thing to do if you encounter this problem is to try a few standard tricks. Here are some typical ones:

- Try another normalisation method (e.g. LayerNorm, InstanceNorm…);
- Increase the batch size, which stabilises the estimation of the mean and standard deviation within each batch;
- Tune the momentum parameter of the moving average. It controls how much weight previous batches keep in the running average, in other words, how far the estimates can “lag behind”;
- Shuffle your training set at every epoch, to avoid correlations between the samples in a batch.
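In PyTorch, these tricks map onto a few constructor arguments. The sketch below is illustrative only: the channel count, momentum value, batch size and random stand-in dataset are arbitrary assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Momentum of the running average. In PyTorch, `momentum` is the weight given to
# the *new* batch statistic (default 0.1): a larger value follows the batch
# statistics more closely (less lag, more noise), a smaller one averages over
# more batches.
bn = nn.BatchNorm2d(16, momentum=0.3)

# Normalisations that do not depend on the other samples in the batch.
# GroupNorm with a single group normalises each sample over (C, H, W),
# a LayerNorm-style normalisation for convolutional feature maps.
layer_norm_like = nn.GroupNorm(num_groups=1, num_channels=16)
instance_norm = nn.InstanceNorm2d(16, affine=True)

# Larger batches, reshuffled at every epoch: with shuffle=True, the DataLoader
# draws a new permutation each time it is iterated. Random tensors stand in
# for a real MNIST-shaped dataset.
dataset = TensorDataset(torch.randn(512, 1, 28, 28), torch.randint(0, 10, (512,)))
loader = DataLoader(dataset, batch_size=256, shuffle=True)
```

Note that PyTorch's `momentum` runs in the opposite direction to the momentum of optimisers: it weights the newest batch, not the history.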

However, sometimes those basic tricks will not suffice. In that case, we propose a more powerful trick.

Keep in mind that we have two different behaviours of the BatchNorm layer:

- In what we will call Batch Estimation Mode, the mean and standard deviation are estimated on the batch. This is the mode used in training;
- In what we will call Inference Mode, the mean and standard deviation come from previous estimates, i.e. the running average. This is what is usually used during validation and inference.
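In PyTorch, these two behaviours correspond to calling `.train()` and `.eval()` on a module. The small sketch below, with an arbitrary seed, sizes and data, feeds the same batch to a freshly created `nn.BatchNorm1d` in both modes.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)  # running_mean starts at 0, running_var at 1

# A batch of 8 samples whose features are centred around 10, far from (0, 1).
x = torch.randn(8, 3) + 10

# Batch Estimation Mode: normalise with this batch's own statistics
# (and, as a side effect, update running_mean / running_var once).
bn.train()
y_train = bn(x)

# Inference Mode: normalise with the stored running estimates instead.
bn.eval()
y_eval = bn(x)

print(y_train.mean(dim=0))  # close to 0 for every feature
print(y_eval.mean(dim=0))   # far from 0: running_mean moved only 10% of the way
```

The gap between the two outputs is the train/validation discrepancy in its most extreme form: the running statistics have received a single update.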

Our solution has two steps. First, we remove the difference between training and validation by always using batch estimation mode. Second, to use the model in production, we still need estimates of the mean and standard deviation so that we can use inference mode. So, once the model is trained, we compute the mean and standard deviation to be used. Because they are then evaluated on a model with fixed weights, we avoid the “lagging behind” effect described above. Concretely, after training, we freeze all the weights of the model and run one epoch in batch estimation mode to estimate the average statistics on the whole dataset.
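Here is a minimal PyTorch sketch of the second step, assuming a standard `nn.Module` and a loader yielding `(inputs, targets)` pairs. `recalibrate_batchnorm` is a hypothetical name, not the authors' code. Setting `momentum` to `None` makes PyTorch replace the exponential moving average with a cumulative average, so every batch of the epoch contributes equally.

```python
import torch
from torch import nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

@torch.no_grad()  # no gradients and no optimizer step: the weights stay frozen
def recalibrate_batchnorm(model: nn.Module, loader) -> None:
    """Re-estimate BatchNorm statistics over one epoch with frozen weights.

    Assumes `loader` yields (inputs, targets) pairs, e.g. the training loader.
    """
    bn_layers = [m for m in model.modules() if isinstance(m, BN_TYPES)]
    saved_momenta = {}
    model.eval()  # every other layer (e.g. Dropout) keeps its inference behaviour
    for bn in bn_layers:
        bn.reset_running_stats()  # forget the lagging estimates from training
        saved_momenta[bn] = bn.momentum
        # momentum=None: cumulative (equally weighted) average over all batches
        bn.momentum = None
        bn.train()  # Batch Estimation Mode, so the running statistics are updated
    for inputs, _ in loader:
        model(inputs)
    for bn in bn_layers:
        bn.momentum = saved_momenta[bn]
    model.eval()  # Inference Mode, using the freshly estimated statistics
```

PyTorch also provides `torch.optim.swa_utils.update_bn(loader, model)`, written for Stochastic Weight Averaging, which performs a very similar pass; unlike this sketch, it puts the whole model, including any Dropout layers, in training mode during the pass.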

## Experimenting with the Solution

To show the advantage of our solution, let’s run a small experiment. We purposely used a very bad architecture and trained it with a relatively high learning rate, resulting in a model with unstable BatchNorm. The code, written in Python using PyTorch, is available here.

The network is a stack of 3 convolution layers, each with BatchNorm and ReLU activation, followed by a global average pooling layer. We trained it on MNIST for 10 epochs using the Adam optimiser. The figure below shows the training and validation accuracy per epoch in 4 modes:

- Mode 0: No BatchNorm layers are used.
- Mode 1: Basic BatchNorm with no modifications.
- Mode 2: Almost Smart BatchNorm: we activated the running stats for inference, but we did not run the model for one extra epoch to estimate the moving average of the statistics.
- Mode 3: Smart BatchNorm: before switching to inference mode, we estimate the average statistics of the dataset over one epoch.
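The network described above can be sketched in PyTorch as follows. This is a hypothetical reconstruction, not the authors' code: the width, the 3x3 kernels, the learning rate and the last convolution producing one channel per class are all assumptions. The `use_batchnorm=False` switch gives the Mode 0 baseline.

```python
import torch
from torch import nn

def make_model(use_batchnorm: bool = True, width: int = 32, num_classes: int = 10) -> nn.Module:
    """3 x (Conv -> BatchNorm -> ReLU), then global average pooling.

    Hypothetical reconstruction: width, kernel size and the last conv
    producing one channel per class are assumptions.
    """
    channels = [1, width, width, num_classes]  # MNIST images have 1 channel
    layers = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
        if use_batchnorm:  # use_batchnorm=False gives the Mode 0 baseline
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU())
    # Global average pooling turns each (num_classes, H, W) map into logits
    layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
    return nn.Sequential(*layers)

model = make_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # assumed "high" learning rate
logits = model(torch.randn(8, 1, 28, 28))  # dummy MNIST-shaped batch
print(logits.shape)  # torch.Size([8, 10])
```

Taken literally, the description puts a ReLU after the last convolution too, so the pooled logits are non-negative; drop that final ReLU if you want a more conventional classifier head.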

We observe two things. First, BatchNorm helps increase the accuracy. Second, without our solution, the validation metric is erratic and uninformative. We also report the test accuracy for all 4 modes.

As you can see, we obtain better results with our solution. The third mode (Mode 2) is really bad: we activate the running statistics (inference mode) but never estimate them on the dataset, so when testing in inference conditions with a batch size of 1, the results are poor. This shows that using running statistics at inference time must be combined with estimating the dataset statistics over a whole epoch before the model is used for inference.

## Is that Solution Perfect?

No, obviously not! Many things can still go wrong. The trickiest is that your estimated mean and standard deviation will still differ from the batch estimates, and some really odd phenomena can still hit you hard. For example, it has been shown that some models can actually encode information in statistical noise. Fortunately, such extreme cases are very rare, and in our experience this solution is quite robust: it should only improve your performance and save you a lot of headaches. If you want to avoid weird behaviours with your BatchNorm layers, go for it.

*Feature image by Pietro Jeng*