Meta-learning Optimizers for Communication-Efficient Learning

Authors: Charles-Étienne Joseph, Benjamin Thérien, Abhinav Moudgil, Boris Knyazev, Eugene Belilovsky

TMLR 2025

Reproducibility Variable Result LLM Response
Research Type Experimental In this work, we propose learned optimization as an approach to alleviate the challenges of communication-efficient distributed learning. Specifically, we follow the setup of local SGD (Stich, 2019) with homogeneous devices and homogeneous data split among them, and demonstrate that our global learned optimizers (fig. 1), meta-trained for this setting, can outperform Local SGD and SlowMo (Wang et al., 2019) as well as data-parallel Adam and SGD. Our results demonstrate that learned optimizers can substantially outperform local SGD and its sophisticated variants while maintaining their communication efficiency. Our learned optimizers can even generalize to unseen and much larger datasets and architectures, including ImageNet and ViTs, and to unseen modalities such as language modeling. We therefore show the potential of learned optimizers for improving communication-efficient distributed learning. Our experiments in the main manuscript assume a distributed setting with homogeneous devices and homogeneous data split among them. Following the convention in learned optimization (Metz et al., 2019; 2022a; Harrison et al., 2022), we mainly report training loss for simplicity when comparing optimizers. However, we do demonstrate that models trained by our global learned optimizers compare favorably on held-out data to models trained with other optimizers (see Figures 4 and 6).
Researcher Affiliation Collaboration Charles-Étienne Joseph (EMAIL), Benjamin Thérien (EMAIL), Abhinav Moudgil (EMAIL), Boris Knyazev (EMAIL), Eugene Belilovsky (EMAIL). Affiliations: Department of Computer Science and Operations Research, Université de Montréal, Montréal, Canada; Department of Computer Science and Software Engineering, Concordia University, Montréal, Canada; Mila, Montréal, Canada; Samsung SAIT AI Lab, Montréal, Canada.
Pseudocode Yes Algorithm 1: Learned optimizers vs. Local SGD. Steps used in both algorithms are not colored. Input: T (number of communication steps), K (number of workers), H (number of local steps), γ (local learning rate), W0,0 (initial weights), D (dataset), L (loss function), Fϕ (learned optimizer), U0 (initial accumulator state). Algorithm 2: Meta-training our Global Learned Optimizers with PES. Note that this algorithm has been adapted from Algorithm 1 of Vicol et al. (2021) with minimal changes to the notation. Input: optimizer parameters θt, gradients
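To make the structure of Algorithm 1 concrete, here is a minimal sketch of the Local SGD outer loop it describes: K workers each take H local SGD steps from the current global weights, the results are averaged, and the pseudo-gradient (global weights minus the average) is handed to a global update rule. Plain Local SGD just applies the average; the paper instead feeds this pseudo-gradient to a learned optimizer Fϕ. All function and variable names below are illustrative, not the paper's implementation (which runs in JAX simulation):

```python
def local_sgd_outer(w_global, shards, grad_fn, global_update, T, H, gamma):
    """Sketch of the Local SGD outer loop (Algorithm 1 structure).

    T  -- number of communication rounds
    H  -- local SGD steps per worker between communications
    gamma -- local learning rate
    global_update -- applies the pseudo-gradient to the global weights
                     (plain averaging here; a learned F_phi in the paper).
    """
    for _ in range(T):                       # T communication rounds
        local_models = []
        for shard in shards:                 # K homogeneous workers
            w = list(w_global)
            for _ in range(H):               # H local SGD steps
                g = grad_fn(w, shard)
                w = [wi - gamma * gi for wi, gi in zip(w, g)]
            local_models.append(w)
        # Communication: average the K local models.
        w_avg = [sum(vals) / len(vals) for vals in zip(*local_models)]
        # Pseudo-gradient consumed by the global update rule.
        delta = [wg - wa for wg, wa in zip(w_global, w_avg)]
        w_global = global_update(w_global, delta)
    return w_global

def average_update(w, delta):
    """Plain Local SGD: subtracting delta exactly recovers the average."""
    return [wi - di for wi, di in zip(w, delta)]

# Toy demo: two workers whose shards pull toward targets [1,1] and [3,3];
# with plain averaging the global weights converge to the midpoint [2,2].
grad = lambda w, target: [2 * (wi - ti) for wi, ti in zip(w, target)]
final = local_sgd_outer([0.0, 0.0], [[1.0, 1.0], [3.0, 3.0]],
                        grad, average_update, T=50, H=5, gamma=0.1)
```

Swapping `average_update` for a parameterized function is what turns this loop into the paper's global learned optimizer while keeping the same communication pattern.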
Open Source Code No No explicit statement or link to open-source code is provided in the paper. The paper states: "Our method is currently implemented in simulation." which does not indicate code release.
Open Datasets Yes We use the Fashion MNIST (FMNIST) dataset (Xiao et al., 2017) (10 classes) with 28×28 images. We also use the CIFAR-10 dataset (Krizhevsky et al., 2009) (10 classes) with 32×32 images. Finally, we scale our setup to the ImageNet dataset (Russakovsky et al., 2015) (1000 classes) with downsampled 32×32 and 64×64 images. For the language modeling task, we use LM1B (Chelba et al., 2013).
Dataset Splits No The paper does not provide specific details on training/validation/test splits (e.g., percentages or sample counts) for the main datasets (FMNIST, CIFAR-10, ImageNet). It mentions using a "test split" for accuracy in the Federated Learning section, but no general splitting methodology for all experiments.
Hardware Specification Yes All timings are measured when training across K=8 A6000 GPUs. We meta-train and evaluate using a single NVIDIA A100.
Software Dependencies No The paper does not explicitly list specific software dependencies with version numbers. It mentions "jax.jit" but without a version.
Experiment Setup Yes In our experiments, Fϕ is a two-hidden-layer MLP with 32 hidden units and ReLU activations, mapping the input Ada features for each parameter p in the optimizee to a two-dimensional vector [dϕ, mϕ]. At step t, the learned optimizer update for all p is given as follows: Fϕ(At,p) = [dϕ,p, mϕ,p]; pt = pt−1 − λ1 dϕ,p exp(λ2 mϕ,p), where Ap are the Ada features computed from statistics of p, and λ1 and λ2 are constants set to 0.001. To meta-train our learned optimizers, we estimate gradients using Persistent Evolution Strategies (PES) (Vicol et al., 2021) and take gradient descent steps using AdamW with a linear-warmup-plus-cosine-decay schedule. Each gradient is estimated from a batch of 8 tasks. T varies from 100 to 1000 during training according to a log-uniform truncation schedule. In our experiments, gradients are estimated with respect to the optimizee's training loss, except for the curves in fig. 4, whose gradients were estimated with respect to the optimizee's validation loss. During meta-training, the learning rate is warmed up for 100 steps to a maximum learning rate before being decayed (following a cosine decay schedule) to 1/3 of the maximum value. All the meta-training details are provided in appendix B. For each task, we use a local batch size Bloc of 128. For both optimizers, we use gradient clipping to a norm of 5 and weight decay of 1e−4 for the local steps.
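The per-parameter update rule pt = pt−1 − λ1 dϕ,p exp(λ2 mϕ,p) can be sketched as follows. This is a toy stand-in, not the paper's JAX implementation: the MLP is randomly initialized rather than meta-trained, the input feature vector is arbitrary, and all names (`make_mlp`, `learned_opt_step`) are hypothetical. It only illustrates how the two MLP outputs, a direction dϕ and a log-magnitude mϕ, combine into a parameter update:

```python
import math
import random

random.seed(0)
LAMBDA1 = LAMBDA2 = 0.001            # both constants are 0.001 in the paper

def make_mlp(in_dim, hidden=32, out_dim=2):
    """Randomly-initialized two-hidden-layer ReLU MLP, a stand-in for F_phi.

    The paper's F_phi uses 32 hidden units and maps per-parameter Ada
    features to the two-dimensional output [d, m].
    """
    def layer(n_in, n_out):
        w = [[random.gauss(0, 1 / n_in ** 0.5) for _ in range(n_in)]
             for _ in range(n_out)]
        return w, [0.0] * n_out
    layers = [layer(in_dim, hidden), layer(hidden, hidden),
              layer(hidden, out_dim)]

    def forward(x):
        for i, (w, b) in enumerate(layers):
            x = [sum(wi * xi for wi, xi in zip(row, x)) + bi
                 for row, bi in zip(w, b)]
            if i < len(layers) - 1:          # ReLU on hidden layers only
                x = [max(0.0, v) for v in x]
        return x
    return forward

def learned_opt_step(p_prev, ada_features, f_phi):
    """p_t = p_{t-1} - lambda1 * d * exp(lambda2 * m), [d, m] = F_phi(A_p)."""
    d, m = f_phi(ada_features)
    return p_prev - LAMBDA1 * d * math.exp(LAMBDA2 * m)

# One update on a scalar parameter from a made-up 4-dimensional feature vector.
f_phi = make_mlp(in_dim=4)
p_new = learned_opt_step(1.0, [0.1, 0.2, 0.3, 0.4], f_phi)
```

Note how the small λ1 = λ2 = 0.001 keep the raw MLP outputs from producing destabilizingly large steps, which matches the paper's choice of fixed scaling constants.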