Can Transformers Learn Full Bayesian Inference in Context?
Authors: Arik Reuter, Tim G. J. Rudner, Vincent Fortuin, David Rügamer
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Extensive experiments on real-world datasets demonstrate that our ICL approach yields posterior samples that are similar in quality to state-of-the-art MCMC or variational inference methods that do not operate in context. ... We present the results of our in-context learning approach on extensive real-world and synthetic datasets in Section 4 and discuss the challenges and the transformative potential of in-context learning for full Bayesian inference in Section 5. ... 4. Experiments: To show that the proposed methodology is not just an abstract concept, we derive exemplary use cases that demonstrate how well ICL is able to keep up with MCMC and VI approaches in practice. ... We evaluate the methods on 50 synthetic datasets and 17 real-world datasets ... Three metrics are employed to compare samples from different approximations of the posterior distribution. |
| Researcher Affiliation | Academia | 1Department of Statistics, LMU Munich, Munich, Germany 2Center for Data Science, New York University, New York, USA 3Department of Computer Science, Technical University of Munich, Munich, Germany 4Helmholtz AI, Munich, Germany. 5Munich Center for Machine Learning (MCML), Munich, Germany. Correspondence to: Arik Reuter <EMAIL>. |
| Pseudocode | Yes | Algorithm 2 Generation of synthetic data for GLMs ... Algorithm 3 Generation of synthetic data for FA ... Algorithm 4 Generation of synthetic data for a GMM. |
| Open Source Code | Yes | The source code for this paper is available at https://github.com/ArikReuter/ICL_for_Full_Bayesian_Inference |
| Open Datasets | Yes | Datasets. We evaluate the methods on 50 synthetic datasets and 17 real-world datasets from a benchmark suite for tabular regression problems proposed by Grinsztajn et al. (2022). |
| Dataset Splits | Yes | We use in total 75 million synthetic samples for all scenarios. Of the total number, half, i.e. 37.5 million, are used for training and 10 percent for validation and the remaining 40 percent for testing. ... For random forest, we use default hyperparameters, as defined in Scikit-learn (Pedregosa et al., 2011) and 10-fold cross-validation. |
| Hardware Specification | Yes | A single L4 GPU is used for the GLM scenarios and a single A100 GPU for the FA and GMM cases. |
| Software Dependencies | No | The paper mentions several software components like "Torchdiffeq (Chen, 2018)", "Pyro (Bingham et al., 2019)", "Numpyro (Phan et al., 2019)", "Adam optimizer (Kingma, 2014)", "Scikit-learn (Pedregosa et al., 2011)". However, it does not provide specific version numbers for these software components or the underlying programming language (e.g., Python version) and its core libraries, which are required for a reproducible description. |
| Experiment Setup | Yes | The dimensionality of encoder representations is set to 512 and is expanded to 1024 in the feed-forward blocks. We use 8 heads and 8 encoder layers with a dropout rate of 0.1. For the decoder part we also use 512 as the dimensionality of the representations and 1024 as the intermediate representation in the feed-forward layers and a dropout rate of 0.1. Furthermore, 3 simple fully connected layers with adaLN conditioning are used for final processing in the decoder. For the time conditioning, we use 3 simple fully connected layers to map the scalar-valued time t onto a 512 dimensional conditioning vector that is used for the adaLN blocks. ... We use an Adam optimizer (Kingma, 2014) with a cosine learning rate schedule (Loshchilov & Hutter, 2016), where the maximum learning rate is 5e-4, the final division factor is 10^4 and 10 percent of the epochs are used for warm-up. We use a weight decay parameter of 10^-5 and a batch size of 1024 and gradient clipping with a maximum gradient norm of one. |
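The split proportions quoted in the Dataset Splits row can be sanity-checked with a few lines of arithmetic. This is a minimal sketch, not the authors' code; it only verifies that the reported 50/10/40 partition of the 75 million synthetic samples is internally consistent.

```python
# Reported totals: 75 million synthetic samples, half for training,
# 10 percent for validation, and the remaining 40 percent for testing.
total = 75_000_000
train = total // 2            # "half, i.e. 37.5 million"
val = total // 10             # "10 percent for validation"
test = total - train - val    # "the remaining 40 percent for testing"

print(train, val, test)       # 37500000 7500000 30000000
assert train + val + test == total
assert test / total == 0.4    # the remainder is indeed 40 percent
```

The numbers confirm the excerpt: 37.5M + 7.5M + 30M sums to the stated 75M.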
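The Experiment Setup row fully specifies the learning-rate schedule: cosine decay with a maximum learning rate of 5e-4, a final division factor of 10^4, and 10 percent of training used for linear warm-up. A minimal sketch of such a schedule is below; the function name and the exact warm-up shape (linear) are assumptions on my part, as the paper excerpt does not spell them out, and the authors likely use a library implementation such as PyTorch's `OneCycleLR`.

```python
import math

def cosine_schedule_with_warmup(step, total_steps, max_lr=5e-4,
                                final_div_factor=1e4, warmup_frac=0.1):
    """Learning rate at `step` (0-indexed): linear warm-up over the first
    `warmup_frac` of training, then cosine decay from max_lr down to
    max_lr / final_div_factor (hypothetical helper, for illustration)."""
    warmup_steps = int(total_steps * warmup_frac)
    min_lr = max_lr / final_div_factor
    if step < warmup_steps:
        # Linear ramp from 0 up to max_lr over the warm-up phase.
        return max_lr * (step + 1) / warmup_steps
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# The peak is reached exactly at the end of the 10% warm-up phase.
print(cosine_schedule_with_warmup(100, 1000))  # 0.0005
```

With the reported final division factor of 10^4, the schedule ends near 5e-8, i.e. four orders of magnitude below the peak.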