The AdEMAMix Optimizer: Better, Faster, Older

Authors: Matteo Pagliardini, Pierre Ablin, David Grangier

ICLR 2025

Reproducibility Variable Result LLM Response
Research Type Experimental Our experiments on language modeling and image classification show quite surprisingly that gradients can stay relevant for tens of thousands of steps. They help to converge faster, and often to lower minima: e.g., a 1.3B parameter Ad EMAMix LLM trained on 101B tokens performs comparably to an Adam W model trained on 197B tokens (+95%). Moreover, our method significantly slowsdown model forgetting during training.
Researcher Affiliation Collaboration Matteo Pagliardini (EPFL), Pierre Ablin (Apple), David Grangier (Apple). Work done while interning at Apple.
Pseudocode Yes Algorithm 1: AdEMAMix optimizer. Differences with AdamW are in blue.
1: Input: data distribution D; initial model parameters θ^(0); number of iterations T; learning rate η; small constant ε; AdamW parameters β1, β2, and λ; AdEMAMix parameters β3, α; warmup parameter T_{α,β3} (usually set to T); β_start (usually set to β1).
2: Initialize timestep: t ← 0
3: Initialize EMAs: m1^(0) ← 0, m2^(0) ← 0, ν^(0) ← 0
4: for t ∈ [T] do
5:   t ← t + 1
6:   Optional: use schedulers η(t), β3^(t) ← f_β3(t, β3, β_start, T_{α,β3}), and α^(t) ← f_α(t, α, T_{α,β3})
7:   Sample batch: x ∼ D
8:   Compute gradient: g^(t) ← ∇_θ L_{θ^(t−1)}(x)
9:   Update the fast EMA m1: m1^(t) ← β1 m1^(t−1) + (1 − β1) g^(t)
10:  Update the slow EMA m2: m2^(t) ← β3^(t) m2^(t−1) + (1 − β3^(t)) g^(t)
11:  Update the second-moment estimate: ν^(t) ← β2 ν^(t−1) + (1 − β2) (g^(t))²
12:  Apply bias corrections: m̂1^(t) ← m1^(t) / (1 − β1^t), ν̂^(t) ← ν^(t) / (1 − β2^t)
13:  Update parameters: θ^(t) ← θ^(t−1) − η(t) ( (m̂1^(t) + α^(t) m2^(t)) / (√ν̂^(t) + ε) + λ θ^(t−1) )
14: end for
15: Return optimized parameters θ^(T)
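Ignoring the optional schedulers in line 6, a single AdEMAMix update (lines 7–13 of Algorithm 1) can be sketched in plain Python on a list of scalar parameters. This is a minimal illustration, not the released implementation; the function name and default hyperparameters are placeholders.

```python
import math

def ademamix_step(theta, g, state, t, lr=1e-3, beta1=0.9, beta2=0.999,
                  beta3=0.9999, alpha=5.0, eps=1e-8, wd=0.0):
    """One AdEMAMix update on a list of floats.

    `state` holds the fast EMA m1, slow EMA m2, and second moment nu
    (each a list the same length as theta). Note only m1 and nu are
    bias-corrected; the slow EMA m2 is used as-is, per Algorithm 1.
    """
    m1, m2, nu = state["m1"], state["m2"], state["nu"]
    new_theta = []
    for i, (p, gi) in enumerate(zip(theta, g)):
        m1[i] = beta1 * m1[i] + (1 - beta1) * gi        # fast EMA (line 9)
        m2[i] = beta3 * m2[i] + (1 - beta3) * gi        # slow EMA (line 10)
        nu[i] = beta2 * nu[i] + (1 - beta2) * gi * gi   # second moment (line 11)
        m1_hat = m1[i] / (1 - beta1 ** t)               # bias corrections (line 12)
        nu_hat = nu[i] / (1 - beta2 ** t)
        update = (m1_hat + alpha * m2[i]) / (math.sqrt(nu_hat) + eps)
        new_theta.append(p - lr * (update + wd * p))    # decoupled weight decay (line 13)
    return new_theta
```

With alpha = 0 the slow EMA drops out and the step reduces to plain AdamW, which is a useful sanity check when porting the algorithm.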
Open Source Code Yes The full implementation of AdEMAMix in PyTorch is provided in the following repository: https://github.com/apple/ml-ademamix. The full implementation of AdEMAMix in JAX is provided in the following repository: https://github.com/apple/ml-ademamix
Open Datasets Yes We use the RedPajama v2 (Computer, 2023) dataset for all of our experiments. ... We use a Mamba architecture (Gu & Dao, 2023) with an embedding dimension of 768, an expansion factor of 2, a state size of 16, and a depth of 24. We use a batch size of 120 sequences of 1024 tokens. We use the EleutherAI/gpt-neox-20b tokenizer from Black et al. (2022), using the Hugging Face library (Wolf et al., 2019). ... We use two subsets of the ImageNet (Russakovsky et al., 2015) dataset: (i) the widely used ImageNet-1k subset... (ii) a filtered and pre-processed version of ImageNet-21k (Ridnik et al., 2021) containing 11M images corresponding to 10,450 classes.
Dataset Splits No For the learning rate, we use 3k warmup steps followed by, unless specified otherwise, a cosine decay to 10⁻⁵. We use batch sizes of 64, 96, and 128 for our 110M, 335M, and 1.3B parameter models, respectively. ... For each, we measure the test loss on a held-out test set. ... For our experiments on ImageNet-1k, given that all models overfit the dataset, we select the best model according to the minimum validation loss, akin to using early stopping.
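The warmup-plus-cosine learning-rate schedule quoted above can be sketched as follows. Only the 3k warmup steps and the 10⁻⁵ floor come from the excerpt; the linear warmup shape, peak value, and total-step default are illustrative assumptions.

```python
import math

def lr_schedule(step, peak_lr, warmup_steps=3000, total_steps=256_000, final_lr=1e-5):
    """Linear warmup to peak_lr over warmup_steps, then cosine decay
    down to final_lr at total_steps. Warmup shape is assumed linear."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, in [0, 1].
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return final_lr + 0.5 * (peak_lr - final_lr) * (1.0 + math.cos(math.pi * progress))
```

At step 0 the rate is 0, at the end of warmup it equals the peak, and at the final step it has decayed to exactly the 10⁻⁵ floor.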
Hardware Specification Yes We train on up to 8 A100 NVIDIA GPUs using data-parallelism. ... We train on 8 A100 NVIDIA GPUs using PyTorch Fully Sharded Data-Parallelism (Zhao et al., 2023, FSDP).
Software Dependencies No Our implementation uses JAX (Bradbury et al., 2018), and we train using bfloat16, except for normalization modules and softmax, which use float32. The optimizer states and operations are in float32. ... Our implementation is in PyTorch (Paszke et al., 2019) and uses bfloat16. ... The following is a code skeleton for our AdEMAMix optimizer in Optax, an optimization library based on JAX (Bradbury et al., 2018). ... We use the EleutherAI/gpt-neox-20b tokenizer from Black et al. (2022), using the Hugging Face library (Wolf et al., 2019).
Experiment Setup Yes For the learning rate, we use 3k warmup steps followed by, unless specified otherwise, a cosine decay to 10⁻⁵. We extensively tuned the hyperparameters for both AdamW and AdEMAMix models (see App. B.1). We use the RedPajama v2 (Computer, 2023) dataset for all of our experiments. We use batch sizes of 64, 96, and 128 for our 110M, 335M, and 1.3B parameter models, respectively. Depending on the model, we vary the number of iterations from 256k to 1.5M. For AdEMAMix, we use β3 = 0.9999 and α ∈ {5, 8, 10} depending on the model. A full description of the architecture and hyperparameters used is in App. B.1.
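Algorithm 1 references warmup schedulers f_α and f_β3 whose exact forms are defined in the paper, not in the excerpts above. Purely as an illustration, here is one plausible pair: linear warmup for α, and warmup for β3 that interpolates linearly in the averaging horizon 1/(1 − β) rather than in β itself, since values of β near 1 behave very non-linearly. Both forms are assumptions, not the paper's definitions.

```python
def alpha_scheduler(t, alpha, t_warmup):
    """Illustrative linear warmup from 0 to alpha over t_warmup steps
    (a placeholder for the paper's f_alpha)."""
    return min(1.0, t / t_warmup) * alpha

def beta3_scheduler(t, beta3, beta_start, t_warmup):
    """Illustrative warmup from beta_start to beta3, linear in the
    effective averaging horizon 1/(1 - beta) (a placeholder for the
    paper's f_beta3)."""
    frac = min(1.0, t / t_warmup)
    h_start = 1.0 / (1.0 - beta_start)  # horizon at the start of warmup
    h_end = 1.0 / (1.0 - beta3)         # target horizon
    h = h_start + frac * (h_end - h_start)
    return 1.0 - 1.0 / h
```

Warming up β3 and α matters because a freshly initialized slow EMA with β3 = 0.9999 would otherwise inject a large, poorly averaged term into the very first updates.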