Fast and Effective Weight Update for Pruned Large Language Models
Authors: Vladimír Boža
TMLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | 4 Experiments. General setup. We implement our algorithms by extending the Wanda (Sun et al., 2023) codebase, which relies on PyTorch and the Hugging Face library. Similarly to Wanda and SparseGPT, we use 128 calibration samples from the C4 training dataset (Raffel et al., 2020). We run pruning on a machine with two Quadro RTX 5000 GPUs (each with 16GB of GPU memory). Since we prune layers sequentially in order, we only need to load one layer into GPU memory at a time. This allows us to prune 70B-parameter LLaMA models using a relatively small GPU. Unless stated otherwise, we prune for k = 20 iterations, using ks = 15 sparsification steps, and set the dampening factor to λ = 0.1 and the ADMM penalty factor to ρ = 1. We compare our methods to Wanda (Sun et al., 2023), which does no weight update and simply prunes the weights with the lowest product of magnitude and activation norm, and SparseGPT (Frantar & Alistarh, 2023), which uses multiple approximations to select pruned weights and calculate weight updates. For both methods, we use their public implementations and default hyperparameter settings. Models and evaluation. We test our methods on LLaMA (Touvron et al., 2023a) and LLaMA2 (Touvron et al., 2023b) models. Similarly to previous works (Frantar & Alistarh, 2023; Sun et al., 2023), we measure the performance of pruned models on language modeling and zero-shot tasks. Our main focus is perplexity on held-out WikiText (Merity et al., 2016), considered a go-to metric for evaluating language model compression (Frantar & Alistarh, 2023). As additional verification, we use the same seven tasks as Wanda from the EleutherAI LM Harness (Gao et al., 2021). 4.1 Reconstruction error convergence. As a first experiment, we study the quality of our update algorithm. We use a fixed sparsity mask derived using Wanda with 50% sparsity and observe reconstruction error convergence in one layer. 
We compare our algorithm to gradient-based approaches using the Adam and SGD optimizers with varying learning rates. We also compare it to the SparseGPT update (without mask selection) used in the Wanda paper. The results for selected layers of LLaMA-7B are presented in Figure 1. Our ADMM-based algorithm is superior to both gradient-based algorithms and SparseGPT, as it converges almost instantly after computing the initial XᵀX matrix inverse. We also note that ADMM works well with the default setting of ρ = 1 and does not require learning rate tuning, in stark contrast with SGD and Adam, which have different optimal learning rates in different layers. 4.2 Weight update quality comparison. In this experiment, we first prune each layer of LLaMA-7B to 60% or 80% sparsity using Wanda mask selection and then update weights using either the gradient-based (via Adam) or the ADMM update. We select the pruning mask in a single step, i.e., we do not do any gradual mask selection. We test using 1, 10, 20, 50, and 100 update steps. We also test the performance of the SparseGPT weight update and, for reference, include results of running SparseGPT with its own gradual mask selection. We measure perplexity on WikiText and time overhead (over the forward pass) for each update option. Using just one update step, we can almost beat SparseGPT and all gradient-based algorithms (Figure 2). The ADMM update almost converges within ten update steps, while the gradient-based algorithms need more than 100 steps. ADMM is thus clearly a faster and superior weight update algorithm compared to the gradient-based update. Our algorithm also provides a better weight update than the SparseGPT weight update, and at 60% sparsity, it is even better than SparseGPT with its own iterative mask selection. Furthermore, we explicitly compare SparseGPT and ADMM weight updates over different weight masks. 
We select either the Wanda or the SparseGPT mask and apply the SparseGPT or ADMM weight update (in the case of the SparseGPT mask, the SparseGPT update is a no-op, and for the ADMM update, we rewind weights and keep the selected mask). Results are summarized in Table 2. Our ADMM weight update is always better than the SparseGPT update. Note that our mask selection is also better than SparseGPT's (9.22 vs. 9.92 perplexity). 4.3 Pruning LLaMA-7B. Based on previous observations, we set the number of update iterations to 20, which should provide a pruning overhead similar to SparseGPT (Table 3) and also guarantee reasonable convergence of weight updates. We compare our weight update after mask selection without gradual pruning (ADMM1) and our gradual pruning algorithm, which computes the mask over 15 iterations (ADMM-Grad), with Wanda and SparseGPT pruning. We prune LLaMA-7B to various sparsities and also with 2:4 structured sparsity. First, we measure WikiText perplexity (Table 1). We see that our weight update over a fixed Wanda mask (ADMM1) produces better results than any other algorithm at 50%, 60%, and 2:4 sparsities. Note that SparseGPT generates the pruning mask iteratively, which gives it a slight edge at higher sparsities. When selecting the mask gradually, we are superior to all previously developed algorithms, especially at higher sparsities. Finally, we measure performance on seven zero-shot tasks (we use the same selection as the authors of Wanda): BoolQ (Clark et al., 2019), RTE (Wang et al., 2018), HellaSWAG (Zellers et al., 2019), WinoGrande (Sakaguchi et al., 2021), ARC easy and challenge (Clark et al., 2018), and OpenbookQA (Mihaylov et al., 2018). Our results (Table 4) show that our algorithm is superior to the previous ones except on the RTE task. We note that results for the RTE task are slightly erratic (e.g., there is better performance at 60% sparsity than at 50%). 
We attribute this to the small RTE dataset size (277 samples). Notably, we recover 30-40% of the performance drop of SparseGPT on the BoolQ task at 50-70% sparsities and also on the WinoGrande task at 50-60% sparsities. When using 2:4 sparsity, we recover 20-25% of the performance drop on the WinoGrande and ARC-e tasks. Table 4: Zero-shot accuracies on various tasks during pruning of LLaMA-7B. Columns: Sparsity, Method, BoolQ, RTE, HellaSwag, WinoGrande, ARC-e, ARC-c, OBQA, Mean. Rows: 0% Dense 75.05 66.43 56.92 69.93 75.34 41.89 34.40 59.99; 50% Wanda 71.22 55.60 51.85 66.06 69.11 36.86 28.80 54.21; 50% SparseGPT 73.05 52.34 51.21 68.42 70.70 36.43 28.60 54.39; 50% ADMM-Grad 73.63 52.34 52.33 69.13 70.74 37.88 30.20 55.18; 60% Wanda 69.26 59.56 43.76 62.35 62.58 30.29 25.20 50.43; 60% SparseGPT 70.7 62.09 44.84 65.58 64.14 30.97 25.20 51.93; 60% ADMM-Grad 72.41 58.84 46.61 66.77 64.52 31.65 26.20 52.43; 70% Wanda 59.78 58.12 28.81 50.82 32.40 18.85 14.20 37.57; 70% SparseGPT 62.35 55.95 33.77 59.35 45.70 23.97 17.20 42.61; 70% ADMM-Grad 66.05 53.79 36.29 59.74 50.84 25.50 18.60 44.40; 80% Wanda 37.82 48.37 26.29 48.77 27.23 20.56 13.00 31.72; 80% SparseGPT 41.89 52.70 27.83 48.38 30.30 18.77 13.40 33.32; 80% ADMM-Grad 56.14 52.70 28.75 50.74 31.56 18.94 12.40 35.89; 2:4 Wanda 69.3 51.99 42.06 62.75 60.94 28.07 24.60 48.53; 2:4 SparseGPT 70.46 60.65 42.99 64.88 61.49 30.12 23.60 50.60; 2:4 ADMM-Grad 70.27 55.59 44.88 66.14 64.18 30.97 25.20 51.03. Table 5: Perplexity of pruned LLaMA-2 variants on WikiText. Columns: Method, Sparsity, 7B, 13B, 70B. Rows: Dense 0% 5.12 4.57 3.12; Wanda 50% 6.42 5.56 3.98; SparseGPT 50% 6.51 5.63 3.98; ADMM-Grad 50% 6.33 5.52 3.95; Wanda 60% 9.71 7.75 4.98; SparseGPT 60% 9.58 7.80 4.98; ADMM-Grad 60% 8.70 7.09 4.81; Wanda 80% 5e3 2e3 1e2; SparseGPT 80% 108.87 94.23 25.86; ADMM-Grad 80% 55.93 43.58 18.84; Wanda 2:4 11.02 8.27 5.16; SparseGPT 2:4 10.17 8.32 5.40; ADMM-Grad 2:4 9.74 7.78 5.19 |
| Researcher Affiliation | Academia | Vladimír Boža EMAIL Faculty of Mathematics, Physics and Informatics, Comenius University, Bratislava, Slovakia |
| Pseudocode | Yes | Algorithm 1 Layerwise gradual pruning with ADMM. Given weight matrix W, calibration input X, desired sparsity sf, number of iterations k, number of sparsification steps ks, dampening factor λ (usually 0.1), and penalty factor ρ (usually 1), we prune matrix W to the desired sparsity and accurately update weights for the given weight mask. |
| Open Source Code | Yes | The code is available at https://github.com/fmfi-compbio/admm-pruning. |
| Open Datasets | Yes | Similarly to Wanda and SparseGPT, we use 128 calibration samples from the C4 training dataset (Raffel et al., 2020). Our main focus is perplexity on held-out WikiText (Merity et al., 2016), considered a go-to metric for evaluating language model compression (Frantar & Alistarh, 2023). As additional verification, we use the same seven tasks as Wanda from the EleutherAI LM Harness (Gao et al., 2021). |
| Dataset Splits | Yes | Similarly to Wanda and SparseGPT, we use 128 calibration samples from the C4 training dataset (Raffel et al., 2020). Our main focus is perplexity on held-out WikiText (Merity et al., 2016), considered a go-to metric for evaluating language model compression (Frantar & Alistarh, 2023). As additional verification, we use the same seven tasks as Wanda from the EleutherAI LM Harness (Gao et al., 2021). |
| Hardware Specification | Yes | We run pruning on a machine with two Quadro RTX 5000 GPUs (each with 16GB of GPU memory). |
| Software Dependencies | No | The paper mentions PyTorch and the Hugging Face library but does not specify version numbers for these software components. |
| Experiment Setup | Yes | Unless stated otherwise, we prune for k = 20 iterations, using ks = 15 sparsification steps, and set the dampening factor to λ = 0.1 and ADMM penalty factor ρ = 1. |
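
For concreteness, the Wanda baseline's selection rule quoted above (prune the weights with the lowest product of magnitude and activation norm) can be sketched in NumPy. This is an illustrative reimplementation, not the authors' code; the function name and the per-output-row comparison group are our assumptions:

```python
import numpy as np

def wanda_mask(W, X, sparsity):
    """Sketch of Wanda-style mask selection: score each weight by |w| times
    the L2 norm of its input feature's calibration activations, then keep
    the highest-scoring weights within each output row.
    W: (d_out, d_in) weights, X: (n_samples, d_in) calibration inputs."""
    act_norm = np.linalg.norm(X, axis=0)            # per-input-feature activation norm
    scores = np.abs(W) * act_norm[None, :]          # Wanda importance score
    k_keep = int(round(W.shape[1] * (1.0 - sparsity)))
    mask = np.zeros_like(W, dtype=bool)
    top = np.argsort(-scores, axis=1)[:, :k_keep]   # indices of kept weights per row
    np.put_along_axis(mask, top, True, axis=1)
    return mask
```

The key point the paper relies on is that this selection alone does no weight update; the mask is then handed to an update procedure such as ADMM.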
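
The ADMM weight update discussed in Sections 4.1-4.2 admits a compact sketch: one dampened (XᵀX + ρI) inverse computed up front, then cheap alternating steps, which is why it "converges almost instantly" relative to gradient methods. The NumPy version below is a hedged reconstruction from the text, not the released implementation; the exact dampening placement and initialization may differ:

```python
import numpy as np

def admm_update(W0, X, mask, rho=1.0, lam=0.1, iters=20):
    """Minimize ||X @ W - X @ W0||_F^2 subject to W being zero outside
    `mask`, via ADMM splitting W = Z with Z constrained to the mask support.
    W0: (d_in, d_out) dense weights, X: (n, d_in) calibration inputs,
    mask: boolean (d_in, d_out); lam is the dampening factor, rho the penalty."""
    d_in = W0.shape[0]
    H = X.T @ X
    H = H + lam * np.mean(np.diag(H)) * np.eye(d_in)  # dampening, as in lambda = 0.1
    A_inv = np.linalg.inv(H + rho * np.eye(d_in))     # computed once, reused every step
    target = H @ W0                                   # X^T X W0
    Z = np.where(mask, W0, 0.0)                       # feasible (masked) starting point
    U = np.zeros_like(W0)                             # scaled dual variable
    for _ in range(iters):
        W = A_inv @ (target + rho * (Z - U))          # ridge-like W-step
        Z = np.where(mask, W + U, 0.0)                # projection onto the mask
        U = U + W - Z                                 # dual update
    return Z
```

With a fixed mask, every per-iteration cost is a matrix multiply, matching the paper's observation that the update needs no learning-rate tuning and works with the default ρ = 1.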
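
Algorithm 1 (layerwise gradual pruning with ADMM) can likewise be sketched end to end with the reported defaults k = 20, ks = 15, λ = 0.1, ρ = 1. The linear sparsity schedule and the plain magnitude criterion for the growing mask are simplifying assumptions for illustration; the paper's mask scoring and schedule may differ:

```python
import numpy as np

def gradual_prune_admm(W0, X, final_sparsity, k=20, ks=15, rho=1.0, lam=0.1):
    """Gradually prune W0 to `final_sparsity` over ks sparsification steps
    while running k ADMM iterations in total.
    W0: (d_in, d_out), X: (n, d_in). Returns the pruned weight matrix."""
    d_in = W0.shape[0]
    H = X.T @ X
    H = H + lam * np.mean(np.diag(H)) * np.eye(d_in)  # dampening
    A_inv = np.linalg.inv(H + rho * np.eye(d_in))     # one-time factorization
    target = H @ W0
    Z = W0.copy()
    U = np.zeros_like(W0)
    for it in range(k):
        W = A_inv @ (target + rho * (Z - U))          # ADMM W-step
        # tighten the sparsity target linearly over the first ks iterations
        s = final_sparsity * min(it + 1, ks) / ks
        thresh = np.quantile(np.abs(W + U), s)
        mask = np.abs(W + U) > thresh                 # simplified magnitude criterion
        Z = np.where(mask, W + U, 0.0)                # project onto current mask
        U = U + W - Z
    return Z
```

Because the mask only grows toward the final sparsity while the ADMM state (Z, U) carries over, weights pruned early can still influence the update of the remaining weights, which is the intuition behind ADMM-Grad outperforming one-shot masking at high sparsities.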