
Reading Note: NBDT: Neural-Backed Decision Trees

TITLE: NBDT: Neural-Backed Decision Trees

AUTHOR: Alvin Wan, Lisa Dunlap, Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez

ASSOCIATION: UC Berkeley, Boston University

FROM: arXiv:2004.00221

CONTRIBUTION

  1. A method is proposed for running any classification neural network as a decision tree, by defining a set of embedded decision rules constructed from the final fully-connected layer. Induced hierarchies, built from the network's own weights, are designed to be easier for neural networks to learn.
  2. A tree supervision loss is proposed, which boosts neural network accuracy by 0.5% and produces high-accuracy NBDTs. NBDTs achieve accuracies comparable to neural networks on small-, medium-, and large-scale image classification datasets.
  3. Qualitative and quantitative evidence of the semantic interpretability of NBDTs is presented.

METHOD

Steps for Converting CNN into a Decision Tree

  1. Build an induced hierarchy;
  2. Fine-tune the model with a tree supervision loss;
  3. For inference, featurize samples with the neural network backbone;
  4. Run the decision rules embedded in the fully-connected layer.

The following figure illustrates the main steps for converting a classification neural network into a decision tree:

[Figure: Main steps for converting a classification neural network into a decision tree]
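As a concrete, simplified sketch of steps 3 and 4 (featurize a sample with the backbone, then run the embedded decision rules), the snippet below walks the induced tree with hard decision rules in PyTorch. The `Node` structure, the `backbone` callable, and the function names are my own assumptions rather than the paper's released API.

```python
import torch

# Hypothetical node structure for the induced hierarchy. Each node stores a
# representative vector; leaves additionally store the class index they stand for.
class Node:
    def __init__(self, representative, children=None, class_index=None):
        self.representative = representative  # 1-D tensor with the same size as the features
        self.children = children or []        # empty list for leaves
        self.class_index = class_index        # set only for leaves


def hard_inference(x, backbone, root):
    """Featurize a sample with the backbone, then greedily walk the induced tree.

    At every internal node the decision rule is an inner product between the
    featurized sample and each child's representative vector; the child with
    the largest score wins (the hard decision rule).
    """
    feat = backbone(x.unsqueeze(0)).squeeze(0)  # penultimate-layer features of one sample
    node = root
    while node.children:                        # descend until a leaf is reached
        scores = torch.stack([feat @ child.representative for child in node.children])
        node = node.children[scores.argmax().item()]
    return node.class_index                     # the leaf's class is the prediction
```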

Building Induced Hierarchies

The following figure illustrates how to build an induced hierarchy from the network’s final fully-connected layer. For the leaf nodes, the representative vectors are taken from the rows of the FC layer’s weight matrix; each parent’s representative vector is the average of its children’s vectors.

[Figure: Building induced hierarchies from the FC layer's weights]
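To make the construction more concrete, here is a rough sketch of how an induced hierarchy could be built from the final FC layer, assuming the weight matrix is available as a NumPy array (e.g., `model.fc.weight.detach().cpu().numpy()`). SciPy's agglomerative clustering with Ward linkage stands in for the paper's hierarchy-building step; the function name and the exact linkage choice are my assumptions.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage


def build_induced_hierarchy(fc_weight):
    """Build a binary induced hierarchy from the final FC layer's weight matrix.

    fc_weight: array of shape (num_classes, feature_dim). Each row is the
    representative vector of one leaf (one class); agglomerative clustering
    merges the closest clusters pairwise, and each new internal node's
    representative is the average of its two children's vectors.
    """
    fc_weight = np.asarray(fc_weight, dtype=float)
    num_classes = fc_weight.shape[0]
    reps = [fc_weight[i] for i in range(num_classes)]    # reps[i] = vector of node i
    children = [[] for _ in range(num_classes)]          # leaves have no children

    # Each linkage row merges two existing clusters; new clusters are numbered
    # num_classes, num_classes + 1, ..., which matches appending to the lists above.
    for left, right, _, _ in linkage(fc_weight, method="ward"):
        left, right = int(left), int(right)
        children.append([left, right])
        reps.append((reps[left] + reps[right]) / 2.0)    # parent = mean of children

    return reps, children, len(reps) - 1                 # the last cluster created is the root
```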

In this work, the authors also extract a minimal subset of the WordNet hierarchy that includes all classes as leaves, pruning redundant leaves and single-child intermediate nodes. To leverage this source of labels, a hypothesis is generated for each intermediate node of the induced hierarchy by finding the earliest WordNet ancestor shared by the leaves of its subtree.
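As a rough illustration of the WordNet step, the snippet below hypothesizes a name for an intermediate node by taking the lowest common hypernym of its leaves' synsets with NLTK's WordNet interface. Picking the first noun sense of each class name is a simplification; the paper's procedure for resolving synsets and pruning the hierarchy is more involved.

```python
from nltk.corpus import wordnet as wn


def hypothesize_node(leaf_names):
    """Return a candidate WordNet hypernym naming an intermediate node.

    leaf_names: class names of the leaves under the node, e.g. ["cat", "dog"].
    The lowest hypernym shared by all of their synsets is used as the hypothesis.
    """
    synsets = [wn.synsets(name, pos=wn.NOUN)[0] for name in leaf_names]  # first noun sense per class
    common = synsets[0]
    for s in synsets[1:]:
        common = common.lowest_common_hypernyms(s)[0]   # walk up to the shared ancestor
    return common.name()


# Example: hypothesize_node(["cat", "dog"]) would suggest a shared ancestor
# such as "carnivore.n.01" as the meaning of that intermediate node.
```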

Training with Tree Supervision Loss

A tree supervision loss is added to the final loss function to encourage the network to separate the representative vectors of each internal node's children. Two variants are proposed: a hard and a soft tree supervision loss. The final loss is

$$\mathcal{L} = \mathcal{L}_{\text{original}} + \omega \, \mathcal{L}_{\text{hard/soft}},$$

where $\mathcal{L}_{\text{original}}$ is the typical cross-entropy classification loss, $\mathcal{L}_{\text{hard/soft}}$ is the hard or soft tree supervision loss, and $\omega$ weights the tree supervision term.

The hard tree supervision loss is defined as

$$\mathcal{L}_{\text{hard}} = \frac{1}{N} \sum_{i=1}^{N} \mathrm{CrossEntropy}\big(\mathcal{D}(i)_{\text{pred}},\, \mathcal{D}(i)_{\text{label}}\big),$$

where $N$ is the number of nodes in the tree excluding leaves, $\mathcal{D}(i)_{\text{pred}}$ is the predicted distribution over node $i$'s children, and $\mathcal{D}(i)_{\text{label}}$ is the label at node $i$. Note that nodes not on the path from the root to the label's leaf have no defined loss and do not contribute.
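A minimal sketch of how the hard tree supervision loss could be computed, reusing the hypothetical `Node` structure from the inference sketch above. Only the internal nodes on the root-to-label path are passed in, matching the note that off-path nodes have no defined loss; the `path_nodes` bookkeeping is my own simplification.

```python
import torch
import torch.nn.functional as F


def hard_tree_supervision_loss(feat, path_nodes):
    """feat: featurized sample; path_nodes: list of (node, correct_child_index) pairs
    for the internal nodes on the path from the root to the label's leaf.

    Each node contributes a standard cross-entropy term over its children's scores;
    nodes off the path are simply not included, since they have no defined loss.
    """
    losses = []
    for node, correct_child in path_nodes:
        scores = torch.stack([feat @ child.representative for child in node.children])
        target = torch.tensor([correct_child])                       # index of the correct child
        losses.append(F.cross_entropy(scores.unsqueeze(0), target))  # cross entropy at this node
    return torch.stack(losses).mean()
```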

The soft tree supervision loss is defined as

$$\mathcal{L}_{\text{soft}} = \mathrm{CrossEntropy}\big(\mathcal{D}_{\text{pred}},\, \mathcal{D}_{\text{label}}\big),$$

where $\mathcal{D}_{\text{pred}}$ is the predicted distribution over the leaves, obtained by multiplying the child probabilities along each root-to-leaf path, and $\mathcal{D}_{\text{label}}$ is the target distribution.
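And a matching sketch for the soft variant, again reusing the hypothetical `Node` structure: each leaf's probability is the product of softmaxed child probabilities along its root-to-leaf path, and the resulting distribution over leaves is compared against the label. The helper names are assumptions, not the paper's code.

```python
import torch
import torch.nn.functional as F


def soft_leaf_distribution(feat, root, num_classes):
    """Return D_pred: one probability per leaf, computed as the product of the
    softmaxed child probabilities along each root-to-leaf path."""
    probs = [None] * num_classes

    def descend(node, path_prob):
        if not node.children:                  # leaf: record its accumulated path probability
            probs[node.class_index] = path_prob
            return
        scores = torch.stack([feat @ child.representative for child in node.children])
        child_probs = F.softmax(scores, dim=0) # distribution over this node's children
        for child, p in zip(node.children, child_probs):
            descend(child, path_prob * p)

    descend(root, torch.ones(()))
    return torch.stack(probs)                  # sums to 1 over the leaves


# The soft tree supervision loss is then a cross entropy against the label:
# loss = F.nll_loss(torch.log(soft_leaf_distribution(feat, root, 10)).unsqueeze(0),
#                   torch.tensor([label]))
```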

The following figure gives an example of the hard and soft tree supervision loss.

[Figure: Hard and soft tree supervision loss]

PERFORMANCE

On CIFAR10, CIFAR100, TinyImageNet, and ImageNet, NBDT outperforms competing decision-tree-based methods, even uninterpretable variants such as decision forests, by up to 18%. On CIFAR10, CIFAR100, and TinyImageNet, NBDTs largely stay within 1% of the base neural network's accuracy.

[Figure: Performance comparison]

SOME THOUGHTS

  1. The performance seems promising. However, the ablation studies are confusing because the compared settings differ in more than one variable.
  2. The method for constructing a reasonable hierarchy is not described exhaustively. My best guess is that the authors force the tree to be binary.
  3. Is it possible for the leaves to have duplicated labels?