Scalable Meta-Learning via Mixed-Mode Differentiation

Authors: Iurii Kemaev, Dan A. Calian, Luisa M. Zintgraf, Gregory Farquhar, Hado van Hasselt

ICML 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this paper, we... use modern hardware and libraries for tensor programming to demonstrate that the proposed algorithmic technique, whilst requiring only minor code modifications, yields significant performance improvements in common meta-learning scenarios. In a representative setting, MixFlow-MG demonstrates reductions of up to 95% in active memory consumption and a 25% reduction in wall-clock time, thus allowing bilevel gradient setups to be scaled by more than an order of magnitude in a compute-efficient way. (Section 5, Benchmarking Language Modelling Tasks)
Researcher Affiliation | Industry | 1Google DeepMind. Correspondence to: Iurii Kemaev <EMAIL>.
Pseudocode | Yes | Algorithm 1: Standard Truncated-BPTT (Equation (3)); Algorithm 2: Mixed-mode Truncated-BPTT (Equation (4))
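The mixed-mode idea behind Algorithm 2 can be illustrated with a minimal JAX sketch. When backpropagating through gradient-based inner updates, the dominant cost is Hessian-vector products; because the Hessian is symmetric, the vector-Hessian product needed by reverse mode can instead be computed as a forward-mode JVP through the reverse-mode gradient (forward-over-reverse), avoiding a second reverse pass. The quadratic loss and all names below are illustrative stand-ins, not code from the paper:

```python
import jax
import jax.numpy as jnp

def loss(theta, x):
    # toy quadratic loss standing in for the inner-level objective
    return jnp.sum((theta * x - 1.0) ** 2)

def hvp_reverse_over_reverse(theta, x, v):
    # standard approach: differentiate the gradient again in reverse mode
    return jax.grad(lambda t: jnp.vdot(jax.grad(loss)(t, x), v))(theta)

def hvp_forward_over_reverse(theta, x, v):
    # mixed mode: a forward-mode JVP through the reverse-mode gradient
    return jax.jvp(lambda t: jax.grad(loss)(t, x), (theta,), (v,))[1]

theta = jnp.arange(3.0)
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.ones(3)
# the two modes agree because the Hessian is symmetric
agree = jnp.allclose(hvp_reverse_over_reverse(theta, x, v),
                     hvp_forward_over_reverse(theta, x, v))
```

The memory savings come from the forward-over-reverse variant: the JVP is evaluated alongside the gradient computation instead of recording and re-traversing a second reverse-mode tape.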
Open Source Code | Yes | We included a minimalistic implementation in JAX and PyTorch for MixFlow-MG in Appendix A.4 for reference and easy adoption.
Open Datasets | No | We chose the language modelling domain for the inner-level optimisation, where the standard loss is the next-token-prediction loss NTP(θ,x). We use the Chinchilla family of language models (Hoffmann et al., 2022a) with RoPE (Su et al., 2024) and the Adam optimiser (Kingma, 2014).
Dataset Splits | No | The paper mentions 'Batch size {2, 4, 8}' and 'Sequence length {2048, 4096, 8192}' as hyperparameters for their sweep, but does not provide specific information regarding dataset splits (e.g., training, validation, test percentages or counts) or how the data for their experiments was partitioned.
Hardware Specification | Yes | Benchmarking was performed in JAX (Bradbury et al., 2018) on TPUv5p and H100 using the OpenXLA backend and libraries from DeepMind et al. (2020).
Software Dependencies | No | The paper mentions 'JAX (Bradbury et al., 2018)', 'PyTorch (Paszke et al., 2017)', and the 'OpenXLA backend and libraries from DeepMind et al. (2020)', but does not provide specific version numbers for any of these software components.
Experiment Setup | Yes | Table 1. Sweep over tasks: hyperparameters and values.
  Task: {learning lr, maml, loss weighting}
  Model size (×10^6): {57, 106, 163, 217, 306}
  # of inner updates (T): {2, 4, 8}
  Batch size: {2, 4, 8}
  Sequence length: {2048, 4096, 8192}