Cut Your Losses in Large-Vocabulary Language Models
Authors: Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. |
| Researcher Affiliation | Industry | Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl (Apple) |
| Pseudocode | Yes | Algorithm 1 Memory-efficient indexed matrix multiplication |
| Open Source Code | Yes | https://github.com/apple/ml-cross-entropy |
| Open Datasets | Yes | We use the Alpaca dataset (Taori et al., 2023) for inputs and labels and Gemma 2 (2B) Instruct weights to compute E and for C. We pretrain Qwen 2.5 7B Instruct (Qwen Team, 2024), Phi 3.5 Mini Instruct (Abdin et al., 2024), Gemma 2 2B Instruct (Rivière et al., 2024), and Mistral NeMo (Mistral AI Team, 2024) on 5% of the OpenWebText dataset (Gokaslan et al., 2019) using CCE-Kahan-FullC and torch.compile. |
| Dataset Splits | Yes | We pretrain Qwen 2.5 7B Instruct... on 5% of the OpenWebText dataset... We report validation perplexity on a held-out 0.25% of OpenWebText and find that CCE-Kahan-FullC produces identical curves as torch.compile (Fig. 5). |
| Hardware Specification | Yes | Measured on an A100-SXM4 GPU with 80 GB of RAM, PyTorch 2.4.1, CUDA 12.4, rounded to closest MB. |
| Software Dependencies | Yes | Measured on an A100-SXM4 GPU with 80 GB of RAM, PyTorch 2.4.1, CUDA 12.4, rounded to closest MB. |
| Experiment Setup | Yes | First we examine the runtime and memory of various implementations of the cross-entropy loss log softmax_{x_i}(CᵀE). We consider a batch of 8,192 tokens with a vocabulary size of 256,000 and hidden dimension 2,304. This corresponds to Gemma 2 (2B) (Rivière et al., 2024). We use the Alpaca dataset (Taori et al., 2023) for inputs and labels and Gemma 2 (2B) Instruct weights to compute E and for C. [...] We fine-tune Qwen 2.5 7B Instruct (Qwen Team, 2024), Phi 3.5 Mini Instruct (Abdin et al., 2024), Gemma 2 2B Instruct (Rivière et al., 2024), and Mistral NeMo (Mistral AI Team, 2024) on the Alpaca dataset (Taori et al., 2023) using CCE and torch.compile as the control. |
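To make the memory claim concrete: the setup above (8,192 tokens × 256,000 vocabulary) materializes an 8,192 × 256,000 logit matrix for a naive cross-entropy loss, which is what CCE avoids. The sketch below is not the paper's CUDA implementation (Algorithm 1); it is a minimal NumPy illustration of the underlying idea, computing log softmax_{x_i}(CᵀE) via a running log-sum-exp over vocabulary chunks so the full N × V logit matrix is never held in memory. All names (`chunked_cross_entropy`, `chunk`) are illustrative, not from the released code.

```python
import numpy as np

def chunked_cross_entropy(E, C, targets, chunk=4096):
    """Mean NLL: -log softmax_{x_i}(C @ e_i), without an (N, V) logit matrix.

    E: (N, D) token embeddings; C: (V, D) classifier; targets: (N,) indices.
    """
    N = E.shape[0]
    # Logit of each token's target class: C[x_i] . e_i (indexed matmul).
    picked = np.einsum("nd,nd->n", E, C[targets])
    # Running (max, sum) state for a streamed log-sum-exp over the vocabulary.
    m = np.full(N, -np.inf)   # running max logit per token
    s = np.zeros(N)           # running sum of exp(logit - m)
    for start in range(0, C.shape[0], chunk):
        logits = E @ C[start:start + chunk].T      # only (N, chunk) at a time
        cm = np.maximum(m, logits.max(axis=1))     # updated running max
        s = s * np.exp(m - cm) + np.exp(logits - cm[:, None]).sum(axis=1)
        m = cm
    lse = m + np.log(s)                            # log-sum-exp over all V classes
    return (lse - picked).mean()                   # mean negative log-likelihood
```

At the paper's scale, peak loss-computation memory drops from O(N·V) to O(N·chunk); the paper's kernels go further by fusing this with the matrix multiplication on-chip, which is where the 24 GB → 1 MB figure comes from.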