Post by Ioannis Mitliagkas, Dan Iter and Chris Ré.
Includes results from a body of work performed by Stefan Hadjis, Ce Zhang and the authors.
Deep learning is all the rage, and the standard method for solving the underlying learning problem is Stochastic Gradient Descent (SGD). People often add momentum, a parameter meant to help the algorithm converge faster. While everyone tunes the step size, for some reason many people fix momentum at 0.9. Can we converge faster by tuning it? Yep!
Point 1: Tune your momentum! The optimal setting for momentum depends on the data, the task and, as we'll see, the degree of asynchrony. So, tune momentum!
To scale deep learning systems to many workers, one popular technique sounds a bit crazy: remove all the locks! Such methods are known as asynchronous-parallel or "Hogwild!" methods. They are used by real companies like Microsoft and Google, not just by joker academics like us. However, they are not well understood. We recently described a new theoretical link between asynchrony and momentum.
Point 2: Asynchrony implicitly increases the momentum.
Our companion blog post gives a nice introduction to the theory behind this new line of work.
These results hold on many systems: on our prototype, on MXNet, and on TensorFlow. We've verified them from AlexNet to Inception v3 to Nihilism 4.0 (ok, I made that last one up). The point is that if you don't tune momentum, your conclusions can change qualitatively. For example, the Adam paper tunes the momentum and shows asynchrony is the method of choice. In contrast, the latest TensorFlow paper does not tune momentum and concludes synchronous is faster.
In this post, we're not arguing for a particular strategy; we've explored different levels of parallelism, and some asynchrony can be faster than either extreme. Here, our goal is to contribute one point to the discussion of how to build deep learning systems... and get us one step closer to the singularity? Either way, just remember:
"Tune your f*%&ing momentum."
Ray knows tuning is the key to the singularity.
This post contains a quick summary of results to make the above point. If you're interested in the theoretical exposition, try our theory post.
Here is the typical SGD rule people use to train deep learning models:

$$w_{t+1} = w_t - \alpha_t \nabla f(w_t; z_{i_t}) + \mu (w_t - w_{t-1})$$
where $f$ is the loss, $w_t$ is the parameter value, $\alpha_t$ denotes the learning rate and $z_{i_t}$ is the mini-batch used to evaluate the gradient at time $t$. Here $\mu$ is the parameter called momentum. Intuitively, it acts like a fuzzy memory, pushing the next iteration in the same direction as the last step.
The larger the μ, the bigger the push.
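For readers who prefer code, here is a minimal sketch of one step of the momentum rule, $w_{t+1} = w_t - \alpha_t \nabla f(w_t; z_{i_t}) + \mu (w_t - w_{t-1})$, written in plain Python over a list of parameters. The function name is hypothetical, the learning rate is held constant, and the toy quadratic loss stands in for a real mini-batch gradient.

```python
def momentum_sgd_step(w, w_prev, grad, lr, mu):
    """One heavy-ball momentum step, applied coordinate-wise:
    w_next = w - lr * grad + mu * (w - w_prev)."""
    return [wi - lr * gi + mu * (wi - pi) for wi, pi, gi in zip(w, w_prev, grad)]

# Toy usage: minimize f(w) = 0.5 * sum(w_i^2), whose gradient is w itself.
w_prev = [1.0, -2.0]
w = list(w_prev)          # start at rest, so the first step gets no momentum push
for _ in range(200):
    grad = w              # stand-in for a mini-batch gradient
    w, w_prev = momentum_sgd_step(w, w_prev, grad, lr=0.1, mu=0.5), w
```

Setting `mu=0` recovers plain SGD; a larger `mu` gives the previous step a bigger push.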
Among deep learning practitioners, and even some theoreticians, "momentum" is a synonym for 0.9. A large number of papers and tutorials prescribe 0.9 (paper, paper, paper, tutorial, tutorial) or a simple schedule that is independent of every other aspect of the system (tutorial, tutorial, tutorial, tutorial). Meanwhile, everyone tunes other hyperparameters, such as the step size $\alpha_t$.
So what happens if we tune the momentum? Tuning can significantly reduce the number of steps required to reach a given loss. To keep this short, we report the number of iterations needed to reach a fixed loss. But there is no funny business: all runs achieve the same world-beating, cat-recognizing, AI-overlord-producing accuracy.
We first show results on CIFAR. We vary the number of workers and, for each setting, report the number of iterations needed to reach a particular loss divided by the number of iterations the synchronous case needs. We call this ratio the statistical 'penalty' of asynchrony; keep in mind that each asynchronous iteration can be much faster, since there is no locking. We first draw the penalty curve obtained with the standard momentum value, 0.9. Then we draw the same curve when we grid-search momentum. Check out our paper for more details on the experimental setup.
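The measurement itself is easy to mimic on a toy problem. The sketch below is hypothetical in every detail and is not our experimental setup: it minimizes $f(w) = 0.5 w^2$ with momentum, models M workers crudely as a fixed gradient staleness of M - 1 steps (a stand-in we chose for illustration, not the model from our paper), counts iterations until the loss stays below a target, and divides by the synchronous count under the same strategy. The helper names, the momentum grid, and that normalization are our own assumptions.

```python
import math

# Hypothetical momentum grid: includes negative values, 0.0 and the default 0.9.
MOMENTUM_GRID = [round(-0.5 + 0.1 * k, 1) for k in range(15)]  # -0.5, -0.4, ..., 0.9

def iters_to_target(mu, workers, lr=0.1, w0=1.0, target=1e-6, max_iters=10_000):
    """Iterations of stale-gradient momentum descent on f(w) = 0.5 * w**2 until the
    loss falls below `target` (capped at max_iters). Toy model: every gradient is
    read `workers - 1` steps late."""
    staleness = workers - 1
    history = [w0] * (staleness + 2)           # history[-1] = w_t, history[-2] = w_{t-1}
    for t in range(max_iters):
        w, w_prev = history[-1], history[-2]
        loss, prev_loss = 0.5 * w * w, 0.5 * w_prev * w_prev
        if not (math.isfinite(loss) and math.isfinite(prev_loss)):
            return max_iters                   # diverged: treat as never reaching the target
        if loss < target and prev_loss < target:
            return t                           # two small iterates in a row guard against a lucky zero crossing
        stale_grad = history[-1 - staleness]   # gradient of 0.5 * w**2 is w, evaluated at w_{t - staleness}
        history.append(w - lr * stale_grad + mu * (w - w_prev))
        history.pop(0)
    return max_iters

def best_iters(workers, grid=MOMENTUM_GRID):
    """Grid-search momentum: fewest iterations over the grid."""
    return min(iters_to_target(mu, workers) for mu in grid)

def penalty_fixed(workers, mu=0.9):
    """Penalty with momentum fixed at mu, relative to one worker with the same mu."""
    return iters_to_target(mu, workers) / iters_to_target(mu, 1)

def penalty_tuned(workers):
    """Penalty with tuned momentum, relative to one worker with tuned momentum."""
    return best_iters(workers) / best_iters(1)

for m in (1, 2, 4, 8):
    print(f"{m} workers: penalty at mu=0.9 = {penalty_fixed(m):.2f}, tuned = {penalty_tuned(m):.2f}")
```

Because 0.9 is in the grid, the tuned run can never need more iterations than the fixed one at the same worker count. The toy numbers say nothing about the CIFAR or ImageNet results.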
The first plot shows the penalty incurred under three different tuning strategies; the second shows the optimal momentum that we found.
Tuning momentum makes a huge difference. With 16 workers, tuning over non-negative momentum values reduces the penalty by about 2.5x compared with the fixed 0.9, and allowing negative momentum values cuts the 16-worker penalty by another 2x. For any fixed number of workers, every strategy runs at exactly the same speed per iteration; only the tuned momentum converges in MUCH fewer iterations.
The next figure shows the tuned curve for ImageNet. Here we see that, for some datasets, tuning momentum can eliminate the penalty for asynchrony.
So the takeaway:
Point 1: "Tune your f*%&ing momentum."
One question stands out: why does the optimal momentum change as we change the number of workers? This is a little puzzling, but here is a rough way to think about it: more asynchrony increases the staleness of the updates. Since momentum is a kind of fuzzy memory, increased staleness may be like having a longer fuzzy memory? Kind of, and we can be more precise...
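Our theorem predicts that, when training asynchronously, there are two sources of momentum:
Explicit momentum: the value μ that we set in the update rule.
Implicit momentum: the momentum that asynchrony itself introduces.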
The total effective momentum is the sum of explicit and implicit terms.
The second important theoretical prediction is that more workers introduce more implicit momentum.
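Here is a back-of-the-envelope sketch of this prediction. It assumes that M asynchronous workers add implicit momentum of roughly 1 - 1/M (the form a simple queueing model of asynchrony gives; treat it as an assumption here) and adds that to the explicit μ, as described above. The helper names are hypothetical and the sketch says nothing about any particular system.

```python
def implicit_momentum(workers: int) -> float:
    """Implicit momentum from asynchrony, ASSUMING the 1 - 1/M form."""
    if workers < 1:
        raise ValueError("need at least one worker")
    return 1.0 - 1.0 / workers

def effective_momentum(explicit_mu: float, workers: int) -> float:
    """Total momentum = explicit (the mu we set) + implicit (from asynchrony)."""
    return explicit_mu + implicit_momentum(workers)

for m in (1, 2, 4, 8, 16):
    print(f"{m:2d} workers: implicit = {implicit_momentum(m):.4f}, "
          f"total with mu = 0.9: {effective_momentum(0.9, m):.4f}")
```

Under this simple model, keeping the explicit μ fixed at 0.9 pushes the total momentum past 1 as soon as a second worker is added.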
Consider the following thought experiment. We vary the number of asynchronous workers and, in each case, tune for the optimal explicit momentum. According to the model, we would see a picture like the following:
There is an optimal level of total momentum, equal to the value tuned for a single worker (the synchronous case). When we introduce a second worker, asynchrony induces non-zero implicit momentum (orange bars). As we add more workers, we get more implicit momentum, so less explicit momentum is needed to reach the same total.
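Under the same back-of-the-envelope model (implicit momentum of roughly 1 - 1/M for M workers, added to the explicit μ; both are assumptions), the explicit momentum that keeps the total at its synchronous optimum is simply the difference. The value 0.9 below is a hypothetical placeholder for the tuned synchronous momentum, not a measured optimum, and the function names are illustrative.

```python
def implicit_momentum(workers: int) -> float:
    """ASSUMED implicit momentum from asynchrony: 1 - 1/M."""
    return 1.0 - 1.0 / workers

def explicit_for_target(target_mu: float, workers: int) -> float:
    """Explicit momentum that keeps total = explicit + implicit at target_mu."""
    return target_mu - implicit_momentum(workers)

TUNED_SYNC_MU = 0.9  # hypothetical: pretend this was the best value for one worker
schedule = {m: explicit_for_target(TUNED_SYNC_MU, m) for m in (1, 2, 4, 8, 16)}
for m, mu in schedule.items():
    print(f"{m:2d} workers -> explicit momentum {mu:+.4f}")
```

The required explicit momentum shrinks monotonically and, at 16 workers, dips slightly below zero, in the spirit of the earlier observation that negative momentum values helped at high asynchrony. These toy values are not the measured optima.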
So do we see this in experiments? Yes! We see it on our system as well as on TensorFlow and MXNet, so it is not an artifact of any one implementation. Here, we grid-search the learning rate and momentum for a number of different configurations and report the optimal explicit momentum values below for CIFAR and ImageNet.
We see the monotonic behavior the theory predicts: as we increase the level of asynchrony, we have to decrease the amount of explicit momentum.
Point 2: Asynchrony introduces momentum implicitly, so you'd better tune momentum if you use asynchrony!
This is a new line of work for us, and we don't have many answers!