Diverse Prototypical Ensembles Improve Robustness to Subpopulation Shift

Authors: Nguyen Nhat Minh To, Paul F R Wilson, Viet Nguyen, Mohamed Harmanani, Michael Cooper, Fahimeh Fooladgar, Purang Abolmaesumi, Parvin Mousavi, Rahul Krishnan

ICML 2025

Reproducibility Variable Result LLM Response
Research Type Experimental In empirical evaluation on nine real-world datasets, covering diverse domains and kinds of subpopulation shift, our method of Diverse Prototypical Ensembles (DPEs) often outperforms the prior state-of-the-art in worst-group accuracy. We empirically validate DPE using nine real-world datasets proposed in (Yang et al., 2023) to assess robustness against different types of subpopulation shifts. Our results show DPE’s superior performance over prior state-of-the-art methods, including in challenging cases like attribute generalization and imbalance.
Researcher Affiliation Academia 1 Department of Electrical and Computer Engineering, University of British Columbia, Vancouver, Canada; 2 Vector Institute, Toronto, Canada; 3 School of Computing, Queen's University, Kingston, Canada; 4 Department of Computer Science, University of Toronto, Toronto, Canada. Correspondence to: Minh Nguyen Nhat To <EMAIL>.
Pseudocode Yes Our training pipeline (Algorithm 1) consists of two stages: feature extractor training and prototypical ensemble training. In Stage 1, the feature extractor and classification head are trained with empirical risk minimization (ERM) on the training data to learn feature representations. In Stage 2, an ensemble of class-specific prototypes is initialized and trained on class-balanced or group-balanced subsets of the validation data, depending on whether subgroup annotations are available. A distance-based loss and an inter-prototype similarity loss are used to update each ensemble member. During inference, class probabilities are computed from the joint decision of the members of the prototypical ensemble. Algorithm 1: Subpopulation Prototypical Ensemble Diversification.
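The Stage 2 ingredients described above (balanced subsets, distance-based prototype classifiers, a joint ensemble decision, and an inter-prototype similarity term) can be illustrated with a small plain-Python sketch. It is not the authors' implementation, and several choices are assumptions: each member holds one prototype per class; logits are negative squared Euclidean distances; the joint decision averages the members' softmax outputs; the similarity term is the mean cosine similarity between different members' prototypes of the same class; and balanced subsets downsample every class or group to the size of the rarest one. All function names are hypothetical.

```python
import math
import random
from collections import defaultdict
from itertools import combinations
from typing import Dict, Hashable, List

Vector = List[float]


def balanced_subset(labels: List[Hashable], seed: int = 0) -> List[int]:
    # Indices of a subset with equally many samples per label (class or group).
    # Assumption: every label is downsampled to the size of the rarest one.
    by_label: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    n = min(len(v) for v in by_label.values())
    rng = random.Random(seed)
    subset: List[int] = []
    for idxs in by_label.values():
        subset.extend(rng.sample(idxs, n))
    return sorted(subset)


def sq_dist(a: Vector, b: Vector) -> float:
    # Squared Euclidean distance between two feature vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a: Vector, b: Vector) -> float:
    # Cosine similarity; assumes neither vector is all zeros.
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def member_probs(feature: Vector, prototypes: List[Vector]) -> List[float]:
    # One member: logits are negative distances to its class prototypes (assumed metric).
    return softmax([-sq_dist(feature, p) for p in prototypes])


def ensemble_probs(feature: Vector, ensemble: List[List[Vector]]) -> List[float]:
    # Joint decision (assumed rule): average the members' class probabilities.
    per_member = [member_probs(feature, protos) for protos in ensemble]
    n_classes = len(per_member[0])
    return [sum(p[c] for p in per_member) / len(per_member) for c in range(n_classes)]


def inter_prototype_similarity(ensemble: List[List[Vector]]) -> float:
    # Mean cosine similarity between different members' prototypes of the same class.
    # Minimizing this (assumed form of the term) pushes members toward diverse prototypes.
    n_classes = len(ensemble[0])
    sims = [
        cosine(m1[c], m2[c])
        for m1, m2 in combinations(ensemble, 2)
        for c in range(n_classes)
    ]
    return sum(sims) / len(sims)


if __name__ == "__main__":
    # Two members, two classes, 2-D features (toy numbers, not from the paper).
    ensemble = [
        [[1.0, 0.0], [0.0, 1.0]],
        [[0.9, 0.1], [0.1, 0.9]],
    ]
    print(ensemble_probs([1.0, 0.2], ensemble))
    print(inter_prototype_similarity(ensemble))
    print(balanced_subset(["a", "a", "a", "b"], seed=0))
```

In a real training loop the similarity term would be added, with some weight, to the distance-based classification loss and minimized by gradient descent; this sketch only evaluates the quantities.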
Open Source Code Yes The code is available at https://github.com/minhto2802/dpe4subpop.
Open Datasets Yes Specifically, the evaluation includes WATERBIRDS (Wah et al., 2011), CELEBA (Liu et al., 2015), METASHIFT (Liang & Zou, 2022), IMAGENETBG (Xiao et al., 2021), NICO++ (Zhang et al., 2023), LIVING17 (Santurkar et al., 2021), CHEXPERT (Irvin et al., 2019), CIVILCOMMENTS (Borkan et al., 2019), and MULTINLI (Schuhmann et al., 2022).
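Dataset Splits Yes We use the same training/validation/test splits given by Yang et al. (2023). More details of all datasets are provided in Appendix A.1.
Table 1. Summary of the WATERBIRDS dataset (# Attr. = number of attributes, # Cls. = number of classes, # Tr. / # Val. / # Test = number of training / validation / test samples).
Dataset | # Attr. | # Cls. | # Tr. | # Val. | # Test
WATERBIRDS | 2 | 2 | 4795 | 1199 | 5794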
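As a quick consistency check on Table 1, the split sizes can be turned into proportions and a subgroup count. This sketch assumes, as is the usual convention for this benchmark, that a subgroup is a (class, attribute) pair; the variable names are illustrative only.

```python
# Split sizes from Table 1 (WATERBIRDS).
splits = {"train": 4795, "val": 1199, "test": 5794}
n_attr, n_cls = 2, 2

total = sum(splits.values())  # 11788 images in total
fractions = {name: size / total for name, size in splits.items()}

# Assumption: one subgroup per (class, attribute) pair.
n_groups = n_cls * n_attr

for name, frac in fractions.items():
    print(f"{name}: {splits[name]} samples ({frac:.1%})")
print(f"subgroups: {n_groups}")
```

Under this convention, a group-balanced subset of WATERBIRDS draws from four subgroups.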
Hardware Specification Yes To quantify computational efficiency, we benchmarked runtime and GPU memory usage on an RTX6000 with a ResNet-50 backbone and batch size 1.
Software Dependencies No Image datasets such as WATERBIRDS, METASHIFT, and IMAGENETBG use the SGD optimizer with a learning rate of 1e-2 and a batch size of 128, while text-based datasets such as CIVILCOMMENTS and MULTINLI use BertAdam with a learning rate of 1e-4 and a batch size of 16. The number of training epochs varies by dataset, ranging from 4 for CIVILCOMMENTS to 300 for WATERBIRDS. We adopted pretrained ResNet-50 for image data and BERT for textual data to facilitate a direct comparison with state-of-the-art benchmarks on subpopulation shift robustness. The text mentions software components such as 'SGD optimizer', 'BertAdam', 'ResNet-50', and 'BERT' but provides no version numbers for any of them, nor does it list core programming languages or frameworks with versions (e.g., Python, PyTorch).
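The Stage 1 settings quoted above can be collected into a small configuration object. This is a hypothetical sketch: the names `Stage1Config`, `MODALITY`, `EPOCHS`, and `stage1_config` are illustrative, the optimizer names are plain labels rather than imported classes, and only values stated in the text are filled in (the remaining per-dataset epochs are in Table 10).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stage1Config:
    optimizer: str  # label only; no optimizer library is imported here
    lr: float
    batch_size: int


# Values quoted in the text.
IMAGE_CFG = Stage1Config(optimizer="SGD", lr=1e-2, batch_size=128)
TEXT_CFG = Stage1Config(optimizer="BertAdam", lr=1e-4, batch_size=16)

# Hypothetical modality lookup covering only the datasets named in the text.
MODALITY = {
    "WATERBIRDS": "image",
    "METASHIFT": "image",
    "IMAGENETBG": "image",
    "CIVILCOMMENTS": "text",
    "MULTINLI": "text",
}

# Only the two epoch counts stated in the text; other datasets are in Table 10.
EPOCHS = {"CIVILCOMMENTS": 4, "WATERBIRDS": 300}


def stage1_config(dataset: str) -> Stage1Config:
    # Image datasets share one setting, text datasets the other.
    return IMAGE_CFG if MODALITY[dataset] == "image" else TEXT_CFG


if __name__ == "__main__":
    for name in MODALITY:
        print(name, stage1_config(name), EPOCHS.get(name, "see Table 10"))
```

Keeping the shared optimizer settings per modality and the epoch counts per dataset mirrors how the text describes them.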
Experiment Setup Yes The training procedure of the DPE framework consists of two stages and uses the dataset-specific hyperparameters detailed in Tables 10 and 11. Hyperparameters were selected to maximize worst-group accuracy (WGA) on the validation set, ensuring fair evaluation across datasets. Table 10. Hyperparameters of representation learning (columns: Dataset, # Epochs, Optimizer, LR, Batch size). Table 11. Hyperparameters of prototypical ensemble learning (columns: Dataset, LR, Optimizer, Batch size, λ).
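Worst-group accuracy is the minimum of the per-group accuracies, and model selection keeps the setting with the highest validation WGA. A minimal sketch, with hypothetical function names and toy setting labels:

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def worst_group_accuracy(
    preds: Sequence[int], labels: Sequence[int], groups: Sequence[Hashable]
) -> float:
    # Accuracy is computed separately within each group; WGA is the minimum.
    correct: Dict[Hashable, int] = defaultdict(int)
    count: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        count[g] += 1
        correct[g] += int(p == y)
    return min(correct[g] / count[g] for g in count)


def select_by_val_wga(val_wga: Dict[str, float]) -> str:
    # Keep the hyperparameter setting with the highest validation WGA.
    return max(val_wga, key=val_wga.get)


if __name__ == "__main__":
    wga = worst_group_accuracy([1, 1, 0, 0], [1, 0, 0, 0], ["a", "a", "b", "b"])
    print(wga)
    print(select_by_val_wga({"setting_1": 0.80, "setting_2": 0.85}))
```

Because WGA is a minimum, a setting that improves average accuracy can still lose model selection if it hurts the rarest group.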