Variational Classification: A Probabilistic Generalization of the Softmax Classifier

Authors: Shehzaad Zuzar Dhuliawala, Mrinmaya Sachan, Carl Allen

TMLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Empirical evaluation on image and text classification datasets demonstrates that our proposed approach, variational classification, maintains classification accuracy while the reshaped latent space improves other desirable properties of a classifier, such as calibration, adversarial robustness, robustness to distribution shift and sample efficiency useful in low-data settings. Through a series of experiments on vision and text datasets, we demonstrate that VC achieves comparable accuracy to regular softmax classification while the aligned latent distribution improves calibration, robustness to adversarial perturbations (specifically FGSM, white box), generalisation under domain shift and performance in low-data regimes.
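The FGSM white-box attack mentioned above perturbs an input one step along the sign of the input gradient of the loss. Below is a minimal numpy sketch against a hypothetical logistic-regression victim model; the paper attacks trained neural classifiers, so the model, function names and parameters here are illustrative assumptions only:

```python
import numpy as np

def bce_loss(x, y, w, b):
    """Binary cross-entropy of a logistic model p = sigmoid(w.x + b)."""
    p = 1.0 / (1.0 + np.exp(-(x @ w + b)))
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def fgsm_perturb(x, y, w, b, eps):
    """One-step FGSM: x_adv = x + eps * sign(grad_x loss).

    For this hypothetical logistic model the input gradient of the
    cross-entropy has the closed form (sigmoid(w.x + b) - y) * w,
    so no autodiff framework is needed for the sketch.
    """
    p = 1.0 / (1.0 + np.exp(-(x @ w + b)))
    grad_x = (p - y) * w          # d(loss)/dx in closed form
    return x + eps * np.sign(grad_x)
```

By construction the perturbation is bounded by eps in the L-infinity norm and (for nonzero weights) strictly increases the loss, which is what makes the one-step attack a cheap robustness probe.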
Researcher Affiliation | Academia | Shehzaad Dhuliawala (EMAIL), Department of Computer Science, ETH Zurich, Switzerland; Mrinmaya Sachan (EMAIL), Department of Computer Science, ETH Zurich, Switzerland; Carl Allen (EMAIL), AI Centre, ETH Zurich, Switzerland
Pseudocode | Yes | Algorithm 1 Variational Classification (VC)
1: Input: pθ(z|y), qϕ(z|x), pπ(y), Tψ(z); learning-rate schedule {η_θ^t, η_ϕ^t, η_π^t, η_ψ^t}_t; β
2: Initialise θ, ϕ, π, ψ; t ← 0
3: while not converged do
4:   {x_i, y_i}_{i=1}^m ∼ D   [sample batch from data distribution p(x, y)]
5:   for i = 1 ... m do
6:     z_i ∼ qϕ(z|x_i),  z'_i ∼ pθ(z|y_i)   [e.g. qϕ(z|x_i) := δ_{z−fω(x_i)}, ϕ := ω, so z_i = fω(x_i)]
7:     pθ(y_i|z_i) = pθ(z_i|y_i) pπ(y_i) / Σ_y pθ(z_i|y) pπ(y)
8:   end for
9:   g_θ ← (1/m) Σ_{i=1}^m ∇_θ [log pθ(y_i|z_i) + β log pθ(z_i|y_i)]
10:  g_ϕ ← (1/m) Σ_{i=1}^m ∇_ϕ [log pθ(y_i|z_i) − β Tψ(z_i)]   [e.g. using reparameterisation trick]
11:  g_π ← (1/m) Σ_{i=1}^m ∇_π log pπ(y_i)
12:  g_ψ ← (1/m) Σ_{i=1}^m ∇_ψ [log σ(Tψ(z_i)) + log(1 − σ(Tψ(z'_i)))]
13:  θ ← θ + η_θ^t g_θ;  ϕ ← ϕ + η_ϕ^t g_ϕ;  π ← π + η_π^t g_π;  ψ ← ψ + η_ψ^t g_ψ;  t ← t + 1
14: end while
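The class posterior on line 7 of Algorithm 1 is Bayes' rule applied to the class-conditional latent priors. A minimal numpy sketch, assuming the diagonal-covariance Gaussian priors pθ(z|y) described under Experiment Setup; the function name and array shapes are illustrative assumptions, not the authors' code:

```python
import numpy as np

def gaussian_class_posterior(z, mus, log_vars, log_prior):
    """p(y|z) via Bayes' rule over K diagonal-Gaussian class-conditionals.

    Shapes (illustrative): z (d,), mus and log_vars (K, d), log_prior (K,).
    Everything is computed in log space for numerical stability.
    """
    # log N(z; mu_y, diag(var_y)) for each class y
    log_pz_y = -0.5 * np.sum(
        log_vars + np.log(2 * np.pi) + (z - mus) ** 2 / np.exp(log_vars),
        axis=1,
    )
    logits = log_pz_y + log_prior   # log p(z|y) + log p(y)
    logits -= logits.max()          # shift for a stable softmax
    p = np.exp(logits)
    return p / p.sum()              # normalise over classes
```

The normalisation over classes is exactly the denominator Σ_y pθ(z_i|y) pπ(y) on line 7; with a point-mass qϕ(z|x) (line 6), z is simply the encoder output fω(x).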
Open Source Code | Yes | Code: www.github.com/shehzaadzd/variational-classification
Open Datasets | Yes | Empirical evaluation on image and text classification datasets demonstrates that our proposed approach... on three standard benchmarks (CIFAR-10, CIFAR-100, and Tiny-Imagenet)... robustness benchmarks, CIFAR-10-C, CIFAR-100-C and Tiny-Imagenet-C, proposed by Hendrycks & Dietterich (2019)... trained on the SNLI dataset (Bowman et al., 2015) and tested on the MNLI dataset (Williams et al., 2018)... Cancer detection uses the Camelyon17 dataset (Bandi et al., 2018) from the WILDS datasets (Koh et al., 2021)... on 10 MedMNIST classification datasets (Yang et al., 2021)
Dataset Splits | Yes | Models are trained on 500 samples from MNIST, 1000 samples from CIFAR-10 and 50 samples from AGNews. We compute the AUROC when a model is trained on CIFAR-10 and evaluated on the CIFAR-10 validation set mixed (in turn) with SVHN, CIFAR-100, and CelebA.
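The AUROC in the mixed in-distribution/OOD evaluation above can be computed rank-wise from per-example scores: the probability that a randomly chosen in-distribution example scores higher than a randomly chosen OOD one. A small sketch; using a confidence score as the ranking statistic is our assumption, not a detail stated in the excerpt:

```python
import numpy as np

def auroc(scores_in, scores_out):
    """Rank-based AUROC for OOD detection.

    scores_in / scores_out: 1-D score arrays (e.g. a hypothetical
    model-confidence score) for in-distribution and OOD examples.
    Ties between pairs count as 0.5, matching the standard definition.
    """
    s_in = np.asarray(scores_in, dtype=float)[:, None]
    s_out = np.asarray(scores_out, dtype=float)[None, :]
    wins = (s_in > s_out).mean()   # fraction of correctly ordered pairs
    ties = (s_in == s_out).mean()  # fraction of tied pairs
    return wins + 0.5 * ties
```

A perfect separator scores 1.0 and a score carrying no information about the split scores 0.5, so AUROC is threshold-free, which is why it suits mixed-evaluation setups like the one described.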
Hardware Specification | No | No specific hardware details (such as GPU/CPU models or specific cloud instances) are mentioned in the paper.
Software Dependencies | No | No specific software dependencies with version numbers are mentioned in the paper.
Experiment Setup | Yes | MC-Dropout (Gal & Ghahramani, 2016) for CIFAR-10 and CIFAR-100 on ResNet-50 (p = 0.2, averaging over 10 samples). Temperature scaling was performed as in Guo et al. (2017), with the temperature tuned on an in-distribution validation set. Class-conditional priors pθ(z|y) are multivariate Gaussians with parameters learned from the data (we use diagonal covariance for simplicity). In our experiments we use 20 equally spaced bins.
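The "20 equally spaced bins" refer to the calibration evaluation. A minimal sketch of expected calibration error (ECE) under that binning, with illustrative names: per bin, the gap between mean accuracy and mean confidence is weighted by the bin's share of examples.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=20):
    """ECE with n_bins equally spaced confidence bins over [0, 1].

    confidences: predicted confidence per example (e.g. max softmax).
    correct: 1 if the prediction was right, else 0.
    """
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # digitize against the interior edges gives bin ids 0 .. n_bins-1
    bin_ids = np.clip(np.digitize(confidences, edges[1:-1]), 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            gap = abs(correct[mask].mean() - confidences[mask].mean())
            ece += mask.mean() * gap   # weight gap by bin occupancy
    return ece
```

A perfectly calibrated model (accuracy matching confidence in every bin) scores 0; an overconfident model accumulates the confidence-accuracy gaps.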