Adaptive Group Robust Ensemble Knowledge Distillation

Authors: Patrik Kenfack, Ulrich Aïvodji, Samira Ebrahimi Kahou

TMLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our experiments on several datasets demonstrate the superiority of the proposed ensemble distillation technique and show that it can even outperform classic model ensembles based on majority voting. Our source code is available at https://github.com/patrikken/AGRE-KD
Researcher Affiliation | Academia | Patrik Joslin Kenfack (EMAIL), ÉTS Montréal, Mila Quebec AI Institute; Ulrich Aïvodji (EMAIL), ÉTS Montréal, Mila Quebec AI Institute; Samira Ebrahimi Kahou (EMAIL), University of Calgary, Mila Quebec AI Institute, Canada CIFAR AI Chair
Pseudocode | Yes | Appendix A, "AGRE-KD: Algorithm and method overview": In Algorithm 1, we provide a high-level description of the AGRE-KD methodology for training and validation of group-robust distilled models. We also provide an overview of our method in Figure 2. Algorithm 1: Adaptive Group Robust Ensemble Knowledge Distillation (AGRE-KD)
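The row above references Algorithm 1 without reproducing it. As a rough illustrative sketch only: ensemble distillation targets can be formed by averaging temperature-softened teacher distributions under normalized per-teacher weights. The weighting rule below is a placeholder, not the paper's actual adaptive rule (which Algorithm 1 defines); all names here are hypothetical.

```python
import math

def softmax(logits, tau=1.0):
    """Temperature-softened softmax: larger tau flattens the distribution."""
    exps = [math.exp(z / tau) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_ensemble_target(teacher_logits, weights, tau=4.0):
    """Combine per-teacher softened distributions under normalized weights.

    `weights` stands in for AGRE-KD's adaptive per-teacher weights; the
    real weighting rule is the one given in Algorithm 1 of the paper.
    """
    total_w = sum(weights)
    norm_w = [w / total_w for w in weights]
    dists = [softmax(z, tau) for z in teacher_logits]
    num_classes = len(teacher_logits[0])
    return [sum(w * d[c] for w, d in zip(norm_w, dists))
            for c in range(num_classes)]

# Two hypothetical teachers over 3 classes; the second teacher is upweighted.
target = weighted_ensemble_target([[2.0, 0.5, -1.0], [0.0, 1.5, 0.5]],
                                  weights=[0.3, 0.7], tau=4.0)
```

The resulting `target` is a valid probability distribution that the student is trained to match.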
Open Source Code | Yes | Our source code is available at https://github.com/patrikken/AGRE-KD
Open Datasets | Yes | For example, in the Waterbirds dataset (Sagawa et al., 2019), which contains images of landbirds and waterbirds... We considered different settings of the Colored MNIST (CMNIST) dataset by varying the ratios of bias-aligned samples in the training dataset... Waterbirds (Sagawa et al., 2019; Liu et al., 2021) is a dataset of birds derived from Caltech-UCSD Birds (CUB) (Wah et al., 2011)... The CelebA (Liu et al., 2015) dataset contains images of celebrities... CivilComments (Koh et al., 2021) is a textual dataset... Both networks are pretrained on the ImageNet-1K (Russakovsky et al., 2015) dataset. For the language task, we use the BERT (Devlin et al., 2019) model for teachers and DistilBERT (Sanh et al., 2019) for the student model; the language models are pretrained on BookCorpus and English Wikipedia.
Dataset Splits | Yes | Following Kirichenko et al. (2022); LaBonte et al. (2024), we use half of the validation set of each benchmark to perform last-layer retraining with DFR; as we do not use the group and class labels during the KD training, we do not perform further hyperparameter tuning or model selection. We considered different settings of the Colored MNIST (CMNIST) dataset by varying the ratios of bias-aligned samples in the training dataset, i.e., the proportion of samples where the color and digit correspond {99.5%, 99%, 98%, and 95%}. This means that for the configuration with a ratio of 99.5%, only 0.5% of samples will have a digit-color mismatch, and decreasing the ratio reduces the strength of the spurious correlation in the training dataset.
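The bias-aligned ratio quoted above fixes what fraction of CMNIST samples get a mismatched color. A minimal sketch of that construction, assuming ten colors indexed like the digits; the helper name and sampling details are hypothetical, not taken from the paper.

```python
import random

def color_for_digit(digit, ratio, rng):
    """With probability `ratio`, the color index matches the digit
    (bias-aligned); otherwise a different color is drawn (bias-conflicting)."""
    if rng.random() < ratio:
        return digit  # spurious color-digit correlation
    return rng.choice([c for c in range(10) if c != digit])

rng = random.Random(0)
ratio = 0.995  # the 99.5% bias-aligned configuration
labels = [rng.randrange(10) for _ in range(20000)]
colors = [color_for_digit(d, ratio, rng) for d in labels]
conflicting = sum(c != d for c, d in zip(colors, labels)) / len(labels)
# conflicting is approximately 1 - ratio, i.e. about 0.5% of samples
```

Lowering `ratio` toward 95% increases the share of bias-conflicting samples and so weakens the spurious correlation, matching the quoted description.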
Hardware Specification | No | The paper does not provide specific hardware details (e.g., exact GPU/CPU models, processor types with speeds, memory amounts, or detailed computer specifications) used for running its experiments.
Software Dependencies | No | Our implementation uses PyTorch (Paszke et al., 2017; 2019), PyTorch Lightning (Falcon & team, 2019), and Milkshake (LaBonte, 2023). While software components are mentioned, specific version numbers for PyTorch and PyTorch Lightning are not provided.
Experiment Setup | Yes | Following related works on KD (Du et al., 2020; Fukuda et al., 2017; Chen et al., 2022), we set the temperature hyperparameter τ = 4 (Eq. 1) and show in an ablation study in Supplementary D (Figure 5) that increasing τ can exert a positive effect on WGA up to certain values. We provide further details about hyperparameters in Supplementary B. For the vision tasks, we used an initial learning rate of 1 × 10^-3 with a cosine learning rate scheduler; we used batch sizes of 32 and 100 for the Waterbirds and CelebA datasets, respectively. For the CivilComments dataset, we use an initial learning rate of 1 × 10^-5 with a linear learning rate scheduler, a batch size of 16, and train for ten epochs. We keep all hyperparameters fixed to train the teacher and student models. For the optimizer, we used AdamW (Loshchilov et al., 2019) and SGD for the language and vision datasets, respectively, with a weight decay of 1 × 10^-4.
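The quoted setup fixes the distillation temperature τ = 4 (Eq. 1 of the paper). For context, a minimal sketch of the standard Hinton-style temperature-scaled distillation loss this refers to; this is the generic formulation, not necessarily the paper's exact objective.

```python
import math

def softened(logits, tau):
    """Softmax at temperature tau."""
    exps = [math.exp(z / tau) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, tau=4.0):
    """Distillation loss: KL(teacher || student) at temperature tau,
    scaled by tau**2 so the gradient magnitude stays comparable as tau grows."""
    p = softened(teacher_logits, tau)
    q = softened(student_logits, tau)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return tau * tau * kl
```

The loss is zero when the student exactly matches the teacher's logits and positive otherwise; raising τ softens both distributions, which is the effect the quoted ablation (Supplementary D, Figure 5) studies.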