Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape View
Authors: Kaiyue Wen, Zhiyuan Li, Jason Wang, David Hall, Percy Liang, Tengyu Ma
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We evaluate the effectiveness of WSD-S with extensive experiments on LLMs from 0.1B to 1.2B parameters in a continual learning setting with 50B, 100B, and 200B tokens as the three target compute budgets. We empirically show that WSD-S has performance comparable with independent oracle runs with cosine learning rate schedules optimally tuned for each of the three budgets. |
| Researcher Affiliation | Academia | Kaiyue Wen1, Zhiyuan Li2, Jason Wang1, David Hall1, Percy Liang1, Tengyu Ma1; 1 Stanford University, 2 Toyota Technological Institute at Chicago (author email addresses redacted in the extracted text). |
| Pseudocode | No | The paper describes methods using mathematical equations and textual descriptions, but no explicit pseudocode or algorithm blocks are provided. |
| Open Source Code | No | We use a TPU v3-256 model to train the model with the Levanter framework in Jax (Bradbury et al., 2018; CRFM, 2024). (CRFM, 2024) refers to 'Stanford CRFM. Levanter. https://github.com/stanford-crfm/levanter, 2024.'. This indicates the use of a third-party framework, not the release of the authors' own implementation code. |
| Open Datasets | Yes | These models are trained on the Pile dataset (Gao et al., 2020) with a context length of 4096 and a batch size of 4M tokens. ...We pretrain a 124M GPT-2 model on Open Web Text. ...We reran our experiments on another dataset called DCLM (Li et al. (2024)). |
| Dataset Splits | No | The paper mentions training on various token budgets (e.g., 50B, 100B, 200B tokens) and refers to 'validation loss', implying a validation set is used. However, it does not specify the explicit percentages, sample counts, or methodology for how the datasets (The Pile, Open Web Text, DCLM) were split into training, validation, and test sets for reproducibility. |
| Hardware Specification | Yes | We use a TPU v3-256 model to train the model with the Levanter framework in Jax (Bradbury et al., 2018; CRFM, 2024). |
| Software Dependencies | No | We use a standard Adam optimizer. We use a TPU v3-256 model to train the model with the Levanter framework in Jax (Bradbury et al., 2018; CRFM, 2024). The paper mentions software like Adam, Levanter, and Jax, but does not provide specific version numbers for these components. |
| Experiment Setup | Yes | These models are trained on the Pile dataset (Gao et al., 2020) with a context length of 4096 and a batch size of 4M tokens. ...For the 0.1B and 0.3B models, we use a peak learning rate of 6e-4, and for the 0.6B and 1.2B models, we use a peak learning rate of 4e-4. These values are chosen following current empirical practice (e.g. see Groeneveld et al. (2024)). We set the minimal learning rate to 0.1 of the peak learning rate. ...We set the batch size to 1024... The fraction of time spent decaying is chosen to be 10%. The exact hyperparameters are deferred to Appendix D. |
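The hyperparameters quoted above (peak learning rate, minimal learning rate at 0.1 of the peak, and a 10% decay fraction) can be combined into a minimal sketch of a Warmup-Stable-Decay schedule. This is an illustration only, not the authors' implementation: the warmup fraction and the linear decay shape are assumptions, since the paper defers the exact settings to its Appendix D.

```python
def wsd_lr(step, total_steps, peak_lr=6e-4, min_ratio=0.1,
           warmup_frac=0.01, decay_frac=0.10):
    """Sketch of a Warmup-Stable-Decay (WSD) learning-rate schedule.

    Three phases: linear warmup to peak_lr, a constant plateau, then a
    decay over the final decay_frac of training down to min_ratio * peak_lr.
    NOTE: warmup_frac and the linear decay shape are illustrative
    assumptions; the paper's exact hyperparameters are in its Appendix D.
    """
    warmup_steps = int(warmup_frac * total_steps)
    decay_start = int((1 - decay_frac) * total_steps)
    min_lr = min_ratio * peak_lr

    if step < warmup_steps:
        # linear warmup from ~0 to the peak learning rate
        return peak_lr * (step + 1) / warmup_steps
    if step < decay_start:
        # stable phase: hold the peak learning rate constant
        return peak_lr
    # decay phase: linear anneal from peak_lr down to min_lr
    frac = (step - decay_start) / (total_steps - decay_start)
    return peak_lr + frac * (min_lr - peak_lr)
```

During the stable phase the rate is simply the peak (e.g. `wsd_lr(5000, 10000)` returns `6e-4`), which is what lets WSD-S reuse a single run across several compute budgets: a decay branch can be launched from any checkpoint on the plateau.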