Probe Pruning: Accelerating LLMs through Dynamic Pruning via Model-Probing
Authors: Qi Le, Enmao Diao, Ziyan Wang, Xinran Wang, Jie Ding, Li Yang, Ali Anwar
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Comprehensive evaluations of PP on LLaMA-2/3 and OPT models reveal that even minimal probing, using just 1.5% of FLOPs, can substantially enhance the efficiency of structured pruning of LLMs. For instance, when evaluated on LLaMA-2-7B with WikiText2, PP achieves a 2.56× lower ratio of performance degradation per unit of runtime reduction compared to the state-of-the-art method at a 40% pruning ratio. Our code is available at https://github.com/Qi-Le1/Probe_Pruning. Table 1: Comparison of LLM structured pruning methods. Table 2: Zero-shot performance of LLaMA-2-7B/13B and OPT-13B after pruning attention and MLP blocks without fine-tuning: PP demonstrates superior performance in nearly all scenarios. Figure 3 shows ablation study results for various probe combinations. |
| Researcher Affiliation | Academia | Qi Le¹, Enmao Diao, Ziyan Wang², Xinran Wang¹, Jie Ding¹, Li Yang², Ali Anwar¹. ¹University of Minnesota; ²University of North Carolina at Charlotte. The authors acknowledge the Minnesota Supercomputing Institute (MSI) at the University of Minnesota for providing resources that contributed to the research results reported in this paper. The work of Qi Le was supported by the Amazon Machine Learning System Fellowship. The work of Xinran Wang and Ali Anwar was supported by the 3M Science and Technology Graduate Fellowship and the Samsung Global Research Outreach Award. The work of Jie Ding was supported in part by the National Science Foundation under CAREER Grant No. 2338506 and Grant No. 2220286. The work of Ziyan Wang and Li Yang was supported by the National Science Foundation under Grant No. 2348376. |
| Pseudocode | Yes | Algorithm 1: Probe Pruning. Input: an LLM M with L blocks, each containing the transformation F_l, the intermediate transformation T_l, and layer normalization LN_l; a calibration dataset D; inference batches B. System executes: run the calibration dataset D through model M to obtain historical states V. For each batch B_t: initialize the hidden state X_0 for batch B_t; then for each block l = 0, ..., L-1: (1) generate a probe P_l from LN_l(X_l) using residual importance (Section 4.2); (2) use P_l to execute the intermediate transformation of block l, gathering the resulting intermediate hidden states X_{l,int,probe} = T_l(P_l); (3) integrate the probing states X_{l,int,probe} with the historical states via importance-scaled fusion (Section 4.3); (4) compute the PP_sp pruning metric from the integrated states (Section 4.4) and prune the weight channels accordingly; (5) execute full inference on X_l with the pruned weights W_l, denoted F_l(X_l). |
| Open Source Code | Yes | Our code is available at https://github.com/Qi-Le1/Probe_Pruning. |
| Open Datasets | Yes | We evaluate PP on three popular model families: LLaMA-2 7B/13B (Touvron et al., 2023), LLaMA-3 8B (Meta AI, 2024), and OPT-13B (Zhang et al., 2022). We evaluate accuracy on commonsense reasoning tasks, including BoolQ (Clark et al., 2019), PIQA (Bisk et al., 2020), HellaSwag (Zellers et al., 2019), WinoGrande (Sakaguchi et al., 2019), ARC-Easy (Clark et al., 2018), ARC-Challenge (Clark et al., 2018), and OpenbookQA (Mihaylov et al., 2018). For evaluating perplexity on the text generation task, we use WikiText2 (Merity et al., 2016). We use the C4 (Raffel et al., 2020) dataset as the calibration dataset for all methods. |
| Dataset Splits | Yes | For the commonsense reasoning tasks, our implementation follows Gao et al. (2021), setting the sequence length of each batch to match its longest sample. For the text generation task, we set the sequence length to 1024. For PP, FLAP (An et al., 2024), and Wanda-sp (An et al., 2024), we use 2,000 samples with sequence lengths of 1,024 tokens as the calibration dataset for the text generation task, and 2,000 samples with sequence lengths of 512 tokens for the commonsense reasoning task. For LLM-Pruner (Ma et al., 2023), we follow the original implementation details: we use 10 randomly selected samples, each truncated to a length of 128 tokens, to build importance metrics, and 20,000 samples with sequence lengths of 256 tokens for recovery retraining. For LoRAPrune (Zhang et al., 2023), we follow the original implementation details: we randomly sample 20,000 sentences from the C4 dataset, each 512 tokens long, according to the original calibration dataset preparation process. |
| Hardware Specification | Yes | We conduct all experiments on NVIDIA A100 GPUs. Additionally, we evaluate each block's end-to-end runtime across all batches of WikiText2 and the inference speedup at a 40% pruning ratio on NVIDIA A100 GPUs, similar to previous studies (Sun et al., 2023; Ma et al., 2023). |
| Software Dependencies | No | We use the DeepSpeed package (Rasley et al., 2020) to measure the FLOPs. For LLM-Pruner (Ma et al., 2023),... we employ the AdamW (He et al., 2020) optimizer with 100 warmup steps, set the LoRA (Hu et al., 2021) rank r to 8... All training processes are optimized using the AdamW optimizer with a linear learning rate decay. |
| Experiment Setup | Yes | We set the batch size to 20 for all tasks. For the commonsense reasoning tasks, our implementation follows Gao et al. (2021), setting the sequence length of each batch to match its longest sample. For the text generation task, we set the sequence length to 1024. For PP, we set the default probe size to 5% of the batch size and 50% of the sequence length, approximating 1.5% of the FLOPs cost relative to dense model inference. We use 2,000 calibration samples for PP, Wanda-sp, and FLAP, and 20,000 calibration samples for tuning LoRAPrune and LLM-Pruner. For LLM-Pruner... we use 10 randomly selected samples, each truncated to a length of 128 tokens, to build importance metrics, and 20,000 samples with sequence lengths of 256 tokens for recovery retraining. Specifically, in the recovery stage, we employ the AdamW (He et al., 2020) optimizer with 100 warmup steps, set the LoRA (Hu et al., 2021) rank r to 8, use a learning rate of 1e-4, a batch size of 64, and perform recovery retraining for 2 epochs. |
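The per-block loop of Algorithm 1 can be sketched as follows. This is a minimal toy in NumPy, not the paper's implementation: the residual-importance probe selection (Section 4.2), importance-scaled fusion (Section 4.3), and the PP_sp metric (Section 4.4) are replaced with simple hypothetical stand-ins (leading-token slicing, a fixed-weight average, and a Wanda-style |W|·‖X‖ score), and a single linear layer stands in for a Transformer block.

```python
import numpy as np

rng = np.random.default_rng(0)

def select_probe(x, batch_frac=0.05, seq_frac=0.5):
    """Toy probe selection: keep a fraction of the batch and sequence
    dimensions (the paper's defaults: 5% of batch, 50% of sequence).
    The paper selects by residual importance; here we simply take the
    leading sequences/tokens for illustration."""
    b, s, d = x.shape
    nb = max(1, int(round(b * batch_frac)))
    ns = max(1, int(round(s * seq_frac)))
    return x[:nb, :ns, :]

def channel_scores(hidden, weight):
    """Toy Wanda-style metric: |W| scaled by the L2 norm of each input
    channel of the fused hidden states, summed over output channels."""
    act_norm = np.linalg.norm(hidden.reshape(-1, hidden.shape[-1]), axis=0)
    return (np.abs(weight) * act_norm).sum(axis=0)

def prune_channels(weight, scores, ratio=0.4):
    """Structured pruning: zero out the lowest-scoring input channels."""
    k = int(weight.shape[1] * ratio)
    drop = np.argsort(scores)[:k]
    pruned = weight.copy()
    pruned[:, drop] = 0.0
    return pruned

# One "block": x -> x @ W.T; a batch of 20 sequences, 16 tokens, 32 channels.
x = rng.standard_normal((20, 16, 32))
W = rng.standard_normal((32, 32))
hist = rng.standard_normal((1, 1, 32))  # stand-in for calibration-derived historical states

probe = select_probe(x)                  # run only the probe through the block
fused = 0.5 * probe + 0.5 * hist         # stand-in for importance-scaled fusion
W_pruned = prune_channels(W, channel_scores(fused, W))
y = x @ W_pruned.T                       # full inference with the pruned weights
print(probe.shape)                       # (1, 8, 32): 5% of batch, 50% of sequence
```

Pruning is decided per batch from the cheap probe pass alone; only the full pass uses the pruned weights, which is where the runtime reduction comes from.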
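As a back-of-the-envelope check on the stated probe cost (5% of the batch, 50% of the sequence, ~1.5% of dense FLOPs), one can bound the probe's share under the simplifying assumption that MLP FLOPs scale linearly with token count while attention FLOPs scale quadratically with sequence length; the reported 1.5% falls between these two bounds.

```python
batch_frac, seq_frac = 0.05, 0.5

# Linear-in-tokens components (MLP, projections): 5% * 50% of the tokens.
mlp_share = batch_frac * seq_frac           # 0.025 -> 2.5%

# Quadratic-in-sequence component (attention score matrix).
attn_share = batch_frac * seq_frac ** 2     # 0.0125 -> 1.25%

print(f"MLP-like share: {mlp_share:.2%}, attention-like share: {attn_share:.2%}")
```

The blended cost depends on the model's actual MLP/attention FLOPs mix, so this only shows the paper's figure is in the expected range, not how it was computed.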