Beyond Human Data: Scaling Self-Training for Problem-Solving with Language Models
Authors: Avi Singh, John D Co-Reyes, Rishabh Agarwal, Ankesh Anand, Piyush Patil, Xavier Garcia, Peter J Liu, James Harrison, Jaehoon Lee, Kelvin Xu, Aaron T Parisi, Abhishek Kumar, Alexander A Alemi, Alex Rizkowsky, Azade Nova, Ben Adlam, Bernd Bohnet, Gamaleldin Fathy Elsayed, Hanie Sedghi, Igor Mordatch, Isabelle Simpson, Izzeddin Gur, Jasper Snoek, Jeffrey Pennington, Jiri Hron, Kathleen Kenealy, Kevin Swersky, Kshiteej Mahajan, Laura A Culp, Lechao Xiao, Maxwell Bileschi, Noah Constant, Roman Novak, Rosanne Liu, Tris Warkentin, Yamini Bansal, Ethan Dyer, Behnam Neyshabur, Jascha Sohl-Dickstein, Noah Fiedel
TMLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Testing on advanced MATH reasoning and APPS coding benchmarks using PaLM 2 models, we find that ReST^EM scales favorably with model size and significantly surpasses fine-tuning only on human data. Overall, our findings suggest that self-training with feedback can reduce dependence on human-generated data. |
| Researcher Affiliation | Industry | All authors are with Google DeepMind. Correspondence to EMAIL. |
| Pseudocode | Yes | Algorithm 1: ReST (Expectation-Maximization). Given an initial policy (e.g., a pre-trained LM), ReST^EM iteratively applies Generate and Improve steps to update the policy. Input: $\mathcal{D}$: training dataset; $\mathcal{D}_{val}$: validation dataset; $\mathcal{L}(x, y; \theta)$: loss; $r(x, y)$: non-negative reward function; $I$: number of iterations; $N$: number of samples per context. for $i = 1$ to $I$ do // Generate (E-step): generate dataset $\mathcal{D}_i$ by sampling $\mathcal{D}_i = \{ (x^j, y^j)\vert_{j=1}^{N} \ \text{s.t.}\ x^j \sim \mathcal{D},\ y^j \sim p_\theta(y \mid x^j) \}$, then annotate $\mathcal{D}_i$ with the reward $r(x, y)$. // Improve (M-step): while reward improves on $\mathcal{D}_{val}$ do optimise $\theta$ to maximize the objective $J(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}_i}\left[ r(x, y) \log p_\theta(y \mid x) \right]$ end end. Output: policy $p_\theta$. |
| Open Source Code | No | The paper does not provide an explicit statement about releasing its source code or a direct link to a code repository for the methodology described. |
| Open Datasets | Yes | We evaluate ReST^EM primarily on mathematical problem solving using the Hendrycks MATH dataset (Hendrycks et al., 2021b) and code generation using the APPS (Introductory) dataset (Hendrycks et al., 2021a). ... For measuring transfer performance, we look at GSM8K (Cobbe et al., 2021), Hungarian HS finals (Paster, 2023), and HumanEval (Chen et al., 2021) datasets. We also evaluate our models using the BIG-Bench Hard (Suzgun et al., 2022) benchmark to evaluate general capabilities. |
| Dataset Splits | Yes | Training Datasets. We evaluate ReST^EM primarily on mathematical problem solving using the Hendrycks MATH dataset (Hendrycks et al., 2021b) and code generation using the APPS (Introductory) dataset (Hendrycks et al., 2021a). MATH and APPS (Introductory) contain 7500 and 2342 training problems, respectively. We select these tasks because the model outputs can be automatically evaluated as correct or incorrect, which makes them well suited for ReST^EM. Both datasets offer binary rewards: on MATH, model-generated answers can be easily verified for correctness using the ground-truth answer, while on APPS, test cases determine whether the generated code is correct. Evaluation. We report generalization performance using the test splits of the MATH and APPS (Introductory) datasets. For measuring transfer performance, we look at GSM8K (Cobbe et al., 2021), Hungarian HS finals (Paster, 2023), and HumanEval (Chen et al., 2021) datasets. We also evaluate our models using the BIG-Bench Hard (Suzgun et al., 2022) benchmark to evaluate general capabilities. |
| Hardware Specification | Yes | Due to the cost of our experiments (thousands of TPU hours for every fine-tuning run), each experiment is performed once. |
| Software Dependencies | No | The paper mentions using PaLM 2 models (Google et al., 2023) and refers to general concepts like "language model" and "policy," but it does not specify any particular software libraries with version numbers (e.g., Python, PyTorch, TensorFlow, CUDA) that would be needed to reproduce the experiment. |
| Experiment Setup | Yes | During each iteration of ReST^EM, we generated a fixed number of solutions per problem for the E-step: 32 for the MATH dataset and 64 for the APPS dataset. For generating solutions, we sample from the language model using top-K sampling with K=40 and a temperature of 0.7. However, directly using all these model-generated solutions can lead to an imbalanced dataset, as we will have many more correct solutions for the easier problems. To mitigate this, we introduced a cut-off threshold on the maximum number of solutions per problem included in the fine-tuning dataset: 10 for both MATH and APPS (a design choice also used by Zelikman et al., 2022). For fine-tuning, we use the few-shot prompt (and the question) as input to the model and the model-generated solutions as targets. We only apply the next-token prediction loss (Equation 1) on the targets. |
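The Generate/Improve loop of Algorithm 1 can be illustrated with a toy, tabular stand-in for the language model. This is a minimal sketch under stated assumptions, not the authors' implementation: the problems, the answer space, and every function name (`generate`, `improve`, `rest_em`, `expected_accuracy`) are hypothetical. The policy is a table of categorical distributions $p(y \mid x)$ and the reward is binary, as on MATH and APPS, so the M-step objective reduces to the log-likelihood of the correct samples only. Because the policy is tabular, that objective can be maximized in closed form (the reward-weighted empirical distribution per problem), which here replaces the gradient-based "while reward improves on $\mathcal{D}_{val}$" loop.

```python
import random
from collections import Counter

# Hypothetical toy task: problem x has correct answer 2 * x, and the
# "policy" is a table of categorical distributions over answers 0..9.
ANSWERS = list(range(10))


def reward(x: int, y: int) -> float:
    """Binary, non-negative reward r(x, y): 1.0 if the answer is correct."""
    return 1.0 if y == 2 * x else 0.0


def generate(policy, problems, n_samples, rng):
    """E-step: sample N answers per problem from p(y | x), annotate with r(x, y)."""
    data = []
    for x in problems:
        probs = [policy[x][y] for y in ANSWERS]
        for y in rng.choices(ANSWERS, weights=probs, k=n_samples):
            data.append((x, y, reward(x, y)))
    return data


def improve(policy, data):
    """M-step: maximize sum of r(x, y) * log p(y | x) over the sampled data.

    For a tabular policy the maximizer is the reward-weighted empirical
    distribution per problem. Problems with no rewarded sample leave the
    objective constant, so their old distribution is kept (a choice of this sketch).
    """
    new_policy = {x: dict(dist) for x, dist in policy.items()}
    weights = {}
    for x, y, r in data:
        if r > 0:
            weights.setdefault(x, Counter())[y] += r
    for x, counts in weights.items():
        total = sum(counts.values())
        new_policy[x] = {y: counts[y] / total for y in ANSWERS}
    return new_policy


def expected_accuracy(policy, problems):
    """Average probability mass the policy puts on the correct answer."""
    return sum(policy[x][2 * x] for x in problems) / len(problems)


def rest_em(problems, iterations=2, n_samples=32, seed=0):
    """Run I iterations of Generate (E-step) followed by Improve (M-step)."""
    rng = random.Random(seed)
    policy = {x: {y: 1 / len(ANSWERS) for y in ANSWERS} for x in problems}
    for _ in range(iterations):
        data = generate(policy, problems, n_samples, rng)
        policy = improve(policy, data)
    return policy
```

Starting from a uniform policy (expected accuracy 0.1), any problem with at least one correct sample collapses onto its correct answer after one M-step, so `expected_accuracy(rest_em([0, 1, 2, 3, 4]), [0, 1, 2, 3, 4])` should end up far above 0.1. With a real language model the M-step is instead approximate, via fine-tuning on the rewarded samples.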
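The decoding configuration listed under Experiment Setup (top-K sampling with K=40 and temperature 0.7) can be sketched over a single next-token distribution. This is a minimal, framework-free illustration assuming raw logits given as a Python list; `top_k_sample` is a hypothetical helper, not PaLM 2's decoding code. Temperature divides the logits before the softmax, and top-K restricts sampling to the K highest-scoring tokens. Since dividing by a positive temperature preserves the ranking of logits, applying the two operations in either order keeps the same K candidates.

```python
import math
import random
from typing import Optional, Sequence


def top_k_sample(
    logits: Sequence[float],
    k: int = 40,
    temperature: float = 0.7,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample a token index using top-K filtering plus temperature scaling."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    # Keep only the indices of the K largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Temperature-scaled softmax over the survivors (max subtracted for stability).
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # random.choices normalizes the weights, so no explicit division is needed.
    return rng.choices(top, weights=weights, k=1)[0]
```

A lower temperature sharpens the distribution toward the highest-scoring tokens, while K bounds how far into the tail sampling can reach; drawing 32 or 64 such samples per problem gives the diverse candidate solutions the E-step needs.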
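The per-problem cut-off from the Experiment Setup row, combined with the binary rewards described under Dataset Splits, amounts to a filtering step when building each iteration's fine-tuning set. This is a minimal sketch assuming every sample has already been scored with a 0/1 reward; the `Sample` record and `build_finetune_set` function are hypothetical. The excerpt does not say which correct solutions are kept when a problem has more than 10, so keeping a uniform random subset is an assumption.

```python
import random
from collections import defaultdict
from typing import List, NamedTuple, Tuple


class Sample(NamedTuple):
    problem_id: str
    prompt: str     # few-shot prompt + question (model input)
    solution: str   # model-generated solution (fine-tuning target)
    reward: float   # binary reward: 1.0 if correct, 0.0 otherwise


def build_finetune_set(
    samples: List[Sample], max_per_problem: int = 10, seed: int = 0
) -> List[Tuple[str, str]]:
    """Keep correct samples only, capped at max_per_problem per problem."""
    rng = random.Random(seed)
    correct = defaultdict(list)
    for s in samples:
        if s.reward > 0:  # with binary rewards this drops incorrect solutions
            correct[s.problem_id].append(s)
    dataset = []
    for pid in sorted(correct):
        kept = correct[pid]
        if len(kept) > max_per_problem:
            # Assumption: a uniform random subset of the correct solutions.
            kept = rng.sample(kept, max_per_problem)
        dataset.extend((s.prompt, s.solution) for s in kept)
    return dataset
```

For example, an easy problem with 32 correct samples contributes only 10 (input, target) pairs, while a hard problem with 2 correct samples contributes both, which is the rebalancing the cut-off is meant to achieve. Problems with no correct sample contribute nothing.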