Neural Causal Graph for Interpretable and Intervenable Classification

Authors: Jiawei Wang, Shaofei Lu, Da Cao, Dongyu Wang, Yuquan Le, Zhe Quan, Tat-Seng Chua

ICLR 2025

Reproducibility

Variable | Result | LLM Response
Research Type | Experimental | "We validate the effectiveness of the NCG through extensive experiments on both the ImageNet dataset and a custom causal-related dataset. Our results demonstrate that NCG surpasses existing classification models in terms of both accuracy and robustness. Additionally, we showcase the test-time intervention mechanism of the NCG model, which achieves nearly 95% top-1 accuracy on the ImageNet dataset."
Researcher Affiliation | Academia | "¹Hunan University, Changsha, China ²University of Science and Technology of China, Hefei, China ³National University of Singapore, Singapore EMAIL, EMAIL, EMAIL"
Pseudocode | Yes | "Algorithm 1: Pseudo code for constructing the causal graph. Algorithm 2: Pseudo code for training the NCG network."
Open Source Code | Yes | "Additionally, we make the datasets and implementations available to the research community to facilitate further research." https://github.com/JaveyWang/NCG
Open Datasets | Yes | "Two datasets are used in this study: Bird and ImageNet (Deng et al., 2009). Additionally, we make the datasets and implementations available to the research community to facilitate further research." https://github.com/JaveyWang/NCG
Dataset Splits | Yes | "Table 1: The statistics of two datasets, including the number of prior and posterior concepts, and the size of training set and testing set."

Dataset | Prior | Posterior | Training set | Testing set
Bird | 16 | 9 | 11,700 | 450
ImageNet | 1,357 | 1,000 | 1,281,167 | 50,000
Hardware Specification | Yes | "Our framework is implemented based on PyTorch and DGL, and all models are trained with a minibatch size of 512 on a machine equipped with four NVIDIA 3090 Ti GPUs."
Software Dependencies | No | "Our framework is implemented based on PyTorch and DGL"
Experiment Setup | Yes | "All models are trained with a minibatch size of 512 on a machine equipped with four NVIDIA 3090 Ti GPUs. We use AdamW (Kingma & Ba, 2014; Loshchilov & Hutter, 2017) as the optimizer with a learning rate of 0.01 for the Bird dataset and 0.001 for the ImageNet dataset. We train the models for 50 epochs for the Bird dataset and 3 epochs for the ImageNet dataset without any data augmentation."
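The reported hyperparameters can be gathered into a small configuration helper for anyone attempting a reproduction. The sketch below is illustrative only: the helper name `get_train_config` and the dictionary layout are our own, and only the numeric values (batch size 512, AdamW, learning rates 0.01/0.001, epochs 50/3) are taken from the quoted setup.

```python
# Hyperparameters as reported in the paper's experiment setup.
# NOTE: the helper name and dict structure are illustrative, not from the paper.
TRAIN_CONFIGS = {
    "Bird": {"optimizer": "AdamW", "lr": 0.01, "batch_size": 512, "epochs": 50},
    "ImageNet": {"optimizer": "AdamW", "lr": 0.001, "batch_size": 512, "epochs": 3},
}

def get_train_config(dataset: str) -> dict:
    """Return the reported training configuration for a given dataset."""
    if dataset not in TRAIN_CONFIGS:
        raise KeyError(f"No reported configuration for dataset {dataset!r}")
    return TRAIN_CONFIGS[dataset]
```

A reproduction script would feed these values into `torch.optim.AdamW` and the dataloader; note that the paper reports no data augmentation for either dataset.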