Softmax is not Enough (for Sharp Size Generalisation)
Authors: Petar Veličković, Christos Perivolaropoulos, Federico Barbero, Razvan Pascanu
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | To motivate our theory, we train a simple architecture including a single attentional head to predict a feature of the maximum item in a set. Each item's features are processed with a deep MLP before attending, and the output vector of the attention is passed to a deep MLP predictor (see Appendix A for experimental details). We train this model using sets of 16 items, and in Figure 2 we visualise the head's attentional coefficients, computed over sets of varying size at inference time. While the model indeed attributes focus sharply and cleanly on the maximum item, this only holds true on the problem sizes that the model was trained on. |
| Researcher Affiliation | Collaboration | ¹Google DeepMind, ²University of Oxford. Correspondence to: Petar Veličković <EMAIL>. |
| Pseudocode | Yes | `def adaptive_temperature_softmax(logits):`<br>`original_probs = jax.nn.softmax(logits)`<br>`poly_fit = jnp.array([-0.037, 0.481, -2.3, 4.917, -1.791])  # see Figure 6`<br>`entropy = jnp.sum(-original_probs * jnp.log(original_probs + 1e-9), axis=-1, keepdims=True)  # compute the Shannon entropy`<br>`beta = jnp.where(  # beta = 1 / theta`<br>`entropy > 0.5,  # don't overcorrect low-entropy heads`<br>`jnp.maximum(jnp.polyval(poly_fit, entropy), 1.0),  # never increase entropy`<br>`1.0)`<br>`return jax.nn.softmax(logits * beta)`<br>Figure 4. Our implementation of adaptive temperature in JAX. |
| Open Source Code | Yes | The JAX (Bradbury et al., 2018) implementation of our adaptive-θ softmax is provided in Figure 4, and we use it as a drop-in replacement for jax.nn.softmax in all of our experiments. |
| Open Datasets | Yes | To validate the utility of our proposed adaptive temperature scheme, we evaluate it on both our previously mentioned max retrieval task (which gives us a pristine environment for evaluating whether adaptive temperature leads to more useful attention heads) and the CLRS-Text algorithmic reasoning benchmark (Markeeva et al., 2024), which represents a challenging reasoning task for decoder-only Transformers and is hence likely to require low-entropy behaviour. |
| Dataset Splits | Yes | We train our model using sets of 16 items, and in Figure 2 we visualise the head's attentional coefficients, computed over sets of varying size at inference time. ... we fine-tune Gemma 2B models (Gemma Team et al., 2024) on the thirty algorithmic execution tasks in CLRS-Text, plotting their performance profiles in- and out-of-distribution at various problem sizes. ... Both Gemma 2B variants were explicitly trained on CLRS-Text tasks (the training set sizes are denoted by red dots) and are evaluated zero-shot. |
| Hardware Specification | Yes | Under this implementation, adaptive temperature can easily scale to large context windows (which we have validated empirically up to 131,072 tokens) on a single NVIDIA A100 node. |
| Software Dependencies | No | A concise implementation of our network using JAX (Bradbury et al., 2018) and Flax (Heek et al., 2024) is as follows: ... |
| Experiment Setup | Yes | We train our model for 100,000 gradient steps using the Adam optimiser (Kingma & Ba, 2015) with an initial learning rate of η = 0.001. At each step, we present to the model a batch of 128 input sets. All sets within a batch have the same size, sampled uniformly as n ∼ U{5, …, 16}. The model is trained using cross-entropy, along with L2 regularisation with hyperparameter λ = 0.001. |
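
The Research Type row reports attention that is sharp at the training size but fades on larger sets. A minimal NumPy sketch can illustrate the arithmetic behind this, with NumPy standing in for `jax.numpy`. It assumes, purely for illustration and not as the paper's trained model, that the maximum item's logit exceeds each of the other n - 1 logits by a fixed gap. The maximum item's attention weight is then exp(gap) / (exp(gap) + n - 1), which tends to zero as n grows. The helper names `softmax` and `max_item_weight` are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def max_item_weight(n: int, gap: float = 5.0) -> float:
    """Softmax weight on the maximum item when its logit exceeds each of the
    other n - 1 logits by a fixed `gap` (hypothetical bounded-logit setting).

    Closed form: exp(gap) / (exp(gap) + n - 1), which tends to 0 as n grows.
    """
    logits = np.zeros(n)
    logits[0] = gap  # item 0 plays the role of the maximum item
    return float(softmax(logits)[0])


for n in (16, 64, 256, 1024):  # 16 matches the training size in the Research Type row
    print(n, round(max_item_weight(n), 3))
```

With the gap held fixed, the printed weights decrease monotonically with n. Keeping the maximum item equally sharp on larger sets would require the logit gap to grow with the set size.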
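
The flattened Figure 4 code in the Pseudocode row can be laid out as a runnable sketch. This version is a NumPy port written for illustration, not the authors' JAX code. `np.polyval` takes coefficients highest-degree first, the same order as `jnp.polyval`, and the local `softmax` helper stands in for `jax.nn.softmax`. The coefficients and the 0.5 entropy threshold are copied from Figure 4, whose comment attributes the polynomial fit to Figure 6.

```python
import numpy as np

# Polynomial coefficients from Figure 4, highest degree first (np.polyval order).
POLY_FIT = np.array([-0.037, 0.481, -2.3, 4.917, -1.791])


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis (stand-in for jax.nn.softmax)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def adaptive_temperature_softmax(logits: np.ndarray) -> np.ndarray:
    original_probs = softmax(logits)
    # Shannon entropy (natural log) of each row, kept as a trailing axis of size 1.
    entropy = np.sum(-original_probs * np.log(original_probs + 1e-9),
                     axis=-1, keepdims=True)
    # beta = 1 / theta. Rows with entropy <= 0.5 keep beta = 1 (no overcorrection);
    # elsewhere beta is clamped to at least 1, so the distribution is never made flatter.
    beta = np.where(entropy > 0.5,
                    np.maximum(np.polyval(POLY_FIT, entropy), 1.0),
                    1.0)
    return softmax(logits * beta)


# Usage: one dispersed row of 256 logits with the maximum item 5 units above the rest.
logits = np.zeros((1, 256))
logits[0, 0] = 5.0
print(softmax(logits)[0, 0], adaptive_temperature_softmax(logits)[0, 0])
```

Because `beta` is applied only when the entropy exceeds 0.5 and is clamped below at 1.0, a sharply peaked row passes through unchanged. A dispersed row, such as the 256-item example, is sharpened toward its maximum item.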
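
The Experiment Setup row describes a batching scheme of 128 sets per step, all sharing one size n drawn uniformly from {5, ..., 16}, together with a regularised loss. Both can be sketched as follows. The item encoding is hypothetical, because Appendix A's actual features are not given in this excerpt. Here each item carries a scalar priority, used to pick the maximum, and a class label that serves as the feature to predict. The exact L2 form (for example, whether a 1/2 factor is used) is also an assumption. The names `sample_batch`, `regularised_cross_entropy`, and `NUM_CLASSES` are illustrative.

```python
from typing import Optional

import numpy as np

BATCH_SIZE = 128            # sets per gradient step (Experiment Setup row)
MIN_SIZE, MAX_SIZE = 5, 16  # n ~ U{5, ..., 16}, shared by every set in a batch
NUM_CLASSES = 10            # hypothetical: number of values the predicted feature can take


def sample_batch(rng: np.random.Generator, n: Optional[int] = None):
    """Sample BATCH_SIZE sets that all contain exactly n items.

    Hypothetical encoding: each item has a scalar priority and a class label;
    the target is the label of the highest-priority item.
    """
    if n is None:
        # integers() excludes the upper bound, hence MAX_SIZE + 1.
        n = int(rng.integers(MIN_SIZE, MAX_SIZE + 1))
    priorities = rng.uniform(size=(BATCH_SIZE, n))
    labels = rng.integers(NUM_CLASSES, size=(BATCH_SIZE, n))
    max_index = priorities.argmax(axis=1)
    targets = labels[np.arange(BATCH_SIZE), max_index]
    return priorities, labels, targets


def regularised_cross_entropy(class_logits, targets, params, l2=1e-3):
    """Mean cross-entropy plus l2 * sum of squared parameters (assumed L2 form)."""
    shifted = class_logits - class_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets].mean()
    penalty = sum(float(np.sum(p ** 2)) for p in params)
    return float(nll + l2 * penalty)


rng = np.random.default_rng(0)
priorities, labels, targets = sample_batch(rng)   # training-size batch
eval_priorities, _, _ = sample_batch(rng, n=64)   # larger, out-of-distribution sets
print(priorities.shape, eval_priorities.shape)
```

Passing an explicit `n` larger than 16 produces the out-of-distribution evaluation sets described in the Dataset Splits row. The Adam optimiser and the 100,000-step training loop are omitted.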