Causal Invariance-aware Augmentation for Brain Graph Contrastive Learning
Authors: Minqi Yu, Jinduo Liu, Junzhong Ji
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experiments on three real-world brain disease datasets demonstrate that our method achieves state-of-the-art performance, effectively generalizes to multi-site brain datasets, and provides certain interpretability. The code is available at https://github.com/qinsheng1900/CIA-GCL. |
| Researcher Affiliation | Academia | Beijing Municipal Key Laboratory of Multimedia and Intelligent Software Technology, College of Computer Science, Beijing University of Technology, Beijing, China. Correspondence to: Jinduo Liu <EMAIL>. |
| Pseudocode | Yes | A. Algorithm Pseudocode. Algorithm 1: CIA-GCL for brain graph analysis. Input: brain graph G, labels Y. Output: brain invariant graph Ginv. 1: for number of training iterations do 2: Sample mini training batch 3: \\Invariant Subgraph Extraction: 4: Calculate the probabilities of edges in Ginv by Eq. (3) 5: Select the reserved edge mask Minv based on the threshold t by Eqs. (4)-(5) 6: Obtain the invariant graph Ginv and the spurious graph Gs by Eq. (1) 7: for each subject Gk do 8: \\Augmented Sample Generation: 9: Select Gsi with Yi = Yk 10: Obtain the mixed spurious subgraph Gmsk by Eqs. (8)-(9) 11: Combine Ginvk and Gmsk to get Gposk by Eq. (10) 12: end for 13: \\Gradient-based optimization via backpropagation: 14: Calculate the causal loss Lcau by Eq. (11) 15: Calculate the invariant supervised loss Linv by Eq. (12) 16: Calculate the contrastive supervised loss Lcon by Eq. (13) 17: Update the model by minimizing the combination of the above three losses 18: end for 19: return brain invariant graph Ginv |
| Open Source Code | Yes | Experiments on three real-world brain disease datasets demonstrate that our method achieves state-of-the-art performance, effectively generalizes to multi-site brain datasets, and provides certain interpretability. The code is available at https://github.com/qinsheng1900/CIA-GCL. |
| Open Datasets | Yes | We validate our approach on three real-world brain disease datasets. The three rs-fMRI datasets are Autism Brain Imaging Data Exchange (ABIDE) I, ABIDE II, and ADHD200, which are publicly available MRI datasets collected from different international imaging sites. ABIDE I: ABIDE I (Di Martino et al., 2014) is a common dataset for evaluating Autism Spectrum Disorder (ASD) brain network classification; it anonymously collected and shared fMRI and phenotype data for a total of 1035 subjects from 17 different sites around the world, including 505 subjects with ASD and 530 typical controls (TC). We used two brain segmentation atlases to divide the brain into smaller regions, obtaining two versions of the ABIDE I dataset. The first atlas is AAL (Tzourio-Mazoyer et al., 2002), which contains 90 cerebral regions and 26 cerebellar regions. ABIDE II: ABIDE II was created to advance scientific discovery regarding the brain connectome in autism spectrum disorder (ASD). ADHD200: The ADHD200 dataset, collected from 8 independent imaging sites, consists of 491 scans from typically developing individuals and 285 from children and adolescents with ADHD (ages 7-21 years). |
| Dataset Splits | Yes | All experimental results are obtained with 10-fold cross-validation. On the three datasets, our method outperforms other methods in most metrics. The causal analysis-based CI-GNN and the contrastive learning-based methods both performed well on the three datasets, demonstrating the feasibility of causal analysis and contrastive learning in brain graph analysis. Our method performs poorly on the Recall metric, indicating that it tends to make conservative decisions. We also calculated the average results for all methods across the 3 datasets and performed t-tests in Figure 2. It is evident that our method outperforms the other methods in most of the metrics, and its overall performance across the 3 datasets is statistically significantly different from the others, at significance levels of p < 0.01 and p < 0.001. We employ ten-fold cross-validation to obtain a dependable and stable model; the ratio of the training set, validation set, and test set is 8:1:1. |
| Hardware Specification | Yes | All experiments are conducted on a server equipped with an NVIDIA GeForce RTX 3090 GPU and an AMD Ryzen 9 5950X 16-core CPU. |
| Software Dependencies | No | Traditional machine learning methods: the traditional methods include a support vector machine (SVM) classifier and a random forest (RF) classifier, both implemented using the scikit-learn library (Pedregosa et al., 2011). |
| Experiment Setup | Yes | The model uses the Adam optimizer with a learning rate of 4e-5, a batch size of 32, and a maximum of 300 epochs. In the subgraph selection step, r = 0.25. In the contrastive loss, τ1 and τ2 are both set to 0.01. In the total loss L = λ1·Lcon + λ2·Lcau + Linv, the loss weights are λ1 = 1 and λ2 = 0.5. |
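
The invariant-subgraph extraction step in the pseudocode row (Eqs. 3-5, selection ratio r = 0.25 from the setup row) can be sketched as below. This is an illustrative reading, not the paper's exact formulation: the quantile-based choice of threshold t and the function/variable names (`split_invariant_spurious`, `edge_prob`) are assumptions.

```python
import numpy as np

def split_invariant_spurious(adj, edge_prob, r=0.25):
    """Keep the top-r fraction of edges (by learned probability) as the
    invariant subgraph G_inv; the remaining edges form the spurious
    subgraph G_s (Eqs. 1, 3-5, sketched).

    adj:       (n, n) weighted adjacency matrix of a brain graph
    edge_prob: (n, n) learned edge probabilities
    r:         fraction of edges reserved as invariant (paper uses 0.25)
    """
    # Pick threshold t so that roughly a fraction r of edges survives.
    t = np.quantile(edge_prob, 1.0 - r)
    m_inv = (edge_prob >= t).astype(adj.dtype)  # reserved-edge mask M_inv
    g_inv = adj * m_inv                         # invariant subgraph G_inv
    g_s = adj * (1.0 - m_inv)                   # spurious subgraph G_s
    return g_inv, g_s, m_inv

rng = np.random.default_rng(0)
adj = rng.random((90, 90))       # e.g. 90 AAL cerebral regions
prob = rng.random((90, 90))      # stand-in for learned edge probabilities
g_inv, g_s, m_inv = split_invariant_spurious(adj, prob, r=0.25)
```

Note that `g_inv + g_s` reconstructs the full adjacency matrix, matching the decomposition G = Ginv + Gs implied by the algorithm.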
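
The augmented-sample generation step (Eqs. 8-10: mix the spurious subgraph of a same-label subject i into subject k's spurious part, then recombine with the invariant subgraph) might look like the following. The convex-mix rule and the coefficient `alpha` are illustrative assumptions; the paper's exact mixing operator may differ.

```python
import numpy as np

def make_positive_sample(g_inv_k, g_s_k, g_s_i, m_inv, alpha=0.5):
    """Generate an augmented positive sample G_pos for subject k (sketch).

    g_s_i is the spurious subgraph of another subject i with Y_i = Y_k.
    The two spurious subgraphs are convexly mixed into G_ms (Eqs. 8-9),
    then recombined with the invariant subgraph G_inv_k (Eq. 10).
    """
    g_ms_k = alpha * g_s_k + (1.0 - alpha) * g_s_i  # mixed spurious subgraph
    # Invariant edges are kept from G_inv_k; all other edges come from G_ms.
    g_pos_k = g_inv_k + g_ms_k * (1.0 - m_inv)
    return g_pos_k
```

Because the mix only touches edges outside the reserved mask M_inv, the invariant structure of subject k is preserved in every positive sample, which is the point of the causal-invariance-aware augmentation.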
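
The 8:1:1 train/validation/test ratio reported in the splits row fits naturally into ten-fold cross-validation: per fold, one part is the test set, one part the validation set, and the remaining eight parts the training set. A minimal sketch, with the rotation scheme (validation part = fold k+1) as an assumption:

```python
import numpy as np

def ten_fold_splits(n_subjects, seed=0):
    """Yield (train, val, test) index arrays for 10-fold cross-validation
    with an 8:1:1 ratio. Which part serves as validation in each fold is
    an illustrative choice here (the part after the test part)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_subjects)
    parts = np.array_split(idx, 10)
    for k in range(10):
        test = parts[k]
        val = parts[(k + 1) % 10]
        train = np.concatenate([parts[j] for j in range(10)
                                if j not in (k, (k + 1) % 10)])
        yield train, val, test

# e.g. the 1035 ABIDE I subjects
splits = list(ten_fold_splits(1035))
```

Over the ten folds, every subject appears in the test set exactly once, so averaging fold results uses each subject's held-out prediction once.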
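
The total objective from the setup row combines the three losses from the pseudocode as a weighted sum; with the reported weights, a causal-loss unit counts half as much as a contrastive-loss unit:

```python
def total_loss(l_con, l_cau, l_inv, lam1=1.0, lam2=0.5):
    """Total objective L = λ1·L_con + λ2·L_cau + L_inv, with the
    reported weights λ1 = 1 and λ2 = 0.5 as defaults."""
    return lam1 * l_con + lam2 * l_cau + l_inv
```

For example, with L_con = 2.0, L_cau = 4.0, and L_inv = 1.0, the total is 1·2.0 + 0.5·4.0 + 1.0 = 5.0.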