TITLE: NBDT: Neural-Backed Decision Trees
AUTHOR: Alvin Wan, Lisa Dunlap, Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez
ASSOCIATION: UC Berkeley, Boston University
FROM: arXiv:2004.00221
CONTRIBUTION
- A method is proposed for running any classification neural network as a decision tree by defining a set of embedded decision rules that can be constructed from the fully-connected layer. The authors also design induced hierarchies, which are easier for neural networks to learn.
- A tree supervision loss is proposed, which boosts neural network accuracy by 0.5% and produces high-accuracy NBDTs. NBDTs achieve accuracies comparable to neural networks on small, medium, and large-scale image classification datasets.
- Qualitative and quantitative evidence of semantic interpretations is presented.
METHOD
Steps for Converting CNN into a Decision Tree
- Build an induced hierarchy;
- Fine-tune the model with a tree supervision loss;
- For inference, featurize samples with the neural network backbone;
- And run decision rules embedded in the fully-connected layer.
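The inference steps in the list above can be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: the toy tree, the class names, the node layout (`vector`, `children`, `label`), and the `features` vector standing in for the backbone output are all made up for the example. It assumes the "hard" rule in which each node sends the sample to the child whose representative vector has the largest inner product with the features.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def hard_inference(features, nodes, root):
    """Walk from the root to a leaf, always taking the child whose
    representative vector has the largest inner product with the features."""
    node, path = root, [root]
    while nodes[node]["children"]:
        node = max(nodes[node]["children"],
                   key=lambda child: dot(features, nodes[child]["vector"]))
        path.append(node)
    return nodes[node]["label"], path


# Hypothetical hand-built tree: ids 0-3 are leaves (classes), 4-6 are internal.
nodes = {
    0: {"vector": [1.0, 0.0], "children": [], "label": "cat"},
    1: {"vector": [0.8, 0.6], "children": [], "label": "dog"},
    2: {"vector": [-1.0, 0.0], "children": [], "label": "car"},
    3: {"vector": [-0.8, -0.6], "children": [], "label": "truck"},
    4: {"vector": [0.9, 0.3], "children": [0, 1], "label": None},
    5: {"vector": [-0.9, -0.3], "children": [2, 3], "label": None},
    6: {"vector": [0.0, 0.0], "children": [4, 5], "label": None},
}

features = [0.5, 0.9]  # stand-in for the featurized sample from the backbone
label, path = hard_inference(features, nodes, root=6)
print(label, path)
```

The returned `path` is the sequence of decisions taken, which is what makes the prediction inspectable node by node.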
The following figure illustrates the main steps for converting a classification neural network into a decision tree:
Building Induced Hierarchies
The following figure illustrates how to build induced hierarchies from the network’s final fully-connected layer. For the leaf nodes, the representative vectors are taken from the weights of the FC layer. Each parent’s representative vector is computed by averaging the vectors of its children.
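A rough sketch of this construction is given below. It makes several assumptions that are not stated in these notes: the rows of the FC weight matrix are normalized first, the two closest active nodes are merged greedily (a simple agglomerative rule chosen for the example, which yields a binary tree), and each parent averages its two direct children. The function name `build_induced_hierarchy` and the toy weights are hypothetical.

```python
import math
from itertools import combinations


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def mean(vectors):
    """Element-wise average of a list of vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def sq_dist(a, b):
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def build_induced_hierarchy(fc_weights):
    """fc_weights holds one row per class. Returns (nodes, root_id)."""
    # Leaves: one node per class, represented by its (normalized) FC row.
    nodes = {i: {"vector": normalize(w), "children": [], "label": i}
             for i, w in enumerate(fc_weights)}
    active = list(nodes)
    next_id = len(nodes)
    while len(active) > 1:
        # Assumed merge rule: join the two closest active nodes.
        a, b = min(combinations(active, 2),
                   key=lambda p: sq_dist(nodes[p[0]]["vector"],
                                         nodes[p[1]]["vector"]))
        # Parent vector = average of its children's vectors.
        nodes[next_id] = {
            "vector": mean([nodes[a]["vector"], nodes[b]["vector"]]),
            "children": [a, b],
            "label": None,
        }
        active = [n for n in active if n not in (a, b)] + [next_id]
        next_id += 1
    return nodes, active[0]


# Toy FC layer: 4 classes, 2-dimensional features (made-up numbers).
toy_weights = [[1.0, 0.1], [0.9, 0.2], [-1.0, 0.1], [-0.9, -0.2]]
nodes, root = build_induced_hierarchy(toy_weights)
for node_id, node in nodes.items():
    print(node_id, node["children"])
```

With these toy weights, classes 0 and 1 end up under one parent and classes 2 and 3 under another, because their weight rows point in similar directions.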
In this work, the authors find a minimal subset of the WordNet hierarchy that includes all classes as leaves, pruning redundant leaves and single-child intermediate nodes. To leverage this source of labels, a hypothesis is generated for each intermediate node by finding the earliest common ancestor of the leaves in its subtree.
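The labelling idea can be illustrated with a toy taxonomy. The sketch below is only an approximation under stated assumptions: the `parent` dictionary is a made-up mini hierarchy rather than WordNet, it assumes each concept has a single parent (WordNet itself allows several), and it reads "earliest ancestor" as the deepest ancestor shared by all leaves of the subtree.

```python
def ancestors(node, parent):
    """Return [node, parent, grandparent, ..., root]."""
    chain = [node]
    while node in parent:
        node = parent[node]
        chain.append(node)
    return chain


def earliest_common_ancestor(leaves, parent):
    """Deepest concept that is an ancestor of every leaf, or None."""
    others = [set(ancestors(leaf, parent)) for leaf in leaves[1:]]
    # Walking upward from the first leaf, the first shared hit is the deepest.
    for candidate in ancestors(leaves[0], parent):
        if all(candidate in chain for chain in others):
            return candidate
    return None


# Hypothetical child -> parent taxonomy, for illustration only.
parent = {
    "cat": "feline", "feline": "animal",
    "dog": "canine", "canine": "animal",
    "car": "vehicle", "truck": "vehicle",
    "animal": "entity", "vehicle": "entity",
}

print(earliest_common_ancestor(["cat", "dog"], parent))
print(earliest_common_ancestor(["car", "truck"], parent))
print(earliest_common_ancestor(["cat", "truck"], parent))
```

An intermediate node whose subtree contains only "cat" and "dog" would be hypothesized to mean "animal" in this toy setting.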
Training with Tree Supervision Loss
A tree supervision loss is added to the final loss function to encourage the network to separate the representative vectors of each internal node. Two variants are proposed: a hard tree supervision loss and a soft tree supervision loss. The final loss is
$$L = L_{original} + L_{hard/soft}$$
where $L_{original}$ is the usual cross-entropy loss for classification, and $L_{hard/soft}$ stands for the hard or soft tree supervision loss.
The hard tree supervision loss is defined as
$$L_{hard} = \frac{1}{N} \sum_{i=1}^{N} \mathrm{CrossEntropy}\left(D(i)_{pred}, D(i)_{label}\right)$$
where $N$ is the number of nodes in the tree, excluding leaves, $D(i)_{pred}$ is the predicted probability distribution at node $i$, and $D(i)_{label}$ is the label at node $i$. Note that nodes that do not lie on the path from the root to the label's leaf have no defined loss.
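As a rough illustration of the hard loss described above, the sketch below sums a cross-entropy term at every internal node on the path from the root to the true leaf and divides by the number of internal nodes. The assumptions are mine, not the paper's: each node's predicted distribution is a softmax over the inner products between the features and its children's vectors, and the toy tree, `features`, and `label_path` are invented for the example.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def hard_tree_loss(features, vectors, children, label_path):
    """vectors: node id -> representative vector.
    children: internal node id -> list of child ids.
    label_path: node ids from the root down to the true leaf."""
    loss = 0.0
    # Only nodes on the label's path contribute a term.
    for node, correct_child in zip(label_path, label_path[1:]):
        kids = children[node]
        probs = softmax([dot(features, vectors[k]) for k in kids])
        loss += -math.log(probs[kids.index(correct_child)])
    return loss / len(children)  # N = number of internal nodes


# Hypothetical toy tree: leaves 0-3, internal nodes 4-6 (6 is the root).
vectors = {
    0: [1.0, 0.0], 1: [0.8, 0.6], 2: [-1.0, 0.0], 3: [-0.8, -0.6],
    4: [0.9, 0.3], 5: [-0.9, -0.3], 6: [0.0, 0.0],
}
children = {4: [0, 1], 5: [2, 3], 6: [4, 5]}

features = [0.5, 0.9]    # stand-in for the backbone output
label_path = [6, 4, 1]   # root -> node 4 -> true leaf 1
loss = hard_tree_loss(features, vectors, children, label_path)
print(loss)
```

In training, this value would be added to the ordinary cross-entropy loss of the network.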
The soft tree supervision loss is defined as
$$L_{soft} = \mathrm{CrossEntropy}\left(D_{pred}, D_{label}\right)$$
where $D_{pred}$ is the predicted distribution over leaves and $D_{label}$ is the target distribution.
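A minimal sketch of the soft variant follows. It assumes that the probability of a leaf is the product of the child probabilities along its path from the root, with each node applying a softmax over inner products with its children's vectors. That reading, the helper names, and the toy tree are assumptions made for the example.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def leaf_distribution(features, vectors, children, root):
    """Return leaf id -> probability, multiplying child probabilities
    along each root-to-leaf path."""
    dist = {}
    stack = [(root, 1.0)]
    while stack:
        node, prob = stack.pop()
        kids = children.get(node, [])
        if not kids:  # reached a leaf
            dist[node] = prob
            continue
        probs = softmax([dot(features, vectors[k]) for k in kids])
        for kid, p in zip(kids, probs):
            stack.append((kid, prob * p))
    return dist


def soft_tree_loss(features, vectors, children, root, target):
    """Cross-entropy between the leaf distribution and a target
    distribution given as leaf id -> probability."""
    dist = leaf_distribution(features, vectors, children, root)
    return -sum(t * math.log(dist[leaf]) for leaf, t in target.items() if t > 0)


# Hypothetical toy tree: leaves 0-3, internal nodes 4-6 (6 is the root).
vectors = {
    0: [1.0, 0.0], 1: [0.8, 0.6], 2: [-1.0, 0.0], 3: [-0.8, -0.6],
    4: [0.9, 0.3], 5: [-0.9, -0.3], 6: [0.0, 0.0],
}
children = {4: [0, 1], 5: [2, 3], 6: [4, 5]}

features = [0.5, 0.9]  # stand-in for the backbone output
dist = leaf_distribution(features, vectors, children, root=6)
loss = soft_tree_loss(features, vectors, children, root=6, target={1: 1.0})
print(dist, loss)
```

Unlike the hard variant, every leaf receives some probability mass here, so a mistake near the root does not rule out the correct class outright.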
The following figure gives an example of the hard and soft tree supervision losses.
PERFORMANCE
On all four datasets (CIFAR10, CIFAR100, TinyImageNet, and ImageNet), NBDT outperforms competing decision-tree-based methods, even uninterpretable variants such as a decision forest, by up to 18%. On CIFAR10, CIFAR100, and TinyImageNet, NBDTs largely stay within 1% of neural network performance.
SOME THOUGHTS
- The performance seems promising. However, the ablation studies are confusing because their experimental settings differ in more than one variable.
- The method for constructing a reasonable hierarchy is not described in much detail. My best guess is that the authors force the tree to be binary.
- Is it possible for leaves to have duplicate labels?