Matryoshka Quantization
Authors: Pranav Ajit Nair, Puranjay Datta, Jeff Dean, Prateek Jain, Aditya Kusupati
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate the efficacy of MatQuant when applied to quantizing the Feed-Forward Network (FFN) parameters of standard LLMs (Gemma-2 2B, 9B, and Mistral 7B); the FFN (Vaswani et al., 2017) is typically the main latency block, hence the focus on improving the most significant component's latency. Our results show that MatQuant produces int8 and int4 models with comparable accuracy to independently trained baselines, despite the benefit of shared model parameters. Critically, the int2 models generated by MatQuant significantly outperform their individually trained counterparts, with 4% higher accuracy on downstream tasks (Figure 1b). We also extend MatQuant to quantize all weights of a Transformer layer. In Figure 1c, we find that quantizing with MatQuant shifts the quantized weight distribution toward higher values, contributing to improved int2 performance. Finally, in Section 7, we also demonstrate that using an extra bit to represent outliers significantly boosts the performance for our sliced int2 models. |
| Researcher Affiliation | Industry | 1Google DeepMind 2Google Research. Correspondence to: Pranav Nair <EMAIL>, Puranjay Datta <EMAIL>, Jeff Dean <EMAIL>, Prateek Jain <EMAIL>, Aditya Kusupati <EMAIL>. |
| Pseudocode | No | The paper describes mathematical formulations and objectives (Equations 1-7) but does not contain a clearly labeled 'Pseudocode' or 'Algorithm' block with structured steps. |
| Open Source Code | No | The paper does not contain an unambiguous sentence stating the release of code for the described methodology, nor does it provide a direct link to a source-code repository. |
| Open Datasets | Yes | We use C4’s test set to calculate perplexity, and for downstream evaluations, we test on ARC-c, ARC-e (Clark et al., 2018), BoolQ (Clark et al., 2019), HellaSwag (Zellers et al., 2019), PIQA (Bisk et al., 2020), and Winogrande (Sakaguchi et al., 2020). |
| Dataset Splits | No | For OmniQuant experiments, we sample 128 examples with a sequence length of 2048 from the C4 dataset (Raffel et al., 2020) and train using a batch size of 4. We train for a total of 10M tokens for all models except the int2 baseline, where we train the model for 20M tokens (Shao et al., 2023). For QAT experiments, we sample a fixed set of 100M tokens from the C4 dataset and train all our models using a batch size of 16 and a sequence length of 8192 for a single epoch. The paper mentions sampling data for training and using C4's test set, but it does not specify explicit training/validation/test splits (e.g., percentages or sample counts) for the overall dataset used to produce these samples, nor does it describe a detailed splitting methodology for reproducibility. |
| Hardware Specification | Yes | We run all our experiments on TPUv5e chips. |
| Software Dependencies | No | The paper mentions techniques like OmniQuant and QAT, and models like Gemma-2 and Mistral 7B, but does not specify version numbers for any software libraries or dependencies (e.g., Python, PyTorch, TensorFlow versions). |
| Experiment Setup | Yes | For OmniQuant experiments, we use a constant learning rate of 1e-3, and for QAT experiments, we linearly warm up the learning rate to 1e-5 over 150 steps and use a cosine decay schedule thereafter. For OmniQuant experiments, we sample 128 examples with a sequence length of 2048 from the C4 dataset (Raffel et al., 2020) and train using a batch size of 4. We train for a total of 10M tokens for all models except the int2 baseline, where we train the model for 20M tokens (Shao et al., 2023). For Co-distillation experiments where OmniQuant is the base algorithm, we train for a total of 8.3M tokens. For QAT experiments, we sample a fixed set of 100M tokens from the C4 dataset and train all our models using a batch size of 16 and a sequence length of 8192 for a single epoch. For Attn + FFN experiments with QAT, we sample a fixed set of 300M tokens from C4 and train with a batch size of 16 for a single epoch. We use (λ8, λ4, λ2) = (0.1, 0.1, 1.0) for all our Gemma experiments unless otherwise stated. In the case of Mistral 7B, for OmniQuant experiments, we use (λ8, λ4, λ2) = (0.4, 0.4, 1.0), and for QAT experiments we use (λ8, λ4, λ2) = (0.2, 0.2, 1.0). |
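The setup above combines two ideas that are easy to misread from prose alone: int4 and int2 models are obtained by slicing the most significant bits of a shared int8 quantization, and training optimizes a weighted sum of per-precision losses with weights (λ8, λ4, λ2). The sketch below illustrates both under stated assumptions; `slice_bits` and `matquant_loss` are hypothetical names, and the per-precision losses are placeholders, not the paper's actual objective terms (Equations 1-7).

```python
def slice_bits(code_int8: int, target_bits: int) -> int:
    """Keep the `target_bits` most significant bits of a signed int8 code.

    Python's >> is an arithmetic shift, so the sign of the code is
    preserved; e.g. an int8 code of 127 sliced to 4 bits gives 7.
    """
    assert 1 <= target_bits <= 8
    return code_int8 >> (8 - target_bits)


def matquant_loss(losses: dict, lambdas: dict) -> float:
    """Weighted multi-precision objective: L = λ8·L8 + λ4·L4 + λ2·L2.

    `losses` maps a precision name to that slice's (placeholder) loss;
    `lambdas` holds the mixing weights reported in the table.
    """
    return sum(lambdas[k] * losses[k] for k in lambdas)


# Loss weights from the table (Gemma setting): int2 dominates, which is
# consistent with int2 being the hardest slice to get right.
gemma_lambdas = {"int8": 0.1, "int4": 0.1, "int2": 1.0}

# Placeholder per-precision losses, purely for illustration.
example_losses = {"int8": 2.0, "int4": 3.0, "int2": 5.0}
total = matquant_loss(example_losses, gemma_lambdas)  # 0.2 + 0.3 + 5.0 = 5.5
```

Because all three precisions share one set of int8 codes, serving an int4 or int2 model requires no extra storage beyond the int8 weights, only the bit-slicing step at load time.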