Teaching Transformers Causal Reasoning through Axiomatic Training
Authors: Aniket Vashishtha, Abhinav Kumar, Atharva Pandey, Abbavaram Gowtham Reddy, Kabir Ahuja, Vineeth N. Balasubramanian, Amit Sharma
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our results, based on applying axiomatic training to learn the transitivity axiom and d-separation rule, indicate that such generalization is possible. To avoid data contamination issues, we start with a 67 million parameter transformer model and train it from scratch. On both tasks, we find that a model trained on linear causal chains (along with some noisy variations) can generalize well to complex graphs, including longer causal chains, causal chains with reversed order, and graphs with branching. To handle diverse text inputs, the same method is extended to finetune language models. Finetuning Llama-3.1 8B model on our axiomatic data leads to significant gains on causal benchmarks such as Corr2Cause and CLEAR, in some cases providing state-of-the-art performance surpassing GPT-4. |
| Researcher Affiliation | Collaboration | Aniket Vashishtha 1 Abhinav Kumar 2 Atharva Pandey 3 Abbavaram Gowtham Reddy 4 Kabir Ahuja 5 Vineeth N Balasubramanian 6 Amit Sharma 3... 1Work primarily done as a Research Fellow at Microsoft Research India... 2Massachusetts Institute of Technology, USA 3Microsoft Research, India 4CISPA Helmholtz Center for Information Security, Germany... 5University of Washington, USA 6Work primarily done at IIT Hyderabad... |
| Pseudocode | No | The paper describes the methods and procedures in paragraph text without structured pseudocode or algorithm blocks. |
| Open Source Code | Yes | Our code repository can be accessed at: https://github.com/AniketVashishtha/Causal_Axioms. |
| Open Datasets | Yes | Our code repository can be accessed at: https://github.com/AniketVashishtha/Causal_Axioms. (This repository contains synthetic training data as stated in the Impact Statement: 'To mitigate these risks, we release code and synthetic training data openly'). We evaluate on two benchmarks of causal reasoning in natural language: CLEAR (Chen et al., 2024) and Corr2Cause (Jin et al., 2024a). |
| Dataset Splits | Yes | We train the model on data from simple causal graphs such as sequential chains with 3-6 nodes and evaluate its performance on more complex graphs, including longer chain-like graphs with 7-15 nodes, graphs with branching, longer variable names, and edge direction perturbations (see Figure 1). The training data consists of sequential chains of lengths from [3,6]. |
| Hardware Specification | Yes | Training was performed on 3 GPUs using DeepSpeed Stage 3 with a total batch size of 128 (16 samples per GPU with gradient accumulation). One A40 GPU and one A100 GPU were used for training the transformer model from scratch for all positional-encoding experiments. |
| Software Dependencies | No | We used Huggingface (Wolf et al., 2020) for implementation. The fine-tuning used LoRA with rank 64, alpha 16, and dropout 0.1. Training was performed on 3 GPUs using DeepSpeed Stage 3... No specific software versions are provided. |
| Experiment Setup | Yes | We used a learning rate of 1e-4 with linear scheduling and 3% warmup ratio, training for 4102 max steps on axiomatic instance samples with sequences of maximum length 4096 tokens. We employed mixed precision (bfloat16) training with flash attention for efficiency. After training, the LoRA weights were merged with the base model for inference. ... The fine-tuning used LoRA with rank 64, alpha 16, and dropout 0.1. Training was performed on 3 GPUs using DeepSpeed Stage 3 with a total batch size of 128 (16 samples per GPU with gradient accumulation). |
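The axiomatic training instances quoted above (transitivity queries over linear causal chains of 3-6 nodes) can be sketched in a minimal form. The function name and the exact text serialization below are hypothetical; the paper's released data may use a different format.

```python
def make_transitivity_instance(nodes, query):
    """Build one illustrative axiomatic instance from a linear causal
    chain (e.g. A -> B -> C) with a yes/no transitivity question.

    nodes: chain order, e.g. ["A", "B", "C"]
    query: ordered pair (x, y) asking "Does x cause y?"
    """
    # Premise: the chain's direct edges, stated as text.
    premise = ". ".join(f"{a} causes {b}" for a, b in zip(nodes, nodes[1:]))
    # Label is "Yes" iff x precedes y in the chain, i.e. (x, y) lies in
    # the transitive closure of the direct edges -- the transitivity axiom.
    x, y = query
    label = "Yes" if nodes.index(x) < nodes.index(y) else "No"
    return f"{premise}. Does {x} cause {y}?", label

# Training chains use 3-6 nodes; evaluation chains are longer (7-15),
# branched, or reversed, per the Dataset Splits row above.
question, answer = make_transitivity_instance(["A", "B", "C"], ("A", "C"))
# question == "A causes B. B causes C. Does A cause C?", answer == "Yes"
```

Generating many such (question, label) pairs, plus the noisy chain variations the paper mentions, yields the from-scratch training corpus for the 67M-parameter transformer.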