Avoiding spurious sharpness minimization broadens applicability of SAM

Authors: Sidak Pal Singh, Hossein Mobahi, Atish Agarwala, Yann Dauphin

ICML 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | "We find that in both training regimes, our proposed algorithms, namely PRECOND FUNCTIONAL-SAM and PRECOND SAM, significantly outperform SAM at all the model scales in Tables 2 and 3. The gains typically range from 5×10⁻³ to 10⁻² and are largest for the 117.9M-parameter model. We find that SAM performs the worst of the lot, even worse than ADAMW, while being 2× as computationally expensive."
Researcher Affiliation | Industry | "1Google Research, 2Google DeepMind. Correspondence to: Sidak Pal Singh <EMAIL>."
Pseudocode | Yes | "In Appendix C, we provide the JAX code snippets for SAM and FUNCTIONAL-SAM, demonstrating that the difference in their implementation is a matter of a few lines. Further, like SAM, FUNCTIONAL-SAM remains compatible with methods like ADAMW which take a gradient and then further process it. Listing 1: Illustration of how to get the gradients in the two methods. FUNCTIONAL-SAM differs from the SAM implementation only in the last couple of lines."
Open Source Code | Yes | "We used the Nanodo (Liu et al., 2024) framework to implement these minimal decoder-only Transformer models, in Flax (Heek et al., 2024) together with JAX (Bradbury et al., 2018). URL http://github.com/google-deepmind/nanodo. In Appendix C, we provide the JAX code snippets for SAM and FUNCTIONAL-SAM."
Open Datasets | Yes | "We consider the next-token prediction task in the case of language using the C4 dataset, and image classification in the case of vision. Since typical vocabulary sizes in language are on the order of tens of thousands, besides ImageNet-1K we adopt other datasets like JFT (Sun et al., 2017) and ImageNet-21K (Ridnik et al., 2021) to make the settings further comparable in terms of number of outputs. For both settings, we employ Transformer-based networks: Nanodo (Liu et al., 2024), a simplified version of GPT-2 (Radford et al., 2019), for language modeling, and Vision Transformer (ViT; Dosovitskiy et al., 2021) for vision tasks. Furthermore, in both cases, we train with ADAMW as the optimizer and measure the normalized sharpness gradient contributions (Eqn. 9) throughout training using exact Hessian-vector products. We present the results for Nanodo on C4 as well as ViT on ImageNet-1K and JFT in Figure 2, and those on ImageNet-21K in Appendix A.1."
Dataset Splits | No | The paper uses the C4 dataset and mentions 'Validation Loss' in figures and tables, implying the use of a validation split. It also mentions 'training in an online fashion where no batch is seen more than once due to the massive size of the C4 corpus'. However, it does not explicitly state percentages, absolute sample counts, or a specific predefined splitting methodology with a citation for the C4 dataset as used in the experiments.
Hardware Specification | No | The paper mentions a 'compute budget' and 'model scales (including billion-parameter scale)' but does not provide specifics of the hardware used, such as GPU/CPU models, processor types, or memory amounts.
Software Dependencies | No | The paper mentions using ADAMW, the Nanodo (Liu et al., 2024) framework, Flax (Heek et al., 2024), and JAX (Bradbury et al., 2018). While these software components are listed and cited, specific version numbers for Flax or JAX are not provided.
Experiment Setup | Yes | "In both scenarios, we consider a batch size of 256 sequences, of maximum length 512, and evaluate the model at 5 different sizes: 2M (for prototyping), 23.9M, 42.5M, 117.9M, and 1208M in terms of non-embedding parameters (see details in Appendix A.6), trained with ADAMW (Kingma & Ba, 2017; Loshchilov & Hutter, 2019) as the underlying optimizer on the C4 dataset (Raffel et al., 2020). We used a decoupled weight decay parameter set to 10⁻⁴ for all experimental settings. We use the Nanodo (Liu et al., 2024) framework to implement these minimal decoder-only Transformer models, in Flax (Heek et al., 2024) together with JAX (Bradbury et al., 2018). For thoroughness, at each model size separately, we simultaneously tune the learning rate and perturbation strength ρ in our experiments, even though this may be difficult at even bigger model sizes or when faced with computational constraints. The optimal values of the perturbation strength for PRECOND FUNCTIONAL-SAM at the various scales are respectively 0.5, 0.5, 0.4, 0.4 (Table 2)."
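The Pseudocode row above refers to Listing 1 in the paper's Appendix C, which is not reproduced here. For orientation, the standard SAM gradient that FUNCTIONAL-SAM modifies can be sketched in a few lines; this NumPy version is an illustration of vanilla SAM only (names like `sam_gradient` are ours, not the authors'), not the paper's JAX listing.

```python
import numpy as np

def sam_gradient(loss_grad, w, rho):
    """One SAM gradient step: ascend to the (first-order) worst-case
    point within an L2 ball of radius rho, then take the loss
    gradient at that perturbed point.

    loss_grad: function mapping parameters -> gradient of the loss.
    """
    g = loss_grad(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # normalized ascent direction
    return loss_grad(w + eps)                    # gradient at the perturbed point

# Toy quadratic loss L(w) = 0.5 * ||w||^2, whose gradient is simply w.
grad = lambda w: w
w = np.array([3.0, 4.0])                 # ||w|| = 5
g_sam = sam_gradient(grad, w, rho=0.5)   # eps = 0.1 * w, so g_sam = 1.1 * w
print(g_sam)                             # -> [3.3 4.4]
```

As the paper notes for FUNCTIONAL-SAM, this returned gradient can then be handed to any downstream optimizer such as ADAMW, since it is a drop-in replacement for the plain loss gradient.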
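The Open Datasets row quotes the paper's use of "exact Hessian-vector products" to measure sharpness gradient contributions. In JAX, exact HVPs are obtained by composing forward- and reverse-mode autodiff; the NumPy sketch below only illustrates what an HVP computes, via a central finite difference on the gradient, and is not the paper's implementation.

```python
import numpy as np

def hvp_fd(grad_fn, w, v, h=1e-5):
    """Finite-difference Hessian-vector product:
    H v ≈ (∇L(w + h v) − ∇L(w − h v)) / (2 h).
    This approximates the exact autodiff HVP and is for illustration only."""
    return (grad_fn(w + h * v) - grad_fn(w - h * v)) / (2 * h)

# Quadratic L(w) = 0.5 * w^T A w has gradient A w and Hessian A,
# so the HVP should recover A v (exactly, since the gradient is linear).
A = np.array([[2.0, 0.0],
              [0.0, 6.0]])
grad_fn = lambda w: A @ w
w = np.array([1.0, 1.0])
v = np.array([1.0, 0.0])
print(hvp_fd(grad_fn, w, v))  # -> approximately A v = [2. 0.]
```

For a quadratic this central difference is exact up to floating-point error; for general deep-network losses the autodiff route the paper uses avoids the step-size sensitivity of finite differences.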
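The Experiment Setup row specifies ADAMW with a decoupled weight decay of 10⁻⁴. Below is a minimal sketch of a single decoupled-weight-decay update (Loshchilov & Hutter, 2019) in NumPy; only the weight-decay value comes from the paper, while the remaining hyperparameters are illustrative defaults, not the authors' tuned settings.

```python
import numpy as np

def adamw_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999,
               eps=1e-8, weight_decay=1e-4):
    """One ADAMW step with *decoupled* weight decay: the decay term is
    applied directly to the weights rather than folded into the gradient.
    weight_decay=1e-4 matches the paper's setting; lr, b1, b2, eps are
    generic defaults for illustration."""
    m = b1 * m + (1 - b1) * g            # first-moment (momentum) estimate
    v = b2 * v + (1 - b2) * g ** 2       # second-moment estimate
    m_hat = m / (1 - b1 ** t)            # bias correction at step t
    v_hat = v / (1 - b2 ** t)
    w = w - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w)
    return w, m, v

# Single update on a 1-D toy parameter.
w, m, v = np.array([1.0]), np.array([0.0]), np.array([0.0])
w1, m1, v1 = adamw_step(w, np.array([2.0]), m, v, t=1, lr=0.1)
```

In the paper's pipeline, the gradient `g` fed into this update would be the SAM or FUNCTIONAL-SAM gradient rather than the plain loss gradient, which is exactly the compatibility the Pseudocode row highlights.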