Attention as a Hypernetwork

Authors: Simon Schug, Seijin Kobayashi, Yassir Akram, João Sacramento, Razvan Pascanu

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We find empirically that this latent code is predictive of the subtasks the network performs on unseen task compositions, revealing that latent codes acquired during training are reused to solve unseen problem instances. To further examine the hypothesis that the intrinsic hypernetwork of multi-head attention supports compositional generalization, we ablate whether making the hypernetwork-generated linear value network nonlinear strengthens compositionality. We find that this modification improves compositional generalization on abstract reasoning tasks.
Researcher Affiliation | Collaboration | Simon Schug (ETH Zürich), Seijin Kobayashi (ETH Zürich), Yassir Akram (ETH Zürich), João Sacramento (Google, Paradigms of Intelligence Team), Razvan Pascanu (Google DeepMind)
Pseudocode | Yes | Algorithm 1: Multi-head softmax attention. Algorithm 2: Hypernetwork linear attention. Figure A1: Pseudocode comparing multi-head softmax attention to hypernetwork linear attention. Differences between the two are highlighted in yellow.
Open Source Code | Yes | Code available at https://github.com/smonsays/hypernetwork-attention
Open Datasets | Yes | We train decoder-only transformer models with 50M parameters autoregressively for 130 billion tokens on the C4 dataset (Raffel et al., 2020).
Dataset Splits | Yes | Unless noted otherwise, we hold out 25% of all possible rule combinations for evaluation in our experiments.
Hardware Specification | Yes | We used a Linux workstation with two Nvidia RTX 3090 GPUs with 24GB of memory each for development and conducted hyperparameter searches and experiments using one Linux server with four Nvidia RTX 3090 GPUs as well as a Slurm cluster equipped with Nvidia RTX 4090 GPUs. For the language modeling experiments, we used 16 Cloud TPU v5e with a complete run taking 72-100 hours.
Software Dependencies | Yes | We implemented our experiments in Python using JAX (Bradbury et al., 2018, Apache License 2.0), Flax (Heek et al., 2023, Apache License 2.0), NanoDo (Liu et al., 2024, Apache License 2.0) and the DeepMind JAX Ecosystem (Babuschkin et al., 2020, Apache License 2.0).
Experiment Setup | Yes | Hyperparameters. For all tasks and models we perform a grid search over the learning rate, weight decay and warmup steps. We report the search grid as well as all other hyperparameters in Table A3.
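The Research Type and Pseudocode entries refer to a "latent code" and a "hypernetwork-generated linear value network". One way to read these terms is through a standard identity of multi-head softmax attention: for a fixed query-key pair, the attention scores across heads can be treated as a code that linearly mixes per-head matrices into a single pair-specific linear map. The sketch below checks that identity numerically in plain Python. It is not the authors' Algorithm 1 or Algorithm 2; the function name `attention_two_ways`, the tensor sizes, and the random weights are all assumptions made for illustration.

```python
import math
import random


def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def matmul(A, B):
    """Multiply matrix A by matrix B (both lists of rows)."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def rand_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def attention_two_ways(seed=0, T=4, D=6, H=3, d_head=2):
    """Compute one query's attention output in two equivalent ways.

    All sizes and weights are illustrative assumptions, not values from the paper.
    """
    rng = random.Random(seed)
    X = rand_matrix(rng, T, D)                              # T token embeddings
    W_q = [rand_matrix(rng, d_head, D) for _ in range(H)]   # per-head query maps
    W_k = [rand_matrix(rng, d_head, D) for _ in range(H)]   # per-head key maps
    W_v = [rand_matrix(rng, d_head, D) for _ in range(H)]   # per-head value maps
    W_o = [rand_matrix(rng, D, d_head) for _ in range(H)]   # per-head output maps
    i = T - 1                                               # query position

    # codes[h][j]: attention score of head h for the pair (query i, key j).
    # Reading codes[:][j] across heads gives an H-dimensional code per pair.
    codes = []
    for h in range(H):
        q = matvec(W_q[h], X[i])
        logits = [
            sum(a * b for a, b in zip(q, matvec(W_k[h], X[j]))) / math.sqrt(d_head)
            for j in range(T)
        ]
        codes.append(softmax(logits))

    # Way 1: usual computation, mix values per head and then project out.
    standard = [0.0] * D
    for h in range(H):
        mixed = [0.0] * d_head
        for j in range(T):
            v = matvec(W_v[h], X[j])
            mixed = [m + codes[h][j] * vi for m, vi in zip(mixed, v)]
        standard = [s + o for s, o in zip(standard, matvec(W_o[h], mixed))]

    # Way 2: the code across heads generates a D x D linear map per (i, j) pair.
    basis = [matmul(W_o[h], W_v[h]) for h in range(H)]
    hyper = [0.0] * D
    for j in range(T):
        W_ij = [
            [sum(codes[h][j] * basis[h][r][c] for h in range(H)) for c in range(D)]
            for r in range(D)
        ]
        hyper = [a + o for a, o in zip(hyper, matvec(W_ij, X[j]))]

    return standard, hyper, codes
```

Both return values agree up to floating-point error, because the sum over heads and the sum over keys can be exchanged. The table reports that the paper's ablation makes the generated value network nonlinear; that variant is not reproduced here.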
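The Dataset Splits entry says that 25% of all possible rule combinations are held out for evaluation. A minimal sketch of such a split is shown below, assuming a task is defined by assigning one rule to each of several features. The rule names, the number of features, and the helper `split_rule_combinations` are hypothetical and are not taken from the paper or its repository.

```python
import itertools
import random


def split_rule_combinations(rules, n_features, holdout_frac=0.25, seed=0):
    """Enumerate every assignment of one rule per feature, then hold out a fraction.

    Returns (train_combinations, heldout_combinations) as disjoint lists of tuples.
    """
    combos = list(itertools.product(rules, repeat=n_features))
    rng = random.Random(seed)
    rng.shuffle(combos)
    n_holdout = round(holdout_frac * len(combos))
    return combos[n_holdout:], combos[:n_holdout]


# Hypothetical rule names and feature count, chosen only for illustration.
rules = ["rule_a", "rule_b", "rule_c", "rule_d"]
train, heldout = split_rule_combinations(rules, n_features=3)
```

Splitting at the level of combinations, rather than individual examples, means every held-out task is a composition the model never saw during training, even though each individual rule may appear in the training set.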