Stochastic Re-weighted Gradient Descent via Distributionally Robust Optimization

Authors: Ramnath Kumar, Kushal Alpesh Majmundar, Dheeraj Mysore Nagaraj, Arun Suggala

TMLR 2024 | Venue PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable Result LLM Response
Research Type Experimental We demonstrate the effectiveness of RGD on various learning tasks, including supervised learning, meta-learning, and out-of-domain generalization. Notably, RGD achieves state-of-the-art results on diverse benchmarks, with improvements of +0.7% on DomainBed, +1.44% on tabular classification, +1.94% on GLUE with BERT, and +1.01% on ImageNet-1K with ViT.
Researcher Affiliation Industry Ramnath Kumar EMAIL Google Inc. Kushal Majmundar EMAIL Google Inc. Dheeraj Nagaraj EMAIL Google Inc. Arun Sai Suggala EMAIL Google Inc.
Pseudocode Yes Algorithm 1 Re-weighted Gradient Descent (RGD)
1: Input: Data {z_i}_{i=1}^n, learning rate sequence {η_t}_{t=1}^T, number of iterations T, loss function ℓ, re-weighting function g, mini-batch size B
2: for t = 0 . . . T−1 do
3: Sample minibatch {z_i}_{i=1}^B
4: Compute losses for points in the minibatch: ℓ_i ← ℓ(z_i; θ_t), ∀i ∈ 1 . . . B
5: Compute per-sample weights: w_i ← g(ℓ_i), ∀i ∈ 1 . . . B
6: Compute the weighted pseudo-gradient: v_t ← Σ_{i=1}^B w_i ∇_θ ℓ(z_i; θ_t)
7: Update model parameters: θ_{t+1} ← Π_Θ(θ_t − η_t v_t)
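The loop in Algorithm 1 can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation: the squared-loss linear model, the clipped-exponential choice of g, and the batch-mean (rather than sum) pseudo-gradient are all assumptions made here for a self-contained, numerically stable example. The projection Π_Θ is omitted (Θ = R^d).

```python
import numpy as np

def rgd_step(theta, X, y, eta, g, rng, batch_size=32):
    """One RGD step on per-sample squared loss for a linear model."""
    idx = rng.choice(len(X), size=batch_size, replace=False)
    Xb, yb = X[idx], y[idx]
    preds = Xb @ theta
    losses = 0.5 * (preds - yb) ** 2          # per-sample losses l_i
    weights = g(losses)                        # per-sample weights w_i = g(l_i)
    grads = (preds - yb)[:, None] * Xb         # per-sample gradients of the loss
    v = (weights[:, None] * grads).mean(axis=0)  # weighted pseudo-gradient
    return theta - eta * v                     # parameter update

# Hypothetical re-weighting function: exponential of the loss, clipped at 5.
g = lambda losses: np.minimum(np.exp(losses), 5.0)

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4))
theta_true = np.array([1.0, -2.0, 0.5, 3.0])
y = X @ theta_true                             # noiseless synthetic targets
theta = np.zeros(4)
for t in range(500):
    theta = rgd_step(theta, X, y, eta=0.05, g=g, rng=rng)
```

On this noiseless problem the iterates recover theta_true; the weighting simply makes high-loss samples contribute more per step.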
Open Source Code No Our proposed loss function is a single line of change. However, one would have to play around with the learning rate (generally lower than the baseline setting). Our experiments are based on public datasets and open-source code repositories. The proposed final formulation RGD requires one line of code change. Suppose the per-sample loss is given. Example code for applying RGD in Jax is shown below. This section does not provide an explicit link to a code repository for the methodology presented in this paper, nor does it state that the authors' full implementation code is available.
Open Datasets Yes We demonstrate the effectiveness of RGD on various learning tasks, including supervised learning, meta-learning, and out-of-domain generalization. Notably, RGD achieves state-of-the-art results on diverse benchmarks, with improvements of +0.7% on DomainBed, +1.44% on tabular classification, +1.94% on GLUE with BERT, and +1.01% on ImageNet-1K with ViT.
Dataset Splits Yes To further demonstrate that our weight clipping can effectively handle benign outliers, we perform the following experiment. We randomly flip the labels in the CIFAR-10 and CIFAR-100 datasets (we vary the proportion of flips from 0% to 40%) and compare the performance of RGD with the state-of-the-art KL-DRO optimization technique TERM (Li et al., 2021). We use the Long-Tailed CIFAR dataset, where we reduce the number of training samples per class according to an exponential function as proposed by Cui et al. (2019). In meta-learning, the goal is to learn representations that generalize effectively to new tasks, even when provided with limited examples. However, task heterogeneity poses a significant challenge. Some tasks may be inherently simpler to learn, leading models to prioritize these and neglect the more difficult, less frequent tasks. To briefly explain, consider the dataset PACS, which consists of photos, art, cartoons, and sketches of the same set of classes (for instance, dogs and cats, amongst others). The goal of the task is to learn from three of these domains and evaluate the performance on the left-out domain (similar to a k-fold cross-validation).
Hardware Specification No The paper does not explicitly mention the specific hardware (e.g., GPU models, CPU types, or TPUs) used for running its experiments. It only refers to implementing in JAX and basing experiments on open-source code repositories, without hardware details for their own runs.
Software Dependencies No Our proposed loss function is a single line of change. However, one would have to play around with the learning rate (generally lower than the baseline setting). Our experiments are based on public datasets and open-source code repositories. The proposed final formulation RGD requires one line of code change. Suppose the per-sample loss is given. Example code for applying RGD in JAX is shown below (beginning `import jax.numpy as jnp`, `import jax`). While JAX is mentioned and code snippets are provided, no specific version number for JAX or any other software dependency is stated.
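The one-line loss change the paper describes can be sketched like this. The sketch below uses NumPy (the paper's snippet uses jax.numpy analogously); the clipped-exponential weight and the scaling γ = 1/(τ+1) follow the hyperparameter description quoted in this report, but the exact form of the weighting function is an assumption here, not the authors' verbatim code.

```python
import numpy as np

def rgd_loss(per_sample_losses, tau=5.0):
    """Hypothetical RGD-style re-weighted loss given per-sample losses."""
    gamma = 1.0 / (tau + 1.0)                              # scaling factor
    weights = np.minimum(np.exp(per_sample_losses), tau)   # w_i = g(l_i), clipped at tau
    # In a training loop the weights would be held constant w.r.t. the
    # parameters (e.g. jax.lax.stop_gradient) so only the gradient is re-weighted.
    return float(np.mean(gamma * weights * per_sample_losses))

scalar_loss = rgd_loss(np.array([0.1, 0.5, 2.0, 4.0]))
```

Because the weights are detached from the gradient computation, swapping the plain mean loss for `rgd_loss` is the only change needed in an existing training step.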
Experiment Setup Yes RGD introduces only one additional hyperparameter, the clipping factor (τ). The optimizer, weight decay, batch sizes, etc. were kept constant with the baseline across all our experiments. Details regarding hyperparameter tuning are relegated to Appendix 5. In this section, we describe the common hyperparameter tuning space used across all experiments in our paper unless otherwise mentioned. The two parameters we tuned were τ and lr. We use a simple grid search for τ over [1, 3, 5, 7, 9] across the experiments, where the scaling factor (γ) is by default set as 1/(τ+1). This allowed our loss to be bounded in [0, 1] and helped fairly compare RGD-χ² and RGD. The lr was tuned by a proxy of lr_mult, where we scaled the learning rate by a fraction in the range [0.5, 1.5]. We trained the BERT-base model for 450K steps, and tuned the learning rate (lr) for the baseline, and lr and the clipping factor for RGD. We trained for 100K steps with a batch size of 256. We used the default learning rate of 0.0016 for the baseline. For RGD, we fix the clipping threshold to 1 and tune the learning rate.
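The tuning space quoted above can be enumerated as a small grid. The τ values, the γ = 1/(τ+1) default, and the base learning rate 0.0016 come from the quoted setup; the five-point spacing of lr_mult within [0.5, 1.5] is an assumption made here for illustration.

```python
import itertools

taus = [1, 3, 5, 7, 9]                      # grid for the clipping factor
lr_mults = [0.5, 0.75, 1.0, 1.25, 1.5]      # assumed spacing within [0.5, 1.5]
base_lr = 0.0016                            # baseline default learning rate

configs = [
    {"tau": tau, "gamma": 1.0 / (tau + 1), "lr": base_lr * m}
    for tau, m in itertools.product(taus, lr_mults)
]
```

Each config then reuses the baseline's optimizer, weight decay, and batch size, so the grid covers only the two tuned quantities.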