Optimizing importance weighting in the presence of sub-population shifts
Authors: Floris Holstege, Bram Wouters, Noud van Giersbergen, Cees Diks
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In the context of last-layer retraining of DNNs in the presence of sub-population shifts, we show empirically, for benchmark vision and natural language processing datasets, that existing state-of-the-art importance weighting methods (see Section 2) can be improved significantly by optimizing for the weights. We find that both the weighted average and worst-group accuracy generally improve, and that optimized weights increase robustness against the choice of hyperparameters for training. We show that the effect of optimizing weights is larger when the limited size of the training sample becomes pressing, namely in the case of a small total sample size or when the minority groups are small. |
| Researcher Affiliation | Academia | 1University of Amsterdam, Department of Quantitative Economics 2Tinbergen Institute EMAIL |
| Pseudocode | Yes | Details of the optimization procedure are provided in Algorithm 1. ... See Algorithm 2 in Appendix D for details. |
| Open Source Code | Yes | We provide an implementation of this procedure as an open-source package (link). |
| Open Datasets | Yes | The goal of this section is to verify empirically that existing state-of-the-art methods for addressing sub-population shifts benefit from using optimized group weights. We use benchmark classification datasets for studying sub-population shifts: two from computer vision (Waterbirds, CelebA) and one from natural language processing (MultiNLI). ... Waterbirds: this dataset from Sagawa et al. (2020a) is a combination of the Places dataset (Zhou et al., 2016) and the CUB dataset (Welinder et al., 2010). |
| Dataset Splits | Yes | The training and validation sets are randomly split. For Waterbirds, this is different from previous implementations, where the validation set is balanced in terms of the groups (Sagawa et al., 2020a). We deviate from this, since it is more realistic that the training and validation sets are drawn from the same distribution. ... MultiNLI: ... create a smaller binary version of the dataset (50,000 samples in training, 20,000 in validation, 30,000 in test)... |
| Hardware Specification | No | The paper does not provide specific hardware details such as GPU or CPU models. It mentions using pre-trained DNN models (ResNet50, BERT) but not the computational resources they were run on. |
| Software Dependencies | No | For the Waterbirds and CelebA datasets, we use the ResNet50 architecture implemented in the torchvision package... For the MultiNLI dataset, we use the base BERT model implemented in the transformers package (Wolf et al., 2019)... We use the AdamW optimizer (Loshchilov & Hutter, 2017) with the standard settings in PyTorch... For all methods, unless otherwise mentioned, we use a logistic regression with the sklearn.LogisticRegression class with the liblinear solver... While software packages are mentioned (torchvision, transformers, PyTorch, sklearn), specific version numbers for these dependencies are not provided. |
| Experiment Setup | Yes | For Waterbirds, this means using a learning rate of 10^-3, a weight decay of 10^-3, a batch size of 32, and training for 100 epochs without early stopping. For CelebA, this means using a learning rate of 10^-3, a weight decay of 10^-4, a batch size of 128, and training for 50 epochs without early stopping. We use stochastic gradient descent (SGD) with a momentum parameter of 0.9. ... For finetuning the BERT model on MultiNLI, we use the AdamW optimizer (Loshchilov & Hutter, 2017) with the standard settings in PyTorch. When finetuning, we use the hyperparameters of Izmailov et al. (2022), training for 10 epochs with a batch size of 16, a learning rate of 10^-5, a weight decay of 10^-4, and linear learning rate decay. ... For the L1 penalty, we select it from the following values: 0.1, 1.0, 3.3, 10.0, 33.33, 100, 300, 500. |
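The table above reports that last-layer retraining uses a group-weighted logistic regression (`sklearn.LogisticRegression` with the liblinear solver) and that worst-group accuracy is the robustness metric. A minimal sketch of that setup is given below; the function names (`fit_group_weighted`, `worst_group_accuracy`) and the way group weights enter as per-sample weights are illustrative assumptions, not the paper's exact implementation.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def fit_group_weighted(X, y, groups, group_weights, C=1.0):
    """Fit a last-layer logistic regression where each sample is weighted
    by the weight of its group (hypothetical interface; the paper instead
    optimizes these weights rather than fixing them by hand)."""
    sample_weight = np.asarray([group_weights[g] for g in groups], dtype=float)
    clf = LogisticRegression(penalty="l1", C=C, solver="liblinear")
    clf.fit(X, y, sample_weight=sample_weight)
    return clf

def worst_group_accuracy(clf, X, y, groups):
    """Minimum per-group accuracy: the standard robustness metric for
    sub-population shifts."""
    preds = clf.predict(X)
    return min((preds[groups == g] == y[groups == g]).mean()
               for g in np.unique(groups))
```

Upweighting minority groups via `group_weights` (e.g. `{0: 1.0, 1: 2.0}`) trades some average accuracy for worst-group accuracy; the paper's contribution is to optimize these weights instead of fixing them heuristically.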
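The experiment setup also lists an L1 penalty grid (0.1, 1.0, 3.3, 10.0, 33.33, 100, 300, 500). The sketch below shows one plausible way such a grid could be searched, selecting by worst-group validation accuracy; the selection criterion and the mapping of the penalty strength to sklearn's inverse-regularization parameter `C = 1/lambda` are assumptions for illustration, not details confirmed by the paper.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Candidate L1 penalty strengths, as listed in the experiment setup.
L1_GRID = [0.1, 1.0, 3.3, 10.0, 33.33, 100, 300, 500]

def select_l1_penalty(X_tr, y_tr, X_val, y_val, groups_val):
    """Grid-search the L1 penalty, keeping the value with the highest
    worst-group accuracy on the validation set (assumed criterion)."""
    best_lam, best_wga = None, -1.0
    for lam in L1_GRID:
        # sklearn parameterizes regularization by its inverse, C.
        clf = LogisticRegression(penalty="l1", C=1.0 / lam, solver="liblinear")
        clf.fit(X_tr, y_tr)
        preds = clf.predict(X_val)
        wga = min((preds[groups_val == g] == y_val[groups_val == g]).mean()
                  for g in np.unique(groups_val))
        if wga > best_wga:
            best_lam, best_wga = lam, wga
    return best_lam, best_wga
```

Because large penalty values can shrink all coefficients to zero (constant predictions), selecting by worst-group rather than average accuracy guards against degenerate solutions that sacrifice the minority groups.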