Stabilizing Sharpness-Aware Minimization Through A Simple Renormalization Strategy
Authors: Chengli Tan, Jiangshe Zhang, Junmin Liu, Yicheng Wang, Yunda Hao
JMLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this section, we present the empirical results on a range of tasks. From the perspective of algorithmic stability, we first investigate how SSAM ameliorates the issue of training instability with realistic data sets. We then provide the convergence results on a quadratic loss function. To demonstrate that the increased stability does not come at the cost of performance degradation, we also evaluate it on tasks such as training deep classifiers from scratch. The results suggest that SSAM can achieve comparable or even superior performance compared to SAM. For completeness, sometimes we also include the results of the standard formulation of SAM proposed by Foret et al. (2021) and denote it by SAM. CIFAR-10 and CIFAR-100. Here we adopt several popular backbones, ranging from basic ResNets (He et al., 2016) to more advanced architectures such as WideResNet (Zagoruyko and Komodakis, 2016), ResNeXt (Xie et al., 2017), and PyramidNet (Han et al., 2017). |
| Researcher Affiliation | Collaboration | Chengli Tan EMAIL School of Mathematics and Statistics, Northwestern Polytechnical University, Shaanxi, Xi'an, 710129, China SGIT AI Lab, State Grid Corporation of China, Shaanxi, Xi'an, 710054, China |
| Pseudocode | Yes | Algorithm 1 SSAM Optimizer |
| Open Source Code | Yes | A PyTorch implementation is available at https://github.com/cltan023/stablesam2024. |
| Open Datasets | Yes | Finally, we demonstrate the improved performance of SSAM on several representative data sets and tasks. Keywords: deep neural networks, sharpness-aware minimization, expected risk analysis, uniform stability, stochastic optimization... Similar trends are also observed when we replace the synthetic inputs with real data sets like MNIST and CIFAR-10 (see Appendix A). |
| Dataset Splits | Yes | Beyond the training and test set, we also construct a validation set containing 5000 images out of the training set. Moreover, we only employ basic data augmentations such as horizontal flip, random crop, and normalization. We set the mini-batch size to be 128 and each model is trained up to 200 epochs with a cosine learning rate decay (Loshchilov and Hutter, 2016). The default base optimizer is SGD with a momentum of 0.9. To determine the best choice of hyper-parameters for each backbone, slightly different from Kwon et al. (2021); Kim et al. (2022), we first use SGD to grid search the learning rate and the weight decay coefficient over {0.01, 0.05, 0.1} and {1.0e-4, 5.0e-4, 1.0e-3}, respectively. For SAM and the variants, these two hyper-parameters are then fixed. |
| Hardware Specification | Yes | We basically follow the official instructions, except that we only use four NVIDIA GeForce RTX 4090s. |
| Software Dependencies | No | The paper mentions the 'PyTorch package' and 'timm library' but does not specify their version numbers in the text. |
| Experiment Setup | Yes | We set the mini-batch size to be 128 and each model is trained up to 200 epochs with a cosine learning rate decay (Loshchilov and Hutter, 2016). The default base optimizer is SGD with a momentum of 0.9. To determine the best choice of hyper-parameters for each backbone, slightly different from Kwon et al. (2021); Kim et al. (2022), we first use SGD to grid search the learning rate and the weight decay coefficient over {0.01, 0.05, 0.1} and {1.0e-4, 5.0e-4, 1.0e-3}, respectively. For SAM and the variants, these two hyper-parameters are then fixed. As suggested by Kwon et al. (2021), the perturbation radius ρ of ASAM needs to be much larger, and we thus range it from {0.5, 1.0, 2.0}. In contrast, we sweep the perturbation radius ρ of other optimizers over {0.05, 0.1, 0.2}. We run each model with three different random seeds and report the mean and the standard deviation of the accuracy on the test set. |
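The SSAM update referenced in the table (Algorithm 1) can be sketched as follows. This is an illustrative NumPy sketch, not the paper's implementation: it assumes the renormalization strategy rescales the gradient taken at the SAM-perturbed point back to the norm of the gradient at the current point, and `grad_fn` is a hypothetical helper returning the loss gradient.

```python
import numpy as np

def ssam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SSAM-style update (illustrative sketch).

    SAM first perturbs w by rho * g / ||g|| (the ascent step), then
    descends along the gradient at the perturbed point. The
    renormalization sketched here rescales that perturbed gradient so
    its norm matches ||g||, which is the stabilizing idea assumed from
    the paper's title.
    """
    g = grad_fn(w)
    g_norm = np.linalg.norm(g)
    if g_norm == 0.0:
        return w                          # already at a stationary point
    eps = rho * g / g_norm                # SAM ascent perturbation
    g_adv = grad_fn(w + eps)              # gradient at the perturbed point
    g_adv_norm = np.linalg.norm(g_adv)
    if g_adv_norm > 0.0:
        g_adv = g_adv * (g_norm / g_adv_norm)  # renormalize to ||g||
    return w - lr * g_adv                 # descent step
```

On the quadratic loss used for the paper's convergence experiments, iterating this step drives the parameters toward the minimizer just as plain gradient descent does, while the renormalization keeps the effective step length tied to the local gradient norm.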
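The two-stage hyperparameter search described in the setup row can be laid out explicitly. This is a minimal sketch of the grids quoted above, with illustrative seed values; the training loop itself is omitted.

```python
from itertools import product

# Stage 1: SGD grid-searches the learning rate and weight decay; the
# best pair is then fixed for SAM and its variants (per the setup row).
learning_rates = [0.01, 0.05, 0.1]
weight_decays = [1.0e-4, 5.0e-4, 1.0e-3]
sgd_grid = list(product(learning_rates, weight_decays))  # 9 configurations

# Stage 2: only the perturbation radius rho is swept per optimizer.
# ASAM requires a much larger rho than the other variants.
rho_grid = {
    "SAM":  [0.05, 0.1, 0.2],
    "SSAM": [0.05, 0.1, 0.2],
    "ASAM": [0.5, 1.0, 2.0],
}

# Each model is run with three random seeds; mean and standard
# deviation of test accuracy are reported (seed values are illustrative).
seeds = [0, 1, 2]
```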