Relaxed Recursive Transformers: Effective Parameter Sharing with Layer-wise LoRA
Authors: Sangmin Bae, Adam Fisch, Hrayr Harutyunyan, Ziwei Ji, Seungyeon Kim, Tal Schuster
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this work, we revisit parameter sharing for LLMs, and propose novel methodologies to convert existing, unshared models into smaller, and more efficient, Recursive Transformers. These models use a single block of unique layers that are recursively reused across multiple loops, yet still achieve impressive performance relative to their reduced size. To mitigate the potential performance degradation associated with parameter sharing, we first initialize the shared block of layers based on the original model's pre-trained parameters, and then finetune the resulting recursive model for a limited number of uptraining steps. Importantly, we show that our initialization strategies allow us to achieve strong performance with minimal training time. [...] We introduce a framework for initializing and training Relaxed Recursive Transformers and demonstrate strong performance compared to non-recursive models of comparable size. For example, when we uptrained a recursive Gemma 1B model converted from a pretrained Gemma 2B (Team et al., 2024), we observed up to 13.5 absolute accuracy improvement (22% error reduction) on few-shot tasks compared to a non-recursive Gemma 1B model (pretrained from scratch). Furthermore, we show that by incorporating knowledge distillation (Hinton et al., 2015; Kim & Rush, 2016), our recursive Gemma model, uptrained on 60 billion tokens, achieves performance on par with the full-size Gemma model trained on a massive 3 trillion token corpus (see §3.3 for details). [...] Based on our Relaxed Recursive Transformer, we also evaluate a key use case for continuous depth-wise batching with early-exiting (Bae et al., 2023; Schuster et al., 2022; Elbayad et al., 2020; Graves, 2016a), which opportunistically makes predictions for samples with high confidence at earlier stages. From our simulation, early exits reveal a substantial throughput improvement of up to 2-3× compared to a vanilla Transformer with the same architecture. Notably, the recursive Gemma model, which outperforms the vanilla Pythia model, can theoretically achieve a nearly 4× increase in throughput (see §3.8 for details). |
| Researcher Affiliation | Collaboration | Sangmin Bae¹, Adam Fisch², Hrayr Harutyunyan³, Ziwei Ji³, Seungyeon Kim², Tal Schuster² ¹KAIST AI ²Google DeepMind ³Google Research EMAIL, EMAIL |
| Pseudocode | No | The paper describes methods and architectural components using mathematical formulas and descriptive text, e.g., 'h^ℓ_t = f(h^{ℓ−1}_t; Φ_{((ℓ−1) mod L/B)+1}), ∀ℓ ∈ [1, L]', and includes figures illustrating concepts. However, there are no explicitly labeled pseudocode blocks or algorithm sections presented in a structured, step-by-step format commonly associated with pseudocode. |
| Open Source Code | No | To ensure the reproducibility of our work, we provide a comprehensive description of our model architectures in Appendix F, and details of experimental settings can be found in Appendix G. We utilized the open-source Hugging Face framework and followed established open-source frameworks for evaluation, further enhancing reproducibility. We plan to release the source codes upon publication to facilitate future research. |
| Open Datasets | Yes | We evaluate our method on three popular pretrained LLMs: Gemma 2B (Team et al., 2024), TinyLlama 1.1B (Zhang et al., 2024b), and Pythia 1B (Biderman et al., 2023). [...] After converting to Recursive Transformers, we uptrained models on the SlimPajama dataset (Soboleva et al., 2023). [...] The SlimPajama dataset (Soboleva et al., 2023). SlimPajama is an open-source dataset designed for training large language models, which is created by cleaning and deduplicating the RedPajama dataset (Computer, 2023). |
| Dataset Splits | No | Uptraining setting To convert vanilla Transformers into Recursive Transformers, we conducted further uptraining on either 15 billion or 60 billion tokens from the SlimPajama dataset (Soboleva et al., 2023). [...] We used the Language Model Evaluation Harness framework (Gao et al., 2023) to evaluate accuracy on seven few-shot tasks, and averaged them for performance comparison. [...] Evaluation setting We evaluated perplexity on test sets from three language modeling datasets: SlimPajama, RedPajama, and PG19 (Rae et al., 2019). Additionally, we used the Language Model Evaluation Harness framework (Gao et al., 2023) to evaluate accuracy on seven few-shot tasks: LAMBADA (Paperno et al., 2016), HellaSwag (Zellers et al., 2019), PIQA (Bisk et al., 2020), WinoGrande (Sakaguchi et al., 2020), ARC-easy and ARC-challenge (Clark et al., 2018), and OpenBookQA (Mihaylov et al., 2018). We adhered to the standard number of shots specified by the evaluation framework for each dataset. |
| Hardware Specification | Yes | Eight H100 GPUs were used for the training. [...] Using a single A100 40GiB GPU, we measured these decoding times across different batch sizes, until an out-of-memory error occurred or a specific memory constraint was reached. |
| Software Dependencies | No | We utilized the open-source Hugging Face framework and followed established open-source frameworks for evaluation, further enhancing reproducibility. We employed the Hugging Face training framework (Wolf et al., 2020) and enhanced memory efficiency through the Zero Redundancy Optimizer (ZeRO) (Rajbhandari et al., 2020) from the DeepSpeed library (Rasley et al., 2020), along with mixed precision training. While specific frameworks and libraries are mentioned, no version numbers for these software components are provided, which is necessary for reproducible software dependencies. |
| Experiment Setup | Yes | Uptraining setting To convert vanilla Transformers into Recursive Transformers, we conducted further uptraining on either 15 billion or 60 billion tokens from the SlimPajama dataset (Soboleva et al., 2023). We used the Language Model Evaluation Harness framework (Gao et al., 2023) to evaluate accuracy on seven few-shot tasks, and averaged them for performance comparison. Detailed experimental setup for uptraining or evaluation can be found in Appendix G. [...] The context length was set to 2048, and the batch size was approximately 2 million tokens. We used the AdamW optimizer (Loshchilov & Hutter, 2019) with a learning rate of 2e-4, utilizing a cosine annealing learning rate scheduler (Loshchilov & Hutter, 2017). Additionally, we set warmup steps to 200 for 15 billion token training and 800 for 60 billion token training. |
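The layer-sharing rule quoted in the Pseudocode row — layer ℓ reuses parameter block ((ℓ−1) mod L/B)+1, so B unique blocks are looped L/B times — can be sketched in a few lines. This is an illustrative toy, not the paper's implementation; the block functions below are hypothetical stand-ins for real Transformer layers.

```python
def recursive_forward(h, blocks, total_layers):
    """Apply B unique blocks recursively over `total_layers` layers:
    layer l (1-indexed) uses block ((l - 1) mod B), matching the
    looped sharing rule h^l = f(h^{l-1}; Phi_{((l-1) mod L/B)+1})."""
    B = len(blocks)
    for l in range(1, total_layers + 1):
        h = blocks[(l - 1) % B](h)
    return h

# Toy usage: B = 2 unique "layers" looped over L = 4 layers (2 loops).
blocks = [lambda x: x + 1, lambda x: x * 2]
out = recursive_forward(0, blocks, 4)  # ((0+1)*2 + 1) * 2 = 6
```

The same B parameter blocks are visited in a fixed cyclic order, which is what lets a recursive model keep the depth of the original network while storing only L/B-th of its layer parameters.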