Directional Gradient Projection for Robust Fine-Tuning of Foundation Models

Authors: Chengyue Huang, Junjiao Tian, Brisa Maneechotesuwan, Shivang Chopra, Zsolt Kira

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experimental results show that DiGraP consistently outperforms existing baselines across image classification and VQA tasks with discriminative and generative backbones, improving both in-distribution (ID) generalization and out-of-distribution (OOD) robustness. Overview: we test DiGraP on a variety of benchmarks, tasks and architectures to validate its effectiveness.
Researcher Affiliation | Academia | Chengyue Huang, Junjiao Tian, Brisa Maneechotesuwan, Shivang Chopra, Zsolt Kira — Georgia Institute of Technology
Pseudocode | Yes | Algorithm 1: Adam with Trainable Directional Gradient Projection

    Input: θ₀ (pre-trained model), α (learning rate), µ (learning rate for ω), (β₁, β₂) = (0.9, 0.999)
    Initialize: m₀ ← 0, v₀ ← 0
    for t = 1 to T do
        g_{t,1} ← ∇_θ L(θ_{t−1}), g_{t,2} ← θ_{t−1} − θ₀             ▷ gradients of the objectives (Eq. 3)
        g^proj_{t,1} ← ((g_{t,1} · g_{t,2}) / ‖g_{t,2}‖²) g_{t,2}     ▷ gradient projection (Eq. 3)
        if t = 1 then
            ω_t ← 0                                                   ▷ initialize ω
        else
            ω̂_t ← Normalization(α_{t−1} g_{t,1} · g^proj_{t−1,1})
            ω_t ← max(0, min(1, AdamUpdate(ω_{t−1}, ω̂_t, µ, t)))     ▷ updating ω (Eq. 8)
        if g_{t,1} · g_{t,2} < 0 then
            g_t ← g_{t,1}                                             ▷ unconstrained gradient descent (Eq. 6)
        else
            g_t ← g_{t,1} − ω_t g^proj_{t,1}                          ▷ directional gradient projection (Eq. 4, Eq. 7)
        m_t ← β₁ m_{t−1} + (1 − β₁) g_t
        v_t ← β₂ v_{t−1} + (1 − β₂) g_t²
        m̂_t ← m_t / (1 − β₁ᵗ), v̂_t ← v_t / (1 − β₂ᵗ)                ▷ bias correction
        θ_t ← θ_{t−1} − α_t m̂_t / (√v̂_t + ϵ)                        ▷ parameter update
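The core of the algorithm above (the sign test on g_{t,1} · g_{t,2} and the ω-weighted projection) can be sketched in pure Python. This is an illustrative sketch, not the authors' implementation: parameters are flat lists, and the trainable ω update (Eq. 8) is taken as a given input rather than computed.

```python
# Minimal sketch of one DiGraP gradient-combination step (Algorithm 1).
# theta/theta0/grad_loss are flat parameter vectors; omega ∈ [0, 1] is the
# (here fixed) projection strength. All names are illustrative.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def digrap_step(theta, theta0, grad_loss, omega):
    """Return the combined gradient g_t fed to the optimizer.

    theta:     current parameters
    theta0:    pre-trained parameters
    grad_loss: gradient of the fine-tuning loss at theta (g_{t,1})
    omega:     projection strength in [0, 1]
    """
    g1 = grad_loss
    g2 = [t - t0 for t, t0 in zip(theta, theta0)]   # g_{t,2} = theta - theta0
    denom = dot(g2, g2)
    if denom == 0.0 or dot(g1, g2) < 0:
        # Unconstrained descent: the loss gradient already points back
        # toward the pre-trained weights, so no projection is needed.
        return list(g1)
    # Project g1 onto g2 and subtract a fraction omega of that component.
    coeff = dot(g1, g2) / denom
    g_proj = [coeff * x for x in g2]
    return [a - omega * b for a, b in zip(g1, g_proj)]
```

With omega = 1 the component of the loss gradient that drives the weights further from θ₀ is removed entirely; with omega = 0 the step reduces to plain gradient descent.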
Open Source Code | Yes | The code is available at https://github.com/chengyuehuang511/DiGraP
Open Datasets | Yes | For Sec. 4.1 and Sec. 4.2, we use DomainNet (Peng et al., 2019) as the benchmark, which consists of six domains (real, sketch, painting, infograph, clipart and quickdraw) with 345 classes. ... For Sec. 4.3, we fine-tune on VQAv2 (Goyal et al., 2017) and test on nine OOD datasets using LoRA (Hu et al., 2021). For the near OODs, we evaluate on VQAv2's six variants, namely IV-VQA (Agarwal et al., 2020), CV-VQA (Agarwal et al., 2020), VQA-Rephrasings (Shah et al., 2019), VQA-CP v2 (Agrawal et al., 2018), VQA-CE (Dancette et al., 2021) and AdVQA (Sheng et al., 2021)... We also include TextVQA (Singh et al., 2019), VizWiz (Bigham et al.) and OK-VQAv2 (Reichman et al., 2023), which are constructed from different sources than VQAv2, as the far OOD datasets.
Dataset Splits | Yes | For Sec. 4.1 and Sec. 4.2, we use DomainNet (Peng et al., 2019) as the benchmark, which consists of six domains ... We fine-tune our model on the real domain and evaluate on all other domains. ... We consider VQAv2 (Goyal et al., 2017) as the ID dataset. We further evaluate the model on six near OOD datasets ... and three far OOD datasets... We sample 10% of the VQAv2 training and validation set.
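The 10% subsampling of the VQAv2 training and validation sets could be reproduced as below; the paper does not state the sampling procedure, so a seeded uniform sample (function name and seed are assumptions) is used for determinism.

```python
# Hedged sketch of the 10% VQAv2 subsampling mentioned above.
# Uniform sampling with a fixed seed is an assumption, not the paper's recipe.
import random

def sample_fraction(examples, fraction=0.1, seed=0):
    """Return a reproducible uniform sample of `fraction` of `examples`."""
    rng = random.Random(seed)
    k = max(1, int(len(examples) * fraction))
    return rng.sample(examples, k)
```

Fixing the seed makes the subset identical across runs, which matters when comparing fine-tuning methods on the same 10% slice.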
Hardware Specification | Yes | We use 4 RTX 2080 GPUs for each experiment. ... We use 8 A40 GPUs for each experiment.
Software Dependencies | No | We use the LAVIS (Li et al., 2022) public repository to fine-tune all methods. Standard hyper-parameters are used for all: learning rate (1e-3), weight decay (1e-4), optimizer (AdamW), scheduler (linear warmup with cosine annealing), warm-up learning rate (1e-4), minimum learning rate (1e-4), accumulation steps (2), beam size (5).
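The linear-warmup-with-cosine-annealing schedule with the values listed above can be sketched as a step-indexed function. Warm-up length and step granularity are assumptions here; LAVIS's own scheduler may differ in detail.

```python
# Sketch of linear warmup + cosine annealing with the quoted values:
# base LR 1e-3, warm-up LR 1e-4, minimum LR 1e-4.
import math

def warmup_cosine_lr(step, warmup_steps, total_steps,
                     base_lr=1e-3, warmup_lr=1e-4, min_lr=1e-4):
    if step < warmup_steps:
        # Linear ramp from warmup_lr up to base_lr.
        frac = step / max(1, warmup_steps)
        return warmup_lr + frac * (base_lr - warmup_lr)
    # Cosine decay from base_lr down to min_lr.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

The schedule starts at 1e-4, peaks at 1e-3 when warm-up ends, and decays back to the 1e-4 floor by the final step.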
Experiment Setup | Yes | For DiGraP, we fine-tune the model using SGD with a learning rate of 1e-2 and µ = 0.1 with a batch size of 256. ... Standard hyper-parameters are used for all: learning rate (1e-3), weight decay (1e-4), optimizer (AdamW), scheduler (linear warmup with cosine annealing), warm-up learning rate (1e-4), minimum learning rate (1e-4), accumulation steps (2), beam size (5). The model is trained for 10 epochs with a batch size of 128 ... For LoRA (Hu et al., 2021), we limit our study to only adapting the attention weights and freeze the MLP modules for parameter efficiency, specifically applying LoRA to Wq, Wk, Wv, Wo with r = 8 ... We use λ = 0.5 for all DiGraP results in Tab. 2. The regularization hyper-parameter is found through cross-validation, and the model with the best ID validation accuracy is taken.
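A back-of-envelope sketch of the LoRA budget implied by the setup above: rank r = 8 adapters on the four attention projections (Wq, Wk, Wv, Wo), MLPs frozen. The hidden size and layer count are illustrative, not taken from the paper.

```python
# Parameter count for LoRA on the four attention projections with r = 8.
# hidden=768 and n_layers=12 are assumed for illustration only.

def lora_trainable_params(hidden=768, r=8, n_layers=12,
                          targets=("Wq", "Wk", "Wv", "Wo")):
    """Each adapted d×d weight gains two low-rank factors A (r×d) and
    B (d×r), i.e. 2*d*r trainable parameters per target matrix."""
    per_weight = 2 * hidden * r
    return per_weight * len(targets) * n_layers

frozen_attn = 768 * 768 * 4 * 12  # the frozen attention weights themselves
# The adapters amount to roughly 2% of the attention parameters they modify.
```

This is why restricting LoRA to the attention weights keeps the fine-tuned fraction small while still covering every layer.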