Ask Your Distribution Shift if Pre-Training is Right for You

Authors: Benjamin Cohen-Wang, Joshua Vendrow, Aleksander Madry

TMLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our study suggests that, as a rule of thumb, pre-training can help mitigate poor extrapolation but not dataset biases. After providing theoretical motivation and empirical evidence for this finding, we explore two of its implications for developing robust models: (1) pre-training and interventions designed to prevent exploiting biases have complementary robustness benefits, and (2) fine-tuning on a (very) small, non-diverse but de-biased dataset can result in significantly more robust models than fine-tuning on a large and diverse but biased dataset.
Researcher Affiliation | Academia | Benjamin Cohen-Wang (EMAIL), Massachusetts Institute of Technology; Joshua Vendrow (EMAIL), Massachusetts Institute of Technology; Aleksander Mądry (EMAIL), Massachusetts Institute of Technology
Pseudocode | No | The paper includes theoretical motivations and proofs (e.g., Theorem 4.1), but it does not contain a clearly labeled 'Pseudocode' or 'Algorithm' block with structured steps formatted like code.
Open Source Code | Yes | Code is available at https://github.com/MadryLab/pretraining-distribution-shift-robustness
Open Datasets | Yes | A common paradigm for developing machine learning models is pre-training them on a large, diverse dataset (e.g., ImageNet (Deng et al., 2009), JFT-300M (Sun et al., 2017), LAION-5B (Schuhmann et al., 2022)) and then fine-tuning them on task-specific data. To illustrate this, we consider two distribution shifts of ImageNet (Deng et al., 2009): ImageNet-V2 (Recht et al., 2019) and ImageNet Sketch (Wang et al., 2019). In particular, we investigate the effectiveness of this strategy on WILDS-FMoW (Christie et al., 2018; Koh et al., 2020), a distribution shift benchmark for classifying satellite images. As a case study, we consider the task of predicting hair color (blond vs. non-blond) in the CelebA dataset (Liu et al., 2015).
Dataset Splits | Yes | To split a shifted dataset into an in-support split and an out-of-support split, we would ideally measure the reference distribution probability density pref of inputs in the shifted dataset and assign inputs with small pref to the out-of-support split. [...] To estimate this ratio for the entire shifted dataset, we split the dataset into 10 folds and train a classifier to estimate pref/pshift on each fold using the remaining 9 folds. Specifically, we consider three natural shifts of the ImageNet dataset: ImageNet-V2 (Recht et al., 2019), which closely resembles ImageNet, ImageNet Sketch (Wang et al., 2019), which consists of sketches of ImageNet classes, and ImageNet-R (Hendrycks et al., 2020a). As a baseline, we train 100 ResNet-50 models from scratch on random subsets ranging from 25% of the reference dataset to the entire dataset.
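The cross-fitted density-ratio procedure quoted above can be sketched in a few lines. This is an illustrative toy version, not the authors' code: the synthetic Gaussian features, the logistic-regression probe, and the median threshold for the in-support/out-of-support cut are all assumptions made for the example.

```python
# Sketch: split a shifted dataset into in-support / out-of-support halves by
# estimating p_ref(x) / p_shift(x) with a cross-fitted classifier.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

rng = np.random.default_rng(0)
X_ref = rng.normal(0.0, 1.0, size=(500, 8))    # reference distribution
X_shift = np.vstack([
    rng.normal(0.0, 1.0, size=(250, 8)),       # overlaps the reference support
    rng.normal(4.0, 1.0, size=(250, 8)),       # far outside the reference support
])

def density_ratios(X_ref, X_shift, n_folds=10):
    """Cross-fitted estimate of p_ref(x) / p_shift(x) for each shifted input.

    Each fold of the shifted data is scored by a classifier trained on the
    reference set plus the remaining 9 folds, as described in the paper.
    """
    ratios = np.empty(len(X_shift))
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=0)
    for train_idx, test_idx in kf.split(X_shift):
        X_train = np.vstack([X_ref, X_shift[train_idx]])
        y_train = np.concatenate([np.ones(len(X_ref)), np.zeros(len(train_idx))])
        clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
        p_ref = clf.predict_proba(X_shift[test_idx])[:, 1]
        # p(ref | x) / p(shift | x) is proportional to p_ref(x) / p_shift(x).
        ratios[test_idx] = p_ref / (1.0 - p_ref + 1e-12)
    return ratios

ratios = density_ratios(X_ref, X_shift)
in_support = ratios >= np.median(ratios)  # small ratio -> out-of-support split
```

With this setup, the overlapping shifted points receive much larger ratio estimates than the far-away ones, so thresholding recovers the intended split.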
Hardware Specification | Yes | All models are trained using the FFCV data-loading library (Leclerc et al., 2022) on a cluster of A100 GPUs.
Software Dependencies | No | The paper mentions several software libraries and frameworks (FFCV, PyTorch Image Models, CLIP, OpenCLIP) and references the papers where they were introduced (e.g., Leclerc et al., 2022; Wightman, 2019; Radford et al., 2021; Ilharco et al., 2021). However, it does not provide specific version numbers for any of these software dependencies (e.g., PyTorch 1.x, Python 3.x, CUDA x.x).
Experiment Setup | Yes | We run AdamW for 100 epochs, using a cosine learning rate schedule with a peak learning rate of 0.003 and 10 warmup epochs, a batch size of 512, a weight decay of 0.1, and gradient clipping at global norm 1. We fully fine-tune models by running AdamW for 8 epochs, using a cosine learning rate schedule with 1 warmup epoch. We select the best peak learning rate (in terms of reference accuracy) among 3e-4, 1e-4, 3e-5, 1e-5, 3e-6, and 1e-6. We use a batch size of 512, a weight decay of 0.1, and gradient clipping at global norm 1.
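The pre-training schedule described above (cosine decay to a peak learning rate of 0.003, with 10 warmup epochs out of 100) can be written as a small function. This is a sketch, not the authors' implementation; in particular, the linear shape of the warmup is an assumption, since the quoted setup only says "warmup epochs".

```python
import math

def lr_at_epoch(epoch, peak_lr=0.003, total_epochs=100, warmup_epochs=10):
    """Cosine learning-rate schedule with linear warmup (assumed shape)."""
    if epoch < warmup_epochs:
        # Linear warmup from peak_lr / warmup_epochs up to peak_lr.
        return peak_lr * (epoch + 1) / warmup_epochs
    # Cosine decay from peak_lr toward 0 over the remaining epochs.
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
```

The fine-tuning runs would use the same shape with `total_epochs=8`, `warmup_epochs=1`, and a peak learning rate chosen from the grid 3e-4 down to 1e-6.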