Maschinelles Lernen Lernen

Exponential moving average and batch size

April 12, 2019

Exponential moving averages are used in several places in machine learning (often under the name of momentum). We look at the connection between batch size and momentum.

So today, we're looking at a less specialized topic than the internals of JIT optimizations, but one with generally applicable conclusions.

There are a number of places where momentum is used in learning. One of them is batch norm, but today, we're looking at its use in optimizers. Specifically, we're interested in the effect of batch size.

Batch Size and Moving Averages in the Optimizer

One thing that often happens to me (I do most of my experimentation on a GTX 1080 Ti card, so with 11GB of GPU memory) is that I need to reduce the batch size for my training relative to what the people publishing models use. How should I change the optimizer hyperparameters in such a case? I've seen this question on the PyTorch forums a couple of times, and it also came up recently in a discussion in the context of the fast.ai course, so I finally wrote down the calculation.

A quick look at the LAMB optimizer

Let's take the recent and exciting LAMB optimizer as an example, and also highlight the key differences from the more established ADAM. Given a parameter (weight) $w$ and a gradient $g$ for this parameter, it does the following four things:

  1. Just like ADAM, it keeps (pointwise) exponentially weighted moving averages of the gradient $g$ and the squared gradient $g^2$ [1], and uses their de-biased estimates, let's call them $m$ and $v$, to normalize the averaged gradient to $\hat m = m / \sqrt{v + \varepsilon}$. We call the "momentum" parameters of these two exponentially weighted moving averages $\beta_1$ and $\beta_2$, respectively. ADAM without weight decay would now take a step by multiplying $\hat m$ with the learning rate [2]. But we continue.

  2. LAMB then adds a weight decay term to get the step direction $s = \hat m + \lambda w$, where $w$ is the current parameter. This is similar to the AdamW variant of ADAM.

  3. Third, it rescales the step to the size of $w$ instead of using the magnitude of the "raw" step $s$. So the step is $\hat s = \|w\|_2 \cdot (s / \|s\|_2)$, where $\|\,\cdot\,\|_2$ is the Euclidean norm. I wrote this differently from the original paper to emphasize that $s$ gives the direction, but $w$ the magnitude.

  4. Now, given a learning rate $\eta$, the step is $w_\text{new} = w - \eta \hat s$.
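The four steps can be sketched in plain Python for a single parameter tensor flattened to a list of floats. This is a minimal illustration under my own assumptions (the function name `lamb_step`, the default values, and putting $\varepsilon$ inside the square root as in step 1), not the reference LAMB implementation; a real optimizer would use tensor operations and keep the state per parameter.

```python
import math


def lamb_step(w, g, m, v, t, lr, beta1=0.9, beta2=0.999, wd=0.01, eps=1e-6):
    """One LAMB-style update of a single parameter tensor, flattened to lists.

    w, g, m, v: equal-length lists of floats (weights, gradient, EMA states).
    t: 1-based step counter used for de-biasing.
    Returns (new_w, new_m, new_v).
    """
    # Step 1: pointwise EMAs of g and g*g, de-biased, then m / sqrt(v + eps).
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    normed = [mh / math.sqrt(vh + eps) for mh, vh in zip(m_hat, v_hat)]

    # Step 2: add decoupled weight decay to get the direction s.
    s = [ni + wd * wi for ni, wi in zip(normed, w)]

    # Step 3: give s the Euclidean norm of w (whole tensor, not per weight).
    w_norm = math.sqrt(sum(wi * wi for wi in w))
    s_norm = math.sqrt(sum(si * si for si in s))
    s_scaled = [w_norm * si / s_norm for si in s] if s_norm > 0 else s

    # Step 4: plain step with learning rate lr.
    new_w = [wi - lr * si for wi, si in zip(w, s_scaled)]
    return new_w, m, v


w = [3.0, 4.0]
new_w, m, v = lamb_step(w, g=[0.1, -0.2], m=[0.0, 0.0], v=[0.0, 0.0], t=1, lr=0.01)
# The step has length lr * ||w|| (about 0.05 here), whatever the gradient's magnitude.
print(math.dist(w, new_w))
```

Because of step 3, the length of the update is always $\eta \|w\|_2$; the gradient only decides its direction.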

To me it's really interesting that LAMB scales the step to the weight magnitude (and I wonder if we could run into trouble with this when our weights become small, but maybe then we're in a pickle anyway). Note that this is not done for individual weights, but for a whole parameter (tensor). One trick I saw in Jeremy Howard's lecture was to bound the ratio between weight norm and gradient norm to be no larger than 10. When the gradient almost vanishes (e.g. if you're really close to a local optimum), it probably makes sense not to scale it up artificially.
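A capped version of step 3 might look like the following. It is a sketch assuming the bound is applied to the ratio $\|w\|_2 / \|s\|_2$ with $s$ from step 2; the helper name `scaled_step` and the cap `max_ratio=10.0` are illustrative, and the lecture code may differ in its details.

```python
import math


def scaled_step(w, s, max_ratio=10.0):
    """Step 3 with the trust ratio ||w|| / ||s|| capped at max_ratio.

    Hypothetical helper: without the cap this returns ||w|| * s / ||s||.
    """
    w_norm = math.sqrt(sum(x * x for x in w))
    s_norm = math.sqrt(sum(x * x for x in s))
    if s_norm == 0.0:
        return list(s)  # nothing to rescale
    ratio = min(w_norm / s_norm, max_ratio)
    return [ratio * x for x in s]


w = [3.0, 4.0]                          # ||w|| = 5
print(scaled_step(w, [6.0, 8.0]))       # ratio 0.5: rescaled to norm 5
print(scaled_step(w, [0.003, 0.004]))   # ratio 1000 is capped at 10
```

With the cap, a nearly vanishing step stays small instead of being blown up to the size of $w$.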

I wondered a bit (and I'll readily admit I didn't read the papers carefully) how things change when we divide by the square root of the centered second moment (aka variance) rather than of the uncentered second moment. Happily, Jeremy provides a notebook trying out LAMB on his Imagenette toy data set. The experiments there seem to suggest that it does not matter that much (but you have to align the momentum parameters to avoid getting negative variances). However, I didn't tune the parameters a lot, and apparently neither did Jeremy (the notebook was meant to show how to implement LAMB).
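To see why the momentum parameters have to be aligned, here is a small scalar sketch (the helper `ema_variance` and the gradient sequence are made up for illustration) that estimates the variance as $\hat v - \hat m^2$ from the two de-biased averages. With $\beta_1 = \beta_2$ both averages use the same weights, so the estimate is a weighted variance and cannot be negative (up to rounding). With a short-memory $\beta_1$ and a long-memory $\beta_2$, a sudden large gradient moves $\hat m$ much faster than $\hat v$, and the estimate goes negative.

```python
def ema_variance(grads, beta1, beta2):
    """Variance estimate v_hat - m_hat**2 from two de-biased EMAs (scalar case).

    Hypothetical helper for illustration; assumes grads is non-empty.
    """
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    t = len(grads)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return v_hat - m_hat ** 2


# A long quiet stretch followed by one large gradient (made-up data).
grads = [0.0] * 999 + [10.0]
print(ema_variance(grads, beta1=0.9, beta2=0.999))  # negative: m reacts faster than v
print(ema_variance(grads, beta1=0.9, beta2=0.9))    # non-negative: same weights
```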

But we wanted to talk about the rôle of the batch size here. Did you spot it in the four steps? No! It does not appear anywhere explicitly.

So what we will do is pretend that we have found $\lambda$, $\beta_1$, $\beta_2$, $\eta$ that work for a batch size of $b \in \mathbb{N}$, and ask how we would have to change the parameters for a batch size of $1$ (and later, more generally, $b'$) to keep the quantities roughly the same.

Time, discounting, survival rates, and moving averages

So let us look at the update of the moving average in step 1. We can leave out the de-biasing, as its correction factor tends to $1$. Then we have $m = \beta_1 m_\text{old} + (1 - \beta_1) g$. If we now move from a batch size of $b$ to a batch size of $1$, how should $\beta_1$ change to keep $m$ roughly the same?

Let us pretend that $b$ is $12$, like the $12$ months in a year. If I have an annual interest rate $r$, a receivable of $1$ unit due in one year is worth $d_{annual} = (1 + r)^{-1}$ today; $d_{annual}$ is the discount factor. But what is it worth (with a flat interest rate) if it is due in only $n$ months? Clearly we want discounting to be multiplicative, so that what is due in 6 months is worth today what a payment due in 12 months will be worth six months from now. This is achieved by the fractional power $d_{monthly} = (1 + r)^{-1/12} = d_{annual}^{1/12}$, and a receivable due in $n$ months is worth $d_{monthly}^n$. If you're more into the real world than into finance, we can similarly calculate the remaining fraction of radioactive material after a fraction of the half-life.
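The discounting arithmetic fits in a few lines of Python; the 5% rate and the helper names `discount_factor` and `remaining_fraction` are purely illustrative assumptions.

```python
def discount_factor(r, years):
    """Value today of 1 unit due in `years` years at a flat annual rate r."""
    return (1 + r) ** (-years)


def remaining_fraction(elapsed, half_life):
    """Fraction of radioactive material left after `elapsed` time."""
    return 0.5 ** (elapsed / half_life)


r = 0.05  # illustrative annual interest rate
d_annual = discount_factor(r, 1)
d_monthly = discount_factor(r, 1 / 12)

# Twelve monthly discounts compound to one annual discount,
# and so do two half-year discounts.
print(d_monthly ** 12, d_annual, discount_factor(r, 0.5) ** 2)
# Half a half-life leaves 1/sqrt(2) of the material.
print(remaining_fraction(0.5, 1.0))
```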

So what does that mean for the momentum? In order to get the same weighting of $m_\text{old}$ after $b$ steps (with batch size 1) as we previously got with one step, we need to adjust the momentum to $\tilde \beta_1 = \beta_1^{1/b}$. Note that we won't get exactly the same result after these $b$ steps, though, because the $b$ incremental updates, say for $k = 1, \ldots, b$, are weighted with $\tilde \beta_1^{b-k} (1 - \tilde \beta_1)$. These weights do sum to $1 - \tilde \beta_1^{b} = 1 - \beta_1$, so if all $b$ gradients were the same, we would get exactly the same result. In the cash flow analogy, we have moved from a single payment at the end of the period to payments at the end of each fractional period. In general, the quantity should still be of very similar size. The same argument applies to $\beta_2$, and we consider step 1 mastered.
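A small numeric check of this argument (the helper `ema` and the values $b = 8$, $\beta_1 = 0.9$ are illustrative assumptions): the weight on $m_\text{old}$ matches, the per-sample weights sum to $1 - \beta_1$, and with identical per-sample gradients, $b$ per-sample steps reproduce one batch step exactly. With differing per-sample gradients of the same mean, the two routes end up close but not equal.

```python
def ema(m, grads, beta):
    """Apply m <- beta * m + (1 - beta) * g for each g in turn (no de-biasing)."""
    for g in grads:
        m = beta * m + (1 - beta) * g
    return m


b, beta1 = 8, 0.9
beta1_tilde = beta1 ** (1 / b)  # momentum for batch size 1

# After b per-sample steps, m_old carries the same weight as after one batch step.
print(beta1_tilde ** b)  # ~0.9

# Per-sample weights (1 - beta_tilde) * beta_tilde**(b - k), k = 1..b, sum to 1 - beta1.
weights = [(1 - beta1_tilde) * beta1_tilde ** (b - k) for k in range(1, b + 1)]
print(sum(weights))  # ~0.1

# Identical per-sample gradients: both routes give the same m.
m_batch = ema(0.3, [1.0], beta1)             # one step on the batch-mean gradient
m_single = ema(0.3, [1.0] * b, beta1_tilde)  # b steps on single-sample gradients

# Differing per-sample gradients with mean 0.5: close, but not identical.
m_batch2 = ema(0.3, [0.5], beta1)
m_single2 = ema(0.3, [1.0, 0.0] * (b // 2), beta1_tilde)
print(m_batch, m_single, m_batch2, m_single2)
```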

Note that this discussion implicitly assumes that the gradient has a similar size for both batch sizes. This is usually the case because losses typically average (rather than sum) over the batch dimension. But even a constant factor would cancel in the division in step 1 (up to the $\varepsilon$).
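A quick check of the cancellation, as a sketch: `normalized_mean` is a hypothetical helper covering only step 1 for a scalar parameter, and the gradient values are made up. Multiplying every gradient by the same constant, as summing instead of averaging over a batch of 32 would, leaves $m / \sqrt{v + \varepsilon}$ essentially unchanged as long as $\varepsilon$ is small compared to $v$.

```python
import math


def normalized_mean(grads, beta1=0.9, beta2=0.999, eps=1e-12):
    """Step 1 only, for a scalar parameter: de-biased m / sqrt(v + eps)."""
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    t = len(grads)  # assumes at least one gradient
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat / math.sqrt(v_hat + eps)


grads = [0.2, -0.1, 0.4, 0.3]     # made-up per-step gradients
summed = [32 * g for g in grads]  # as if the loss summed over a batch of 32
print(normalized_mean(grads), normalized_mean(summed))  # essentially equal
```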

Adapting the other hyperparameters

So what about the weight decay term in step 2? Let's assume that the weight $w$ stays roughly the same size between steps (it will if $\eta$ is not close to $1$). From step 1 we get $\hat m$ of the same order as before, so to be consistent, we would like $\lambda w$ to be of the same order, too. So we leave $\lambda$ unchanged.

On to the third step: here, there are no hyperparameters. The rescaled step $\hat s$ has norm $\|w\|_2$ whatever the batch size, so it will be of the same size as before.

This means that in the fourth step, which we now take $b$ times as often, the inputs have the same size, so we change the learning rate to $\tilde \eta = \frac{\eta}{b}$ in order to end up after $b$ steps roughly where we previously were after one.

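And we're done. So if you know good values $\tilde \beta_i$ and $\tilde \eta$ for batch size $1$ and change the batch size to $b$ in LAMB (or even have a varying batch size), just set the momentum hyperparameters to $\beta_i = \tilde \beta_i^{b}$ and the learning rate to $\eta = b \tilde \eta$. More generally, going from batch size $b$ to $b'$ means $\beta_i' = \beta_i^{b'/b}$ and $\eta' = \frac{b'}{b} \eta$, while $\lambda$ stays the same.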
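The resulting rule fits in a small helper. This is a sketch of the conversion derived above (the name `adapt_hyperparams` and the example values are hypothetical); it only encodes the heuristic from this post, not a guarantee that the rescaled settings train equally well.

```python
def adapt_hyperparams(lr, betas, wd, b, b_new):
    """Rescale LAMB/ADAM-style hyperparameters from batch size b to b_new.

    Heuristic from the argument above: beta -> beta ** (b_new / b),
    lr -> lr * b_new / b, weight decay unchanged.
    """
    ratio = b_new / b
    new_betas = tuple(beta ** ratio for beta in betas)
    return lr * ratio, new_betas, wd


lr, betas, wd = adapt_hyperparams(1e-3, (0.9, 0.999), 0.01, b=256, b_new=64)
print(lr, betas, wd)
```

For example, going from $b = 256$ down to $b' = 64$ with $\eta = 10^{-3}$, $\beta_1 = 0.9$, $\beta_2 = 0.999$ gives $\eta' = 2.5 \cdot 10^{-4}$, $\beta_1' \approx 0.974$ and $\beta_2' \approx 0.99975$, while the weight decay keeps its value.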


So we saw that when changing the batch size, we need to adapt not only the learning rate but also the momentum parameters, and we found out how. Weight decay, on the other hand, should stay the same, at least in the AdamW / LAMB formulation.

For completeness, I should note that Smith & Le suggest that there is an optimal relation of batch size, momentum, and learning rate in SGD, but here I'm after something much simpler, trying to use the adaptiveness of LAMB and ADAM.

  1. The fancy people use $g \odot g$, the Hadamard (=pointwise) product. But my tastes here are rather plain.

  2. And there is something funny with the $\varepsilon$: ADAM puts it outside the square root, i.e. it divides by $\sqrt{v} + \varepsilon$. (I learned this from the fast.ai course.)