As Simple as Fine-tuning: LLM Alignment via Bidirectional Negative Feedback Loss

Authors: Xin Mao, Huimin Xu, Feng-Lin Li, Ziqi Jin, Wang Chen, Wei Zhang, Anh Tuan Luu

ICLR 2025 | Venue PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We conduct extensive experiments across two challenging QA benchmarks and four reasoning benchmarks. The experimental results show that BNF achieves comparable performance to the best methods on QA benchmarks, while its performance decrease on the four reasoning benchmarks is significantly lower compared to the best methods, thus striking a better balance between value alignment and reasoning ability.
Researcher Affiliation | Collaboration | Xin Mao1, Huimin Xu1, Feng-Lin Li2, Ziqi Jin1, Wang Chen2, Wei Zhang3, Anh Tuan Luu1; 1Nanyang Technological University, 2Shopee Pte. Ltd, 3SEA Group
Pseudocode | Yes | Appendix B, Code Implementation:

def BNF_loss(batch):
    """
    Computes BNF loss for preference optimization.

    Args:
        batch: A tuple of (input_ids, lengths, labels)
            input_ids: input token ids (batch_size, seq_len)
            lengths: response lengths (batch_size,)
            labels: binary labels for preference (batch_size,)
    Returns:
        loss: The computed loss value.
    """
    # Unpack batch elements
    input_ids, lengths, labels = batch

    # Compute log-softmax for policy and reference models
    # policy_logp & ref_logp: (batch_size, seq_len, vocab_size)
    policy_logp = policy_model(input_ids).logits.log_softmax(-1)
    ref_logp = ref_model(input_ids).logits.log_softmax(-1)

    # Get log probabilities for the actual response tokens
    # response_logp has shape (batch_size, seq_len)
    response_logp = torch.gather(policy_logp, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    response_logp_ref = torch.gather(ref_logp, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

    # Sum log probabilities for non-response tokens
    # other_logp has shape (batch_size, seq_len)
    other_logp = (policy_logp.exp().detach() * policy_logp).sum(-1) - response_logp.exp().detach() * response_logp

    # Compute the dynamic target distribution for tokens in response y
    responses_target = torch.clamp(response_logp.exp() / response_logp_ref.exp(), max=1).detach()

    # Compute the dynamic target distribution for other tokens
    others_target = (1 - responses_target) / (1 - response_logp.exp().detach())

    # Compute final loss and apply length normalization
    loss = responses_target * response_logp + others_target * other_logp
    loss = (loss.sum(-1) * labels / lengths).sum()
    return loss
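As a minimal, self-contained illustration of the dynamic target distribution used in the loss above, the sketch below computes the two targets for a single token with scalar log-probabilities in plain Python (no torch). The function name and scalar framing are our own; the paper's implementation operates on batched tensors.

```python
import math

def dynamic_targets(logp, logp_ref):
    """Toy per-token version of the dynamic target computation.

    logp:     policy log-probability of the sampled response token
    logp_ref: reference-model log-probability of the same token
    Returns (responses_target, others_target), mirroring the
    clamp-and-renormalize step in the BNF loss sketch.
    """
    # responses_target = min(p / p_ref, 1), treated as a constant (detached)
    responses_target = min(math.exp(logp) / math.exp(logp_ref), 1.0)
    # Remaining target mass is spread over the non-response tokens
    others_target = (1.0 - responses_target) / (1.0 - math.exp(logp))
    return responses_target, others_target

# If the policy already matches the reference, the response token keeps
# the full target mass and the other tokens receive none.
print(dynamic_targets(math.log(0.5), math.log(0.5)))  # (1.0, 0.0)
```

Note how a policy probability below the reference's shifts target mass onto the non-response tokens, which is what produces the negative feedback on over- and under-confident tokens.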
Open Source Code | Yes | GitHub URL: https://github.com/MaoXinn/BNF
Open Datasets | Yes | For a fair comparison, we use the same preference training datasets constructed by SimPO. Specifically, for each prompt x in Ultrafeedback (Cui et al., 2024), they generate 5 responses with a sampling temperature of 0.8. Then, using PairRM (Jiang et al., 2023b) or ArmoRM (Wang et al., 2024a) to score the 5 responses, they select the highest-scoring one as yw and the lowest-scoring one as yl. We primarily evaluate all the models using two recently proposed instruction-following QA benchmarks: Arena-Hard (Li et al., 2024) and Wild-Bench (Lin et al., 2024). Furthermore, we also evaluate all the models on four logical reasoning benchmarks to verify the impact of these alignment methods on the models' reasoning abilities: MMLU-redux (Gema et al., 2024) (Language), CRUX (Gu et al., 2024) (Code), GSM8K (Cobbe et al., 2021) and MATH-L5 (Hendrycks et al., 2021) (Math). Specifically, we first fine-tune Mistral-Inst and Llama-3-Inst on a widely used mathematical synthetic dataset, MetaMath (Yu et al., 2023), obtaining Mistral-MM and Llama-3-MM. (MetaMath contains 395K mathematical prompts and their corresponding answers: https://huggingface.co/datasets/meta-math/MetaMathQA)
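The pair-construction recipe quoted above (sample 5 responses, score them, keep the best as yw and the worst as yl) can be sketched in a few lines of plain Python. The function name and the numeric scores below are hypothetical stand-ins for PairRM/ArmoRM outputs, not part of the released pipeline.

```python
def build_preference_pair(responses, scores):
    """Given candidate responses and their reward-model scores,
    return (y_w, y_l): the highest- and lowest-scoring responses.
    `scores` stands in for PairRM/ArmoRM outputs."""
    ranked = sorted(zip(scores, responses))
    y_l = ranked[0][1]   # lowest-scoring response
    y_w = ranked[-1][1]  # highest-scoring response
    return y_w, y_l

# Five sampled responses for one prompt, with made-up scores.
responses = ["r1", "r2", "r3", "r4", "r5"]
scores = [0.2, 0.9, 0.5, 0.1, 0.7]
print(build_preference_pair(responses, scores))  # ('r2', 'r4')
```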
Dataset Splits | Yes | For a fair comparison, we use the same preference training datasets constructed by SimPO. Specifically, for each prompt x in Ultrafeedback (Cui et al., 2024), they generate 5 responses with a sampling temperature of 0.8. Then, using PairRM (Jiang et al., 2023b) or ArmoRM (Wang et al., 2024a) to score the 5 responses, they select the highest-scoring one as yw and the lowest-scoring one as yl. Table 6: Experimental results with different pairing ratios. A ratio of 50% indicates that in half of the preference pairs, one response is randomly masked. A ratio of 0% means that no pairwise data is included in the dataset, while 100% represents the original preference dataset. We primarily evaluate all the models using two recently proposed instruction-following QA benchmarks: Arena-Hard (Li et al., 2024) and Wild-Bench (Lin et al., 2024). Furthermore, we also evaluate all the models on four logical reasoning benchmarks to verify the impact of these alignment methods on the models' reasoning abilities: MMLU-redux (Gema et al., 2024) (Language), CRUX (Gu et al., 2024) (Code), GSM8K (Cobbe et al., 2021) and MATH-L5 (Hendrycks et al., 2021) (Math). Figure 5: Experimental results of Mistral and Llama-3 on GSM8K. We save the model every 200 training steps and evaluate them on the test set of GSM8K.
Hardware Specification | Yes | We use 8 A100-80GB-SXM GPUs for training, and the precision is bf16.
Software Dependencies | No | In this paper, we set the maximum sequence length to 4096 and adopt the AdamW optimizer (Loshchilov & Hutter, 2018), applying a cosine learning rate schedule with 10% warm-up steps. Although DTD seems complex at first glance, its code implementation is quite simple and efficient (as shown in Appendix B). Moreover, compared to DPO-series methods, the BNF loss involves no tunable hyper-parameters and eliminates the need for pairwise preference data, which reduces the costs of grid searches. policy_logp = policy_model(input_ids).logits.log_softmax(-1) ref_logp = ref_model(input_ids).logits.log_softmax(-1) response_logp = torch.gather(policy_logp, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
Experiment Setup | Yes | In this paper, we set the maximum sequence length to 4096 and adopt the AdamW optimizer (Loshchilov & Hutter, 2018), applying a cosine learning rate schedule with 10% warm-up steps. Since our proposed BNF loss does not have any extra tunable hyper-parameters, we only perform grid searches on batch size {64, 128, 256} and learning rate {5e-7, 6e-7, 8e-7, 1e-6}. After grid search, we adopt a unified batch size of 128 and select learning rates of 5e-7 for Mistral-7B-Instruct-v0.2, 6e-7 for Meta-Llama-3-8B-Instruct, and 8e-7 for Gemma-2-9b-it.
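The schedule described above (linear warm-up over the first 10% of steps, then cosine decay) can be sketched as a small pure-Python function. The function name, step counts, and decay-to-zero floor are our assumptions for illustration; the paper only specifies the schedule shape, warm-up fraction, and peak learning rates.

```python
import math

def lr_at_step(step, total_steps, peak_lr, warmup_frac=0.1):
    """Cosine learning-rate schedule with linear warm-up.

    Assumes warm-up from 0 to peak_lr over warmup_frac of training,
    then cosine decay from peak_lr to 0 over the remaining steps.
    """
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        # Linear warm-up from 0 to peak_lr
        return peak_lr * step / warmup_steps
    # Cosine decay from peak_lr (progress=0) to 0 (progress=1)
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Halfway through warm-up, the rate is half the peak (e.g. Mistral's 5e-7).
print(lr_at_step(50, 1000, 5e-7))  # 2.5e-07
```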