The Initialization Determines Whether In-Context Learning Is Gradient Descent
Authors: Shifeng Xie, Rui Yuan, Simone Rossi, Thomas Hannagan
TMLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our experiments confirm this result and further observe that a performance gap between one-step GD and multi-head LSA persists. To address this gap, we introduce y_q-LSA, a simple generalization of single-head LSA with a trainable initial guess y_q. We theoretically establish the capabilities of y_q-LSA and provide experimental validation on linear regression tasks, thereby extending the theory that bridges ICL and GD. Finally, inspired by our findings in the case of linear regression, we consider widespread LLMs augmented with initial guess capabilities, and show that their performance is improved on a semantic similarity task. |
| Researcher Affiliation | Collaboration | Shifeng Xie (Telecom Paris, Institut Polytechnique de Paris, France); Rui Yuan (Lexsi Labs, Paris, France); Simone Rossi (EURECOM, France); Thomas Hannagan (Stellantis, France) |
| Pseudocode | No | The paper describes the methodology using mathematical equations and prose, but does not include any clearly labeled pseudocode or algorithm blocks. |
| Open Source Code | No | We will release our code repository upon publication to facilitate reproducibility. |
| Open Datasets | Yes | Our experiments utilize Meta-Llama-3.1-8B-Instruct (Grattafiori et al., 2024), Qwen/Qwen2.5-7B-Instruct (Yang et al., 2024; Team, 2024) and the STS-Benchmark dataset (English subset) (May, 2021). |
| Dataset Splits | No | For experiments in Sections 5.1 and 5.2, we focus on a simplified setting where the LSA consists of a single linear self-attention layer without Layer Norm or softmax. We generate linear functions in a 10-dimensional input space (d = 10) and provide C = 10 context examples per task. ... For each prompt, a context was constructed by randomly sampling 10 labelled examples from the dataset. |
| Hardware Specification | No | The paper mentions using JAX for implementation but does not specify any particular hardware (e.g., GPU/CPU models, memory) used for running the experiments. |
| Software Dependencies | No | The experiments use JAX to implement and train the LSA models. ... The LLMs used in our study were Meta-Llama-3.1-8B-Instruct (Grattafiori et al., 2024) and Qwen/Qwen2.5-7B-Instruct (Yang et al., 2024; Team, 2024). |
| Experiment Setup | Yes | We set the learning rate to lr = 5e-4 and a batch size of 2,048. ... We train for 5,000 gradient steps. ... The model was trained using the Adam optimizer (Kingma, 2014) with a learning rate of 1e-3 and a mean squared error loss function. Training was performed over 10 epochs, with a batch size of 8. |