How to set AdamW's weight decay as you scale model and dataset size
Overview
- The paper shows that the weights learned by the AdamW optimization algorithm can be understood as an exponential moving average (EMA) of recent updates.
- This insight provides critical guidelines for setting the weight decay hyperparameter in AdamW, which should scale with the model and dataset size.
- The EMA timescale, which determines the weight decay for a given learning rate, should lie between roughly one epoch and the total number of training epochs, and its optimal value should not change much as the model and dataset scale.
Plain English Explanation
The paper examines the AdamW optimization algorithm, a popular method for training large machine learning models. The key insight is that the weights learned by AdamW can be viewed as an exponential moving average (EMA) of the recent updates made during training.
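A minimal scalar sketch of this view, assuming the decoupled weight-decay step used by PyTorch's `torch.optim.AdamW`, where each step multiplies the weight by `1 - lr * weight_decay` and then subtracts `lr` times the Adam update. The helper names are hypothetical, and the random `updates` list only stands in for Adam's normalized updates; the identity holds whatever produced them.

```python
import random


def adamw_style_weights(w0: float, updates: list[float], lr: float, wd: float) -> float:
    """Run the decoupled step w <- (1 - lr * wd) * w - lr * u over all updates."""
    w = w0
    for u in updates:
        w = (1.0 - lr * wd) * w - lr * u
    return w


def ema_of_updates(w0: float, updates: list[float], lr: float, wd: float) -> float:
    """Same weight written as an EMA of the targets -u / wd with rate lr * wd."""
    rate = lr * wd                    # 1 / tau_iter
    t = len(updates)
    total = (1.0 - rate) ** t * w0    # contribution of the decayed initial weight
    for s, u in enumerate(updates, start=1):
        # update s has been decayed (t - s) times since it was applied
        total += rate * (1.0 - rate) ** (t - s) * (-u / wd)
    return total


random.seed(0)
updates = [random.uniform(-1.0, 1.0) for _ in range(500)]  # stand-ins for Adam updates
a = adamw_style_weights(0.5, updates, lr=1e-2, wd=0.1)
b = ema_of_updates(0.5, updates, lr=1e-2, wd=0.1)
print(abs(a - b) < 1e-9)  # the two forms agree up to floating-point rounding
```

With `rate = lr * wd`, each update's influence decays by a factor of `1 - rate` per step, so the average effectively covers about `1 / rate` recent steps.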
This EMA interpretation provides important guidelines for setting the weight decay hyperparameter in AdamW. The weight decay controls how quickly the model "forgets" past updates and is a critical setting that impacts the model's performance. The authors show that the optimal weight decay should scale with the model and dataset size.
Specifically, the EMA timescale (intuitively, the number of recent iterations the EMA averages over) is the key parameter: for a fixed learning rate, choosing it fixes the weight decay. The authors argue that this timescale should not be much shorter than one epoch, because the EMA needs to average over every data point, and not much longer than the total number of training epochs, because it also needs to forget very early updates.
Importantly, because the optimal timescale (measured in epochs) stays roughly fixed, the optimal weight decay should decrease as the dataset size increases. And as the model size increases, the optimal weight decay should increase, provided the learning rate is scaled down with model width as the muP recommendation suggests.
The paper validates these guidelines empirically and shows that they are consistent with the hyperparameters chosen in recent large-scale language model pretraining runs such as Llama 1 and 2 and Stable LM. This suggests the guidelines are relevant in practice when scaling up machine learning models.
Technical Explanation
The paper provides a theoretical analysis showing that the weights learned by the AdamW optimization algorithm can be understood as an exponential moving average (EMA) of the recent parameter updates. This EMA interpretation leads to critical insights about how to set the weight decay hyperparameter in AdamW.
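Written out for a single step, the decoupled AdamW update rearranges into an EMA. This is a sketch assuming a constant learning rate η, weight decay λ, Adam's normalized update u_t, dataset size N, and batch size B (symbols introduced here for illustration):

```latex
\begin{aligned}
w_t &= (1 - \eta\lambda)\, w_{t-1} - \eta\, u_t
     = \Bigl(1 - \tfrac{1}{\tau_{\text{iter}}}\Bigr) w_{t-1}
       + \tfrac{1}{\tau_{\text{iter}}} \Bigl(-\tfrac{u_t}{\lambda}\Bigr),
\qquad \tau_{\text{iter}} = \frac{1}{\eta\lambda}, \\
w_t &= \Bigl(1 - \tfrac{1}{\tau_{\text{iter}}}\Bigr)^{t} w_0
     + \sum_{s=1}^{t} \tfrac{1}{\tau_{\text{iter}}}
       \Bigl(1 - \tfrac{1}{\tau_{\text{iter}}}\Bigr)^{t-s} \Bigl(-\tfrac{u_s}{\lambda}\Bigr), \\
\tau_{\text{epoch}} &= \frac{\tau_{\text{iter}}}{N / B} = \frac{B}{\eta\lambda N}.
\end{aligned}
```

The first line shows the weights as an EMA of the targets -u_t/λ with rate 1/τ_iter; the last line converts the timescale from iterations to epochs.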
Specifically, the authors show that there is a one-to-one mapping between the EMA timescale (the number of recent iterations the EMA averages over) and the usual weight decay hyperparameter, given a fixed learning rate. Intuitively, the EMA timescale can be thought of as the "memory" of the optimization algorithm - how many recent updates it remembers.
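The mapping can be sketched as a pair of conversions, assuming a constant learning rate and counting one epoch as `dataset_size / batch_size` steps. The function names and the example numbers (a hypothetical 50,000-example dataset with batch size 100) are illustrative only.

```python
def ema_timescale_iters(lr: float, weight_decay: float) -> float:
    """EMA timescale in optimizer steps: tau_iter = 1 / (lr * weight_decay)."""
    return 1.0 / (lr * weight_decay)


def ema_timescale_epochs(lr: float, weight_decay: float,
                         dataset_size: int, batch_size: int) -> float:
    """Same timescale in epochs: tau_epoch = B / (lr * weight_decay * N)."""
    iters_per_epoch = dataset_size / batch_size
    return ema_timescale_iters(lr, weight_decay) / iters_per_epoch


def weight_decay_for_timescale(lr: float, tau_epochs: float,
                               dataset_size: int, batch_size: int) -> float:
    """Inverse map: choose tau_epoch, get weight_decay = B / (lr * tau_epoch * N)."""
    return batch_size / (lr * tau_epochs * dataset_size)


# Hypothetical setting: 50,000 examples, batch size 100 -> 500 steps per epoch.
tau_iter = ema_timescale_iters(1e-3, 0.1)                  # about 10,000 steps
tau_epoch = ema_timescale_epochs(1e-3, 0.1, 50_000, 100)   # about 20 epochs
wd = weight_decay_for_timescale(1e-3, 20.0, 50_000, 100)   # about 0.1
print(f"{tau_iter:.0f} steps, {tau_epoch:.1f} epochs, weight_decay={wd:.3f}")
```

Because the map is invertible at a fixed learning rate, one can tune the timescale in epochs and derive the weight decay from it, rather than tuning the weight decay directly.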
The authors argue that the EMA timescale should not be much smaller than one epoch, since the EMA must average over all the training data, and not much larger than the total number of training epochs, since it must also forget very early updates that may no longer be relevant. Empirically, the authors find that optimal EMA timescales are consistent with these guidelines.
Critically, these guidelines suggest that the optimal EMA timescale, measured in epochs, should not change much as the model and dataset scale. As the dataset size increases, each epoch spans more iterations (at a fixed batch size), so keeping the timescale fixed in epochs requires a smaller weight decay. As the model size increases, the optimal weight decay should increase if the learning rate is scaled down with width following the muP recommendation, because a smaller learning rate must be offset by a larger weight decay to keep the same timescale.
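A rough check of a configuration against these guidelines could look like the following sketch. The hard cutoffs at 1 and `total_epochs` simplify the paper's "not (much) smaller" and "not (much) bigger" wording, and the function name, messages, and example numbers are hypothetical.

```python
def classify_timescale(lr: float, weight_decay: float, dataset_size: int,
                       batch_size: int, total_epochs: float) -> tuple[float, str]:
    """Return tau_epoch = B / (lr * weight_decay * N) and where it falls."""
    tau_epochs = batch_size / (lr * weight_decay * dataset_size)
    if tau_epochs < 1.0:
        return tau_epochs, "too short: averages over less than one epoch"
    if tau_epochs > total_epochs:
        return tau_epochs, "too long: early updates are never forgotten"
    return tau_epochs, "within the suggested range"


# Hypothetical 100-epoch run on 50,000 examples with batch size 100 and lr 1e-3.
for wd in (5.0, 0.1, 0.001):
    tau, verdict = classify_timescale(1e-3, wd, 50_000, 100, total_epochs=100)
    print(f"weight_decay={wd}: tau={tau:.1f} epochs, {verdict}")
```

In this setting, a weight decay of 5.0 gives a timescale of 0.4 epochs and 0.001 gives 2,000 epochs, so only the middle value falls inside the suggested range.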
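A minimal sketch of both scaling rules, holding the timescale in epochs and the batch size fixed. The learning-rate rule `lr = base_lr * base_width / width` is an assumption standing in for the muP recommendation for Adam on hidden weights, and all names and numbers are hypothetical.

```python
def scaled_weight_decay(base_lr: float, base_width: int, width: int,
                        tau_epochs: float, dataset_size: int,
                        batch_size: int) -> float:
    """Weight decay that keeps tau_epoch fixed after scaling width and data."""
    # Assumed muP-style rule: the hidden-layer learning rate shrinks as 1 / width.
    lr = base_lr * base_width / width
    # Solve tau_epoch = B / (lr * weight_decay * N) for weight_decay.
    return batch_size / (lr * tau_epochs * dataset_size)


base = dict(base_lr=1e-3, base_width=256, tau_epochs=20.0, batch_size=100)
wd_base = scaled_weight_decay(width=256, dataset_size=50_000, **base)
wd_more_data = scaled_weight_decay(width=256, dataset_size=100_000, **base)
wd_wider = scaled_weight_decay(width=512, dataset_size=50_000, **base)
print(wd_base, wd_more_data, wd_wider)  # doubling data halves it, doubling width doubles it
```

Under these assumptions the weight decay is proportional to `width / dataset_size`, which is the direction of change the guidelines predict.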
The authors validate these insights through experiments and show that the hyperparameter choices made in recent large-scale language model pretraining runs, such as Llama 1 and 2 and Stable LM, are consistent with the guidelines derived from the EMA interpretation of AdamW.
Critical Analysis
The paper provides a compelling theoretical analysis and empirical validation of the relationship between the EMA timescale, weight decay, and model and dataset scaling in the context of the AdamW optimization algorithm. The insights offered are potentially quite impactful, as they provide clear guidelines for tuning a critical hyperparameter in large-scale machine learning models.
One potential limitation of the work is that it focuses exclusively on the AdamW algorithm, leaving open the question of whether the insights generalize to other optimization methods. Additionally, the paper does not delve into the potential reasons why the EMA timescale guidelines would hold, beyond the intuitive explanations provided. A deeper exploration of the underlying dynamics and theoretical justifications could further strengthen the conclusions.
It would also be valuable to see the guidelines validated on a broader range of model architectures and tasks, beyond just the language models discussed. Applying the insights to other domains, such as computer vision or reinforcement learning, could help establish their broader applicability.
Finally, the paper does not address potential issues that could arise from the proposed scaling of weight decay, such as interactions with other hyperparameters or potential negative impacts on model performance or generalization. A more comprehensive analysis of the practical implications and potential drawbacks would be a useful addition.
Overall, the paper presents an important and insightful contribution to the understanding of optimization algorithms and their role in scaling up machine learning models. The clear guidelines and empirical validation make it a valuable resource for practitioners working on large-scale model development.
Conclusion
This paper provides an analysis of the AdamW optimization algorithm, showing that the weights it learns can be understood as an exponential moving average (EMA) of recent updates. This EMA interpretation leads to important guidelines for setting the weight decay hyperparameter in AdamW, which should scale with the model and dataset size.
Specifically, the authors argue that the optimal EMA timescale (the number of recent iterations the EMA averages over) should lie between roughly one epoch and the total number of training epochs, and should not change much with scale. This implies that as the dataset size increases, the optimal weight decay should decrease, while as the model size increases, the optimal weight decay should increase (when the learning rate is scaled following muP).
These insights are validated empirically and shown to be consistent with the hyperparameter choices made in recent large-scale language model pretraining efforts. The work offers valuable guidance for practitioners working on scaling up machine learning models, and highlights the importance of understanding the underlying dynamics of optimization algorithms.
This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!
Related Papers
How to set AdamW's weight decay as you scale model and dataset size
Xi Wang, Laurence Aitchison
We show that weights learned by AdamW can be understood as an exponential moving average (EMA) of recent updates. This gives critical insights for how to set the weight decay in AdamW, and how the weight decay should scale with model and dataset size. In particular, the key hyperparameter for an exponential moving average is the EMA timescale. Intuitively, the EMA timescale can be understood as the number of recent iterations the EMA averages over. Given a fixed learning rate, there is a one-to-one mapping from the EMA timescale to the usual weight decay hyperparameter. Thus, choosing an EMA timescale implicitly sets the weight decay. Importantly, there are natural guidelines for sensible values for the EMA timescale: we need to average over all datapoints, so the EMA timescale should not be (much) smaller than 1 epoch, and we need to forget early updates, so the EMA timescale should not be (much) bigger than the total number of training epochs. In our experiments, we find that optimal EMA timescales are consistent with these guidelines, as are the hyperparameters chosen in recent large-scale LLM pretraining runs (e.g. Llama 1+2 and Stable LM). Critically, these guidelines suggest that the optimal EMA timescale should not change (much) as we scale the model and dataset. That implies that as the dataset size increases, the optimal weight decay should fall. Moreover, as the model size increases, the optimal weight decay should also increase (if we follow the muP recommendation for scaling the learning rate).
5/24/2024
Adam with model exponential moving average is effective for nonconvex optimization
Kwangjun Ahn, Ashok Cutkosky
In this work, we offer a theoretical analysis of two modern optimization techniques for training large and complex models: (i) adaptive optimization algorithms, such as Adam, and (ii) the model exponential moving average (EMA). Specifically, we demonstrate that a clipped version of Adam with model EMA achieves the optimal convergence rates in various nonconvex optimization settings, both smooth and nonsmooth. Moreover, when the scale varies significantly across different coordinates, we demonstrate that the coordinate-wise adaptivity of Adam is provably advantageous. Notably, unlike previous analyses of Adam, our analysis crucially relies on its core elements -- momentum and discounting factors -- as well as model EMA, motivating their wide applications in practice.
5/29/2024
The AdEMAMix Optimizer: Better, Faster, Older
Matteo Pagliardini, Pierre Ablin, David Grangier
Momentum based optimizers are central to a wide range of machine learning applications. These typically rely on an Exponential Moving Average (EMA) of gradients, which decays exponentially the present contribution of older gradients. This accounts for gradients being local linear approximations which lose their relevance as the iterate moves along the loss landscape. This work questions the use of a single EMA to accumulate past gradients and empirically demonstrates how this choice can be sub-optimal: a single EMA cannot simultaneously give a high weight to the immediate past, and a non-negligible weight to older gradients. Building on this observation, we propose AdEMAMix, a simple modification of the Adam optimizer with a mixture of two EMAs to better take advantage of past gradients. Our experiments on language modeling and image classification show -- quite surprisingly -- that gradients can stay relevant for tens of thousands of steps. They help to converge faster, and often to lower minima: e.g., a 1.3B parameter AdEMAMix LLM trained on 101B tokens performs comparably to an AdamW model trained on 197B tokens (+95%). Moreover, our method significantly slows down model forgetting during training. Our work motivates further exploration of different types of functions to leverage past gradients, beyond EMAs.
9/6/2024
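The claim that a single EMA cannot weight both the recent and the distant past can be checked with textbook EMA weights: in `m <- beta * m + (1 - beta) * g`, a gradient `k` steps old gets weight `(1 - beta) * beta**k`. This sketch is not the AdEMAMix update; the two `beta` values and the mixing coefficient `ALPHA` are assumptions chosen for illustration.

```python
def ema_weight(beta: float, k: int) -> float:
    """Weight of a gradient k steps old in m <- beta * m + (1 - beta) * g."""
    return (1.0 - beta) * beta ** k


FAST_BETA, SLOW_BETA, ALPHA = 0.9, 0.9999, 5.0  # illustrative values only

for k in (0, 10, 1_000, 10_000):
    fast = ema_weight(FAST_BETA, k)
    slow = ema_weight(SLOW_BETA, k)
    mixed = fast + ALPHA * slow  # a simple two-EMA mixture of the weights
    print(f"k={k:>6}: fast={fast:.2e} slow={slow:.2e} mixed={mixed:.2e}")
```

The fast EMA puts weight 0.1 on the newest gradient but effectively nothing on one 1,000 steps old, while the slow EMA keeps a small, non-negligible weight on old gradients; the mixture retains both properties.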
Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations
Alexander Hagele, Elie Bakouch, Atli Kosson, Loubna Ben Allal, Leandro Von Werra, Martin Jaggi
Scale has become a main ingredient in obtaining strong machine learning models. As a result, understanding a model's scaling properties is key to effectively designing both the right training setup as well as future generations of architectures. In this work, we argue that scale and training research has been needlessly complex due to reliance on the cosine schedule, which prevents training across different lengths for the same model size. We investigate the training behavior of a direct alternative -- constant learning rate and cooldowns -- and find that it scales predictably and reliably similar to cosine. Additionally, we show that stochastic weight averaging yields improved performance along the training trajectory, without additional training costs, across different scales. Importantly, with these findings we demonstrate that scaling experiments can be performed with significantly reduced compute and GPU hours by utilizing fewer but reusable training runs. Our code is available at https://github.com/epfml/schedules-and-scaling/.
10/18/2024
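A constant learning rate with a final cooldown can be sketched as below. The linear warmup, the linear decay shape, and the 20% cooldown fraction are assumptions for illustration, not the exact schedules studied in the paper; the function name is hypothetical.

```python
def constant_with_cooldown(step: int, total_steps: int, peak_lr: float,
                           warmup_steps: int, cooldown_frac: float = 0.2) -> float:
    """Linear warmup, constant plateau, then linear decay over the last steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    cooldown_start = round(total_steps * (1.0 - cooldown_frac))
    if step < cooldown_start:
        return peak_lr
    # Decay linearly so the final step (total_steps - 1) is just above zero.
    remaining = total_steps - step
    return peak_lr * remaining / (total_steps - cooldown_start)


schedule = [constant_with_cooldown(s, 1_000, 1e-3, warmup_steps=100) for s in range(1_000)]
print(schedule[50], schedule[500], schedule[900], schedule[999])
```

Because the plateau does not depend on the total run length, a single constant-rate run can be branched into cooldowns at several points, which is what allows training runs of different lengths to reuse the same trajectory.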