The Initialization Determines Whether In-Context Learning Is Gradient Descent

Authors: Shifeng Xie, Rui Yuan, Simone Rossi, Thomas Hannagan

TMLR 2025

Reproducibility Variable | Result | LLM Response

Research Type: Experimental
  Our experiments confirm this result and further observe that a performance gap between one-step GD and multi-head LSA persists. To address this gap, we introduce yq-LSA, a simple generalization of single-head LSA with a trainable initial guess yq. We theoretically establish the capabilities of yq-LSA and provide experimental validation on linear regression tasks, thereby extending the theory that bridges ICL and GD. Finally, inspired by our findings in the case of linear regression, we consider widespread LLMs augmented with initial guess capabilities, and show that their performance is improved on a semantic similarity task.

Researcher Affiliation: Collaboration
  Shifeng Xie (EMAIL), Telecom Paris, Institut Polytechnique de Paris, France; Rui Yuan (EMAIL), Lexsi Labs, Paris, France; Simone Rossi (EMAIL), EURECOM, France; Thomas Hannagan (EMAIL), Stellantis, France

Pseudocode: No
  The paper describes the methodology using mathematical equations and prose, but does not include any clearly labeled pseudocode or algorithm blocks.

Open Source Code: No
  We will release our code repository upon publication to facilitate reproducibility.

Open Datasets: Yes
  Our experiments utilize Meta-Llama-3.1-8B-Instruct (Grattafiori et al., 2024), Qwen/Qwen2.5-7B-Instruct (Yang et al., 2024; Team, 2024) and the STS-Benchmark dataset (English subset) (May, 2021).

Dataset Splits: No
  For experiments in Sections 5.1 and 5.2, we focus on a simplified setting where the LSA consists of a single linear self-attention layer without LayerNorm or softmax. We generate linear functions in a 10-dimensional input space (d = 10) and provide C = 10 context examples per task. ... For each prompt, a context was constructed by randomly sampling 10 labelled examples from the dataset.

Hardware Specification: No
  The paper mentions using JAX for implementation but does not specify any particular hardware (e.g., GPU/CPU models, memory) used for running the experiments.

Software Dependencies: No
  The experiments use JAX to implement and train the LSA models. ... The LLMs used in our study were Meta-Llama-3.1-8B-Instruct (Grattafiori et al., 2024) and Qwen/Qwen2.5-7B-Instruct (Yang et al., 2024; Team, 2024).

Experiment Setup: Yes
  We set the learning rate to lr = 5×10⁻⁴ and a batch size of 2,048. ... We train for 5,000 gradient steps. ... The model was trained using the Adam optimizer (Kingma, 2014) with a learning rate of 1e-3 and a mean squared error loss function. Training was performed over 10 epochs, with a batch size of 8.
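The core idea the report summarizes — single-head LSA viewed as one step of gradient descent, generalized by a trainable initial guess yq — can be sketched as follows. This is a minimal illustration in jax.numpy, assuming the standard one-step-GD parameterization of linear self-attention; the matrix `W` stands in for the learned preconditioner, and the exact yq-LSA layer in the paper may be parameterized differently.

```python
import jax.numpy as jnp

def yq_lsa_predict(params, X, y, x_q):
    """Predict y for query x_q from C in-context examples (X, y).

    Implements one preconditioned GD step on the in-context
    least-squares loss, started from the trainable initial guess
    y_q rather than from zero (hypothetical parameterization).
    """
    W, y_q = params["W"], params["y_q"]
    C = X.shape[0]                 # number of context examples
    residual = y - y_q             # residuals w.r.t. the initial guess
    grad = X.T @ residual / C      # gradient of the least-squares loss
    return y_q + x_q @ (W @ grad)  # one GD step, preconditioned by W

# Toy usage (d = 2, C = 2); setting y_q = 0 recovers the plain
# one-step-GD construction for single-head LSA.
params = {"W": jnp.eye(2), "y_q": 1.0}
X = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([2.0, 3.0])
x_q = jnp.array([1.0, 1.0])
pred = yq_lsa_predict(params, X, y, x_q)
```

Because `W` and `y_q` are ordinary array-valued parameters, both can be trained end-to-end with `jax.grad` and an optimizer such as Adam, matching the training setup the report quotes.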