Robust Weight Imprinting: Insights from Neural Collapse and Proxy-Based Aggregation
Authors: Justus Westerhoff, Golzar Atefi, Mario Koddenbrock, Alexei Figueroa, Alexander Löser, Erik Rodner, Felix Alexander Gers
TMLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this work, we propose the general IMPRINT framework, identifying three main components: generation, normalization, and aggregation. Through the lens of this framework, we conduct an in-depth analysis and a comparison of the existing methods. Our findings reveal the benefits of representing novel data with multiple proxies in the generation step and show the importance of proper normalization. Beyond an extensive analytical grounding, our framework enables us to propose a novel variant of imprinting which outperforms previous work on transfer learning tasks by 4%. We present IMPRINT, a framework that enables a comprehensive analysis of existing imprinting techniques. More precisely, we generalize prior work by decomposing imprinting into three principal steps (see fig. 1). The computational efficiency of imprinting allows us to perform a large number of experiments. Through IMPRINT, we are able to propose a novel, best-performing imprinting strategy using multi-proxy weight imprinting in combination with L2 normalization, outperforming previously studied methods, as depicted in fig. 2. |
| Researcher Affiliation | Collaboration | Justus Westerhoff (DATEXIS, Berliner Hochschule für Technik (BHT), Germany); Golzar Atefi (DATEXIS, Berliner Hochschule für Technik (BHT), Germany); Mario Koddenbrock (KI Werkstatt, Hochschule für Technik und Wirtschaft Berlin (HTW), Germany); Alexei Figueroa (DATEXIS, Berliner Hochschule für Technik (BHT), Germany); Alexander Löser (DATEXIS, Berliner Hochschule für Technik (BHT), Germany); Erik Rodner (KI Werkstatt, Hochschule für Technik und Wirtschaft Berlin (HTW) & Merantix Momentum, Germany); Felix A. Gers (DATEXIS, Berliner Hochschule für Technik (BHT), Germany) |
| Pseudocode | Yes | Algorithm 1: k-least-squares. Input: class data $\{H_c\}_{c=1}^{C}$, number of proxies $k$. Output: weights $\{W_c\}_{c=1}^{C}$ with shape $[k, l]$ per class. 1: if $k = 1$ then; 2: return standard least squares weights $W_{LS}$ for each class (see eq. (A.1)); 3: end if; 4: for each class $c$ do; 5: cluster $H_c$ into $k$ clusters via k-means; 6: assign each cluster $j$ to proxy class $(c, j)$; 7: collect proxy samples $\{H_{(c,j)}\}_{j=1}^{k}$; 8: end for; 9: compute least squares weights $w_{(c,j)}$ jointly for all proxy classes $(c, j)$; 10: assemble $W_c = [w_{(c,1)}, \dots, w_{(c,k)}]$ for each original class $c$; 11: return $\{W_c\}_{c=1}^{C}$ |
| Open Source Code | Yes | We publicly release our code at https://github.com/DATEXIS/IMPRINT. |
| Open Datasets | Yes | To find out the best imprinting strategy within our IMPRINT framework, we focus on tasks T created from the datasets MNIST (Deng, 2012), Fashion MNIST (Xiao et al., 2017), and CIFAR-10 (Krizhevsky et al., 2009), each containing 10 classes. In the analysis of neural collapse (NC), we also look at the FM's pre-training data (ImageNet). As its test set is not available, we use its validation set in NC1 computations. MNIST-M (Ganin et al., 2016), SVHN (Netzer et al., 2011), USPS (Hull, 1994). Digit classification datasets each containing digits 0-9 with domain-specific visual characteristics. |
| Dataset Splits | Yes | To investigate the effect of the number of samples given, we look at n-shot (n ∈ ℕ) scenarios. For that, we randomly pre-sample the training data of T to n samples per class, transitioning into the low-data regime. MNIST (Deng, 2012). A benchmark dataset of handwritten digits (0-9), consisting of 60 000 training and 10 000 test grayscale images of size 28 × 28. CIFAR-10 (Krizhevsky et al., 2009). A dataset of 32 × 32 RGB images covering 10 object classes, with 50 000 training and 10 000 test samples. ImageNet (Deng et al., 2009). We use the ILSVRC 2012 version (commonly called ImageNet-1K) containing 1 000 classes and 1.2M training images. Since the test set is not publicly available, we use the validation set (50 000 images) as a stand-in. MNIST-M (Ganin et al., 2016), SVHN (Netzer et al., 2011), USPS (Hull, 1994). Digit classification datasets each containing digits 0-9 with domain-specific visual characteristics. MNIST-M applies color and texture transformations to MNIST digits, yielding 60 000 training and 10 000 RGB images of size 28 × 28. SVHN consists of digit crops from house numbers in Google Street View, totaling 73 257 training and 26 032 test images of 32 × 32 in RGB. USPS contains scanned and normalized handwritten digits in 16 × 16 grayscale, split into 7 219 training and 2 007 test images. |
| Hardware Specification | Yes | To ensure transparency, all timing measurements were obtained on identical hardware (8 vCPUs on Intel Xeon Gold 6438Y+ nodes (2 sockets, 64 physical cores/128 threads)). |
| Software Dependencies | No | The paper mentions using KMeans from sklearn and PyTorch's torchvision models, but no specific version numbers are provided for these software dependencies. For example: "k-means cluster centers using KMeans from sklearn (Pedregosa et al., 2011)" and "To generate the embeddings, we use PyTorch's torchvision models." |
| Experiment Setup | Yes | Unless stated otherwise, we report the median test accuracy over three different seeds. In sections 5.1 and 5.2, we evaluate the imprinting performance by varying the FM (4) and T (12) and report average accuracy and standard deviation (std) across these. Due to the heterogeneity of models and tasks, large std are expected. Therefore, overall method comparisons are based on ranks rather than absolute accuracies. Methods are ranked by their median accuracy, yielding 4 × 12 = 48 potentially different ranks. We report the average rank and assess statistical significance of ranking (dis-)agreements using critical difference (CD) diagrams as explained in section 3.3. The code used to generate these diagrams is inspired by Ismail Fawaz et al. (2019). Table A.4: Weight decay values λ used for each model: resnet18, resnet50 (He et al., 2016): 0.0001; vit_b_16 (Dosovitskiy et al., 2021): 0.1; swin_b (Liu et al., 2021): 0.05. |
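
The k-least-squares procedure quoted in the Pseudocode row (Algorithm 1) can be sketched in Python roughly as follows. This is a minimal illustration under stated assumptions, not the authors' released implementation: eq. (A.1) is not reproduced in this summary, so the sketch assumes ridge-regularized least squares onto one-hot targets, and the function names, the `reg` and `seed` parameters, and the max-over-proxies `predict` helper are all hypothetical. Only the use of `KMeans` from sklearn is taken from the paper excerpt.

```python
import numpy as np
from sklearn.cluster import KMeans


def least_squares_weights(H, labels, n_targets, reg=1e-3):
    """Ridge-regularized least squares onto one-hot targets.

    Assumption: this stands in for eq. (A.1), which is not shown here.
    H has shape [n, l]; the result has shape [n_targets, l].
    """
    Y = np.eye(n_targets)[labels]                   # one-hot targets, [n, n_targets]
    gram = H.T @ H + reg * np.eye(H.shape[1])       # regularized Gram matrix, [l, l]
    return np.linalg.solve(gram, H.T @ Y).T


def k_least_squares(class_data, k, reg=1e-3, seed=0):
    """class_data: list of C arrays H_c with shape [n_c, l].

    Returns a list of C weight arrays W_c with shape [k, l].
    """
    C = len(class_data)
    H_all = np.concatenate(class_data, axis=0)
    if k == 1:
        # Steps 1-3: plain least squares with one weight vector per class.
        labels = np.concatenate(
            [np.full(len(H), c) for c, H in enumerate(class_data)]
        )
        W = least_squares_weights(H_all, labels, C, reg)
        return [W[c:c + 1] for c in range(C)]
    proxy_labels = []
    for c, H in enumerate(class_data):
        # Steps 4-8: split each class into k clusters; cluster j of class c
        # becomes proxy class (c, j), encoded as the integer c * k + j.
        clusters = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(H)
        proxy_labels.append(c * k + clusters)
    proxy_labels = np.concatenate(proxy_labels)
    # Step 9: solve jointly for all C * k proxy classes.
    W = least_squares_weights(H_all, proxy_labels, C * k, reg)
    # Step 10: regroup the proxy weights by original class.
    return [W[c * k:(c + 1) * k] for c in range(C)]


def predict(weights, H):
    """Hypothetical aggregation: score each class by its best-matching proxy."""
    scores = np.stack([(H @ W_c.T).max(axis=1) for W_c in weights], axis=1)
    return scores.argmax(axis=1)
```

Calling `k_least_squares(class_data, k=2)` on a list of per-class embedding arrays returns one `[2, l]` weight block per class. Each class needs at least `k` samples for the clustering step, which matters in the n-shot settings described in the Dataset Splits row.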