Optimizing for Interpretability in Deep Neural Networks with Tree Regularization

Authors: Mike Wu, Sonali Parbhoo, Michael C. Hughes, Volker Roth, Finale Doshi-Velez

JAIR 2021

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Using intuitive toy examples, benchmark image datasets, and medical tasks for patients in critical care and with HIV, we demonstrate that this new family of tree regularizers yield models that are easier for humans to simulate than L1 or L2 penalties without sacrificing predictive power.
Researcher Affiliation | Academia | Mike Wu, Stanford University, Stanford, CA 94305 USA; Sonali Parbhoo, Harvard University SEAS, Cambridge, MA 02138 USA; Michael C. Hughes, Tufts University, Medford, MA 02153 USA; Volker Roth, University of Basel, Basel, Switzerland; Finale Doshi-Velez, Harvard University SEAS, Cambridge, MA 02138 USA
Pseudocode | Yes | Algorithm 1: Average decision path length (APL) cost function
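A minimal sketch of what an APL cost function can look like, assuming scikit-learn: fit a decision tree to the network's own predictions (not the true labels), then average the length of each input's root-to-leaf path. The function name `average_path_length`, the `predict` callable standing in for the network, and the mapping of the tree hyperparameter h to `min_samples_leaf` are assumptions for illustration, not details confirmed by the paper. This version counts every node on the path, root and leaf included; subtract 1 to count only decision nodes.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def average_path_length(X, predict, h, seed=0):
    """Hypothetical APL sketch: mean root-to-leaf path length of a tree
    trained to mimic the model's predictions on X."""
    y_hat = predict(X)  # labels produced by the model being explained
    # Assumption: h is treated as the minimum number of samples per leaf.
    tree = DecisionTreeClassifier(min_samples_leaf=h, random_state=seed)
    tree.fit(X, y_hat)
    # Sparse (n_samples, n_nodes) 0/1 matrix marking the nodes each input visits.
    node_indicator = tree.decision_path(X)
    # Nodes visited per input, counting both the root and the leaf.
    path_lengths = np.asarray(node_indicator.sum(axis=1)).ravel()
    return float(path_lengths.mean())


# Usage: a stand-in "network" that thresholds the first feature.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
step = lambda Z: (Z[:, 0] > 0).astype(int)
print(average_path_length(X, step, h=1))  # prints 2.0 (one split: root + leaf)
```

A model whose decisions need a deeper tree to mimic yields a larger APL, which is why it can serve as a proxy for how hard the model is for a human to simulate.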
Open Source Code | No | The paper does not explicitly provide a link to open-source code for the methodology described, nor does it state that the code is available in supplementary materials or will be released.
Open Datasets | Yes | We study timeseries data for 11,786 septic ICU patients from the public MIMIC-III dataset (Johnson, Pollard, Shen, Lehman, Feng, Ghassemi, Moody, Szolovits, Celi, & Mark, 2016). ... The EuResist Integrated Database (Zazzi, Incardona, Rosen-Zvi, Prosperi, Lengauer, Altmann, Sonnerborg, Lavee, Schulter, & Kaiser, 2012) describes 53,236 patients diagnosed with HIV. ... Timeseries data containing broadband recordings of 630 speakers of American English reading ten phonetically rich sentences (Garofolo et al., 1993). ... Next, we explore the use of tree regularization in image classification on two standard benchmarks: MNIST (LeCun, 1998) and Fashion MNIST (Xiao, Rasul, & Vollgraf, 2017).
Dataset Splits | Yes | Sepsis Critical Care (ICU): 7,070 patients are used in training, 1,769 for validation, and 294 for test. HIV Therapy Outcome (HIV): 37,618 patients are used for training, 7,986 for testing, and 7,632 for validation. Phonetic Speech (TIMIT): There are 6,303 sequences, which we split into 3,697 for training, 925 for validation, and 1,681 for testing.
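The quoted split sizes can be checked against the dataset totals with simple arithmetic. The HIV split sums to 53,236, matching the cohort size quoted under Open Datasets, and the TIMIT split sums to the stated 6,303 sequences; the ICU split sums to 9,133, fewer than the 11,786 septic patients quoted, a gap the excerpt does not explain. The short script below (names `SPLITS` and `split_summary` are illustrative) computes each total and the train/validation/test shares.

```python
# (train, validation, test) sizes as quoted in the Dataset Splits row.
SPLITS = {
    "ICU": (7070, 1769, 294),
    "HIV": (37618, 7632, 7986),
    "TIMIT": (3697, 925, 1681),
}


def split_summary(sizes):
    """Return the total and each split's share of it, rounded to 3 decimals."""
    total = sum(sizes)
    return total, tuple(round(n / total, 3) for n in sizes)


for name, sizes in SPLITS.items():
    total, shares = split_summary(sizes)
    print(f"{name}: total={total}, train/val/test shares={shares}")
```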
Hardware Specification | Yes | Table 2: Training time for a single epoch in seconds on a single Intel Core i5 CPU.
Software Dependencies | No | The paper mentions using "Python's scikit-learn (Pedregosa, Varoquaux, Gramfort, Michel, Thirion, Grisel, Blondel, Prettenhofer, Weiss, Dubourg, Vanderplas, Passos, Cournapeau, Brucher, Perrot, & Duchesnay, 2011a)" and the Adam optimizer, but does not provide version numbers for these or for any other libraries or frameworks used in the implementation.
Experiment Setup | Yes | 2D Parabola MLP experiment: Equation 1 is optimized via Adam (Kingma & Ba, 2014) using a batch size of 100 and a learning rate of 1e-3 for 250 epochs. These hyperparameters were found with grid search. Real-World Timeseries: For optimization, we use Adam with a learning rate of 1e-3, a batch size of 256, decision tree hyperparameter h = 1000, train for 300 epochs, surrogate datasets of size J = 100, and retrain every 25 steps. Image Classification: We use the Adam optimizer with learning rate 3e-4, batch size 128, for 30 epochs in training the target neural network. To fit the surrogate, we use Adam with learning rate 1e-3, batch size 256, weight decay 1e-4 for 50 epochs. The surrogate is retrained every epoch of optimizing the main neural network. UCI Machine Learning Benchmarks: In each of the following datasets, the target neural model is trained for 500 epochs with 1e-4 learning rate using Adam and a minibatch size of 128. We train under 20 different λ between 0.0001 and 10.0. We do not do early stopping to preserve overfitting effects. We use 250 samples from the convex hull and retrain every 50 gradient steps. We set h = 25 for Wine and h = 100 otherwise.
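Several setups above refer to a surrogate that is periodically refit. In tree regularization, APL is not differentiable in the network weights, so a small regression model is trained to predict APL from the flattened weights, using a dataset of J (weights, APL) pairs; its smooth prediction stands in for APL during training. The sketch below fits such a surrogate with numpy only, using hand-written Adam with L2 weight decay. The one-hidden-layer tanh architecture, the hidden size of 25, the squared-error loss, coupled (gradient-added) weight decay, and the name `fit_surrogate` are all assumptions for illustration; only the default learning rate, weight decay, epochs, and batch size mirror the Image Classification numbers quoted above.

```python
import numpy as np


def fit_surrogate(weights, apl, hidden=25, lr=1e-3, weight_decay=1e-4,
                  epochs=50, batch_size=256, seed=0):
    """Hypothetical sketch: regress APL on flattened network weights.

    weights: (J, D) array, one row per snapshot of the target network.
    apl:     (J,) array of APL values measured for those snapshots.
    Returns a prediction function and the training MSE after each epoch.
    """
    rng = np.random.default_rng(seed)
    J, D = weights.shape
    params = {
        "A": rng.normal(0.0, 1.0 / np.sqrt(D), size=(D, hidden)),
        "b": np.zeros(hidden),
        "c": rng.normal(0.0, 1.0 / np.sqrt(hidden), size=hidden),
        "d": np.array([apl.mean()], dtype=float),  # output bias starts at mean APL
    }
    m = {k: np.zeros_like(p) for k, p in params.items()}  # Adam first moments
    v = {k: np.zeros_like(p) for k, p in params.items()}  # Adam second moments
    beta1, beta2, eps, t = 0.9, 0.999, 1e-8, 0

    def forward(x):
        hid = np.tanh(x @ params["A"] + params["b"])
        return hid, hid @ params["c"] + params["d"][0]

    losses = []
    for _ in range(epochs):
        order = rng.permutation(J)
        for start in range(0, J, batch_size):
            idx = order[start:start + batch_size]
            x, y = weights[idx], apl[idx]
            hid, pred = forward(x)
            g_pred = 2.0 * (pred - y) / len(idx)  # d(MSE)/d(pred)
            g_hid = np.outer(g_pred, params["c"]) * (1.0 - hid ** 2)  # back through tanh
            grads = {
                "A": x.T @ g_hid,
                "b": g_hid.sum(axis=0),
                "c": hid.T @ g_pred,
                "d": np.array([g_pred.sum()]),
            }
            t += 1
            for k in params:
                g = grads[k] + weight_decay * params[k]  # assumed coupled L2 weight decay
                m[k] = beta1 * m[k] + (1 - beta1) * g
                v[k] = beta2 * v[k] + (1 - beta2) * g ** 2
                m_hat = m[k] / (1 - beta1 ** t)  # bias-corrected moments
                v_hat = v[k] / (1 - beta2 ** t)
                params[k] -= lr * m_hat / (np.sqrt(v_hat) + eps)
        losses.append(float(np.mean((forward(weights)[1] - apl) ** 2)))

    return (lambda w: forward(np.atleast_2d(w))[1]), losses
```

With J = 100 pairs and a batch size of 256, each epoch here is a single full-batch Adam step. In the Real-World Timeseries setup quoted above, a fresh set of J = 100 pairs would be gathered and the surrogate refit every 25 steps, so that it keeps tracking APL as the target network's weights move.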