Upweighting Easy Samples in Fine-Tuning Mitigates Forgetting
Authors: Sunny Sanyal, Hayden Prairie, Rudrajit Das, Ali Kavis, Sujay Sanghavi
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically demonstrate the efficacy of our method on both language and vision tasks. As an example, when fine-tuning Gemma 2 2B on MetaMathQA, our method results in only a 0.8% drop in accuracy on GSM8K (another math dataset) compared to standard fine-tuning, while preserving 5.4% more accuracy on the pre-training datasets. |
| Researcher Affiliation | Collaboration | University of Texas at Austin; Google Research. Correspondence to: Sunny Sanyal <EMAIL>, Hayden Prairie <EMAIL>, Rudrajit Das <EMAIL>, Ali Kavis <EMAIL>, Sujay Sanghavi <EMAIL>. |
| Pseudocode | Yes | Algorithm 1: Fine-tuning with Pre-trained Loss-Oriented Weighting (FLOW). Input: pre-trained model θ\*, dataset {(x_i, y_i)}_{i=1}^n for the new task, and temperature parameter τ; f_i(θ) denotes the i-th sample's loss at θ, with a non-negative loss function (e.g., cross-entropy loss). 1. Compute sample weights: w_i = exp(−f_i(θ\*)/τ). 2. Weighted loss: L(θ) = Σ_{i=1}^n w_i f_i(θ). 3. Fine-tune with the weighted loss: θ̂ := argmin_θ L(θ). Output: fine-tuned model θ̂. |
| Open Source Code | Yes | Our code is publicly available here. |
| Open Datasets | Yes | Models. We experimented with ResNet-18 and ResNet-50 (Wightman et al., workshop) pre-trained on ImageNet-1K (IN-1K). Datasets. We used seven widely-used image classification datasets: CIFAR-10 (Krizhevsky, 2009), CIFAR-100 (Krizhevsky, 2009), Flowers102 (Nilsback & Zisserman, 2008), Caltech101 (Li et al., 2022), Cars (Krause et al., 2013), and Dogs (Parkhi et al., 2012). Datasets. Following previous work (Biderman et al., 2024; Chen et al., 2024b), we fine-tune on MetaMathQA (Yu et al., 2023), a mathematical reasoning dataset that is bootstrapped from the training sets of GSM8K (Cobbe et al., 2021) and MATH (Hendrycks et al., 2021b) using an LLM. |
| Dataset Splits | Yes | Evaluation metrics. Vision models are trained with task-specific parts, such as the classification head (head) and batch norm (BN); see Appendix B for how FLOW works with task-specific parts. Forgetting is measured by how much the model's top-1 validation accuracy on ImageNet-1K (subsequently referred to as IN-1K accuracy) drops after fine-tuning. Datasets. We train with all 395K samples in MetaMathQA. Evaluation metrics. To evaluate the validity of FLOW, we break down our metrics into general-capability and target fine-tuning evaluations. To evaluate general capabilities, we again follow a setup similar to Chen et al. (2024b), using commonsense reasoning, 5-shot MMLU (Hendrycks et al., 2021a), and 3-shot MBPP (Austin et al., 2021) metrics. To evaluate the target domain, we use 5-shot GSM8K (Cobbe et al., 2021). |
| Hardware Specification | Yes | Table 8. Hyperparameter configurations for fine-tuning ResNet-18 and ResNet-50 on the image classification datasets. # GPUs: 1 × A6000 |
| Software Dependencies | Yes | All training for language experiments is done with Hugging Face peft (Mangrulkar et al., 2022), transformers (Wolf et al., 2020), datasets (Lhoest et al., 2021), and accelerate (Gugger et al., 2022). |
| Experiment Setup | Yes | For both Gemma 2 2B (Team et al., 2024) and Llama 3.2 3B (Grattafiori et al., 2024), we run hyperparameter sweeps over learning rates for each baseline. For standard fine-tuning, ℓ2 regularization, and FLOW, we sweep the learning rate over [1e-4, 2e-5, 1e-5, 5e-6], and for LoRA (r = 64) we sweep over [2e-4, 2e-1], following the learning rates used in (Biderman et al., 2024). We then select the learning rate that results in the best GSM8K (Cobbe et al., 2021) accuracy, oblivious to general-capability metrics. We report the hyperparameters used for our Gemma 2 2B (Team et al., 2024) experiments in Table 6 and for Llama 3.2 3B (Grattafiori et al., 2024) in Table 7. |
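The weighting scheme in the Pseudocode row above can be sketched in a few lines. This is a minimal NumPy illustration, not the authors' implementation: `flow_weights` and `flow_loss` are hypothetical names, and the toy loss values are invented for the example. The negative sign in the exponent reflects the paper's core idea of upweighting easy samples, i.e., samples with low loss under the frozen pre-trained model θ\*.

```python
import numpy as np

def flow_weights(pretrained_losses, tau):
    """FLOW sample weights: w_i = exp(-f_i(theta*) / tau).

    Easy samples (low loss under the pre-trained model) get weights
    near 1; hard samples are exponentially down-weighted.
    """
    return np.exp(-np.asarray(pretrained_losses, dtype=float) / tau)

def flow_loss(current_losses, weights):
    """Weighted fine-tuning objective: L(theta) = sum_i w_i * f_i(theta)."""
    return float(np.sum(weights * np.asarray(current_losses, dtype=float)))

# Toy example: per-sample cross-entropy losses under the frozen
# pre-trained model (assumed values, for illustration only).
pre_losses = [0.1, 1.0, 5.0]
w = flow_weights(pre_losses, tau=1.0)
print(w)  # weights decrease monotonically; the hardest sample is nearly ignored
print(flow_loss([0.2, 0.9, 4.0], w))
```

In practice the weights are computed once, in a single forward pass of the pre-trained model over the fine-tuning set (with a per-sample loss reduction), and then held fixed while minimizing the weighted loss; the temperature τ controls how sharply hard samples are suppressed.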