What Do Learning Dynamics Reveal About Generalization in LLM Mathematical Reasoning?
Authors: Katie Kang, Amrith Setlur, Dibya Ghosh, Jacob Steinhardt, Claire Tomlin, Sergey Levine, Aviral Kumar
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our analysis focuses on reasoning tasks, whose problem structure allows us to distinguish between memorization (the exact replication of reasoning steps from the training data) and performance (the correctness of the final solution). We find that a model's performance on test prompts can be effectively characterized by a training metric we call pre-memorization train accuracy: the accuracy of model samples on training queries before they begin to copy the exact reasoning steps from the training set. At the dataset level, this metric almost perfectly predicts test accuracy, achieving R² of 0.9 across various models (Llama3 8B, Gemma2 9B), datasets (GSM8k, MATH), and training configurations. Our experiments on data curation show that prioritizing examples with low pre-memorization accuracy leads to 1.5-2x improvements in data efficiency compared to i.i.d. data scaling and other data scaling techniques. |
| Researcher Affiliation | Academia | ¹UC Berkeley, ²CMU. Correspondence to: Katie Kang <EMAIL>. |
| Pseudocode | Yes | Algorithm 1 (Our Data Collection Process). Input: N = N₁ + … + Nₙ, threshold t. Output: updated dataset D′train. Initialize D′train = {}. For i = 1 to n: (1) train the model on Dtrain ∪ D′train; (2) evaluate the model on Dtrain and compute the pre-memorization accuracy of each example; (3) set Pᵢ(x) to the distribution over examples with pre-memorization accuracy below t; (4) collect Nᵢ new examples from Pᵢ(x) and add them to D′train. |
| Open Source Code | No | No explicit statement about providing open-source code or a link to a repository is found in the paper. |
| Open Datasets | Yes | Our experiments show that this phenomenon holds across different models (e.g., Llama3 8B (Dubey et al., 2024), Gemma2 9B (Team et al., 2024)), tasks (e.g., GSM8k (Cobbe et al., 2021), MATH (Hendrycks et al., 2021)), dataset sizes, and hyperparameter settings, with coefficients of determination around or exceeding 0.9. |
| Dataset Splits | No | No explicit percentages or sample counts for training, validation, and test splits are provided. The paper mentions 'training dataset Dtrain = {(xi, yi)}' and 'test dataset, Dtest', but does not specify their relative sizes or methodology for splitting, beyond modifying the overall size of the training data (full, half, or quarter). |
| Hardware Specification | No | The paper does not provide specific hardware details such as GPU models, CPU models, or memory specifications used for the experiments. |
| Software Dependencies | No | The paper mentions models like Llama3 8B and Gemma2 9B, and optimizers such as AdamW and Adam, but does not provide specific version numbers for any ancillary software libraries or programming languages used in the implementation. |
| Experiment Setup | Yes | For all training runs with GSM8k and Llama3 8B, we use the AdamW optimizer with a linear-decay learning rate scheduler with 20 warmup steps, a batch size of 128, and a max gradient norm of 2. Sweep over (learning rate, epochs, dataset size): (5e-5, 6, full), (2e-5, 6, full), (5e-7, 6, full), (2e-4, 6, full), (5e-5, 3, full), (2e-5, 3, full), (5e-7, 3, full), (2e-4, 3, full), (5e-5, 1, full), (5e-7, 1, full), (2e-4, 1, full), (2e-5, 6, half), (2e-5, 12, quarter). (Similar details are provided for MATH Llama3 8B, GSM8k Gemma2 9B, and MATH Gemma2 9B in Appendix B.) |
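The data-collection loop in Algorithm 1 can be sketched in a few lines. The sketch below is a toy illustration, not the authors' implementation: `train_fn`, `pre_mem_acc_fn`, and `sample_fn` are hypothetical stand-ins for fine-tuning an LLM, estimating per-example pre-memorization train accuracy, and collecting new examples similar to the low-accuracy ones.

```python
import random

def collect_data(d_train, budgets, threshold, train_fn, pre_mem_acc_fn, sample_fn):
    """Sketch of Algorithm 1: iteratively collect N_i new examples that
    resemble training examples whose pre-memorization accuracy is below t.

    d_train        -- initial training pool D_train
    budgets        -- [N_1, ..., N_n], new examples to collect per round
    threshold      -- pre-memorization accuracy cutoff t
    train_fn       -- trains a model on a dataset, returns the model
    pre_mem_acc_fn -- estimates an example's pre-memorization accuracy
    sample_fn      -- draws new examples resembling a given "hard" subset
    """
    d_new = []
    for n_i in budgets:
        model = train_fn(d_train + d_new)
        # Prioritize examples the model cannot yet solve without memorizing.
        hard = [x for x in d_train if pre_mem_acc_fn(model, x) < threshold]
        d_new += sample_fn(hard, n_i)
    return d_new

# Toy stand-ins; a real run would fine-tune Llama3/Gemma2 and sample solutions.
random.seed(0)
pool = list(range(20))
train_fn = lambda data: len(data)              # "model" is just the dataset size
pre_mem_acc_fn = lambda model, x: (x % 7) / 7  # fake per-example accuracy
sample_fn = lambda hard, n: random.choices(hard, k=n)

new_data = collect_data(pool, budgets=[5, 5], threshold=0.5,
                        train_fn=train_fn, pre_mem_acc_fn=pre_mem_acc_fn,
                        sample_fn=sample_fn)
print(len(new_data))  # 10
```

The key design point, per the paper, is that the sampling distribution Pᵢ(x) targets examples with *low* pre-memorization accuracy; the authors report this yields 1.5-2x better data efficiency than i.i.d. scaling.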