Scalable Meta-Learning via Mixed-Mode Differentiation
Authors: Iurii Kemaev, Dan A. Calian, Luisa M. Zintgraf, Gregory Farquhar, Hado van Hasselt
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this paper, we... use modern hardware and libraries for tensor programming to demonstrate that the proposed algorithmic technique, whilst requiring only minor code modifications, yields significant performance improvements in common meta-learning scenarios. In a representative setting, MixFlow-MG demonstrates reductions of up to 95% in active memory consumption and 25% in wall-clock time, making it possible to scale bilevel gradient setups by more than an order of magnitude in a compute-efficient way. Section 5. Benchmarking Language Modelling Tasks |
| Researcher Affiliation | Industry | 1Google Deep Mind. Correspondence to: Iurii Kemaev <EMAIL>. |
| Pseudocode | Yes | Algorithm 1 Standard Truncated-BPTT (Equation (3)) Algorithm 2 Mixed-mode Truncated-BPTT (Equation (4)) |
| Open Source Code | Yes | We included a minimalistic implementation in JAX and PyTorch for MixFlow-MG in Appendix A.4 for reference and easy adoption. |
| Open Datasets | No | We chose the language modelling domain for the inner-level optimisation, where the standard loss is the next-token-prediction loss NTP(θ,x). We use the Chinchilla family of language models (Hoffmann et al., 2022a) with RoPE (Su et al., 2024) and the Adam optimiser (Kingma, 2014). |
| Dataset Splits | No | The paper mentions 'Batch size {2, 4, 8}' and 'Sequence length {2048, 4096, 8192}' as hyperparameters for their sweep, but does not provide specific information regarding dataset splits (e.g., training, validation, test percentages or counts) or how the data for their experiments was partitioned. |
| Hardware Specification | Yes | Benchmarking was performed in JAX (Bradbury et al., 2018) on TPUv5p and H100 using the OpenXLA backend and libraries from DeepMind et al. (2020). |
| Software Dependencies | No | The paper mentions 'JAX (Bradbury et al., 2018)', 'PyTorch (Paszke et al., 2017)', and the 'OpenXLA backend and libraries from DeepMind et al. (2020)', but does not provide specific version numbers for any of these software components. |
| Experiment Setup | Yes | Table 1. Sweep over tasks: hyperparameters and values. Task: {learning lr, maml, loss weighting}; Model size (×10⁶): {57, 106, 163, 217, 306}; # of inner updates (T): {2, 4, 8}; Batch size: {2, 4, 8}; Sequence length: {2048, 4096, 8192} |
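The mixed-mode idea behind Algorithms 1 and 2 is to replace reverse-over-reverse Hessian-vector products in the meta-gradient backward pass with forward-over-reverse ones. The following is a minimal JAX sketch of that substitution, not the paper's Appendix A.4 code; the toy inner loss `inner_loss` and the meta-parameter `eta` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def inner_loss(theta, eta, x):
    # Toy inner objective: a quadratic fit plus an eta-weighted
    # regulariser, standing in for the paper's NTP inner loss.
    return jnp.sum((theta * x - 1.0) ** 2) + eta * jnp.sum(theta ** 2)

def hvp_reverse_over_reverse(theta, eta, x, v):
    # Standard BPTT-style HVP: reverse-mode differentiation applied
    # to the reverse-mode inner gradient.
    g = lambda t: jax.grad(inner_loss)(t, eta, x)
    _, vjp_fn = jax.vjp(g, theta)
    return vjp_fn(v)[0]

def hvp_forward_over_reverse(theta, eta, x, v):
    # Mixed-mode HVP: a forward-mode JVP pushed through the
    # reverse-mode inner gradient, avoiding a second reverse sweep.
    g = lambda t: jax.grad(inner_loss)(t, eta, x)
    return jax.jvp(g, (theta,), (v,))[1]

theta = jnp.arange(3.0)
x = jnp.ones(3)
v = jnp.array([1.0, 0.0, -1.0])
a = hvp_reverse_over_reverse(theta, 0.1, x, v)
b = hvp_forward_over_reverse(theta, 0.1, x, v)
```

Because the Hessian is symmetric, both functions return the same vector; the mixed-mode variant simply trades the second reverse pass (and its stored activations) for a cheaper forward tangent pass, which is the source of the memory savings the review quotes above.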