When, Where and Why to Average Weights?
Authors: Niccolò Ajroldi, Antonio Orvieto, Jonas Geiping
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Averaging checkpoints along the training trajectory is a simple yet powerful approach to improve the generalization performance of Machine Learning models and reduce training time. Motivated by these potential gains, and in an effort to fairly and thoroughly benchmark this technique, we present an extensive evaluation of averaging techniques in modern Deep Learning, which we perform using AlgoPerf (Dahl et al., 2023), a large-scale benchmark for optimization algorithms. We investigate whether weight averaging can reduce training time, improve generalization, and replace learning rate decay, as suggested by recent literature. Our evaluation across seven architectures and datasets reveals that averaging significantly accelerates training and yields considerable efficiency gains across all considered workloads, at the price of a minimal implementation and memory cost, while mildly improving generalization. |
| Researcher Affiliation | Academia | 1ELLIS Institute Tübingen 2Max Planck Institute for Intelligent Systems, Tübingen, Germany 3Tübingen AI Center. Correspondence to: Niccolò Ajroldi <EMAIL>. |
| Pseudocode | Yes | We report a PyTorch-style implementation of LAWA and EMA in Algorithms 1 and 2. |
| Open Source Code | Yes | We perform all experiments in PyTorch using the plainLM pretraining codebase and model implementations. The employed model is a decoder-only 124M transformer architecture, similar to GPT-2 (Radford et al., 2019; Karpathy, 2022), enhanced with Rotary Position Embedding (Su et al., 2023), RMSNorm (Zhang & Sennrich, 2019), and SwiGLU (Shazeer, 2020). Footnote 4: https://github.com/Niccolo-Ajroldi/plainLM |
| Open Datasets | Yes | We conduct our analysis on AlgoPerf (Dahl et al., 2023), a large suite of Deep Learning workloads, which provides a comprehensive benchmark for testing optimization algorithms. ... We consider the following workloads from the AlgoPerf suite: (i) a DLRMsmall model on the Criteo 1TB dataset... (ii) U-Net on fastMRI... (iii) ViT on ImageNet-1k... (iv) a GNN model on OGBG... (v) a Transformer-Big on WMT... (vi) a Conformer for speech recognition; (vii) a DeepSpeech model on LibriSpeech... training on 5B tokens from FineWeb-Edu (Penedo et al., 2024). |
| Dataset Splits | Yes | The benchmark is composed of eight workloads, each defined by a dataset, a model architecture, and a predefined target metric on a held-out set, designed to represent optimal performance on such a workload... evaluating on a held-out split of 5M tokens. |
| Hardware Specification | Yes | We conducted experiments on 4x A100-SXM4-40GB and 4x H100-HBM3-80GB GPUs, occasionally resorting to 8x V100-32GB. |
| Software Dependencies | No | We make use of the original AlgoPerf repository, with minimal modifications, using the PyTorch (Paszke et al., 2017) implementation of the corresponding algorithms. |
| Experiment Setup | Yes | To derive a strong baseline algorithm, we perform a grid search over weight decay, gradient clipping, and learning rate schedule hyperparameters, exploring both decay-to-zero and 10% decay, as reported in Table 3. ... The employed model is a decoder-only 124M transformer architecture... training on 5B tokens from FineWeb-Edu (Penedo et al., 2024), using a batch size of 512, context length of 1024 tokens, and the GPT-2 tokenizer. |
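The paper reports PyTorch-style pseudocode for the two averaging schemes it benchmarks, LAWA and EMA, in its Algorithms 1 and 2. As a rough illustration of those two schemes (not the paper's actual implementation), the sketch below represents parameter tensors as plain lists of floats; the hyperparameter names `beta` (EMA decay) and `k` (LAWA window size) are assumptions.

```python
from collections import deque


class EMA:
    """Exponential moving average of model weights:
    avg <- beta * avg + (1 - beta) * current weights."""

    def __init__(self, weights, beta=0.9):
        self.beta = beta
        self.avg = list(weights)  # copy of the initial weights

    def update(self, weights):
        self.avg = [self.beta * a + (1.0 - self.beta) * w
                    for a, w in zip(self.avg, weights)]


class LAWA:
    """Latest weight averaging: uniform average over the k most
    recent checkpoints, kept in a sliding window."""

    def __init__(self, k=3):
        self.window = deque(maxlen=k)  # oldest checkpoint is evicted

    def update(self, weights):
        self.window.append(list(weights))

    def average(self):
        n = len(self.window)
        return [sum(ws) / n for ws in zip(*self.window)]
```

In a real training loop these updates would run over checkpoint tensors (e.g. `model.state_dict()` entries), which is where the minimal memory cost mentioned in the abstract comes from: EMA stores one extra copy of the weights, LAWA stores k.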