The Pitfalls of Memorization: When Memorization Hurts Generalization
Authors: Reza Bayat, Mohammad Pezeshki, Elvis Dohmatob, David Lopez-Paz, Pascal Vincent
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We validate the effectiveness of MAT in improving generalization under subpopulation shift. Then, we provide a detailed analysis of the memorization behaviors of models trained with ERM and MAT. We evaluate our approach on four datasets under subpopulation shift, as detailed in Appendix A. In all experiments, we assume that training environment annotations are not available. For the validation set and for the purpose of model selection, we consider two settings: (1) group annotations are available in the validation set for model selection, and (2) no annotations are available even in the validation set. |
| Researcher Affiliation | Collaboration | 1Mila 2Université de Montréal 3Meta FAIR 4Concordia University |
| Pseudocode | Yes | Algorithm 1 Memorization-Aware Training (MAT). Input: training set {(x_i, y_i)}_{i=1}^n, validation set {(x'_i, y'_i)}_{i=1}^m, pre-trained XRM model f_XRM(x). Optional: validation environment annotations {a'_i}_{i=1}^m; if not available, infer â'_i = 1[argmax f_ho(x'_i) = y'_i]. Model selection: early stopping based on validation worst-group accuracy (using a'_i if provided, or â'_i if inferred). Initialize a classifier f(x). Compute p(y_ho \| x_i) = softmax(f_XRM(x_i)/τ); p(y, y_ho) = (1/n) Σ_i p(y_ho \| x_i) 1[y = y_i], for y ∈ {1, ..., K}; p(y \| y_ho) = p(y, y_ho) / Σ_{y'} p(y', y_ho); p_ho(y \| x_i) = Σ_{y_ho} p(y_ho \| x_i) p(y \| y_ho). Repeat until early stopping: update with the loss (1/n) Σ_i ℓ(softmax(f(x_i) + log p_ho(· \| x_i)), y_i); track worst-group accuracy and update the best model; stop if no improvement is observed after P iterations. |
| Open Source Code | Yes | Equal contribution. Code: https://github.com/facebookresearch/Pitfalls-of-Memorization |
| Open Datasets | Yes | We tested our method on 4 standard datasets: two image datasets, Waterbirds (Sagawa et al., 2019) and CelebA (Liu et al., 2015), and two natural language datasets, MultiNLI (Williams et al., 2017) and CivilComments (Borkan et al., 2019). The configuration of each dataset is provided below. For CelebA, predictors map pixel intensities into a binary blonde/not-blonde label. |
| Dataset Splits | Yes | Input: Training set {(x_i, y_i)}_{i=1}^n, Validation set {(x'_i, y'_i)}_{i=1}^m, ... For evaluation, we report two key metrics on the test set: (1) average test accuracy and (2) worst-group test accuracy, the latter being computed using ground-truth annotations. Table 2: Summary of datasets used, including their data types, the number of classes, the number of groups, and the total dataset size for each. ... Train size |
| Hardware Specification | No | Part of the early experiments were conducted using computational resources provided by Mila Quebec AI Institute. |
| Software Dependencies | No | We use SGD with momentum of 0.9 for the Waterbirds dataset, and we employ AdamW (Loshchilov & Hutter, 2017) with default values of β1 = 0.9 and β2 = 0.999 for the other datasets. We use a pre-trained ResNet-50 (He et al., 2016) for image datasets. For text datasets, we use a pre-trained BERT (Devlin et al., 2018). |
| Experiment Setup | Yes | The hyperparameter search involves testing 16 random hyperparameter combinations sampled from the search space described in Table 3, using a single random seed. We select the hyperparameter combination and the early-stopping iteration that achieve the highest validation worst-group accuracy (computed with ground-truth or pseudo group annotations, depending on the method), or the highest worst-class accuracy if group annotations are unavailable. Table 3: Hyperparameter search space. ERM and MAT share the same hyperparameter search space, except that MAT has one additional hyperparameter, τ, which is used in the softmax function as the temperature parameter to control the sharpness/smoothness of the output distribution. |
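The MAT precomputation and adjusted loss quoted in the Pseudocode row can be sketched in NumPy. This is a minimal sketch under our reading of the extracted Algorithm 1, not the authors' implementation; function and variable names (`mat_adjusted_probs`, `mat_loss`, `tau`) are ours:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mat_adjusted_probs(logits_xrm, labels, tau=1.0):
    """Precompute the MAT adjustment term log p_ho(. | x_i).

    logits_xrm: (n, K) held-out (XRM) logits; labels: (n,) integer labels.
    """
    n, K = logits_xrm.shape
    p_yho_given_x = softmax(logits_xrm / tau)                 # p(y_ho | x_i)
    onehot = np.eye(K)[labels]                                # 1[y = y_i]
    p_joint = onehot.T @ p_yho_given_x / n                    # p(y, y_ho)
    p_y_given_yho = p_joint / p_joint.sum(axis=0, keepdims=True)  # p(y | y_ho)
    p_ho = p_yho_given_x @ p_y_given_yho.T                    # p_ho(y | x_i)
    return np.log(p_ho + 1e-12)

def mat_loss(model_logits, labels, log_p_ho):
    """Cross-entropy on the adjusted logits f(x_i) + log p_ho(. | x_i)."""
    probs = softmax(model_logits + log_p_ho)
    n = len(labels)
    return -np.mean(np.log(probs[np.arange(n), labels] + 1e-12))
```

Because each row of `p_ho` is a proper distribution over the K classes, adding `log p_ho` to the model logits acts as a per-example logit adjustment, analogous to class-prior logit adjustment but conditioned on the held-out model's prediction.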