Building Math Agents with Multi-Turn Iterative Preference Learning

Authors: Wei Xiong, Chengshuai Shi, Jiaming Shen, Aviv Rosenberg, Zhen Qin, Daniele Calandriello, Misha Khalman, Rishabh Joshi, Bilal Piot, Mohammad Saleh, Chi Jin, Tong Zhang, Tianqi Liu

ICLR 2025

Reproducibility Variable Result LLM Response
Research Type Experimental The effectiveness of our framework is validated through training of various language models using an augmented prompt set from the GSM8K and MATH datasets. Our results demonstrate substantial improvements: a supervised fine-tuned Gemma-1.1-it-7B model's performance increased from 77.5% to 83.9% on GSM8K and from 46.1% to 51.2% on MATH. Similarly, a Gemma-2-it-9B model improved from 84.1% to 86.3% on GSM8K and from 51.0% to 54.5% on MATH.
Researcher Affiliation Collaboration University of Illinois Urbana-Champaign (1); University of Virginia (2); Google DeepMind (3); Google Research (4); Princeton University (5); EMAIL; EMAIL
Pseudocode Yes Algorithm 1 Online Iterative M-GSHF
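The pseudocode row above refers to Algorithm 1 (Online Iterative M-GSHF). As a rough, non-authoritative sketch of that outer loop, every name below is a toy stand-in: real trajectory sampling, answer checking, and M-DPO/M-KTO parameter updates would replace these stubs, and the correctness-based pairing heuristic is only an illustration.

```python
import random

# Toy sketch of an online iterative preference-learning loop: per iteration,
# sample several answers per prompt, form (chosen, rejected) pairs, and
# update the policy on those pairs. All components are illustrative stubs.
def sample_answer(policy, prompt, rng):
    # Stand-in for multi-turn trajectory generation with the current policy.
    return rng.choice(["correct", "wrong"])

def online_iterative_mgshf(prompts, num_iterations=3, samples_per_prompt=4, seed=0):
    rng = random.Random(seed)
    policy = "pi_0"  # placeholder for model parameters
    pairs = []
    for t in range(num_iterations):
        pairs = []
        for p in prompts:
            answers = [sample_answer(policy, p, rng) for _ in range(samples_per_prompt)]
            good = [a for a in answers if a == "correct"]
            bad = [a for a in answers if a == "wrong"]
            if good and bad:
                pairs.append((p, good[0], bad[0]))  # (prompt, chosen, rejected)
        policy = f"pi_{t + 1}"  # stand-in for an M-DPO/M-KTO update on `pairs`
    return policy, pairs

policy, pairs = online_iterative_mgshf([f"q{i}" for i in range(8)])
```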
Open Source Code Yes We also provide a comprehensive recipe for the practical implementation of our online iterative multi-turn methods, and will make our models, datasets, and code publicly available for further research and development. ... Additionally, we have open-sourced our training code along with a step-by-step guide, using Gemma-1.1-it-7B as an example. We have also made the processed SFT dataset, prompt set, and the training data for the first iteration of M-DPO/M-KTO available for easy download (see supplemental materials for details).
Open Datasets Yes We use the test sets of MATH (Hendrycks et al., 2021) and GSM8K (Cobbe et al., 2021a) to measure the model's ability to solve the mathematical problems. To construct the training prompt set, we use the prompts from MetaMathQA (Yu et al., 2023) and MMIQC (Liu & Yao, 2024), which is an augmented prompt set from the 7.5K training problems of MATH and 7.47K training problems of GSM8K.
Dataset Splits Yes To construct the training prompt set, we use the prompts from MetaMathQA (Yu et al., 2023) and MMIQC (Liu & Yao, 2024), which is an augmented prompt set from the 7.5K training problems of MATH and 7.47K training problems of GSM8K. ... We evaluate the model every 50 training steps on the split prompt set.
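The row above notes that a split of the prompt set is held out for periodic evaluation. A minimal sketch of such a split is below; the 2% held-out fraction and the fixed seed are assumptions for illustration, not values from the paper.

```python
import random

# Split a pool of training prompts into a train portion and a small
# held-out portion used for periodic evaluation during training.
def split_prompts(prompts, eval_frac=0.02, seed=0):
    rng = random.Random(seed)
    shuffled = prompts[:]
    rng.shuffle(shuffled)
    n_eval = max(1, int(len(shuffled) * eval_frac))
    return shuffled[n_eval:], shuffled[:n_eval]  # (train, eval)

train_prompts, eval_prompts = split_prompts([f"problem-{i}" for i in range(1000)])
```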
Hardware Specification Yes The RLHF experiments of this paper are run with 8x A100 80G GPUs, where an additional machine with 8x A100 40G GPUs is also used to accelerate data collection and model evaluation.
Software Dependencies Yes We use transformers 4.42.4, torch 2.3.0, sympy 1.2, antlr4-python3-runtime 4.11.0, IPython 8.26.0 for all models. We evaluate the models using torch.float and use vllm 0.5.0.post1 for most of the experiments, except for Gemma-2 where vllm 0.5.1 is required. ... For SFT, we use the open-source axolotl project with version 0.4.1 and for online iterative preference learning and RAFT, we use the code base from RLHF Workflow (Dong et al., 2024).
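The pinned versions listed above can be collected into a single install command. The PyPI identifiers `antlr4-python3-runtime` and lowercase `ipython` are assumptions about the package names; the reported version numbers are taken verbatim from the row above.

```shell
# Pinned dependencies as reported for the paper's experiments.
pip install transformers==4.42.4 torch==2.3.0 sympy==1.2 \
    antlr4-python3-runtime==4.11.0 ipython==8.26.0 vllm==0.5.0.post1

# For the Gemma-2 experiments, vllm 0.5.1 is required instead:
# pip install vllm==0.5.1
```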
Experiment Setup Yes We run the iterative training for 3 epochs in total. ... Then, we train the model on the collected samples using the M-DPO/M-KTO loss. ... We train the model for 1 epoch at most and tune the learning rate in {2e-7, 4e-7, 7e-7, 1e-6} with the first iteration of iterative training. Eventually, the learning rate of 4e-7 is used for Gemma-1.1 models and 2e-7 is used for Gemma-2 model and Mistral model. The global batch size is 32 with a warm-up step of 40. ... For all the data generation process, we adopt the following constraints: (1) for each turn, the model can generate up to 512 tokens; (2) the maximal number of steps is H=6; (3) the maximal number of tokens for each trajectory is 2048.
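The M-DPO loss referenced above is a multi-turn variant of DPO. As a hedged sketch of the trajectory-level objective only: the standard DPO negative log-sigmoid of a beta-scaled log-ratio margin, with per-turn log-probabilities summed over the (at most H=6) turns. The masking of externally generated observation tokens described by the paper, and the beta value, are omitted or assumed here.

```python
import math

# Trajectory-level DPO-style loss: margin between the policy/reference
# log-ratio of the chosen (w) trajectory and that of the rejected (l) one.
def m_dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """Each argument is a list of per-turn log-probabilities for the chosen
    (w) or rejected (l) trajectory under the policy or reference model."""
    margin = beta * ((sum(logp_w) - sum(ref_logp_w))
                     - (sum(logp_l) - sum(ref_logp_l)))
    return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log sigmoid(margin)
```

With a zero margin the loss is log 2, and it decreases as the policy prefers the chosen trajectory more strongly than the reference does, matching standard DPO behavior.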