Cut Your Losses in Large-Vocabulary Language Models

Authors: Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB.
Researcher Affiliation | Industry | Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl (Apple)
Pseudocode | Yes | Algorithm 1: Memory-efficient indexed matrix multiplication
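Algorithm 1 in the paper is a blocked Triton kernel; purely as an illustration of the operation it computes (the correct-class logits C_{x_i}^T E_i, gathered without materializing the full N×V logit matrix), here is a minimal NumPy sketch. The function name and variable names are our own, not the paper's:

```python
import numpy as np

def indexed_matmul(E, C, x):
    """Compute the correct-class logit C[x_i] . E_i for each token i
    without forming the full (N, V) logit matrix E @ C.T.

    E: (N, D) token embeddings; C: (V, D) classifier rows; x: (N,) labels.
    """
    # Gather only the classifier rows that the labels select,
    # then take a row-wise dot product: memory is O(N*D), not O(N*V).
    return np.einsum('nd,nd->n', E, C[x])

# Sanity check against the naive route through the full logit matrix.
rng = np.random.default_rng(0)
N, D, V = 16, 8, 32
E = rng.normal(size=(N, D))
C = rng.normal(size=(V, D))
x = rng.integers(0, V, size=N)
naive = (E @ C.T)[np.arange(N), x]  # materializes an (N, V) matrix
assert np.allclose(indexed_matmul(E, C, x), naive)
```

The memory saving is the point: the gathered route never allocates the N×V logits, which at the paper's scale (8,192 tokens, 256,000-entry vocabulary) is the dominant cost.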
Open Source Code | Yes | https://github.com/apple/ml-cross-entropy
Open Datasets | Yes | We use the Alpaca dataset (Taori et al., 2023) for inputs and labels and Gemma 2 (2B) Instruct weights to compute E and for C. We pretrain Qwen 2.5 7B Instruct (Qwen Team, 2024), Phi 3.5 Mini Instruct (Abdin et al., 2024), Gemma 2 2B Instruct (Rivière et al., 2024), and Mistral NeMo (Mistral AI Team, 2024) on 5% of the OpenWebText dataset (Gokaslan et al., 2019) using CCE-Kahan-FullC and torch.compile.
Dataset Splits | Yes | We pretrain Qwen 2.5 7B Instruct... on 5% of the OpenWebText dataset... We report validation perplexity on a held-out 0.25% of OpenWebText and find that CCE-Kahan-FullC produces curves identical to those of torch.compile (Fig. 5).
Hardware Specification | Yes | Measured on an A100-SXM4 GPU with 80 GB of RAM, PyTorch 2.4.1, CUDA 12.4, rounded to the closest MB.
Software Dependencies | Yes | Measured on an A100-SXM4 GPU with 80 GB of RAM, PyTorch 2.4.1, CUDA 12.4, rounded to the closest MB.
Experiment Setup | Yes | First we examine the runtime and memory of various implementations of the cross-entropy loss log softmax_{x_i}(C^T E). We consider a batch of 8,192 tokens with a vocabulary size of 256,000 and hidden dimension 2,304. This corresponds to Gemma 2 (2B) (Rivière et al., 2024). We use the Alpaca dataset (Taori et al., 2023) for inputs and labels and Gemma 2 (2B) Instruct weights to compute E and for C. [...] We fine-tune Qwen 2.5 7B Instruct (Qwen Team, 2024), Phi 3.5 Mini Instruct (Abdin et al., 2024), Gemma 2 2B Instruct (Rivière et al., 2024), and Mistral NeMo (Mistral AI Team, 2024) on the Alpaca dataset (Taori et al., 2023) using CCE, with torch.compile as the control.
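To see why this setup stresses memory, a back-of-envelope sketch of the logit-matrix size alone at these dimensions (8,192 tokens, 256,000-entry vocabulary) is instructive. Note this counts only one fp32 copy of the logits; the paper's 24 GB figure for the full loss computation also includes other stored intermediates:

```python
# Rough memory estimate for a single fp32 copy of the (tokens x vocab)
# logit matrix in the paper's benchmark configuration.
tokens = 8_192
vocab = 256_000
bytes_per_fp32 = 4

logits_gib = tokens * vocab * bytes_per_fp32 / 2**30
print(f"{logits_gib:.1f} GiB")  # prints 7.8 GiB
```

Several such intermediates (logits, log-probabilities, gradients) accumulate during a naive forward-backward pass, which is how the footprint grows into the tens of gigabytes that CCE avoids.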