Merging by Matching Models in Task Parameter Subspaces
Authors: Derek Tam, Mohit Bansal, Colin Raffel
TMLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | To explore the effectiveness of MaTS, we comprehensively compare it to existing merging methods on multitask and intermediate-task merging of language models and vision models trained via parameter-efficient or full-model fine-tuning. We first explore using pre-existing merging methods as an initialization for MaTS and demonstrate that MaTS can significantly boost performance given a suitable choice of merging objective. In particular, in multitask language model merging, we find that MaTS attains state-of-the-art results by a large margin. We use insights from this exploration to develop an effective merging recipe (i.e., a consistent initialization and objective to use) for parameter-efficient and full-model fine-tuning, which we then apply to multitask vision model merging and intermediate-task language model merging. In both cases, we validate that MaTS can boost performance over its initialization and often attains state-of-the-art results. Finally, we discuss how MaTS has a higher computational cost than existing merging methods but is nevertheless dramatically cheaper than explicit multitask training. Taken as a whole, our results validate both our perspective of model merging as matching models in their task parameter subspace as well as the effectiveness of using the conjugate gradient method for solving the corresponding linear system. |
| Researcher Affiliation | Academia | Derek Tam (University of Toronto, Vector Institute); Mohit Bansal (University of North Carolina at Chapel Hill); Colin Raffel (University of Toronto, Vector Institute) |
| Pseudocode | Yes | C Algorithm: We show the algorithm for MaTS in the (IA)³ setting in Algorithm 1. Algorithm 1: MaTS Algorithm |
| Open Source Code | Yes | We ultimately demonstrate that our merging framework, called Matching Models in their Task Parameter Subspace (MaTS), achieves state-of-the-art results in multitask and intermediate-task model merging. We release all of the code and checkpoints used in our work (https://github.com/r-three/mats). |
| Open Datasets | Yes | Experimental set-up. As an experimental setting, we focus on merging models fine-tuned on datasets from the T0 mixture (Sanh et al., 2021) to form a multitask model. Zhou et al. (2022) found eight datasets (listed in appendix D.2) were the most important for performance and thus we first focus on merging models fine-tuned on these eight datasets. ... For vision, we follow Ilharco et al. (2022); Yadav et al. (2023) and merge the same set of CLIP-based models (Radford et al., 2021) fine-tuned on eight tasks listed in appendix D.4. |
| Dataset Splits | No | We compute the empirical Fisher over the validation set and compare various ways of computing the Fisher in section 6.6. The only hyperparameter in MaTS is the number of iterations to run the conjugate gradient method, which we allow to take on values ranging from 10 to 100 in step sizes of 10. We tune the hyperparameters for each method based on validation set performance. |
| Hardware Specification | Yes | In practice, we found that merging with MaTS in the setting from section 6.2 only takes about 11 minutes on a single NVIDIA A6000 GPU, which is a relatively modest amount of time in the realm of training large neural networks. |
| Software Dependencies | No | The paper mentions software components such as the AdamW optimizer, but does not specify any programming languages or library versions (e.g., Python, PyTorch, or TensorFlow versions) that would be needed to reproduce the work. |
| Experiment Setup | Yes | D.1 Fine-tuning Details: For all the models we fine-tuned, we use the same hyperparameter setup for consistency. Concretely, we use AdamW, a learning rate of 1e-4, a batch size of 1024, bfloat16 during training, training for 5K batches, checkpointing every 100 batches, and early stopping if the model has not improved for 5 checkpoints. For (IA)³, the only difference is that the learning rate is set to 5e-3. |
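The core procedure quoted above — casting merging as a linear system and solving it with a bounded number of conjugate-gradient iterations, initialized from an existing merging method — can be sketched in a few lines. This is a minimal NumPy illustration, not the paper's released implementation: the toy Fisher matrices, dimensionality, and averaging-based initialization are all stand-in assumptions for the Fisher-weighted merging system (Σᵢ Fᵢ) θ = Σᵢ Fᵢ θᵢ.

```python
import numpy as np

def conjugate_gradient(A, b, x0, num_iters):
    """Solve A x = b for symmetric positive-definite A via plain CG."""
    x = x0.copy()
    r = b - A @ x          # residual
    p = r.copy()           # search direction
    rs_old = r @ r
    for _ in range(num_iters):
        Ap = A @ p
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < 1e-10:   # converged
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# Toy setting (illustrative, not from the paper): two "task models"
# with synthetic per-task Fisher matrices.
rng = np.random.default_rng(0)
d = 8
thetas = [rng.normal(size=d) for _ in range(2)]
fishers = []
for _ in range(2):
    M = rng.normal(size=(d, d))
    fishers.append(M @ M.T + np.eye(d))   # symmetric positive-definite

A = sum(fishers)                                   # sum_i F_i
b = sum(F @ t for F, t in zip(fishers, thetas))    # sum_i F_i theta_i

# Initialize from simple parameter averaging, mirroring how MaTS is
# initialized from a pre-existing merging method.
x0 = np.mean(thetas, axis=0)
merged = conjugate_gradient(A, b, x0, num_iters=100)

closed_form = np.linalg.solve(A, b)
print("matches closed form:", np.allclose(merged, closed_form, atol=1e-6))
```

In the paper's setting the iteration count (10 to 100 in steps of 10) is the only hyperparameter; for large models CG is attractive precisely because it only needs matrix-vector products with Σᵢ Fᵢ rather than an explicit inverse.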
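The fine-tuning schedule in the Experiment Setup row (checkpoint every 100 batches, stop once the model has not improved for 5 checkpoints) can be made concrete with a small helper. This is a hypothetical sketch of that early-stopping rule only; `should_stop` and its arguments are illustrative names, not code from the paper.

```python
def should_stop(checkpoint_scores, patience=5):
    """Hypothetical early-stopping rule matching the description in D.1:
    stop once the best validation score is more than `patience`
    checkpoints old (higher score = better)."""
    if len(checkpoint_scores) <= patience:
        return False
    best_idx = max(range(len(checkpoint_scores)),
                   key=checkpoint_scores.__getitem__)
    return (len(checkpoint_scores) - 1 - best_idx) >= patience

# Scores improve, then fail to improve for 5 consecutive checkpoints.
plateaued = [0.60, 0.65, 0.70, 0.69, 0.68, 0.67, 0.66, 0.65]
still_improving = [0.60, 0.65, 0.70, 0.69]
```

Under this reading, training on 5K batches with checkpoints every 100 batches yields at most 50 checkpoints, of which the run ends early as soon as 5 in a row fail to beat the best so far.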