Noise Stability Optimization for Finding Flat Minima: A Hessian-based Regularization Approach
Authors: Hongyang R. Zhang, Dongyue Li, Haotian Ju
TMLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We conduct a detailed experimental study to validate our approach and show that it can effectively regularize the Hessian and improve generalization. First, our algorithm can outperform prior approaches on sharpness-reduced training, delivering up to a 2.4% test accuracy increase for fine-tuning ResNets on six image classification datasets. Moreover, the trace of the Hessian reduces by 15.8%, and the largest eigenvalue is reduced by 9.7% with our approach. We also find that the regularization of the Hessian can be combined with alternative regularization methods, such as weight decay and data augmentation, leading to stronger regularization. Second, our approach remains highly effective for improving generalization in pretraining multimodal CLIP models and chain-of-thought fine-tuning. |
| Researcher Affiliation | Academia | Hongyang R. Zhang EMAIL Northeastern University, Boston; Dongyue Li EMAIL Northeastern University, Boston; Haotian Ju EMAIL Northeastern University, Boston |
| Pseudocode | Yes | Algorithm 1 Noise stability optimization (NSO) for regularizing the Hessian of neural networks |
| Open Source Code | Yes | The experiment code for reproducing our empirical findings can be found online at: https://github.com/VirtuosoResearch/Noise-stability-optimization. |
| Open Datasets | Yes | We will fine-tune a pretrained ResNet-34 on several image classification datasets, including aircraft recognition (Aircraft) (Maji et al., 2013), indoor scene recognition (Caltech-256) (Griffin et al., 2007), and medical image classification (retina images for diabetic retinopathy classification) (Pachade et al., 2021). ... In Table 3, we report the comparison between NSO, SGD, SAM, unnormalized SAM (USAM), and adaptive SAM (ASAM). We fine-tune the ResNet-34 network on six image classification datasets [CIFAR-10, CIFAR-100, Aircrafts, Caltech-256, Indoor, Retina] ... We apply our approach to pretraining randomly initialized models by replacing SGD to train contrastive language-image (CLIP) models on a dataset of image-caption pairs. In particular, we use the Conceptual Caption dataset ... We fine-tune pretrained GPT-2 models on two question-answering datasets: Commonsense QA and Strategy QA. |
| Dataset Splits | Yes | Table 3: Comparison between our approach (NSO) with SGD, sharpness-aware minimization (SAM), unnormalized SAM (USAM), and adaptive SAM (ASAM). We fine-tune the ResNet-34 network on six image classification datasets and report the test accuracy and the trace of Hessian using the model in the last epoch of training. The results are averaged over five random seeds. Basic statistics (# Training / # Validation / # Test / # Classes): CIFAR-10: 45,000 / 5,000 / 10,000 / 10; CIFAR-100: 45,000 / 5,000 / 10,000 / 100; Aircrafts: 3,334 / 3,333 / 3,333 / 100; Caltech-256: 7,680 / 5,120 / 5,120 / 256; Indoor: 4,824 / 536 / 1,340 / 67; Retina: 1,396 / 248 / 250 / 5 |
| Hardware Specification | No | The paper does not provide specific hardware details such as GPU models, CPU types, or memory amounts used for running experiments. |
| Software Dependencies | No | The paper does not explicitly list any software names with specific version numbers, such as Python, PyTorch, or CUDA versions. |
| Experiment Setup | Yes | Appendix C: Experiment Details. We describe the setup for Figure 2... We use a 12-layer Vision Transformer as the image encoder and a 12-layer GPT-2 transformer as the text encoder. ... We report the hyper-parameters for the experiments in Section 3. These include a learning rate of 0.0002, momentum of 0.99, weight decay of 0.0001, batch size of 32, and training epochs of 60. We reduce the learning rate by 0.1 every 20 epochs. We choose these hyper-parameters based on a grid search on the validation split. The range in which we conduct a grid search is as follows: Learning rate: 0.005, 0.002, 0.001, 0.0005, 0.0002, and 0.0001; Momentum: 0.9, 0.95, 0.99; Weight decay: 0.01, 0.001, 0.0001; Epochs: 20, 40, and 60; Batch size: 16, 32, and 64. |
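The Pseudocode row above refers to Algorithm 1 (NSO), which the paper describes as perturbing the weights with Gaussian noise and averaging gradients at symmetrically perturbed points to penalize Hessian curvature. Below is a minimal, hedged sketch of that two-point perturbed-gradient idea on a toy quadratic loss; the function name `nso_gradient`, the noise scale `sigma`, and the sample count `k` are illustrative assumptions, not the paper's exact implementation.

```python
import numpy as np

def nso_gradient(grad_fn, w, sigma=0.1, k=1, rng=None):
    """Sketch of a noise-stability gradient estimate.

    Averages gradients at symmetrically perturbed weights w + u and
    w - u, with u ~ N(0, sigma^2 I). To second order, minimizing the
    noise-perturbed loss adds a penalty proportional to the trace of
    the Hessian, encouraging flat minima.
    """
    rng = rng or np.random.default_rng(0)
    g = np.zeros_like(w)
    for _ in range(k):
        u = rng.normal(0.0, sigma, size=w.shape)  # sampled perturbation
        g += grad_fn(w + u) + grad_fn(w - u)      # symmetric two-point average
    return g / (2 * k)

# Toy example: quadratic loss L(w) = 0.5 * w^T A w, so grad L(w) = A w.
A = np.diag([1.0, 10.0])
grad_fn = lambda w: A @ w

w = np.array([1.0, 1.0])
for _ in range(100):
    w -= 0.05 * nso_gradient(grad_fn, w)  # plain SGD step on the averaged gradient
```

On a quadratic the symmetric perturbations cancel exactly, so this reduces to gradient descent; the curvature-penalizing effect only appears for losses with varying Hessian, which is the regime the paper targets.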