Ask Your Distribution Shift if Pre-Training is Right for You
Authors: Benjamin Cohen-Wang, Joshua Vendrow, Aleksander Mądry
TMLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our study suggests that, as a rule of thumb, pre-training can help mitigate poor extrapolation but not dataset biases. After providing theoretical motivation and empirical evidence for this finding, we explore two of its implications for developing robust models: (1) pre-training and interventions designed to prevent exploiting biases have complementary robustness benefits, and (2) fine-tuning on a (very) small, non-diverse but de-biased dataset can result in significantly more robust models than fine-tuning on a large and diverse but biased dataset. |
| Researcher Affiliation | Academia | Benjamin Cohen-Wang EMAIL Massachusetts Institute of Technology Joshua Vendrow EMAIL Massachusetts Institute of Technology Aleksander Mądry EMAIL Massachusetts Institute of Technology |
| Pseudocode | No | The paper includes theoretical motivations and proofs (e.g., Theorem 4.1), but it does not contain a clearly labeled 'Pseudocode' or 'Algorithm' block with structured steps formatted like code. |
| Open Source Code | Yes | Code is available at https://github.com/MadryLab/pretraining-distribution-shift-robustness |
| Open Datasets | Yes | A common paradigm for developing machine learning models is pre-training them on a large, diverse dataset (e.g., ImageNet (Deng et al., 2009), JFT-300M (Sun et al., 2017), LAION-5B (Schuhmann et al., 2022)) and then fine-tuning them on task-specific data. To illustrate this, we consider two distribution shifts of ImageNet (Deng et al., 2009): ImageNet-V2 (Recht et al., 2019) and ImageNet Sketch (Wang et al., 2019). In particular, we investigate the effectiveness of this strategy on WILDS-FMoW (Christie et al., 2018; Koh et al., 2020), a distribution shift benchmark for classifying satellite images. As a case study, we consider the task of predicting hair color (blond vs. non-blond) in the CelebA dataset (Liu et al., 2015). |
| Dataset Splits | Yes | To split a shifted dataset into an in-support split and an out-of-support split, we would ideally measure the reference distribution probability density pref of inputs in the shifted dataset and assign inputs with small pref to the out-of-support split. [...] To estimate this ratio for the entire shifted dataset, we split the dataset into 10 folds and train a classifier to estimate pref/pshift on each fold using the remaining 9 folds. Specifically, we consider three natural shifts of the ImageNet dataset: ImageNet-V2 (Recht et al., 2019), which closely resembles ImageNet, ImageNet Sketch (Wang et al., 2019), which consists of sketches of ImageNet classes, and ImageNet-R (Hendrycks et al., 2020a). As a baseline, we train 100 ResNet-50 models from scratch on random subsets ranging from 25% of the reference dataset to the entire dataset. |
| Hardware Specification | Yes | All models are trained using the FFCV data-loading library (Leclerc et al., 2022) on a cluster of A100 GPUs. |
| Software Dependencies | No | The paper mentions several software libraries and frameworks (FFCV, PyTorch Image Models, CLIP, OpenCLIP) and references the papers where they were introduced (e.g., Leclerc et al., 2022; Wightman, 2019; Radford et al., 2021; Ilharco et al., 2021). However, it does not provide specific version numbers for any of these software dependencies (e.g., PyTorch 1.x, Python 3.x, CUDA x.x). |
| Experiment Setup | Yes | We run AdamW for 100 epochs, using a cosine learning rate schedule with a peak learning rate of 0.003 and 10 warmup epochs, a batch size of 512, a weight decay of 0.1 and gradient clipping at global norm 1. We fully fine-tune models by running AdamW for 8 epochs, using a cosine learning rate schedule with 1 warmup epoch. We select the best peak learning rate (in terms of reference accuracy) among 3 × 10⁻⁴, 1 × 10⁻⁴, 3 × 10⁻⁵, 1 × 10⁻⁵, 3 × 10⁻⁶, 1 × 10⁻⁶. We use a batch size of 512, a weight decay of 0.1, and gradient clipping at global norm 1. |
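The dataset-split procedure quoted above (estimating pref/pshift with a classifier trained across 10 folds) can be sketched as follows. This is a minimal reconstruction, not the authors' code: the use of `LogisticRegression` on precomputed features, the fold count as a parameter, and the probability-to-ratio conversion are assumptions about implementation details the paper does not fully specify.

```python
# Sketch: estimate p_ref / p_shift for each shifted example via a
# cross-validated domain classifier (assumed: logistic regression on
# precomputed features; the paper only specifies the 10-fold scheme).
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

def density_ratio_scores(ref_feats, shift_feats, n_folds=10, seed=0):
    """Return an estimated p_ref/p_shift score per shifted example."""
    scores = np.empty(len(shift_feats))
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    for _, test_idx in kf.split(shift_feats):
        # Train on the reference data plus the other 9 folds of shifted data.
        mask = np.ones(len(shift_feats), dtype=bool)
        mask[test_idx] = False
        X_train = np.vstack([ref_feats, shift_feats[mask]])
        y_train = np.concatenate(
            [np.ones(len(ref_feats)), np.zeros(int(mask.sum()))]
        )
        clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
        # A domain classifier's p(ref|x) yields a density ratio up to the
        # class prior: p_ref/p_shift ∝ p(ref|x) / (1 - p(ref|x)).
        p_ref = clf.predict_proba(shift_feats[test_idx])[:, 1]
        scores[test_idx] = p_ref / (1.0 - p_ref + 1e-12)
    return scores

# Shifted examples with the smallest scores would then be assigned
# to the out-of-support split, per the procedure quoted in the table.
```

The held-out-fold scoring matters: scoring a shifted example with a classifier that saw it during training would bias its estimated ratio downward.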
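The fine-tuning schedule in the Experiment Setup row (cosine decay with warmup) can be written down explicitly. This is a generic sketch of that schedule shape, not the authors' implementation; the linear-warmup form and step granularity are assumptions.

```python
# Sketch of a cosine learning-rate schedule with linear warmup, matching
# the recipe reported in the table (e.g., 8 epochs with 1 warmup epoch,
# peak learning rate swept over 3e-4 down to 1e-6).
import math

def cosine_lr(step, peak_lr, total_steps, warmup_steps):
    """Learning rate at `step`: linear warmup to peak_lr, then cosine decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With 100 steps per epoch, 8 epochs, and 1 warmup epoch, the rate ramps linearly over the first 100 steps, peaks at the chosen learning rate, and decays to zero by step 800; gradient clipping at global norm 1 and weight decay 0.1 would be applied separately in the optimizer step.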