NBDTs and a Realistic View of Interpretability for Deep Learning
About our Guest Speaker
Soroco invited Lisa Dunlap, a second-year PhD student at UC Berkeley, to deliver a talk on her work in explainable AI at the Berkeley Artificial Intelligence Research (BAIR) Lab. The work Lisa discussed in her talk, which you can watch below, focuses on a problem as old as neural networks: lack of explainability. Lisa and her colleagues came up with a new way to combine neural networks with decision trees that keeps the strengths of both while compensating for their weaknesses. Their paper on Neural-Backed Decision Trees (NBDTs) was published at ICLR 2021.
Why This Talk at Soroco
Soroco is building a work graph to help enterprises understand how they do digital work at the user level. Soroco’s technology, Scout, performs process and task discovery: it finds patterns in the data that represent business steps conducted by users, which helps annotate the work graph with the business context of how teams conduct processes.
Soroco’s machine learning algorithms classify user activities into processes and tasks, but when our models suggest that a set of user activities should be attributed to a particular process, it helps to understand why the models think so. Explainable insights can help us provide more accurate predictions while also empowering our customers to highlight information that helps us identify their processes better. That’s where we think the work presented by Lisa is pertinent to what we do.
A deeper knowledge of day-to-day tasks and processes enables teams to identify their pain points and bottlenecks and to discover variations in the way processes are performed. Teams can then standardize their processes, address their system or process bottlenecks, and even automate repetitive tasks to improve their efficiency.
Watch the Talk
Lisa highlighted a key challenge in developing and using machine learning models: understanding how and why a model makes a prediction, i.e. explainability and interpretability. This is a problem for both developers and their users. For users, opaque models are harder to understand and hence harder to trust. For developers, explainable models are easier to debug, while black-box models, regardless of their accuracy, are harder to work with. If a model is not explainable, it is difficult to determine whether an error comes from a poor choice of model, a lack of hyperparameter tuning, or poor data quality. And if the data quality is poor, the developer shouldn’t have to go through every example, relabel, and retrain repeatedly.
Interpretable models provide a human-in-the-loop way to dive into the data. For these reasons, interpretability is one of the most desirable properties of a model, alongside accuracy.
Our most powerful and accurate models, neural networks, are also the least interpretable. Gradient-based interpretability methods and saliency maps highlight the part of the input that was most influential in a decision, but they don’t tell us how the model made the decision. Our most interpretable models, such as decision trees, are not nearly as powerful as neural networks, but they do provide a way to explain how the model arrived at a decision.
Key Ideas from the Talk
NBDTs extend a traditional neural network to make it more interpretable while retaining or even improving its accuracy. An NBDT consists of a decision tree built to work in tandem with a neural network, so that the combined model is easier to explain. Below we describe more of their beneficial properties and what we think makes them powerful.
Plug and play over a traditional neural network
The most beneficial part of this approach is that it keeps the neural network almost as is. The tree is built from the network’s own weights. In the final fully connected layer, each output node has a weight vector that describes the features that node is looking for. Running agglomerative clustering over these weight vectors tells us which nodes can be grouped together. The resulting dendrogram gives us the structure of the decision tree.
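To make the clustering step concrete, here is a minimal sketch in plain Python. It is not the authors' implementation: the class names and weight values in `toy_weights` are made up, and the merge rule (join the two clusters whose mean vectors are closest in Euclidean distance) is an assumption chosen for simplicity.

```python
import math
from itertools import combinations


def build_hierarchy(weights):
    """Greedy agglomerative clustering over per-class weight vectors.

    weights maps a class name to its (hypothetical) weight vector from the
    final fully connected layer. Returns the dendrogram as nested tuples.
    """
    # Each cluster is (subtree, mean_vector, number_of_leaves).
    clusters = [(name, list(vec), 1) for name, vec in weights.items()]
    while len(clusters) > 1:
        # Pick the pair of clusters whose mean vectors are closest.
        i, j = min(
            combinations(range(len(clusters)), 2),
            key=lambda pair: math.dist(clusters[pair[0]][1], clusters[pair[1]][1]),
        )
        (tree_a, vec_a, n_a), (tree_b, vec_b, n_b) = clusters[i], clusters[j]
        # The merged node is represented by the size-weighted mean of its children.
        merged_vec = [(n_a * x + n_b * y) / (n_a + n_b) for x, y in zip(vec_a, vec_b)]
        merged = ((tree_a, tree_b), merged_vec, n_a + n_b)
        clusters = [c for k, c in enumerate(clusters) if k not in (i, j)] + [merged]
    return clusters[0][0]


# Made-up weight vectors for four classes (illustrative values only).
toy_weights = {
    "dog": [0.9, 0.8, 0.1],
    "cat": [0.8, 0.9, 0.2],
    "car": [0.1, 0.2, 0.9],
    "truck": [0.3, 0.1, 0.8],
}

tree = build_hierarchy(toy_weights)
print(tree)
```

With these toy values, dog and cat are merged first, then car and truck, and the two pairs are joined at the root. A real setup would read the weight vectors from a trained model and would likely use a library clustering routine instead of this hand-rolled loop.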
Providing intermediate results that lead to the decision
At this point the intermediate nodes of the decision tree are still not explainable or interpretable, because they have no human-readable meaning. To assign labels to them, the approach uses WordNet, a hierarchy of nouns: for each intermediate node, the earliest common ancestor of all leaves in its subtree is looked up in WordNet. The paper describes an intuitive example to help us understand this. Say Dog and Cat are two categories that the original model predicts, and each corresponds to a leaf of the tree. Clustering tells us that they share a parent. To find a WordNet label for the parent, all ancestor concepts of Dog and Cat are found, like Mammal, Animal, and Living Thing. The closest shared ancestor is Mammal, so we assign Mammal to the parent of Dog and Cat. We do this recursively until all the nodes have a label.
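The labeling step can be sketched as follows. This is an illustration under stated assumptions, not the paper's code: `HYPERNYM` is a tiny hand-written stand-in for WordNet (the real hierarchy is far larger and a concept can have several parents), and the nested-tuple tree is a hypothetical clustering result.

```python
# Toy child -> parent map; a hand-written stand-in for WordNet (assumption).
HYPERNYM = {
    "dog": "mammal",
    "cat": "mammal",
    "mammal": "animal",
    "animal": "living thing",
    "living thing": "entity",
    "car": "vehicle",
    "truck": "vehicle",
    "vehicle": "artifact",
    "artifact": "entity",
}


def ancestors(concept):
    """Return the concept followed by its ancestors, nearest first."""
    chain = [concept]
    while chain[-1] in HYPERNYM:
        chain.append(HYPERNYM[chain[-1]])
    return chain


def closest_shared_ancestor(leaves):
    """Return the nearest concept that is an ancestor of every leaf."""
    chains = [ancestors(leaf) for leaf in leaves]
    for candidate in chains[0]:
        if all(candidate in chain for chain in chains):
            return candidate
    return None


def leaves_of(tree):
    """Collect the leaf names under a nested-tuple tree."""
    if isinstance(tree, str):
        return [tree]
    return [leaf for child in tree for leaf in leaves_of(child)]


def label_tree(tree):
    """Recursively label every inner node with its closest shared ancestor."""
    if isinstance(tree, str):
        return tree
    label = closest_shared_ancestor(leaves_of(tree))
    return {label: [label_tree(child) for child in tree]}


# Hypothetical tree structure, as produced by a clustering step.
unlabeled = (("dog", "cat"), ("car", "truck"))
labeled = label_tree(unlabeled)
print(labeled)
```

In this toy hierarchy the parent of dog and cat is labeled mammal, the parent of car and truck is labeled vehicle, and the root gets the only concept all four share, entity.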
Optimizing for interpretability also optimizes for accuracy
The last and most interesting contribution of the paper is how the authors improve both the accuracy and the interpretability of the neural network by adding a tree supervision loss. The tree supervision loss is a cross-entropy loss that encourages the network to predict the right path in the dendrogram with a higher probability. This loss ended up improving both the explainability and the accuracy of the model.
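One way to picture a loss over tree paths is the sketch below. It rests on several assumptions that are not taken from the talk: each node has a representative vector (the values in `VECTORS` are made up), the probability of each child is a softmax over inner products with the input features, and the probability of a leaf is the product of the child probabilities along its path. The function computes only the tree term; how it is weighted against the network's usual classification loss is left out.

```python
import math

# Hypothetical tree: inner node -> children. Leaves do not appear as keys.
TREE = {
    "root": ["mammal", "vehicle"],
    "mammal": ["dog", "cat"],
    "vehicle": ["car", "truck"],
}

# Made-up representative vectors for every non-root node (illustrative only).
VECTORS = {
    "dog": [0.9, 0.8, 0.1],
    "cat": [0.8, 0.9, 0.2],
    "car": [0.1, 0.2, 0.9],
    "truck": [0.3, 0.1, 0.8],
    "mammal": [0.85, 0.85, 0.15],
    "vehicle": [0.2, 0.15, 0.85],
}


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def leaf_probabilities(features, node="root", prob=1.0):
    """Multiply child probabilities along each root-to-leaf path."""
    children = TREE.get(node)
    if children is None:  # reached a leaf
        return {node: prob}
    child_probs = softmax([dot(features, VECTORS[c]) for c in children])
    result = {}
    for child, p in zip(children, child_probs):
        result.update(leaf_probabilities(features, child, prob * p))
    return result


def tree_supervision_loss(features, true_leaf):
    """Cross entropy between the path probabilities and the true class."""
    return -math.log(leaf_probabilities(features)[true_leaf])


features = [1.0, 0.5, 0.0]  # a made-up, dog-like feature vector
probs = leaf_probabilities(features)
loss_correct = tree_supervision_loss(features, "dog")
loss_wrong = tree_supervision_loss(features, "truck")
print(probs, loss_correct, loss_wrong)
```

The loss is smaller when the true class lies on the path the features favor, so minimizing it pushes the network toward features that send each example down the correct branch at every level of the tree.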
Additional Thoughts and Conclusions
We discussed how NBDTs or related concepts may be helpful in solving some of our problems at Soroco. One of the key challenges with implementing such an approach in the wild is that it can be hard to find labels for the intermediate nodes of the decision tree. We discussed how multimodal models like CLIP could help in such cases. We also discussed how large language models (LLMs) may be useful in purely NLP contexts. Most importantly, the talk debunks the myth that models cannot be both highly accurate and explainable. It has had Soroco developers rethinking how we incorporate explainability into our models to benefit our development process and end-user experience. Overall, this work has done a great job of balancing accuracy and explainability, and we believe many developers and teams can learn from it.