Teaching LLMs How to Learn with Contextual Fine-Tuning
Authors: Younwoo Choi, Muhammad Adil Asif, Ziwen Han, John Willes, Rahul G. Krishnan
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically demonstrate that this simple yet effective modification improves the ability of LLMs to be fine-tuned rapidly on new datasets both within the medical and financial domains. The paper also contains Section 5 (Contextual Fine-Tuning Experiments) with dedicated Experimental Setup (5.1) and Results (5.2) subsections. |
| Researcher Affiliation | Academia | 1University of Toronto 2Vector Institute |
| Pseudocode | Yes | Refer to Algorithm 1 in Appendix A.2 for the detailed algorithm. |
| Open Source Code | Yes | PROJECT PAGE: https://younwoochoi.github.io/cft-iclr/ |
| Open Datasets | Yes | Finally, we open-source a biomedical dataset curated from MDPI journals and other open-source medical textbooks. From Section 4 (OpenMedText): To evaluate the effectiveness of contextual fine-tuning in a domain-adaptive setting, we curated a dataset consisting of both academic journal articles and educational textbooks. |
| Dataset Splits | No | The paper reports training details ("The models are trained for one epoch with a batch size of 128 and a learning rate of 2e-5.") but does not specify train/validation/test split sizes or proportions. |
| Hardware Specification | Yes | All training was conducted with 8 NVIDIA A100 GPUs. |
| Software Dependencies | No | We implemented Flash Attention 2 (Dao, 2024). The paper names libraries and techniques but gives no specific version numbers, which limits reproducibility. |
| Experiment Setup | Yes | The models are trained for one epoch with a batch size of 128 and a learning rate of 2e-5. To assess the efficiency of CFT, we carefully measured the computational resources required for our experiments and compared the overhead introduced by incorporating contextual prompts. Below are the details of our computational setup and findings. We utilized the Fully Sharded Data Parallel (FSDP) training to efficiently distribute the model across multiple GPUs. Training was performed using the bf16 (Brain Floating Point) data format. We implemented Flash Attention 2 (Dao, 2024). All training was conducted with 8 NVIDIA A100 GPUs. With the above configuration, we achieved a training speed of approximately 55,188 tokens per second, measured using the Llama tokenizer. The finetuning required a total of approximately 111.11 GPU-hours to complete. Incorporating contextual prompts increased the total training time by approximately 0.89 GPU-hours, resulting in a total of 112 GPU-hours. Each contextual prompt added only about 0.8% to the length of each training example on average. This slight increase in input length led to less than a 1% increase in total training time. All models are trained using the Adam optimizer (Kingma & Ba, 2015) with default parameters. |
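The Experiment Setup row above can be summarized in code. The sketch below collects the reported hyperparameters into a plain dictionary and sanity-checks the compute-overhead arithmetic quoted from the paper. All variable names are illustrative (not taken from the authors' code), and it assumes the 55,188 tokens/s throughput figure is aggregate across all 8 GPUs.

```python
# Hedged sketch: reported fine-tuning configuration and compute arithmetic.
# Names are illustrative; values are quoted from the paper's setup description.

train_config = {
    "num_epochs": 1,
    "batch_size": 128,
    "learning_rate": 2e-5,
    "optimizer": "adam",              # Adam with default parameters (Kingma & Ba, 2015)
    "precision": "bf16",              # Brain Floating Point
    "parallelism": "fsdp",            # Fully Sharded Data Parallel
    "attention": "flash_attention_2",
    "hardware": "8x NVIDIA A100",
}

# Reported compute figures.
num_gpus = 8
base_gpu_hours = 111.11               # fine-tuning without contextual prompts
prompt_overhead_gpu_hours = 0.89      # extra cost of the contextual prompts
tokens_per_second = 55_188            # assumed aggregate across all 8 GPUs

total_gpu_hours = base_gpu_hours + prompt_overhead_gpu_hours       # ~112
relative_overhead = prompt_overhead_gpu_hours / base_gpu_hours     # ~0.008, i.e. <1%
wall_clock_hours = base_gpu_hours / num_gpus                       # ~13.9 h
approx_total_tokens = tokens_per_second * wall_clock_hours * 3600  # ~2.8e9 tokens

print(f"total: {total_gpu_hours:.2f} GPU-hours, "
      f"overhead: {relative_overhead:.1%}, "
      f"~{approx_total_tokens / 1e9:.2f}B tokens processed")
```

Under these assumptions the figures are internally consistent: the 0.89 GPU-hour overhead is roughly 0.8% of the 111.11 GPU-hour baseline, matching the stated "less than a 1% increase in total training time."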