LLaMaFlex: Many-in-one LLMs via Generalized Pruning and Weight Sharing

Authors: Ruisi Cai, Saurav Muralidharan, Hongxu Yin, Zhangyang Wang, Jan Kautz, Pavlo Molchanov

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We validate our method on Llama 3.1 8B (Dubey et al., 2024), a prominent open-source large language model trained on 15 trillion tokens. The model comprises N = 32 Transformer blocks, each containing a Multi-head Attention (MHA) layer with 8 attention groups and NA = 32 query heads, along with a Multilayer Perceptron (MLP) with an intermediate dimension of D = 14336. The hidden feature dimension is H = 4096. The model has a total of 6.98 billion non-embedding parameters, with 5.64 billion parameters allocated to the MLP layers. Since the original training data is not publicly available, we use a proprietary dataset consisting of high-quality pretraining data. With 4 choices of bj, we sample different tokens for different budget goals (see details in Equ. 8), and use 60.4 billion training tokens in total. Downstream Tasks: We evaluate LLAMAFLEX on several downstream tasks including ARC-easy (Clark et al., 2018), LAMBADA (Paperno et al., 2016), PIQA (Bisk et al., 2020), WinoGrande (Sakaguchi et al., 2021), MMLU (Hendrycks et al., 2020), and HellaSwag (Zellers et al., 2019). Following the approach of Xia et al. (2023), we report 5-shot performance for MMLU and 10-shot performance for HellaSwag, while presenting zero-shot results for the other tasks. Table 2: Downstream task evaluation of the LLAMAFLEX framework. Here, #Params refers to the number of non-embedding parameters. The prefix Exp. indicates that the sub-models are explicitly trained, i.e., b ∈ B, while the prefix Inter. refers to models generated through router interpolation. Table 4: Results of our ablation study on policy-aware modulation. We run LLAMAFLEX for 800 iterations and report the validation loss. All sub-networks show improved performance when modulation is enabled, reducing validation loss by 0.08 on average. We provide more experimental results in Appendix A.
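The architectural figures quoted above (N, H, D, query heads, attention groups) are enough to reproduce the stated parameter totals. The sketch below is a back-of-the-envelope check, assuming a SwiGLU-style gated MLP and grouped-query attention with head dimension H / NA, and ignoring normalization layers, biases, and embeddings; it is an illustration, not the paper's code.

```python
# Hypothetical parameter-count check for the quoted Llama 3.1 8B configuration.
N = 32           # Transformer blocks
H = 4096         # hidden feature dimension
D = 14336        # MLP intermediate dimension
n_q_heads = 32   # query heads (NA)
n_kv_groups = 8  # attention (KV) groups

head_dim = H // n_q_heads          # 128
kv_dim = n_kv_groups * head_dim    # 1024 under grouped-query attention

# Gated MLP (SwiGLU): gate, up, and down projections.
mlp_per_block = 3 * H * D
# Attention: Q and O projections are H x H; K and V are H x kv_dim.
attn_per_block = 2 * H * H + 2 * H * kv_dim

mlp_total = N * mlp_per_block
non_embedding = N * (mlp_per_block + attn_per_block)

print(f"MLP params:   {mlp_total / 1e9:.2f}B")      # 5.64B, matching the quote
print(f"Non-embedding: {non_embedding / 1e9:.2f}B")  # 6.98B, matching the quote
```

The two printed totals agree with the 5.64B MLP and 6.98B non-embedding figures in the quoted text, which is also why D = 14336 (the standard Llama 3.1 8B intermediate size) is the consistent value.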
Researcher Affiliation | Collaboration | Ruisi Cai (1, 2), Saurav Muralidharan (1), Hongxu Yin (1), Zhangyang Wang (2), Jan Kautz (1), Pavlo Molchanov (1). (1) NVIDIA; (2) The University of Texas at Austin.
Pseudocode | No | The paper describes methods using mathematical formulations (e.g., Equations 1, 2, 5, 6, 8, 9, 10) and text descriptions, but it does not include any clearly labeled pseudocode or algorithm blocks.
Open Source Code | No | The paper mentions "TensorRT-LLM (NVIDIA, 2023) and llama.cpp" as common frameworks for which LLAMAFLEX produces uniform architectures, but it does not state that the code for LLAMAFLEX itself is open-source or provide a link to its implementation.
Open Datasets | Yes | We validate our method on Llama 3.1 8B (Dubey et al., 2024), a prominent open-source large language model trained on 15 trillion tokens. Downstream Tasks: We evaluate LLAMAFLEX on several downstream tasks including ARC-easy (Clark et al., 2018), LAMBADA (Paperno et al., 2016), PIQA (Bisk et al., 2020), WinoGrande (Sakaguchi et al., 2021), MMLU (Hendrycks et al., 2020), and HellaSwag (Zellers et al., 2019).
Dataset Splits | Yes | Downstream Tasks: We evaluate LLAMAFLEX on several downstream tasks including ARC-easy (Clark et al., 2018), LAMBADA (Paperno et al., 2016), PIQA (Bisk et al., 2020), WinoGrande (Sakaguchi et al., 2021), MMLU (Hendrycks et al., 2020), and HellaSwag (Zellers et al., 2019). Following the approach of Xia et al. (2023), we report 5-shot performance for MMLU and 10-shot performance for HellaSwag, while presenting zero-shot results for the other tasks.
Hardware Specification | No | The paper does not provide specific hardware details such as GPU models, CPU types, or memory specifications used for running the experiments. It only mentions the use of existing LLM frameworks like TensorRT-LLM for deployment.
Software Dependencies | No | The paper mentions "TensorRT-LLM (NVIDIA, 2023) and llama.cpp" as deployment frameworks, but does not provide specific version numbers for these or any other software dependencies used in the experimental setup (e.g., Python, PyTorch, CUDA versions).
Experiment Setup | Yes | Unless otherwise specified, we set the sequence length to 4096 and the batch size to 128, and fine-tune the model for 28800 iterations. We set the initial learning rate to 4e-5, and use cosine learning rate decay for LLM parameters. Router Architecture and Modulation Details: For each architectural variable, we use a two-layer MLP followed by a shared embedding layer. The embedding layer converts the scalar bj into an embedding vector of dimension 128, and is shared across all architectural choices. Each MLP has an intermediate dimension of 128 and outputs a logits vector, with a dimension corresponding to the number of elastic choices. The routers contain 0.51 million parameters in total. We use Gumbel-Softmax to optimize the routers, following the practice outlined in MaskLLM (Fang et al., 2024). Specifically, we exponentially decay the temperature with rate 0.9999, and linearly scale the scaling factor κ from 1 to 10. We set the initial learning rate to 4e-2 for router tuning. During the tuning process, we employ a combination of both soft and hard Gumbel-Softmax techniques (Jang et al., 2016). For each elastic choice, we modulate the elastic output by a sinusoidal embedding, of size 16, followed by a learnable lightweight MLP with an intermediate dimension of 128. Every modulation network contains only 0.02 million parameters. We also set the initial learning rate to 4e-2 for modulation networks.
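The router and modulation description in this row can be turned into a minimal sketch. The NumPy code below is an illustrative reconstruction, not the authors' implementation: the layer widths (128-dim shared embedding, 128-dim MLP hidden size, 16-dim sinusoidal code) and the 0.9999 temperature decay come from the quoted text, while the ReLU activation, weight initialization, budget value, number of elastic choices, and all function names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_mlp(in_dim, hidden, out_dim):
    """Two-layer MLP weights (biases omitted for brevity)."""
    return (rng.normal(0, 0.02, (in_dim, hidden)),
            rng.normal(0, 0.02, (hidden, out_dim)))

def mlp_forward(params, x):
    w1, w2 = params
    return np.maximum(x @ w1, 0) @ w2  # ReLU activation (assumed)

# Shared embedding: lifts the scalar budget b_j into a 128-dim vector.
embed = init_mlp(1, 128, 128)

# One router head per architectural variable; 4 elastic choices assumed here.
n_choices = 4
router = init_mlp(128, 128, n_choices)

def gumbel_softmax(logits, tau):
    """Soft Gumbel-Softmax sample (Jang et al., 2016)."""
    g = -np.log(-np.log(rng.uniform(1e-9, 1.0, logits.shape)))
    y = (logits + g) / tau
    y -= y.max()  # numerical stability
    e = np.exp(y)
    return e / e.sum()

# Temperature decays exponentially with rate 0.9999 per step, per the text.
tau = 1.0 * 0.9999 ** 800

b = np.array([[0.5]])                  # normalized parameter budget (assumed)
logits = mlp_forward(router, mlp_forward(embed, b))
probs = gumbel_softmax(logits[0], tau)  # soft choice over elastic options
choice = int(probs.argmax())            # hard choice (straight-through in training)

# Policy-aware modulation: a 16-dim sinusoidal code of the choice feeds a
# lightweight MLP that produces a scale for the elastic output.
def sinusoidal(x, dim=16):
    freqs = np.exp(np.arange(dim // 2) * -(np.log(10000.0) / (dim // 2)))
    return np.concatenate([np.sin(x * freqs), np.cos(x * freqs)])

mod_mlp = init_mlp(16, 128, 1)
scale = mlp_forward(mod_mlp, sinusoidal(float(choice))[None, :])
```

In training, the soft probabilities keep the router differentiable while the hard argmax selects a concrete sub-network; the paper additionally scales the logits by a factor κ ramped from 1 to 10, which this sketch omits.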