Understanding Factual Recall in Transformers via Associative Memories
Authors: Eshaan Nichani, Jason Lee, Alberto Bietti
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Empirical Validation. In Figure 1, we train both linear and MLP associative memories to store the association f(x) = x. Given a fixed model size (d, m), we fit datasets with increasing values of N using the cross entropy loss, and plot the largest value of N for which we can obtain at least 99% accuracy. We next empirically verify Theorems 5 and 6. We first train the linear attention model with orthogonal embeddings (Eq. 15) with S = 16, R = 4, and D = 8, and plot the loss over time. |
| Researcher Affiliation | Academia | Eshaan Nichani Princeton University Jason D. Lee Princeton University Alberto Bietti Flatiron Institute |
| Pseudocode | No | The paper describes methods with mathematical formulations and proofs, but it does not contain any structured pseudocode or algorithm blocks. |
| Open Source Code | Yes | Code for all the experiments can be found at https://github.com/eshnich/factual-recall-iclr. |
| Open Datasets | No | In this section we introduce a synthetic factual recall task, and show that one-layer transformers constructed via associative memories can store a number of facts proportional to parameter count. The data distribution consists of prompts containing a subject token s and relation token r hidden amongst a set of noise tokens, which the learner must map to a ground truth answer a(s, r). The paper generates synthetic data rather than using a publicly available dataset or providing access to their generated data. |
| Dataset Splits | No | The data distribution is over length-(T+1) sequences z_(1:T+1) := (z_1, z_2, ..., z_T, z_(T+1)) ∈ V^(T+1), generated via the following procedure: 1. First, sample a subject and relation tuple (s, r) from some distribution p over S × R. 2. Next, sample two distinct indices i, j ∈ [T − 1]. Set z_i = s and z_j = r. 3. For the remaining tokens z_k with k ∈ [T − 1] \ {i, j}, draw z_k uniformly at random from the noise tokens N. 4. Set z_T = EOS. 5. Finally, set z_(T+1) = a(s, r). The paper describes a generative process for the data rather than providing specific train/test/validation splits of a fixed dataset. It mentions 'online batch gradient descent with batch size 1024 on the population loss (i.e. we sample an independent batch at each timestep)', which implies dynamic sampling for training, not fixed splits. |
| Hardware Specification | No | The paper does not provide specific details about the hardware used for running experiments, such as GPU or CPU models. |
| Software Dependencies | No | We use ADAM with a learning rate of 10^-2 for 2^14 steps. While ADAM is mentioned, no specific version numbers for software libraries or dependencies (e.g., Python, PyTorch, TensorFlow) are provided. |
| Experiment Setup | Yes | We train a two-layer neural network on the cross entropy loss... The network is trained using ADAM with a learning rate of 10^-2 for 2^14 steps. We consider a fixed prompt length of T = 32, and train the models via online batch gradient descent with batch size 1024 on the population loss... All models were trained using ADAM for 2^14 steps, with a sweep over learning rates in {.001, .003, .01} (where we consider the best performing model over all learning rates). |
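The five-step generative process quoted in the Dataset Splits row can be sketched in a few lines of Python. This is a minimal illustrative sketch, not the authors' code: the function name `sample_prompt`, the uniform choice of (s, r), and the string token vocabulary are all assumptions made here for clarity.

```python
import random

def sample_prompt(subjects, relations, noise, answer, T=32, rng=random):
    """Sample one sequence following the described procedure.

    Returns the length-T prompt (ending in EOS) and the target token
    z_(T+1) = a(s, r). `answer` plays the role of the ground-truth map
    a(s, r); uniform sampling of (s, r) is an assumption for simplicity.
    """
    # 1. Sample a subject and relation tuple (s, r).
    s = rng.choice(subjects)
    r = rng.choice(relations)
    # 2. Sample two distinct indices i, j in [T - 1]; set z_i = s, z_j = r.
    i, j = rng.sample(range(T - 1), 2)
    # 3. Fill the remaining positions with uniform noise tokens.
    z = [rng.choice(noise) for _ in range(T - 1)]
    z[i], z[j] = s, r
    # 4. Set z_T = EOS.
    z.append("EOS")
    # 5. The target z_(T+1) is the ground-truth answer a(s, r).
    return z, answer(s, r)
```

With batch size 1024 and a fresh batch drawn per step, this corresponds to the online (population-loss) training regime the review quotes, which is why no fixed splits exist.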