Learn Your Reference Model for Real Good Alignment
Authors: Alexey Gorbatovski, Boris Shaposhnikov, Alexey Malakhov, Nikita Surnachev, Yaroslav Aksenov, Ian Maksimov, Nikita Balagansky, Daniil Gavrilov
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our results show that TR alignment methods effectively mitigate overoptimization, enabling models to maintain strong performance even when substantially deviating from the initial reference policy. We demonstrate the efficacy of these approaches not only through toy examples that exhibit reduced overoptimization, but also through direct, side-by-side comparisons in specific tasks such as helpful and harmless dialogue, as well as summarization, where they surpass conventional methods. Additionally, we report significant improvements in general-purpose assistant setups with the Llama3 model on the Alpaca Eval 2 and Arena-Hard benchmarks, highlighting the advantages of Trust Region methods over classical approaches. |
| Researcher Affiliation | Industry | Alexey Gorbatovski, Boris Shaposhnikov, Alexey Malakhov, Nikita Surnachev, Yaroslav Aksenov, Ian Maksimov, Nikita Balagansky, Daniil Gavrilov — T-Tech. *Corresponding author: EMAIL |
| Pseudocode | No | The paper provides mathematical derivations and descriptions of methods but does not include any explicitly labeled pseudocode or algorithm blocks with structured steps. |
| Open Source Code | No | The paper discusses the use of existing open-source models (Pythia, Llama3) and datasets, but does not provide any concrete access information (link to repository, explicit statement of code release) for the implementation of the Trust Region Alignment methods (TR-DPO, TR-IPO, TR-KTO) described in this paper. |
| Open Datasets | Yes | Task-specific datasets: for specialized evaluations, we use the Anthropic-HH [1] (Bai et al., 2022) dataset, which focuses on dialogue alignment, where preferred responses are selected based on their helpfulness and harmlessness. For the summarization task, we employ the Reddit TL;DR summarization [2] (Stiennon et al., 2020) dataset, training models to generate concise and accurate summaries of long social posts. General benchmarks: for broader, general-purpose evaluations, we use the UltraChat-200k [3] (Ding et al., 2023a) dataset, designed to train models' ability to follow instructions in open-domain conversations. Additionally, the UltraFeedback [4] (Cui et al., 2023) dataset provides a binarized preference framework, useful for aligning models across various domains in an offline setting, making it suitable for training and evaluating general-purpose assistants. [1] https://huggingface.co/datasets/Anthropic/hh-rlhf [2] https://huggingface.co/datasets/UCL-DARK/openai-tldr-summarisation-preferences [3] https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k [4] https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized |
| Dataset Splits | Yes | A summary of the dataset sizes is provided in Table 4 (training / validation examples): Anthropic-HH: 160,800 / 8,552; Reddit TL;DR summarization (SFT): 41,947 / 11,941; Reddit TL;DR summarization (Preference): 73,396 / 21,198; UltraChat-200k: 207,865 / 23,110; UltraFeedback: 61,135 / 2,000. |
| Hardware Specification | Yes | All computations were performed on 8 NVIDIA A100 GPUs with 80GB of memory, which provided the necessary computational power to efficiently train our models. |
| Software Dependencies | Yes | From Table 3 (training hyperparameters for Pythia and Llama3 models): Optimizer: Adam (Kingma & Ba, 2014); Memory optimization: DeepSpeed (Rasley et al., 2020); Attention mechanism: FlashAttention-2 (Dao, 2023). |
| Experiment Setup | Yes | D.1 Training details: the training of Pythia and Llama models adhered to a set of hyperparameters optimized for performance (see Table 3). Unless otherwise noted, the following hyperparameters were consistent across all training setups. Table 3 (training hyperparameters for Pythia and Llama3 models): Max tokens length: 1024 (Pythia), 2048 (Llama3); Epochs: 1; Learning rate (SFT): 6.0×10⁻⁶; Learning rate (Baseline/TR): 1.0×10⁻⁶; Optimizer: Adam (Kingma & Ba, 2014); Adam β1: 0.9; Adam β2: 0.95; Batch size: 128; Learning schedule: linear decay (Loshchilov & Hutter, 2016); Warm-up steps: 100; Max gradient norm: 2; Memory optimization: DeepSpeed (Rasley et al., 2020); Attention mechanism: FlashAttention-2 (Dao, 2023). |
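Since no code is released for TR-DPO, TR-IPO, or TR-KTO, the following is a minimal pure-Python sketch of the two reference-model update schemes the Trust Region family is built on: a soft exponential blend toward the policy and a periodic hard copy of it. The function names, the `alpha`/`tau` parameter names, and the flat-list weight representation are illustrative assumptions, not the authors' implementation.

```python
def soft_update(ref_weights, policy_weights, alpha):
    """Soft update: pi_ref <- alpha * pi_theta + (1 - alpha) * pi_ref, element-wise."""
    return [alpha * p + (1.0 - alpha) * r
            for r, p in zip(ref_weights, policy_weights)]

def hard_update(ref_weights, policy_weights, step, tau):
    """Hard update: copy the policy weights into the reference every tau steps."""
    return list(policy_weights) if step % tau == 0 else list(ref_weights)
```

With `alpha = 0` (or `tau` larger than the training run) either rule degenerates to a fixed reference model, i.e. the classical DPO/IPO/KTO setting the paper compares against.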
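The warm-up plus linear-decay schedule from Table 3 can be sketched as a plain function. The paper specifies the 100 warm-up steps, the peak learning rates, and linear decay; the total step count and the decay-to-zero endpoint below are assumptions for illustration.

```python
def lr_at(step, total_steps, peak_lr=1.0e-6, warmup_steps=100):
    """Linear warm-up to peak_lr over warmup_steps, then linear decay to 0.

    peak_lr defaults to the Baseline/TR rate from Table 3; pass 6.0e-6 for SFT.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * (total_steps - step) / (total_steps - warmup_steps)
```

In a real run this would typically be wired up through the framework's scheduler (e.g. a lambda-style scheduler over the optimizer) rather than called by hand.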