Leveraging Flatness to Improve Information-Theoretic Generalization Bounds for SGD
Authors: Ze Peng, Jian Zhang, Yisen Wang, Lei Qi, Yinghuan Shi, Yang Gao
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this section, we experimentally show how our bound captures the true generalization error/gap measured by Cross Entropy (0-1 loss does not have Hessians and is motivationally incompatible) compared to the existing bounds. We vary the hyperparameter and train 6 independent ResNet-18 models on CIFAR-10 at each hyperparameter. |
| Researcher Affiliation | Academia | State Key Laboratory for Novel Software Technology, Nanjing University; State Key Lab of General Artificial Intelligence, School of Intelligence Science and Technology, Peking University; School of Computer Science and Engineering, Southeast University |
| Pseudocode | No | The paper primarily focuses on theoretical derivations and experimental results, but does not include any explicitly labeled pseudocode blocks or algorithms. |
| Open Source Code | Yes | Codes are available at https://github.com/peng-ze/omniscient-bounds. The code for the experiments can be found in the supplementary material or at the same repository. |
| Open Datasets | Yes | The bound is evaluated on ResNet-18 trained on CIFAR-10. Experiments on deep neural networks show our bound... Results for an MLP on MNIST can be found in Appendix C.4. |
| Dataset Splits | Yes | The 6 models are trained by 6 random splits of the training set. Terms involving population statistics (e.g., the population Hessians in the flatness terms and the population gradient in the penalty term) are approximated to the second order and estimated on validation sets. The true generalization error is estimated on a separate test set. |
| Hardware Specification | Yes | Models are trained on 12 NVIDIA RTX 4090D GPUs for 2 days with automatic mixed precision (BF16) and torch.compile() to save memory and increase parallelization. |
| Software Dependencies | No | In experiments, all Hessian traces are computed using PyHessian (Yao et al., 2020). Models are trained... with automatic mixed precision (BF16) and torch.compile(). The paper mentions PyHessian and `torch.compile()` (implying PyTorch), but specific version numbers for these software components are not provided. |
| Experiment Setup | Yes | For both MNIST and CIFAR-10, we train k = 6 independent models from the same randomly chosen initial weights... We use 2-layer MLPs for MNIST... while ResNet-18 for CIFAR-10. ResNet-18 models are trained for 200 epochs while 2-layer MLP models are trained for 500 epochs... We start from a base hyperparameter, where the learning rate is 0.01, the batch size is 60, and no dropout or weight decay is used. For the 2-layer MLP, the base hidden width is 512. We use SGD with momentum of 0.9. We use random horizontal flip and random resized crop. |
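The paper reports computing Hessian traces with PyHessian, which relies on stochastic trace estimation rather than forming the full Hessian. As an illustration of the underlying idea (not the paper's code, and independent of the PyHessian API), the sketch below implements Hutchinson's estimator, tr(H) ≈ (1/m) Σᵢ vᵢᵀHvᵢ with Rademacher vectors vᵢ, given only a matrix-vector product; the toy matrix `H` and the sample count are illustrative choices.

```python
import random

def hutchinson_trace(matvec, dim, num_samples=10_000, seed=0):
    """Estimate tr(H) via Hutchinson's method.

    E[v^T H v] = tr(H) when v has i.i.d. Rademacher entries (+1/-1),
    so averaging v^T H v over many random v converges to the trace.
    Only matrix-vector products are needed, which is why the same idea
    scales to neural-network Hessians via Hessian-vector products.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        hv = matvec(v)  # one Hessian-vector product per sample
        total += sum(vi * hvi for vi, hvi in zip(v, hv))
    return total / num_samples

# Toy symmetric "Hessian" with known trace 1 + 2 + 3 = 6.
H = [[1.0, 0.5, 0.0],
     [0.5, 2.0, 0.5],
     [0.0, 0.5, 3.0]]
matvec = lambda v: [sum(h * x for h, x in zip(row, v)) for row in H]

estimate = hutchinson_trace(matvec, dim=3)
```

In the deep-learning setting, `matvec` would be replaced by a Hessian-vector product computed with a double backward pass, which avoids ever materializing the Hessian.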