Improving Language Model Distillation through Hidden State Matching

Authors: Sayantan Dasgupta, Trevor Cohn

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | "We show the efficacy of our method using encoder-decoder (BART, mBART & T5) and encoder-only (BERT) architectures across a range of tasks from classification to summarization and translation. Our technique is competitive with the current state-of-the-art distillation methods at comparable compression rates and does not require already pretrained student models."
Researcher Affiliation | Academia | Sayantan Dasgupta, Computing & Information Science, University of Melbourne, Melbourne, VIC, Australia, EMAIL; Trevor Cohn, Computing & Information Science, University of Melbourne, Melbourne, VIC, Australia, EMAIL
Pseudocode | No | The paper describes mathematical formulas and derivations for CKA and related losses but does not present any structured pseudocode or algorithm blocks.
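Although the paper gives only mathematical derivations, the core quantity it builds on, linear Centered Kernel Alignment (CKA) between two hidden-state matrices, is straightforward to sketch. The following is an illustrative numpy implementation of standard linear CKA, not code from the paper; the function name `linear_cka` and the use of raw (non-kernelized) features are assumptions.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between representations X (n x d1) and Y (n x d2)
    computed over the same n examples (e.g. teacher vs. student
    hidden states). Returns a similarity in [0, 1]."""
    # Center each feature dimension across examples
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    # CKA(X, Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    numerator = np.linalg.norm(Y.T @ X, ord="fro") ** 2
    denominator = (np.linalg.norm(X.T @ X, ord="fro")
                   * np.linalg.norm(Y.T @ Y, ord="fro"))
    return numerator / denominator
```

Because linear CKA is invariant to orthogonal transformations and isotropic scaling, it can compare hidden states of different widths (d1 ≠ d2) without a learned projection, which is presumably why it is attractive for matching teacher and student layers of different sizes.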
Open Source Code | Yes | "The code is available on GitHub: https://github.com/Sayan21/ICLR25-CKA"
Open Datasets | Yes | "We follow the experimental setup of Shleifer & Rush (2020), who perform distillation for summarization on the CNN/Daily Mail (Hermann et al., 2015) and XSum (Narayan et al., 2018) datasets. We used multilingual data from mC4 (Xue et al., 2020) for all the languages the teacher mBART model covers. We distill the pre-trained mBART students for the downstream task of translation from English to Romanian using the WMT16 dataset (Bojar et al., 2016). We further run a supervised distillation on the pre-trained students for English-Spanish translation using the WMT13 corpus (Allauzen et al., 2013). We finally apply CKA loss to the task-agnostic distillation of BERT... fine-tune the distilled student with CKA loss on downstream GLUE tasks, specifically: SST-2 (Socher et al., 2013); MRPC (Dolan & Brockett, 2005), QQP and STS-B; MNLI (Williams et al., 2017), QNLI (Rajpurkar et al., 2016) and RTE (Wang et al., 2018); and CoLA (Warstadt et al., 2019)."
Dataset Splits | Yes | "We used a context size of 512 and trained the students for 25 epochs, each containing 40,000 text samples of mC4, and computed the sum of CLM loss and KL divergence on the validation set of mC4 at the end of every epoch. We sample 3 million sentence pairs from the WMT13 corpus, which is 14.5 million sentences in size, without replacement for training, and then measure the BLEU score on the test set. The evaluation is performed on the test set of IWSLT2017 (8.6K). We train the model for 30 epochs, with each step involving 320,000 sample texts from the C4 training set, and compute the KL Divergence for the C4 validation set at the end of every epoch."
Hardware Specification | Yes | "All the experiments are performed on an A100 GPU with 80GB memory." / "All the experiments are performed on an A40 GPU with 40GB memory."
Software Dependencies | No | The paper does not explicitly mention specific software dependencies with version numbers.
Experiment Setup | Yes | "We keep the temperature at 1 unless mentioned otherwise and do not use hyperparameters to weigh the loss contributions. We use a batch size of 16 and sum over 8 batches for the computation of CKA and the other losses through gradient accumulation, making the effective batch size 256. We use the Adam optimizer with η = 1e-4 and weight decay 5e-4. The context size used for the input document is 1024, while the context size for the summary is 128. We use the Adam optimizer with η = 3e-5 and weight decay 5e-4 for all the pretraining distillation on mC4. The context size for pretraining of mBART is 512. We use a sequence length of 512 tokens during pretraining using C4 and use the Adam optimizer with learning rate η = 2e-4 and weight decay 5e-4. We use a batch size of 32 for gradient computation and then accumulate the gradient for 40 batches, resulting in a large batch size of 1280. The fine-tuning on GLUE tasks is done with the Adam optimizer with learning rate η = 3e-5 to 1e-4 and weight decay 5e-4 for a batch size of 64."
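The quoted setup relies on gradient accumulation (e.g. micro-batches of 16 summed over 8 steps) to reach a larger effective batch. As a minimal illustration of why this works, the toy numpy sketch below, with hypothetical model and variable names and a simple mean-squared-error objective rather than the paper's losses, shows that averaging micro-batch gradients reproduces the full-batch gradient before the single optimizer update.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=3)             # toy linear-model parameters
X = rng.normal(size=(128, 3))      # one "effective" batch of 128 examples
y = rng.normal(size=128)

def grad_mse(w, X, y):
    # Gradient of 0.5 * mean((Xw - y)^2) with respect to w
    return X.T @ (X @ w - y) / len(y)

# Full-batch gradient computed in one pass
g_full = grad_mse(w, X, y)

# Same gradient via accumulation: 8 micro-batches of 16,
# each scaled by 1/8 so the running sum is the overall mean
g_acc = np.zeros_like(w)
for i in range(8):
    xb, yb = X[i * 16:(i + 1) * 16], y[i * 16:(i + 1) * 16]
    g_acc += grad_mse(w, xb, yb) / 8
```

In a real training loop the accumulated gradient would then feed one Adam step (the paper quotes η = 1e-4 with weight decay 5e-4 for summarization distillation); the micro-batch count here (8 × 16) is illustrative only.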