Learning a Decision Tree Algorithm with Transformers
Authors: Yufan Zhuang, Liyuan Liu, Chandan Singh, Jingbo Shang, Jianfeng Gao
TMLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We use 632 classification datasets from OpenML (Vanschoren et al., 2013), Penn Machine Learning Benchmarks (Romano et al., 2021), along with a synthetic XOR dataset. We require each dataset to have at least 1000 data points, at most 256 features, at most 10 classes, and fewer than 100 categorical features, with no missing data. We randomly select 91 datasets as the left-out test set for evaluating our model's generalization capability while making sure they and their variants do not appear in the training set. |
| Researcher Affiliation | Collaboration | Yufan Zhuang (UC San Diego, Microsoft Research), Liyuan Liu (Microsoft Research), Chandan Singh (Microsoft Research), Jingbo Shang (UC San Diego), Jianfeng Gao (Microsoft Research) |
| Pseudocode | No | The paper describes the methodology in narrative form and illustrates the architecture and generation process in Figure 1, but it does not include a distinct block or figure explicitly labeled as "Pseudocode" or "Algorithm" with structured steps. |
| Open Source Code | Yes | Code available at: https://github.com/EvanZhuang/MetaTree |
| Open Datasets | Yes | We use 632 classification datasets from OpenML (Vanschoren et al., 2013), Penn Machine Learning Benchmarks (Romano et al., 2021), along with a synthetic XOR dataset. |
| Dataset Splits | Yes | We generate our decision-tree training dataset in the following manner: for each dataset, we first divide it into train and test sets with a 70:30 split; then we sample 256 data points with 10 randomly selected feature dimensions from the training set and fit both a GOSDT tree (Lin et al., 2020) and a CART tree (Breiman et al., 1984) of depth 2; we then record the accuracy of the two trees on the test set. We repeat this process to generate 10k trees for each dataset. |
| Hardware Specification | Yes | GOSDT cannot stably generate trees with a depth greater than 2 without incurring Out-of-Memory or Out-of-Time errors on machines with up to 125G of memory (see Sec. 6.3 for memory usage analysis). Table 2: We show the memory usage of CART, GOSDT, and MetaTree when fitting decision trees of depth 2 & 3, using the memory-profiler tool (Pedregosa & Gervais, 2021). GOSDT takes a significant amount of memory since it solves for the full decision-tree solution on the data. *We report statistics for successfully terminated runs; the experiments are conducted on a workstation with 125G of memory and 128 cores. |
| Software Dependencies | No | The paper mentions software like scikit-learn and imodels and uses the LLaMA architecture as a base Transformer, but it does not provide specific version numbers for these components. For example, 'For CART and ID3, we use the sklearn implementation (Pedregosa et al., 2011)' and 'We use the LLaMA (Touvron et al., 2023) architecture as the base Transformer' do not specify versions. |
| Experiment Setup | Yes | We pretrain our model from scratch on the GOSDT dataset, and after training converges, we finetune it on the GOSDT+CART dataset. This curriculum improves performance compared to direct training on the GOSDT+CART dataset (as shown in Appendix A.6). We also show an ablation with an RL training objective in Appendix A.8. Detailed hyperparameters are shown in Appendix A.5. Table A2: Hyperparameters for MetaTree training. Number of Hidden Layers: 12; Number of Attention Heads: 12; Hidden Size: 768; Number of Parameters: 149M; Learning Rate: 5e-5; Learning Rate Schedule: Linear Decay; Optimizer: AdamW (β1 0.9, β2 0.999); Training dtype: bf16; Number of Features: 10; Number of Classes: 10; Block Size: 256; Tree Depth: 2; σ: 5e-2; Number of Warmup Steps: 1000; Number of Training Steps: 4,000,000; Steps in Phase 1 (GOSDT): 1,000,000; Steps in Phase 2 (GOSDT+CART): 3,000,000; Batch Size: 128 |
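The per-dataset tree-generation loop quoted in the Dataset Splits row (70:30 split, sample 256 points, fit a depth-2 tree, record its test accuracy) can be sketched end to end. The sketch below is a hypothetical re-implementation, not the authors' code: it substitutes a simple greedy Gini-based learner for both GOSDT and sklearn's CART, and uses a toy AND-of-two-features target (a task a greedy depth-2 tree can solve) in place of the paper's 632 datasets. All function names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def gini(y):
    """Gini impurity of a label vector."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Greedy search over (feature, quantile threshold) minimizing weighted Gini."""
    best_j, best_t, best_score = None, None, np.inf
    n = len(y)
    for j in range(X.shape[1]):
        for t in np.quantile(X[:, j], np.linspace(0.02, 0.98, 33)):
            mask = X[:, j] <= t
            if mask.all() or not mask.any():
                continue
            score = (mask.sum() * gini(y[mask]) + (~mask).sum() * gini(y[~mask])) / n
            if score < best_score:
                best_j, best_t, best_score = j, t, score
    return best_j, best_t

def majority(y):
    vals, counts = np.unique(y, return_counts=True)
    return vals[np.argmax(counts)]

def fit_tree(X, y, depth):
    """Greedy CART-style tree as a nested dict; leaves hold the majority class."""
    if depth == 0 or len(np.unique(y)) == 1:
        return {"leaf": majority(y)}
    j, t = best_split(X, y)
    if j is None:
        return {"leaf": majority(y)}
    mask = X[:, j] <= t
    return {"feat": j, "thr": t,
            "left": fit_tree(X[mask], y[mask], depth - 1),
            "right": fit_tree(X[~mask], y[~mask], depth - 1)}

def predict_one(tree, x):
    while "leaf" not in tree:
        tree = tree["left"] if x[tree["feat"]] <= tree["thr"] else tree["right"]
    return tree["leaf"]

# Toy stand-in dataset: 2 informative features + 8 distractors, AND target.
X = rng.uniform(-1, 1, size=(2000, 10))
y = ((X[:, 0] > 0) & (X[:, 1] > 0)).astype(int)

# 70:30 train/test split, as described in the paper.
perm = rng.permutation(len(X))
cut = int(0.7 * len(X))
train_idx, test_idx = perm[:cut], perm[cut:]

# Sample 256 training points (the paper also samples 10 feature dims) and fit a depth-2 tree.
sample = rng.choice(train_idx, size=256, replace=False)
tree = fit_tree(X[sample], y[sample], depth=2)

# Record the tree's accuracy on the held-out test set.
preds = np.array([predict_one(tree, x) for x in X[test_idx]])
acc = (preds == y[test_idx]).mean()
print(f"depth-2 tree test accuracy: {acc:.3f}")
```

In the paper this sample-fit-score step is repeated 10k times per dataset; note that a greedy learner of this kind fails on the synthetic XOR dataset (no single split reduces impurity), which is precisely why the paper pairs it with the optimal GOSDT solver.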