Attention as a Hypernetwork
Authors: Simon Schug, Seijin Kobayashi, Yassir Akram, João Sacramento, Razvan Pascanu
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We find empirically that this latent code is predictive of the subtasks the network performs on unseen task compositions, revealing that latent codes acquired during training are reused to solve unseen problem instances. To further examine the hypothesis that the intrinsic hypernetwork of multi-head attention supports compositional generalization, we ablate whether making the hypernetwork-generated linear value network nonlinear strengthens compositionality. We find that this modification improves compositional generalization on abstract reasoning tasks. |
| Researcher Affiliation | Collaboration | Simon Schug (ETH Zürich), Seijin Kobayashi (ETH Zürich), Yassir Akram (ETH Zürich), João Sacramento (Google, Paradigms of Intelligence Team), Razvan Pascanu (Google DeepMind) |
| Pseudocode | Yes | Algorithm 1: Multi-head softmax attention. Algorithm 2: Hypernetwork linear attention. Figure A1: Pseudocode comparing multi-head softmax attention to hypernetwork linear attention; differences between the two are highlighted in yellow. |
| Open Source Code | Yes | Code available at https://github.com/smonsays/hypernetwork-attention |
| Open Datasets | Yes | We train decoder-only transformer models with 50M parameters autoregressively for 130 billion tokens on the C4 dataset (Raffel et al., 2020). |
| Dataset Splits | Yes | Unless noted otherwise, we hold out 25% of all possible rule combinations for evaluation in our experiments. |
| Hardware Specification | Yes | We used a Linux workstation with two Nvidia RTX 3090 GPUs with 24GB of memory each for development, and conducted hyperparameter searches and experiments using one Linux server with 4 Nvidia RTX 3090 GPUs as well as a Slurm cluster equipped with Nvidia RTX 4090 GPUs. For the language modeling experiments, we used 16 Cloud TPU v5e, with a complete run taking 72-100 hours. |
| Software Dependencies | Yes | We implemented our experiments in Python using JAX (Bradbury et al., 2018, Apache License 2.0), Flax (Heek et al., 2023, Apache License 2.0), NanoDO (Liu et al., 2024, Apache License 2.0) and the DeepMind JAX Ecosystem (Babuschkin et al., 2020, Apache License 2.0). |
| Experiment Setup | Yes | Hyperparameters. For all tasks and models we perform a grid search over the learning rate, weight decay and warmup steps. We report the search grid as well as all other hyperparameters in Table A3. |
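The Research Type and Pseudocode rows rest on the paper's observation that multi-head attention can be read as a hypernetwork. The NumPy sketch below illustrates that reading under our own simplifying assumptions: a single unbatched sequence, no masking, no biases, a row-vector convention, and hypothetical function names. It checks numerically that standard multi-head softmax attention equals a form in which, for each query/key pair, the vector of per-head attention scores acts as a latent code that linearly combines the per-head value and output weights into one linear value network. The last function is a hypothetical nonlinear variant in the spirit of the ablation described in the Research Type row (ReLU between the score-mixed value and output weights). It is not the paper's hypernetwork linear attention (Algorithm 2 and the linked repository), which may differ, for example, in how attention scores are normalized.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_scores(x, Wq, Wk):
    """Per-head softmax scores A[h, i, j], normalized over keys j.

    x: (T, D) sequence; Wq, Wk: (H, D, Dh) per-head projections.
    """
    q = np.einsum("td,hde->hte", x, Wq)
    k = np.einsum("td,hde->hte", x, Wk)
    logits = np.einsum("hie,hje->hij", q, k) / np.sqrt(q.shape[-1])
    return softmax(logits, axis=-1)


def multihead_attention(x, Wq, Wk, Wv, Wo):
    """Standard multi-head attention. Wv: (H, D, Dh); Wo: (H, Dh, D)."""
    A = attention_scores(x, Wq, Wk)
    v = np.einsum("td,hde->hte", x, Wv)         # per-head values (H, T, Dh)
    heads = np.einsum("hij,hje->hie", A, v)     # per-head outputs (H, T, Dh)
    return np.einsum("hie,hef->if", heads, Wo)  # project and sum over heads (T, D)


def hypernetwork_attention(x, Wq, Wk, Wv, Wo):
    """The same computation, read as a hypernetwork.

    For each query/key pair (i, j), the head scores a_ij = A[:, i, j] form a
    latent code that configures a linear value network
    W(a_ij) = sum_h a_ij[h] * Wv[h] @ Wo[h], which is applied to x[j].
    """
    A = attention_scores(x, Wq, Wk)
    T = x.shape[0]
    y = np.zeros_like(x)
    for i in range(T):
        for j in range(T):
            a_ij = A[:, i, j]                                # latent code, length H
            W_ij = np.einsum("h,hde,hef->df", a_ij, Wv, Wo)  # generated (D, D) weights
            y[i] += x[j] @ W_ij
    return y


def nonlinear_value_attention(x, Wq, Wk, Wv, Wo, phi=lambda z: np.maximum(z, 0.0)):
    """Hypothetical nonlinear variant: the latent code mixes the value and
    output weights separately, and phi (ReLU by default) is applied in between.
    """
    A = attention_scores(x, Wq, Wk)
    # hidden[i, j] = phi(x[j] @ sum_h A[h, i, j] * Wv[h])  -> (T, T, Dh)
    hidden = phi(np.einsum("hij,jd,hde->ije", A, x, Wv))
    # y[i] = sum_j hidden[i, j] @ sum_h A[h, i, j] * Wo[h]  -> (T, D)
    return np.einsum("hij,ije,hef->if", A, hidden, Wo)


# Usage: with random weights, the standard and hypernetwork forms agree.
rng = np.random.default_rng(0)
T, D, H, Dh = 5, 8, 4, 2
x = rng.normal(size=(T, D))
Wq, Wk, Wv = (rng.normal(size=(H, D, Dh)) for _ in range(3))
Wo = rng.normal(size=(H, Dh, D))
y_standard = multihead_attention(x, Wq, Wk, Wv, Wo)
y_hyper = hypernetwork_attention(x, Wq, Wk, Wv, Wo)
print(np.allclose(y_standard, y_hyper))  # True
```

The double loop in `hypernetwork_attention` is for exposition only. It is far more expensive than the standard form, which computes the same result without generating a weight matrix per query/key pair.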
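The Dataset Splits row states that 25% of all possible rule combinations are held out for evaluation. Below is a minimal sketch of such a split. It assumes, hypothetically, that a task is defined by assigning one of R rules to each of K slots and that the held-out set is drawn uniformly at random with a fixed seed. The function name and parameters are illustrative and are not taken from the paper's code.

```python
import itertools
import random


def split_rule_combinations(num_rules, num_slots, holdout_frac=0.25, seed=0):
    """Hypothetical split: each task assigns one of `num_rules` rules to each of
    `num_slots` slots, and a fixed fraction of all combinations is held out.
    """
    combos = list(itertools.product(range(num_rules), repeat=num_slots))
    random.Random(seed).shuffle(combos)            # seeded, reproducible shuffle
    n_holdout = round(holdout_frac * len(combos))
    return combos[n_holdout:], combos[:n_holdout]  # (train, held-out)


train, held_out = split_rule_combinations(num_rules=4, num_slots=3)
print(len(train), len(held_out))  # 48 16
```

Because whole combinations are held out rather than individual examples, every evaluation task is a combination never seen during training, although each individual rule will typically still appear in some training combination.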