Accelerating Training with Neuron Interaction and Nowcasting Networks
Authors: Boris Knyazev, Abhinav Moudgil, Guillaume Lajoie, Eugene Belilovsky, Simon Lacoste-Julien
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate that NiNo accelerates training with Adam for ConvNets and Transformers, reducing the number of steps to achieve the target performance of Adam by up to 50%. We release our source code and models at https://github.com/SamsungSAILMontreal/nino. [...] We experiment with nine tasks, each defined by a dataset and a neural network architecture in the vision or language domains (Table 1). Four of the tasks, the in-distribution tasks, are of a relatively smaller scale and used to train our meta-models (NiNo, WNN and their variants). The other five tasks, the out-of-distribution tasks, differ from the in-distribution tasks in the architecture and/or dataset. |
| Researcher Affiliation | Collaboration | 1 Samsung SAIT AI Lab, Montreal; 2 Concordia University; 3 Université de Montréal; 4 Mila |
| Pseudocode | Yes | A.9 PSEUDO-CODE We show the pseudo-code for the three main steps in our pipeline: 1. collecting the dataset of checkpoints on a set of training tasks; 2. training NiNo given the dataset of checkpoints; 3. evaluating/using the trained NiNo on new tasks. A.9.1 COLLECT CHECKPOINTS A.9.2 TRAINING NINO A.9.3 USING NINO |
| Open Source Code | Yes | We release our source code and models at https://github.com/SamsungSAILMontreal/nino. |
| Open Datasets | Yes | We use the FashionMNIST (FM), CIFAR-10 (C10) and CIFAR-100 (C100) datasets and two convolutional architectures with three layers: with 16, 32 and 32 channels per layer (e.g. task FM/16) or 32, 64 and 64 channels per layer (e.g. task FM/32). [...] We use the LM1B (Chelba et al., 2013) and WikiText103 (WIKI) (Merity et al., 2016) datasets and train GPT2-style Transformers (Radford et al., 2019) with 3 layers, 24 hidden units and 3 attention heads (denoted as 3-24); [...] To investigate the ability of NiNo on more challenging vision tasks, we trained a small ViT (11M parameters) on ImageNet (Russakovsky et al., 2015), with 32x32 images as in Metz et al. (2022a); Loshchilov & Hutter (2017). |
| Dataset Splits | No | The paper mentions using specific datasets (FashionMNIST, CIFAR-10, CIFAR-100, LM1B, WikiText103, ImageNet) and batch sizes (128 for vision, 32 for language) but does not explicitly state the train/validation/test splits (e.g., percentages or exact counts) for these datasets in the main text. It refers to 'validation set performance' and 'target validation performance' but not the split methodology itself. |
| Hardware Specification | Yes | Training of the NiNo and WNN+ meta-models completes in under 7 and 6 hours respectively on a single NVIDIA RTX8000 with 48GB of memory. [...] This setup fully utilizes a single NVIDIA A100-80GB, training the model in around 6 hours. |
| Software Dependencies | No | The paper mentions 'PyTorch' in Section A.8.2, but does not provide any specific version number for it or any other software libraries or dependencies. For example, 'using PyTorch and the code base from Knyazev et al. (2023)' lacks version information for PyTorch. |
| Experiment Setup | Yes | For example, to use a trained WNN f_ϕ_k on a new task, first Adam is run for c epochs (c = 5 by default) after which f_ϕ_k is applied to predict future (10-th epoch) parameters... [...] In all cases, these tasks are optimized using Adam (Kingma & Ba, 2015) without weight decay, with a constant learning rate of 6e-3 (for FashionMNIST) or 3e-3 (for CIFAR) with a batch size of 128 for T=10k steps. [...] These tasks are optimized for the next token prediction loss with AdamW (Loshchilov & Hutter, 2017), weight decay 1e-2, learning rate 2e-4, batch size of 32, sequence length of 1024 for either 1 epoch (for LM1B) or 4 epochs (for WIKI), corresponding to around 24k or 14k steps respectively. [...] We train meta-models for 20k training iterations using AdamW, learning rate 3e-3 with cosine decay and weight decay 0.01. We sample a batch of 4 checkpoints in each training iteration and use automatic mixed precision. |
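The pipeline quoted in the Pseudocode row (collect checkpoints, train the meta-model, then periodically apply it to nowcast future parameters during base-optimizer training) can be sketched as a toy loop. This is an illustrative sketch only: the real NiNo meta-model is a learned neural-graph network and the base optimizer is Adam, whereas here a linear extrapolation over the last two checkpoints and a plain gradient step stand in for them; all function names (`grad_step`, `nowcast`, `train`) are hypothetical, not from the paper or its code release.

```python
# Toy sketch of the nowcasting loop described in A.9.3 (hypothetical names).

def grad(w):
    # Gradient of a quadratic toy objective (w - 3)^2 with minimum at w = 3.
    return 2.0 * (w - 3.0)

def grad_step(w, lr=0.1):
    # Plain gradient step standing in for an Adam update.
    return w - lr * grad(w)

def nowcast(history):
    # Placeholder for the trained meta-model: linearly extrapolate the
    # parameter trajectory several steps ahead from the last two checkpoints.
    w_prev, w_last = history[-2], history[-1]
    horizon = 5  # illustrative jump length, not a value from the paper
    return w_last + horizon * (w_last - w_prev)

def train(steps=30, c=5):
    # Run the base optimizer; every c steps, replace the parameters with
    # the nowcasted prediction (mirrors "run Adam for c epochs, then
    # apply the meta-model to predict future parameters").
    w = 0.0
    history = [w]
    for t in range(1, steps + 1):
        w = grad_step(w)
        history.append(w)
        if t % c == 0:
            w = nowcast(history)
            history.append(w)
    return w
```

On this toy objective, interleaving nowcast jumps reaches the minimum in fewer base-optimizer steps than gradient steps alone, which is the effect the paper quantifies as "reducing the number of steps ... by up to 50%".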