PRISM: Privacy-Preserving Improved Stochastic Masking for Federated Generative Models

Authors: Kyeongkook Seo, Dong-Jun Han, Jaejun Yoo

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experiments on MNIST, FMNIST, CelebA, and CIFAR10 demonstrate that PRISM outperforms existing methods while maintaining privacy with minimal communication costs. PRISM is the first to successfully generate images under challenging non-IID and privacy-preserving FL environments on complex datasets, where previous methods have struggled.
Researcher Affiliation | Academia | Kyeongkook Seo¹, Dong-Jun Han², Jaejun Yoo¹. ¹Ulsan National Institute of Science and Technology (UNIST), ²Department of Computer Science and Engineering, Yonsei University. EMAIL, EMAIL
Pseudocode | Yes |
Algorithm 1 MADA
  Parameters: learning rate η, communication rounds T, local iterations I
  Input: local datasets {D_k}_{k=1}^K, ImageNet-pretrained VGGNet ψ, random noise z
  Server executes: initialize a random weight W_init and score vector s, then broadcast to all clients.
  for round t = 1, ..., T do
    Client side:
    for each client k ∈ [1, K] do
      s_t^k = s_t                                      // download score vector
      for local iteration i = 1, ..., I do
        θ_t^k ← Sigmoid(s_t^k)
        M_t^k ∼ Bern(θ_t^k)
        W_t^k ← W_init ⊙ M_t^k
        D_fake^k ← W_t^k(z)                            // generate fake images
        extract real and fake features ψ(D_k), ψ(D_fake^k)
        s_t^k ← s_t^k − η ∇L_MMD^k(ψ(D_k), ψ(D_fake^k))  // update local score vector
      end for
      θ_t^k ← Sigmoid(s_t^k)
      θ_t^k ← θ_t^k + N(0, Iσ²); clip to [c, 1−c]
      M_t^k ∼ Bern(θ_t^k)
      upload binary mask M_t^k to the server
    end for
    Server side:
    θ̂_{t+1} ← (1/K) Σ_{k=1}^K M_t^k                   // aggregate the received binary masks
    s_{t+1} ← Sigmoid⁻¹(θ̂_{t+1})
    λ ← hd(M_t, Bern(θ̂_{t+1}))                        // compute the Hamming distance
    s_{t+1} ← (1−λ) s_t + λ s_{t+1}
  end for
  Sample the supermask M ∼ Bern(θ_T)
  Obtain the final model W ← W_init ⊙ M

Algorithm 2 PRISM
  Input: ratio of score layers α
  Output: probability θ_t^k (100−α) and binary mask M_t^k (α)
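As a rough illustration, the client-side mask release (sigmoid, Gaussian noise, clipping, Bernoulli sampling) and the server-side aggregation of Algorithm 1 can be sketched in numpy. All names, the noise scale, and the λ heuristic below are illustrative placeholders, not the authors' implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def client_mask(score, sigma=0.1, c=0.01):
    """One client's mask release: scores -> probabilities -> Gaussian
    noise (the DP step) -> clip to [c, 1-c] -> Bernoulli mask."""
    theta = sigmoid(score)
    theta = theta + rng.normal(0.0, sigma, size=theta.shape)  # DP noise
    theta = np.clip(theta, c, 1.0 - c)
    return (rng.random(theta.shape) < theta).astype(np.int8)

def server_aggregate(masks, s_prev):
    """Average the binary masks, invert the sigmoid, and mix with the
    previous scores using a Hamming-distance-based weight (a sketch:
    here the first client's mask stands in for the previous global mask)."""
    theta_hat = np.clip(np.mean(masks, axis=0), 1e-6, 1 - 1e-6)
    s_new = np.log(theta_hat / (1.0 - theta_hat))  # Sigmoid^{-1}
    m_new = (rng.random(theta_hat.shape) < theta_hat).astype(np.int8)
    lam = np.mean(masks[0] != m_new)               # Hamming distance in [0, 1]
    return (1.0 - lam) * s_prev + lam * s_new

s = rng.normal(size=8)
masks = [client_mask(s) for _ in range(10)]
s = server_aggregate(np.stack(masks), s)
```

Note that clipping θ to [c, 1−c] before Bernoulli sampling is what bounds the sensitivity of each released mask bit.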
Open Source Code | No | Our code is available at PRISM.
Open Datasets | Yes | In this section, we validate the effectiveness of PRISM on the MNIST, FMNIST, CelebA, and CIFAR10 datasets.
Dataset Splits | Yes | The training set of each dataset is distributed across 10 clients following either IID or non-IID data distributions. We provide the detailed partitioning strategies for IID and non-IID simulation in Appendix D. In addition, for a fair comparison, we set (9.8, 10⁻⁵)-DP for all methods.
Appendix D: Here, we elaborate on the dataset partitioning strategies used to simulate the IID and non-IID setups. For the non-IID setup, we used two strategies for partitioning datasets: 1) shards-partitioned and 2) Dirichlet-partitioned.
IID setup. In the IID scenario, we assign equal-sized local datasets by uniformly sampling from the entire training set. This allows for a balanced distribution of data across the clients.
Non-IID setup (shards-partitioned). In Section 5.2 and Section 5.3, we partition the MNIST, FMNIST, and CIFAR10 datasets into 40 segments based on the sorted class labels and randomly assign four segments to each client. For the CelebA dataset, which contains multiple attributes per image, defining a clear non-IID distribution for splitting is inherently ambiguous. In this work, we divide the dataset into two partitions based on a pivotal attribute (gender, in our case). The total number of partitions corresponds to the number of clients, with each client being assigned either the positive or negative subset of images for the pivotal attribute. However, the remaining 39 attributes are still shared among clients, resulting in relatively weak heterogeneity. This explains why the performance drop on CelebA is not as significant compared to the IID case. For a more detailed discussion, please refer to Appendix D.
Non-IID setup (Dirichlet-partitioned). We further explore a more realistic and challenging non-IID scenario, where datasets are partitioned using a Dirichlet distribution. Specifically, we set the Dirichlet parameter α = 0.005 to create a more label-skewed distribution.
For CelebA, we assign data to clients such that each client possesses the pivotal attribute in different proportions drawn from the Dirichlet distribution. For example, client 1 has 60% male and 40% female images, while client 2 has 20% male and 80% female images.
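The Dirichlet-partitioned setup described above can be sketched in a few lines: for each class, a Dirichlet draw over clients decides what fraction of that class each client receives. The function name and the toy labels are illustrative, not taken from the paper's code:

```python
import numpy as np

def dirichlet_partition(labels, n_clients=10, alpha=0.005, seed=0):
    """Label-skewed federated split: per class, draw client proportions
    from Dirichlet(alpha) and slice that class's indices accordingly.
    Small alpha (e.g. 0.005, as in the paper) gives strong skew."""
    rng = np.random.default_rng(seed)
    client_idx = [[] for _ in range(n_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        p = rng.dirichlet(alpha * np.ones(n_clients))
        if not np.isfinite(p).all():
            # very small alpha can underflow; fall back to one-hot assignment
            p = np.eye(n_clients)[rng.integers(n_clients)]
        cuts = (np.cumsum(p)[:-1] * len(idx)).astype(int)
        for k, part in enumerate(np.split(idx, cuts)):
            client_idx[k].extend(part.tolist())
    return client_idx

labels = np.repeat(np.arange(10), 100)  # toy: 10 classes, 100 samples each
parts = dirichlet_partition(labels)
```

With α = 0.005 the per-class proportions are nearly one-hot, so most clients see only a few classes, which is the label-skew regime the paper targets.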
Hardware Specification | No | The paper mentions 'GPU-accelerated gradient descent' in Section J, but does not provide specific models or detailed specifications for the hardware used in experiments.
Software Dependencies | No | The paper mentions the 'Adam optimizer (Kingma & Ba, 2014)' and the 'Opacus library, which is a user-friendly PyTorch framework for differential privacy (Yousefpour et al., 2021)', but does not provide specific version numbers for these or other software components.
Experiment Setup | Yes | In this section, we provide a detailed description of our implementation and experimental settings. In Table 4, we provide the model architectures used in our experiments. We use a ResNet-based generator and set the local epoch to 100 and the learning rate to 0.1. In addition, we do not employ training schedulers or learning rate decay. Our code is based on (Santos et al., 2019; Yeo et al., 2023). They employ the ImageNet-pretrained VGG19 network for feature matching by minimizing Eq. 1. However, calculating the first and second moments requires a large batch size to obtain accurate statistics. To address this issue, (Santos et al., 2019) introduces the Adam moving average (AMA). With a rate λ, the update of the AMA m is expressed as follows:

  m ← m − λ · ADAM(Δm),   (7)

where ADAM denotes the Adam optimizer (Kingma & Ba, 2014) and Δm is the discrepancy of the means of the extracted features. Note that ADAM(Δm) can be interpreted as gradient descent minimizing the L2 loss:

  min_{Δm} (1/2) ‖Δm‖².   (8)

This means the difference of statistics (Δm) is passed through a single MLP layer and updated using the Adam optimizer in the direction of minimizing Eq. 8. By utilizing AMA, Eq. 1 is formulated as

  L_MMD^k = ‖ E_{x∼D_k}[ψ(x)] − E_{y∼D_fake^k}[ψ(y)] ‖² + ‖ Cov(ψ(D_k)) − Cov(ψ(D_fake^k)) ‖².

Algorithms 1 and 2 provide the pseudocode for MADA and PRISM, respectively. AMA is omitted to simply express the flow of our framework. See our code for the PyTorch implementation. We train the local generator for 100 local iterations with a learning rate of 0.1. For the AMA layer, the learning rate is set to 0.005. In addition, we use the Adam optimizer with β₁ = 0.5, β₂ = 0.999 to update the scores of the generators. After all clients complete their training, the communication round is initiated. We set the global epoch to 150 for the MNIST dataset and 350 for the CelebA and CIFAR10 datasets. As we do not tune these parameters, note that there is room for performance improvement through hyperparameter tuning.
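The moment-matching loss above (means plus covariances of extracted features) can be sketched directly in numpy. This is a minimal stand-in: the random arrays below replace the VGG19 features ψ(D_k) and ψ(D_fake^k), and the AMA update is omitted, exactly as in the pseudocode:

```python
import numpy as np

def mmd_feature_loss(feat_real, feat_fake):
    """Squared distance between the first moments (means) and second
    moments (covariances) of real vs. fake feature batches, mirroring
    the L_MMD^k objective described above."""
    mean_diff = feat_real.mean(axis=0) - feat_fake.mean(axis=0)
    cov_diff = np.cov(feat_real, rowvar=False) - np.cov(feat_fake, rowvar=False)
    return float(np.sum(mean_diff ** 2) + np.sum(cov_diff ** 2))

rng = np.random.default_rng(0)
real = rng.normal(size=(64, 16))            # stand-in for psi(D_k)
fake = rng.normal(loc=0.5, size=(64, 16))   # stand-in for psi(D_fake^k)
loss = mmd_feature_loss(real, fake)
```

As the text notes, both moment estimates are batch statistics, which is why small batches make them noisy and why the AMA smoothing of the mean term is needed in practice.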