SAFE: Finding Sparse and Flat Minima to Improve Pruning

Authors: Dongyeop Lee, Kwanhee Lee, Jinseok Chung, Namhoon Lee

ICML 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Extensive evaluations on standard image classification and language modeling tasks reveal that SAFE consistently yields sparse networks with improved generalization performance, which compares competitively to well-established baselines. In addition, SAFE demonstrates resilience to noisy data, making it well-suited for real-world conditions.
Researcher Affiliation | Academia | POSTECH, South Korea. Correspondence to: Dongyeop Lee <EMAIL>.
Pseudocode | Yes | Algorithm 1: SAFE and SAFE+ algorithms

Require: target parameter count d, total train iterations T, dual-update interval K, learning rate η(t), perturbation radius ρ, penalty parameter λ, importance matrix P
 1: Initialize x(0)
 2: u = 0
 3: for t = 0, ..., T − 1 do
 4:   if t mod K = 0 then
 5:     if SAFE then
 6:       z = proj_{||·||₀ ≤ d}(x(t+1) + u)
 7:     else if SAFE+ then
 8:       z = proj^P_{||·||₀ ≤ d}(x(t+1) + u)
 9:     end if
10:     u = u + x(t+1) − z
11:   end if
12:   x(t+1/2) = x(t) − η(t) ∇f(x(t) + ρ ∇f(x(t)) / ||∇f(x(t))||₂)
13:   x(t+1) = x(t+1/2) − η(t) λ (x(t) − z + u)
14: end for
15: return proj_{||·||₀ ≤ d}(x(T))
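Algorithm 1 combines a SAM-style perturbed gradient step with an ADMM-style penalty and a periodic hard-sparsity projection plus dual update. The following is a minimal runnable sketch of that structure on a toy quadratic; the function, hyperparameter values, and variable names are illustrative, not the paper's implementation:

```python
import numpy as np

def top_d_projection(x, d):
    """proj_{||.||_0 <= d}: keep the d largest-magnitude entries, zero the rest."""
    z = np.zeros_like(x)
    if d > 0:
        idx = np.argsort(np.abs(x))[-d:]
        z[idx] = x[idx]
    return z

def safe_toy(grad_f, x0, d, T=500, K=10, lr=0.05, rho=0.05, lam=0.1):
    """Toy SAFE sketch: SAM perturbed gradient step + ADMM-style penalty
    pulling x toward a periodically re-projected sparse target z."""
    x = x0.astype(float).copy()
    u = np.zeros_like(x)
    z = top_d_projection(x, d)
    for t in range(T):
        if t % K == 0:                    # periodic projection and dual update
            z = top_d_projection(x + u, d)
            u = u + x - z
        g = grad_f(x)
        gn = np.linalg.norm(g)
        if gn > 0:                        # SAM: gradient at the perturbed point
            g = grad_f(x + rho * g / gn)
        x_half = x - lr * g               # perturbed gradient step
        x = x_half - lr * lam * (x - z + u)  # penalty toward sparse target
    return top_d_projection(x, d)         # final hard projection
```

On a quadratic loss 0.5·||x − x*||² with x* = [3.0, 0.1, −2.0, 0.05] and d = 2, this converges to a 2-sparse vector retaining approximately the two large coordinates, which is the intended behavior of the projection-plus-penalty scheme.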
Open Source Code | Yes | The code to reproduce the results of the paper is provided in JAX (Bradbury et al., 2018; Heek et al., 2023)² and PyTorch (Paszke et al., 2019)³. Specifically, all image classification experiments using SAFE were conducted in JAX, while PyTorch was used for LLM experiments due to better support for official implementations and pretrained checkpoints of widely adopted models such as LLaMA. To obtain baseline performances, we either run our own implementations (ADMM, GMP, Magnitude) or official ones (SparseGPT, Wanda, ALPS), or refer to reported results from prior work (LTH from Wang et al. (2020); PBW and MLPrune from Zhou et al. (2021)). For sound comparison, when running our own or official implementations, we align key settings such as model architecture and training epochs with those used in prior works to produce reported performance. ²https://github.com/LOG-postech/safe-jax ³https://github.com/LOG-postech/safe-torch
Open Datasets | Yes | Extensive evaluations on standard image classification and language modeling tasks reveal that SAFE consistently yields sparse networks with improved generalization performance... We evaluate pruning performance on VGG-19 (Simonyan, 2014) and ResNet-20/32 (He et al., 2016) using a range of representative pruning methods... For this purpose, we adapt SAFE and SAFE+ to sequentially optimize the reconstruction error minimization (REM) objective for each transformer block (Shin et al., 2024), similarly to other LLM pruning techniques... We compare SAFE with state-of-the-art LLM post-training pruning methods such as SparseGPT (Frantar & Alistarh, 2023), Wanda (Sun et al., 2024), ALPS (ADMM-based) (Meng et al., 2024), as well as magnitude pruning (Han et al., 2015), and evaluate the perplexity on WikiText-2 (Merity et al., 2022) and C4 validation sets. ... We evaluate SAFE on three representative challenges: incorrect training labels (Song et al., 2022), inference-time input corruption that arises naturally (Hendrycks & Dietterich, 2019), and corruptions that are deliberately introduced by adversaries (Szegedy et al., 2014).
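The block-wise reconstruction error minimization (REM) objective mentioned in the excerpt measures how closely a pruned block reproduces the dense block's outputs on calibration inputs. A minimal sketch of that objective, assuming a single linear block and a simple magnitude mask (the layer shapes, mask rule, and data here are illustrative, not the paper's setup):

```python
import numpy as np

def magnitude_mask(W, sparsity):
    """Illustrative mask: zero out the smallest-magnitude fraction of weights."""
    k = int(W.size * sparsity)
    thresh = np.sort(np.abs(W), axis=None)[k] if k > 0 else -np.inf
    return (np.abs(W) >= thresh).astype(W.dtype)

def reconstruction_error(W, W_sparse, X):
    """Block-wise REM objective: || W X - W_sparse X ||_F^2 over the
    calibration inputs X (columns = samples)."""
    return np.linalg.norm(W @ X - W_sparse @ X) ** 2

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))      # dense block weights (illustrative shapes)
X = rng.normal(size=(16, 128))    # 128 calibration samples, as in the paper
mask = magnitude_mask(W, sparsity=0.5)
err = reconstruction_error(W, W * mask, X)
```

A block-wise pruner would minimize this error per transformer block rather than the end-to-end loss, which is what makes post-training pruning tractable on large models.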
Dataset Splits | Yes | Extensive evaluations on standard image classification and language modeling tasks reveal that SAFE consistently yields sparse networks with improved generalization performance... The final validation accuracies are provided in Figure 2 and Table 7 of Appendix C... We evaluate the sparse models trained with ADMM and SAFE, as obtained in Section 4.2, on the CIFAR-10 test set with common image corruptions and adversarial perturbations. Specifically, for common corruptions, we use CIFAR-10C (Hendrycks & Dietterich, 2019)... We compare SAFE with state-of-the-art LLM post-training pruning methods such as SparseGPT (Frantar & Alistarh, 2023), Wanda (Sun et al., 2024), ALPS (ADMM-based) (Meng et al., 2024), as well as magnitude pruning (Han et al., 2015), and evaluate the perplexity on WikiText-2 (Merity et al., 2022) and C4 validation sets. We follow the common practice of randomly selecting 128 samples from the C4 training dataset (Raffel et al., 2020).
Hardware Specification | Yes | All CIFAR experiments are conducted using one or three NVIDIA RTX 3090 GPUs, where a batch size of 126 was used in this case. For CIFAR-10/100, we trained ResNet-20 for 200 epochs, and ResNet-32 and VGG-19 for 300 epochs. All experiments were conducted on a single GPU (NVIDIA A6000 or L40S) or HPU (Intel Gaudi2).
Software Dependencies | No | The code to reproduce the results of the paper is provided in JAX (Bradbury et al., 2018; Heek et al., 2023) and PyTorch (Paszke et al., 2019). Specifically, all image classification experiments using SAFE were conducted in JAX, while PyTorch was used for LLM experiments due to better support for official implementations and pretrained checkpoints of widely adopted models such as LLaMA.
Experiment Setup | Yes | We present various details of our experimental setup. All experiments are run across three different seeds, and the results are provided as the mean and the standard error. Table 5: Hyperparameter details used/searched for SAFE and SAFE+. Here, the perturbation radius and dual-update interval were searched only on ResNet-20/CIFAR-10 and LLaMA-2-7b, then applied across all settings and target sparsities. ... Across all experimental settings, the hyperparameter values remain consistent, with the exception of the penalty parameter λ. We report these in Table 5. Basic hyperparameters such as learning rate, batch size, weight decay, and momentum are set to standard values commonly used in the literature (Kusupati et al., 2020; Ramanujan et al., 2020; Liu et al., 2019). For SAFE-specific hyperparameters, including the perturbation radius and dual-update interval, values were optimized on ResNet-20/CIFAR-10 and LLaMA-2-7b and applied universally across all settings, with only the penalty parameter λ for image classification tasks being adjusted per setting, as reported in Table 6. This demonstrates the general applicability of its hyperparameter values across different tasks.
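The tuning protocol described above, shared perturbation radius ρ and dual-update interval K with only λ re-tuned per image-classification setting, can be expressed as a small configuration helper. The numeric values below are illustrative placeholders, not the values reported in the paper's Tables 5 and 6:

```python
# Shared SAFE hyperparameters: tuned once (per the paper, on ResNet-20/CIFAR-10
# and LLaMA-2-7b) and reused everywhere. Values here are placeholders only.
shared = {"perturbation_radius_rho": 0.05, "dual_update_interval_K": 10}

# Only the penalty parameter lambda is re-tuned per image-classification
# setting; these (model, dataset) entries and values are illustrative.
per_setting_lambda = {
    ("resnet20", "cifar10"): 1e-4,
    ("vgg19", "cifar100"): 1e-3,
}

def config_for(model, dataset):
    """Merge the shared hyperparameters with the per-setting penalty lambda."""
    cfg = dict(shared)
    cfg["penalty_lambda"] = per_setting_lambda[(model, dataset)]
    return cfg
```

Structuring the search this way (tune shared values once, sweep only λ per setting) is what lets the authors claim the hyperparameters transfer across tasks.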