Continual Adaptation of Vision Transformers for Federated Learning

Authors: Shaunak Halbe, James Seale Smith, Junjiao Tian, Zsolt Kira

TMLR 2024

Reproducibility Variable Result LLM Response
Research Type Experimental We formulate this problem for image classification and establish strong baselines for comparison, conducting experiments on CIFAR-100 as well as the challenging, large-scale ImageNet-R and DomainNet datasets. Our approach outperforms both existing methods and our own baselines by as much as 7% while significantly reducing communication and client-level computation costs.
Researcher Affiliation Academia Shaunak Halbe, James Seale Smith, Junjiao Tian, and Zsolt Kira (Georgia Institute of Technology). Correspondence: EMAIL
Pseudocode Yes Appendix A (Algorithm): To better illustrate our proposed method, we present a whole picture of the method in Algorithm 1. The algorithm describes our complete procedure for a global task T_i, where i ∈ [1, N].
Open Source Code Yes Code available at https://github.com/shaunak27/hepco-fed.
Open Datasets Yes Datasets. We conduct our experiments on three image classification datasets. First, we adapt CIFAR-100 Krizhevsky et al. (2009) to our formulation, as it is a commonly used benchmark in CFL. Additionally, we evaluate our methods on the larger-scale ImageNet-R Hendrycks et al. (2021) and DomainNet Peng et al. (2019), which have been used in recent continual learning works Smith et al. (2022) but haven't been explored in a continual federated learning setting.
Dataset Splits Yes Following Wang et al. (2022a), we use 20% of the training set as our validation dataset to determine hyperparameters for our approach and all competing baselines.
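The reported split holds out 20% of the training set for hyperparameter selection. A minimal, stdlib-only sketch of such a split (assumption: a uniform per-sample shuffle with a fixed seed; the report does not state how the 20% is sampled, and the function name is our own):

```python
import random

def train_val_split(indices, val_frac=0.2, seed=0):
    """Shuffle indices with a fixed seed and hold out val_frac for validation."""
    rng = random.Random(seed)
    shuffled = indices[:]          # copy so the caller's list is untouched
    rng.shuffle(shuffled)
    n_val = int(len(shuffled) * val_frac)
    return shuffled[n_val:], shuffled[:n_val]  # (train, validation)

# CIFAR-100 has 50,000 training images, so 20% held out gives 10,000.
train_idx, val_idx = train_val_split(list(range(50000)))
print(len(train_idx), len(val_idx))  # 40000 10000
```

Fixing the seed makes the held-out set identical across the approach and all competing baselines, which is what a fair hyperparameter comparison requires.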
Hardware Specification Yes We use 2 NVIDIA A40 GPUs for all experiments. We conducted benchmarking using 2 NVIDIA TITAN RTX GPUs in a 5-client setup, as described in the experiments section.
Software Dependencies No We implement our methods in PyTorch and use the PyTorch Image Models library Wightman (2019) to obtain pretrained checkpoints. We use the Adam Kingma & Ba (2017) optimizer with β1 = 0.9 and β2 = 0.999. Explanation: The paper mentions PyTorch and the PyTorch Image Models library but does not provide specific version numbers for these software components. It also mentions the Adam optimizer, which is an algorithm rather than a specific software dependency with a version.
Experiment Setup Yes Model Architecture. We use the ViT-B/16 backbone Dosovitskiy et al. (2020) pretrained on ImageNet-1K Russakovsky et al. (2015) as the encoder for our method and all baselines. We use a prompt pool size (M) of 100 and a prompt length (Lp) of 20 with dimension (D) of 768, insert prompts into Multi-head Self-Attention (MSA) layers 1-5 of the ViT encoder following standard practice Wang et al. (2022b), and perform prefix-tuning as done in Smith et al. (2022) by prepending prompts to the keys and values of the MSA layers. The classifier (ϕ) is a fully-connected layer with input dimension D and output dimension equal to the number of classes. We implement the generator θgen using a three-layer fully-connected network and train it for 100 epochs. We encode the class label using an embedding matrix and concatenate the obtained class embedding with the noise vector to form the input of the generator. Implementation Details. For all methods, we use the Adam Kingma & Ba (2017) optimizer with β1 = 0.9 and β2 = 0.999. We use a batch size of 64 for both local and server-side training... For our method and the prompting-based baselines, we use a learning rate of 1e-3, while for baselines that tune the entire model (FedAvg, FedLwF.MC), we use 5e-5... The generator architecture has per-layer input sizes of [128, 256, 1024], with an output size of 768, which is the dimension of the visual query. We train the generator for 100 epochs using a batch size of 64 and a learning rate of 1e-4 with the Adam optimizer. We fine-tune the server model using a learning rate of 1e-4 for 200 epochs. We use a replay ratio of 0.5 for our method... We choose λKL and λMSE values of 1 and 0.1, respectively.
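The generator described in the setup is a three-layer fully-connected network with per-layer input sizes [128, 256, 1024] and a 768-dim output (the visual query), conditioned by concatenating a class embedding with a noise vector. A hedged PyTorch sketch under stated assumptions: the split of the 128-dim input into a 64-dim noise vector and a 64-dim class embedding is our guess (the report does not give these sizes), and all class, module, and variable names here are our own, not the paper's:

```python
import torch
import torch.nn as nn

NUM_CLASSES = 100   # e.g. CIFAR-100
EMBED_DIM = 64      # assumption: class-embedding size is not reported
NOISE_DIM = 64      # assumption: chosen so embedding + noise = 128-dim input
QUERY_DIM = 768     # reported output size, matching the ViT-B/16 query dim

class QueryGenerator(nn.Module):
    """Three-layer MLP conditioned on a class label via a learned embedding."""
    def __init__(self):
        super().__init__()
        self.class_embed = nn.Embedding(NUM_CLASSES, EMBED_DIM)
        # Reported per-layer input sizes [128, 256, 1024], output size 768.
        self.net = nn.Sequential(
            nn.Linear(128, 256), nn.ReLU(),
            nn.Linear(256, 1024), nn.ReLU(),
            nn.Linear(1024, QUERY_DIM),
        )

    def forward(self, labels, noise):
        # Concatenate class embedding with noise to form the generator input.
        x = torch.cat([self.class_embed(labels), noise], dim=1)
        return self.net(x)

gen = QueryGenerator()
labels = torch.randint(0, NUM_CLASSES, (64,))  # batch size 64, as reported
noise = torch.randn(64, NOISE_DIM)
out = gen(labels, noise)
print(tuple(out.shape))  # (64, 768)
```

Training it as reported (Adam, lr 1e-4, batch size 64, 100 epochs) would then only require a standard optimizer loop over this module's parameters.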