Demystifying amortized causal discovery with transformers

Authors: Francesco Montagna, Max Cairney-Leeming, Dhanya Sridhar, Francesco Locatello

TMLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this work, we analyze CSIvA (Ke et al., 2023b), a transformer architecture for amortized inference that promises to train on synthetic data and transfer to real data, on bivariate causal models. First, we bridge the gap with identifiability theory, showing that the training distribution implicitly defines a prior on the causal model of the test observations: consistent with classical approaches, good performance is achieved when we have a good prior on the test data, and the underlying model is identifiable. Second, we find that CSIvA cannot generalize to classes of causal models unseen during training: to overcome this limitation, we theoretically and empirically analyze when training CSIvA on datasets generated by multiple identifiable causal models with different structural assumptions improves its generalization at test time. Overall, we find that amortized causal discovery with transformers still adheres to identifiability theory, contradicting the earlier hypothesis of Lopez-Paz et al. (2015) that supervised learning methods could overcome its restrictions.
Researcher Affiliation | Collaboration | Francesco Montagna, Institute of Science and Technology Austria & Chan Zuckerberg Initiative, EMAIL
Pseudocode | No | The paper describes the CSIvA architecture and its training procedure in detail in Appendix A, but it does not present any part of the methodology in a structured pseudocode or algorithm block.
Open Source Code | Yes | The code for the CSIvA implementation can be found here. The code for reproducing the experiments of the paper can be found here.
Open Datasets | Yes | We consider the accuracy of CSIvA trained on different dataset configurations and tested on real-world datasets. In particular, we perform evaluation on the Tübingen pairs dataset (Mooij et al., 2016), the Sachs biological dataset (Sachs et al., 2005), the Auto MPG dataset on car fuel consumption (Bache & Lichman, 2013), and the Sprinkler dataset, a simple dataset on the causal relations between the binary categorical variables rain, sprinkler on/off, and wet grass.
Dataset Splits | Yes | Unless otherwise specified, in our experiments we train CSIvA on a sample of 15000 synthetically generated datasets, each consisting of 1500 i.i.d. observations. Classes of SCMs are defined by the mechanism type and the noise terms distribution (e.g., linear non-Gaussian): each dataset is generated from a single SCM instance sampled from that class. ... Each architecture we analyze in the experiments is trained 3 times, with different parameter initialization and training samples: the SHD presented in the plots is the average over the 3 models on 1500 distinct test datasets of 1500 points each, and the error bars are 95% confidence intervals.
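The evaluation protocol quoted above (3 retrained models, each scored on 1500 test datasets, reported with 95% confidence intervals) can be sketched as follows. This is a minimal illustration, not the authors' code: the SHD values are randomly generated stand-ins, and the normal-approximation interval over per-model means is an assumption, since the paper does not state how its intervals are computed.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical SHD scores: 3 independently trained models, each evaluated
# on 1500 distinct test datasets. For bivariate graphs the SHD per dataset
# can only be 0, 1, or 2.
shd = rng.integers(0, 3, size=(3, 1500))

# Average SHD per model, then across the 3 models (as reported in the plots).
per_model = shd.mean(axis=1)
mean_shd = per_model.mean()

# 95% confidence interval via a normal approximation over the per-model
# means (an assumption; the paper only says "95% confidence intervals").
sem = per_model.std(ddof=1) / np.sqrt(len(per_model))
ci = 1.96 * sem

print(f"SHD = {mean_shd:.3f} +/- {ci:.3f}")
```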
Hardware Specification | Yes | Our experiments were run on a local computing cluster, using any and all available GPUs (all NVIDIA). For replication purposes, GTX 1080 Tis are entirely suitable, as the batch size was set to match their memory capacity when working with bivariate graphs. All jobs ran with 10GB of RAM and 4 CPU cores. The results presented in this paper were produced after 145 days of GPU time, of which 68 were on GTX 1080 Tis, 13 on RTX 2080 Tis, 11 on A10s, 19 on A40s, and 35 on RTX 3090s. Together with previous experiments while developing our code and experimental design, we used 376 days of GPU time (for reference, at a total cost of 492.14 Euros), similarly split across whichever GPUs were available at the time: 219 on GTX 1080 Tis, 38 on RTX 2080 Tis, 18 on A10s, 63 on RTX 3090s, 31 on A40s, and 6 on A100s.
Software Dependencies | No | The paper mentions several software tools and libraries, such as the 'causally' Python library, 'causal-learn', and 'dodiscover', but it does not specify version numbers for any of them, nor for any programming languages or other critical dependencies.
Experiment Setup | Yes | In Table 1 we detail the hyperparameters for training the network in the experiments. We define an iteration as a gradient update over a batch of 5 datasets. Models are trained until convergence, using a patience of 5 (training until five consecutive epochs without improvement) on the validation loss; this always occurs before the 25th epoch (corresponding to 150000 iterations). The batch size is limited to 5 due to memory constraints.

Table 1: Hyperparameters for the training of the CSIvA models of the experiments in Section 4.
    Hidden state dimension: 64
    Encoder transformer layers: 8
    Decoder transformer layers: 8
    Num. attention heads: 8
    Optimizer: Adam
    Learning rate: 10^-4
    Samples per dataset (n): 1500
    Num. training datasets: 15000
    Num. iterations: < 150000
    Batch size: 5
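The Table 1 hyperparameters above can be transcribed into a single configuration object, which makes the training setup easy to check at a glance. The dictionary layout and key names below are illustrative assumptions; only the values come from Table 1 and the quoted text.

```python
# Hyperparameters transcribed from Table 1 of the paper. The dict structure
# and key names are hypothetical, not the authors' actual config format.
csiva_training_config = {
    "hidden_state_dim": 64,
    "encoder_transformer_layers": 8,
    "decoder_transformer_layers": 8,
    "num_attention_heads": 8,
    "optimizer": "Adam",
    "learning_rate": 1e-4,
    "samples_per_dataset": 1500,    # n i.i.d. observations per dataset
    "num_training_datasets": 15000,
    "max_iterations": 150000,       # training always converges before this
    "batch_size": 5,                # 5 datasets per batch (memory-limited)
    "patience_epochs": 5,           # early stopping on the validation loss
}

for key, value in csiva_training_config.items():
    print(f"{key}: {value}")
```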