SGD with momentum

post by Ioannis Mitliagkas, Dan Iter and Chris Ré.
Includes results from a body of work performed by Stefan Hadjis, Ce Zhang and the authors.

Stochastic Gradient Descent (SGD) and its variants are the optimization method of choice for many large-scale learning problems, including deep learning. A popular approach to running these systems removes locks and synchronization barriers. Such methods, called "asynchronous-parallel" or Hogwild!, are used in many systems at companies like Microsoft and Google.

However, the effectiveness of asynchrony is a bit of a mystery. For convex problems on sparse data, the Hogwild! analysis shows that these race conditions don't slow down convergence too much, while the lack of locking means that each step takes less time. Sparsity cannot be the complete story, though: many people reported that async is faster even for dense data, a setting in which Hogwild!'s theorem does not apply. Speaking of settings in which theorems don't apply...

In deep learning, there has been a debate about how to scale up training. Many systems have run asynchronously, but some have proposed that synchronous training may be faster. As part of our work, we realized that many systems often do not tune a seemingly innocuous parameter, the momentum, and that this parameter changes the results drastically.

Like the step size, the value of the momentum parameter depends on the objective, the data, and the underlying hardware. In all fairness, though, the connection to hardware had not been apparent to people, ourselves included, until the results below: we draw a new, predictive link between asynchrony and momentum, which we verify both theoretically and empirically. This post walks through that link: what the theory predicts, what we see in experiments, and a bonus idea that falls out of it.

If you'd like a reminder about what we mean by synchronous and asynchronous, please check out the Background section at the end of this post.

Asynchrony induces Momentum

In a recent note, we showed that asynchrony induces momentum. Intuitively, the main effect of asynchrony is to increase the staleness of the updates. Consider the following simple stale-read model, inspired by typical deep learning systems.

Staleness model

Worker 1 reads a value from the parameter server at time t and uses it to start computing a gradient based on an assigned mini-batch. By the time the gradient is computed and ready to be sent to the parameter server, a number of other workers have sent their own updates to the server.

Assume for now that our workers run SGD without momentum. Then, under staleness, the update takes the form

w_{t+1} = w_t - α_t ∇f(w_{t-τ_t}; z_{i_t}),

in which the staleness τ_t is a random variable drawn from a staleness distribution:

Staleness distribution

Basically, staleness here acts as “fuzzy” memory: in expectation, our gradient is evaluated at a convex combination of recent parameter values. This is the intuition that leads to the following result.

Theorem

Let us consider the case of M asynchronous workers running plain SGD. Under staleness and assuming a simple queueing theory model, we get the following result.

Asynchrony-induced momentum

The current update is a scaled-down version of the previous update, plus the standard gradient update term. In other words, our update contains a momentum term! Let us first unpack this result and then see how this simple model can predict behavior we see in a real system.
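To get a feel for this, here is a minimal simulation sketch of the staleness model on a toy one-dimensional quadratic. The geometric staleness distribution and the least-squares fit are simplifying assumptions made for illustration; they stand in for the queueing model and for the formal statement in the note.

```python
import numpy as np

# Toy version of the staleness model: SGD on f(w) = 0.5 * w**2 (so grad f(w) = w),
# where every step applies a gradient computed at a stale iterate.  The geometric
# staleness (mean roughly M - 1) is a simplifying stand-in for the queueing model.
rng = np.random.default_rng(0)
M, alpha, T, runs = 8, 0.05, 60, 100_000
W = np.ones((runs, T + 1))                       # W[:, t] holds w_t for each run

for t in range(T):
    tau = rng.geometric(1.0 / M, size=runs) - 1  # staleness of each run's read: 0, 1, ...
    w_read = W[np.arange(runs), np.maximum(t - tau, 0)]
    W[:, t + 1] = W[:, t] - alpha * w_read       # apply the stale gradient step

# The theorem is a statement about expected iterates, so fit
#   E[w_{t+1} - w_t]  ~  mu * E[w_t - w_{t-1}]  -  c * E[grad f(w_t)]
# on the run-averaged trajectory and read off the implicit momentum mu.
m = W.mean(axis=0)
d_next, d_prev, grad = m[2:] - m[1:-1], m[1:-1] - m[:-2], m[1:-1]
A = np.stack([d_prev, -grad], axis=1)
mu_hat, c_hat = np.linalg.lstsq(A, d_next, rcond=None)[0]
print(f"estimated implicit momentum: {mu_hat:.3f}   (1 - 1/M = {1 - 1/M:.3f})")
```

With enough runs, the fitted coefficient on the previous update lands close to 1 - 1/M, which is exactly the momentum-like behavior described above.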

What the theory predicts

The theorem suggests that, when training asynchronously, there are two sources of momentum:

  1. Explicit or algorithmic momentum: what we introduce algorithmically by tuning the momentum parameter in our optimizer.
  2. Implicit or asynchrony-induced momentum: what asynchrony contributes.

For now we can consider that they act additively: the total effective momentum is the sum of explicit and implicit terms. This is a good-enough first-order approximation, but can be improved by carefully modeling their higher-order interactions.
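In symbols, taking 1 - 1/M as the implicit momentum suggested by the simple queueing model for M asynchronous workers (an approximation, not an exact identity), the first-order picture is roughly the following.

```latex
% First-order (additive) picture of the total momentum the optimizer effectively runs with.
% The 1 - 1/M value for the implicit part comes from the simple queueing model for
% M asynchronous workers; treat it as an approximation, not an exact identity.
\mu_{\text{effective}} \;\approx\; \mu_{\text{explicit}} + \mu_{\text{implicit}},
\qquad
\mu_{\text{implicit}} \approx 1 - \frac{1}{M}.
```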

The second important theoretical prediction is that more workers introduce more momentum.

Consider the following thought experiment, visualized in the figure below. We increase the number of asynchronous workers and, in each case, we tune for the optimal explicit momentum. According to the model, we would see a picture like the following:

Model prediction

There is an optimal level of total momentum, equal to the value tuned for a single worker (the synchronous case). When we introduce a second worker, asynchrony induces non-zero momentum (orange bars). As we add more workers, we get more implicit momentum, which eventually exceeds the optimal level.
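Here is the arithmetic behind that picture as a tiny sketch; the target total momentum mu_star = 0.9 and the 1 - 1/M implicit value are illustrative assumptions, not values taken from the experiments.

```python
# If the best *total* momentum is some target mu_star, and M asynchronous workers already
# contribute roughly 1 - 1/M implicitly, a tuner restricted to non-negative values should
# pick whatever is left over.
mu_star = 0.9
for M in [1, 2, 4, 8, 16, 32]:
    mu_implicit = 1.0 - 1.0 / M
    mu_explicit = max(0.0, mu_star - mu_implicit)     # the tuner cannot go below zero here
    print(f"M = {M:2d}   implicit = {mu_implicit:.3f}   tuned explicit = {mu_explicit:.3f}")
```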

What we see in experiments

We tried these experiments on our system as well as on TensorFlow and MXNet, so there is no trickery in the implementation itself. We grid-search the learning rate and momentum for a number of different configurations and report the optimal explicit momentum values below for CIFAR and ImageNet.

Tuning CIFAR | Tuning ImageNet

We see the same monotonic behavior we expect from theory. As we increase the level of asynchrony, we have to decrease the amount of explicit momentum. Asynchrony introduces the rest of the momentum.
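For concreteness, a grid search of this kind looks roughly like the sketch below. The function iterations_to_target is a toy stand-in that simply prefers a total momentum near 0.9; the real experiments use full training runs on CIFAR and ImageNet.

```python
import itertools

# Grid-search the learning rate and the explicit momentum, keeping the pair that reaches
# the training-loss target in the fewest iterations.  `iterations_to_target` is a toy
# stand-in for a real training run, included only so the sketch executes end to end.
def iterations_to_target(lr, momentum, workers):
    implicit = 1.0 - 1.0 / workers               # assumed asynchrony-induced momentum
    return (1.0 + abs(momentum + implicit - 0.9)) / lr

learning_rates = [0.1, 0.01, 0.001]
momenta = [0.0, 0.3, 0.6, 0.9]

def tune(workers):
    grid = itertools.product(learning_rates, momenta)
    return min(grid, key=lambda cfg: iterations_to_target(cfg[0], cfg[1], workers))

print(tune(workers=1), tune(workers=16))         # best (lr, momentum) per configuration
```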

The importance of tuning

We see that tuning can significantly improve the statistical efficiency of asynchronous training. In the first figure, we show results on CIFAR. We plot the normalized number of iterations to convergence for each configuration, a statistical 'penalty' compared to the best case. We first draw the penalty curve we get by using the standard value of momentum, 0.9, in all configurations. Then we draw the penalty curve we get when we grid-search momentum.
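As a concrete, made-up example of the penalty metric (the iteration counts below are for illustration only and are not measurements from the paper):

```python
# Iterations to reach the loss target for each worker count, normalized by the best
# configuration; a value of 1.0 means no statistical penalty relative to the best case.
iters_to_target = {1: 10_000, 4: 12_000, 8: 15_000, 16: 24_000}   # workers -> iterations
best = min(iters_to_target.values())
penalty = {workers: iters / best for workers, iters in iters_to_target.items()}
print(penalty)
```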

Tuning CIFAR


The next figure reports the statistical penalty curve for ImageNet when we tune momentum. Even though we expect async to have strictly worse statistical efficiency, we see that, for larger datasets, tuning momentum can eliminate the penalty of asynchrony. Check out our paper for more details on the experimental setup.

Tuning ImageNet


Bonus idea: Negative momentum

As we saw, our model predicts that given enough asynchronous workers, implicit momentum can exceed the optimal momentum. At first glance, it might seem inevitable that we will incur a statistical penalty in that case. Some not-so-hard thinking, however, suggests a new idea:

Negative momentum animation

We could use a negative value for the explicit momentum. The hope, as shown in the figure, is that negative momentum will counteract the implicit momentum due to asynchrony. It's worth pausing to point out that this idea does not make sense in a synchronous setting as it needlessly increases the variance of the SGD step.
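Extending the earlier back-of-the-envelope sketch, and keeping the same illustrative assumptions (a target total momentum of 0.9 and an implicit momentum of roughly 1 - 1/M), dropping the non-negativity constraint immediately yields negative explicit values once the worker count is large enough.

```python
# Same calculation as before, but without clamping at zero: choose the explicit momentum so
# that explicit + implicit hits the target.  Once 1 - 1/M exceeds the target, the explicit
# value goes negative.  mu_star = 0.9 and 1 - 1/M remain illustrative modelling assumptions.
mu_star = 0.9
for M in [1, 4, 16, 32]:
    mu_explicit = mu_star - (1.0 - 1.0 / M)           # negative once 1 - 1/M > mu_star
    print(f"M = {M:2d}   explicit momentum = {mu_explicit:+.3f}")
```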

We test this idea on CIFAR on our system. The following figure shows the statistical penalty (increased number of iterations to the goal) we pay for asynchrony. The top curve is the penalty for using μ=0.9. The next one tunes momentum over non-negative values and achieves a speed-up of about 2.5x in the fully asynchronous case, using 16 workers. When we allow for tuning over negative momentum values, the penalty for 16 workers improves by another 2x. The bottom plot shows the momentum values selected by the latter tuning process. This is a preliminary result, but it suggests that negative momentum can reduce the statistical penalty further compared to non-negative tuning.

CIFAR negative momentum

Conclusion

A few take-away points:

  1. Asynchrony adds an implicit momentum term to the updates; the more workers, the more implicit momentum.
  2. Momentum is not an innocuous parameter: its optimal explicit value depends on the level of asynchrony, so it should be tuned for each configuration.
  3. Preliminary results suggest that negative explicit momentum can counteract asynchrony-induced momentum and further reduce the statistical penalty.

What's next?


Background

Parallelization

When you scale, you still need to aggregate what you've learned into a single model. There are two major styles of parallelization: synchronous-parallel (left) and asynchronous-parallel, a.k.a. Hogwild! (right).

Synchronous-parallel | Asynchronous-parallel

Synchronous parallel in a nutshell

All workers read the current parameter value from a parameter server and start computing their individual gradients together. Workers that finish early block at a synchronization barrier until all workers are done. The gradients are then aggregated and communicated to the parameter server. The server then replies with the new parameter value and the process starts over. This method has the benefit that workers compute their gradients based on the latest parameter value; however, waiting at the barrier can hurt performance.
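A minimal sketch of one such step, using a toy quadratic loss and an in-process 'parameter server'; both are illustrative stand-ins for the real models and systems discussed here.

```python
import numpy as np

# One synchronous-parallel step with M workers: every worker reads the same parameter
# value, computes a mini-batch gradient of the toy loss 0.5 * ||w - x||^2, and the
# aggregated gradient is applied once all workers have finished.
rng = np.random.default_rng(0)
M, alpha = 4, 0.1
w = np.ones(3)                                   # parameter-server state
data = rng.normal(size=(64, 3))                  # toy dataset

def gradient(w_read, batch):
    return np.mean(w_read - batch, axis=0)       # mini-batch gradient of the toy loss

batches = np.array_split(data, M)                # every worker reads the *same* w
grads = [gradient(w, b) for b in batches]        # barrier: wait until all M are done
w = w - alpha * np.mean(grads, axis=0)           # aggregate, update, and start over
```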

Asynchronous-parallel in a nutshell

Workers read parameter values and send their updates to the parameter server as soon as they are ready. They then read the new parameter values and proceed with the next mini-batch without waiting for other workers to finish. This results in some workers computing gradients based on stale parameter values; however, there are no stalls due to synchronization.
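A matching sketch of the asynchronous case. Rather than real threads, it interleaves worker read/write events, so some gradients end up being computed on stale parameter values; the toy loss and dataset are the same illustrative stand-ins as above.

```python
import numpy as np

# Asynchronous-parallel updates: whichever worker finishes next writes its update without
# waiting for anyone, then immediately reads the fresh value for its next mini-batch.
rng = np.random.default_rng(1)
M, alpha, events = 4, 0.1, 20
w = np.ones(3)                                   # parameter-server state
data = rng.normal(size=(64, 3))                  # toy dataset

def gradient(w_read, batch):
    return np.mean(w_read - batch, axis=0)

reads = [w.copy() for _ in range(M)]             # the value each worker last read
for _ in range(events):
    k = int(rng.integers(M))                     # some worker finishes its gradient
    batch = data[rng.integers(64, size=16)]      # its assigned mini-batch
    w = w - alpha * gradient(reads[k], batch)    # no barrier: apply the (possibly stale) update
    reads[k] = w.copy()                          # the worker reads the new value and continues
```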

Side note: If you think about it, synchronous updates are really just increasing your batch size. Does a larger batch size really increase performance? If not, does it make sense to scale up with a fixed total batch size? Each worker will end up dealing with a handful of samples at every step.

Optimization

The basic update of SGD looks as follows.

SGD

where f is typically some loss function, w_t is the parameter value, α_t denotes the learning rate, and z_{i_t} is the mini-batch used to evaluate the gradient at step t. SGD with momentum (often just called momentum) is a very popular variant that often improves convergence speed by adding memory to the updates.
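As a concrete sketch, here is one SGD step on a toy quadratic loss; the loss is an illustrative stand-in for f.

```python
import numpy as np

# One step of the update above, w_{t+1} = w_t - alpha_t * grad f(w_t; z_{i_t}),
# with the toy loss 0.5 * ||w - x||^2 averaged over the mini-batch standing in for f.
def sgd_step(w, minibatch, alpha):
    grad = np.mean(w - minibatch, axis=0)        # mini-batch gradient of the toy loss
    return w - alpha * grad

rng = np.random.default_rng(0)
w = np.zeros(3)
w = sgd_step(w, rng.normal(size=(32, 3)), alpha=0.1)
```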

"Those who forget their update history are doomed to repeat it."
B.T. Polyak (probably not, but maybe?)

The current parameter update is a scaled-down version of the previous update, plus the gradient term of SGD.

SGD with momentum

As we see in the body of this post, the optimal value for the momentum parameter, μ, depends on a number of factors.
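And the corresponding momentum step on the same toy loss; note the extra μ(w_t - w_{t-1}) term on top of the plain SGD sketch above.

```python
import numpy as np

# Momentum SGD: the new step is mu times the previous update plus the usual gradient term.
def momentum_step(w, w_prev, minibatch, alpha, mu):
    grad = np.mean(w - minibatch, axis=0)        # mini-batch gradient of the toy loss
    w_next = w - alpha * grad + mu * (w - w_prev)
    return w_next, w                             # (w_{t+1}, w_t) for the next call

rng = np.random.default_rng(0)
w, w_prev = np.zeros(3), np.zeros(3)
for _ in range(100):
    w, w_prev = momentum_step(w, w_prev, rng.normal(size=(32, 3)), alpha=0.1, mu=0.9)
```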

Decomposing Performance

The main performance metric is the wall-clock time needed to reach a certain training-loss (or accuracy) target. In order to better understand how our design choices affect this metric, we decompose it into two factors: hardware and statistical efficiency.

Performance factors

The first factor, the number of steps to convergence, is mainly influenced by algorithmic choices and improvements. This factor leads to the notion of statistical efficiency.

Statistical efficiency: The normalized number of steps to convergence.

The second factor, the time to finish one step, is mainly influenced by hardware and system optimizations. It leads to the notion of hardware efficiency.

Hardware efficiency: The normalized amount of time per step.

These two factors of performance are ideal for capturing the tradeoff we face when selecting the amount of asynchrony to use in a system. Synchronous methods have better statistical efficiency, since all gradients have zero staleness; however, they suffer from worse hardware efficiency due to waiting at the synchronization barrier. Asynchronous methods provide worse statistical efficiency, but enjoy significant gains in terms of hardware efficiency: there are no stalls.
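A small numerical sketch of how the two factors combine into wall-clock time; the step counts and per-step times below are made up for illustration, with the synchronous run as the baseline.

```python
# Wall-clock time to the target equals the number of steps to convergence times the time
# per step.  Normalizing each factor by the synchronous baseline gives the two efficiencies.
steps_to_target = {"sync": 10_000, "async": 13_000}    # statistical side (made-up numbers)
seconds_per_step = {"sync": 0.50, "async": 0.30}       # hardware side (made-up numbers)

for name in steps_to_target:
    wall_clock = steps_to_target[name] * seconds_per_step[name]
    norm_steps = steps_to_target[name] / steps_to_target["sync"]     # statistical efficiency
    norm_time = seconds_per_step[name] / seconds_per_step["sync"]    # hardware efficiency
    print(f"{name:5s}  wall-clock = {wall_clock:6.0f}s   "
          f"steps (normalized) = {norm_steps:.2f}   time/step (normalized) = {norm_time:.2f}")
```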