Neural Causal Graph for Interpretable and Intervenable Classification
Authors: Jiawei Wang, Shaofei Lu, Da Cao, Dongyu Wang, Yuquan Le, Zhe Quan, Tat-Seng Chua
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We validate the effectiveness of the NCG through extensive experiments on both the ImageNet dataset and a custom causal-related dataset. Our results demonstrate that NCG surpasses existing classification models in terms of both accuracy and robustness. Additionally, we showcase the test-time intervention mechanism of the NCG model, which achieves nearly 95% top-1 accuracy on the ImageNet dataset. |
| Researcher Affiliation | Academia | (1) Hunan University, Changsha, China; (2) University of Science and Technology of China, Hefei, China; (3) National University of Singapore, Singapore |
| Pseudocode | Yes | Algorithm 1 Pseudo code for constructing the causal graph. Algorithm 2 Pseudo code for training the NCG network. |
| Open Source Code | Yes | Additionally, we make the datasets and implementations available to the research community to facilitate further research: https://github.com/JaveyWang/NCG |
| Open Datasets | Yes | Two datasets are used in this study: Bird and ImageNet (Deng et al., 2009). Additionally, we make the datasets and implementations available to the research community to facilitate further research: https://github.com/JaveyWang/NCG |
| Dataset Splits | Yes | Table 1: The statistics of two datasets, including the number of prior and posterior concepts, and the size of training set and testing set. Bird: 16 prior, 9 posterior, 11,700 training, 450 testing; ImageNet: 1,357 prior, 1,000 posterior, 1,281,167 training, 50,000 testing. |
| Hardware Specification | Yes | Our framework is implemented based on PyTorch and DGL, and all models are trained with a minibatch size of 512 on a machine equipped with four NVIDIA 3090Ti GPUs. |
| Software Dependencies | No | Our framework is implemented based on PyTorch and DGL |
| Experiment Setup | Yes | all models are trained with a minibatch size of 512 on a machine equipped with four NVIDIA 3090Ti GPUs. We use AdamW (Kingma & Ba, 2014; Loshchilov & Hutter, 2017) as the optimizer with a learning rate of 0.01 for the Bird dataset and 0.001 for the ImageNet dataset. We train the models for 50 epochs for the Bird dataset and 3 epochs for the ImageNet dataset without any data augmentation. |
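The reported setup in the Experiment Setup row can be collected into a per-dataset configuration. The sketch below is a minimal, hypothetical helper (function and key names are our own, not from the paper's code) that returns the hyperparameters exactly as quoted: AdamW, batch size 512, learning rate 0.01/0.001, and 50/3 epochs for Bird/ImageNet, with no data augmentation.

```python
def training_config(dataset: str) -> dict:
    """Return the training hyperparameters reported for NCG.

    Hypothetical helper: names are illustrative; values are taken
    verbatim from the paper's reported experiment setup.
    """
    base = {
        "optimizer": "AdamW",
        "batch_size": 512,
        "data_augmentation": False,
    }
    if dataset == "bird":
        return {**base, "lr": 0.01, "epochs": 50}
    if dataset == "imagenet":
        return {**base, "lr": 0.001, "epochs": 3}
    raise ValueError(f"unknown dataset: {dataset}")


# Example: look up the ImageNet schedule.
cfg = training_config("imagenet")
print(cfg["lr"], cfg["epochs"])  # → 0.001 3
```

Keeping the configuration in one place like this makes it easy to verify a reproduction run against the paper's stated hyperparameters.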