Interpretable deep learning for natural language processing
- Author(s): Murdoch, William James
- Advisor(s): Yu, Bin, et al.
Machine-learning (ML) models have demonstrated great success in learning complex patterns that enable them to make predictions about unobserved data. Beyond prediction, the ability to interpret what a model has learned is receiving increasing attention. Such interpretations have found a range of uses, from providing scientific insight to auditing predictions to ensure fairness with respect to protected categories such as race or gender. However, there is still considerable confusion about the notion of interpretability: it is currently unclear both what it means to interpret an ML model and how to actually do so.
In the first part of this thesis, we address the foundational question of what it means to interpret an ML model, and of how to select, evaluate, or even discuss methods for producing such interpretations. We aim to clarify these concerns by defining interpretable machine learning and constructing a unifying framework for existing methods that highlights the under-appreciated role played by human audiences. Within this framework, methods are organized into two classes: model-based and post-hoc. To provide guidance in selecting and evaluating interpretation methods, we introduce three desiderata: predictive accuracy, descriptive accuracy, and relevancy. Using our framework, we review existing work, grounded in real-world studies that exemplify our desiderata, and suggest directions for future work.
The second through fourth parts introduce a succession of methods for interpreting predictions made by neural networks. The second part focuses on Long Short-Term Memory networks (LSTMs) trained on question answering and sentiment analysis, two popular tasks in natural language processing. By decomposing the LSTM's update equations, we introduce a novel method for computing importance scores that measure how much specific inputs contribute to the output of an LSTM. To verify these scores, we use them to search for consistently important patterns of words learned by state-of-the-art LSTMs on sentiment analysis and question answering. The extracted patterns are then quantitatively validated by using them to construct a simple rule-based classifier that approximates the output of the LSTM.
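One way to decompose the update equations into per-word scores is to telescope the final cell state: because tanh(c_0) = 0, tanh(c_T) equals the sum over t of [tanh(c_t) - tanh(c_{t-1})], so the unnormalized softmax score exp(W_k h_T), with h_T = o_T * tanh(c_T), factorizes into one multiplicative term per word. The sketch below assumes a single-layer LSTM, a softmax output layer without bias, and random weights in place of a trained model. The names `run_lstm`, `cell_difference_scores`, and the parameter dictionary `p` are hypothetical, and this is one simple decomposition consistent with the description above, not necessarily the exact method used in the thesis.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def gate(W, b, z, act):
    """Apply an LSTM gate elementwise: act(W z + b)."""
    return [act(a + bk) for a, bk in zip(matvec(W, z), b)]


def run_lstm(xs, p):
    """Run a single-layer LSTM; return cell states c_1..c_T and the final output gate o_T."""
    H = len(p["b_f"])
    h, c, o = [0.0] * H, [0.0] * H, [0.0] * H
    cells = []
    for x in xs:
        z = list(x) + h  # concatenate input x_t with previous hidden state h_{t-1}
        f = gate(p["W_f"], p["b_f"], z, sigmoid)    # forget gate
        i = gate(p["W_i"], p["b_i"], z, sigmoid)    # input gate
        o = gate(p["W_o"], p["b_o"], z, sigmoid)    # output gate
        g = gate(p["W_g"], p["b_g"], z, math.tanh)  # candidate cell update
        c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
        h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
        cells.append(c)
    return cells, o


def cell_difference_scores(xs, p, W_out):
    """scores[t][k] = exp(W_out[k] . (o_T * (tanh(c_t) - tanh(c_{t-1})))), with c_0 = 0.

    The differences telescope to tanh(c_T), so the product over t of scores[t][k]
    equals exp(W_out[k] . h_T), the unnormalized softmax score of class k.
    """
    cells, o_T = run_lstm(xs, p)
    prev = [0.0] * len(o_T)
    scores = []
    for c in cells:
        diff = [ok * (math.tanh(a) - math.tanh(b)) for ok, a, b in zip(o_T, c, prev)]
        scores.append([math.exp(sum(w * d for w, d in zip(row, diff))) for row in W_out])
        prev = c
    return scores


# Usage with random weights standing in for a trained model (illustration only).
rng = random.Random(0)
D, H, K, T = 3, 4, 2, 5  # input size, hidden size, number of classes, sequence length
p = {}
for name in "fiog":
    p["W_" + name] = [[rng.uniform(-0.5, 0.5) for _ in range(D + H)] for _ in range(H)]
    p["b_" + name] = [0.0] * H
W_out = [[rng.uniform(-1.0, 1.0) for _ in range(H)] for _ in range(K)]
xs = [[rng.uniform(-1.0, 1.0) for _ in range(D)] for _ in range(T)]
scores = cell_difference_scores(xs, p, W_out)  # one row of K scores per word
```

Each score multiplies class k's unnormalized score, so a value above 1 means that word pushes the prediction toward class k under this decomposition; searching for phrases whose scores are consistently large is then a matter of aggregating these per-word values over a corpus.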
While feature importance scores are helpful in understanding a model's predictions, they ignore the complex interactions between variables typically learned by neural networks. To address this, the third part introduces contextual decomposition (CD), an interpretation algorithm for analysing individual predictions made by standard LSTMs without any changes to the underlying model. By decomposing the output of an LSTM, CD captures the contributions of combinations of words or variables to its final prediction. On the task of sentiment analysis with the Yelp and Stanford Sentiment Treebank (SST) data sets, we show that CD reliably identifies words and phrases of contrasting sentiment, and how they are combined to yield the LSTM's final prediction. Using the phrase-level labels in SST, we also demonstrate that CD can successfully extract positive and negative negations from an LSTM, something that has not previously been done.
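Decomposing an LSTM requires splitting every gate, which is too long for a short example, but the core invariant behind the decomposition can be illustrated on a deliberately simpler model: a small feedforward ReLU network, swapped in here for the LSTM. Every hidden vector is split into beta, the part attributed to a chosen phrase, and gamma, the rest, so that beta + gamma equals the ordinary activation at every layer. Two choices below are assumptions of this sketch rather than details taken from the thesis: the bias is assigned entirely to gamma, and the nonlinearity is split by averaging each part's marginal effect over both orders of adding the parts. All function names are hypothetical.

```python
import random


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def relu(v):
    return [max(0.0, x) for x in v]


def split_linear(W, b, beta, gamma):
    """Linear layer: beta is mapped by W alone; the bias goes to gamma (an assumption)."""
    return matvec(W, beta), [a + bk for a, bk in zip(matvec(W, gamma), b)]


def split_relu(beta, gamma):
    """Split ReLU(beta + gamma) by averaging each part's marginal effect over both
    orders of adding the parts. Since ReLU(0) = 0, the parts sum to ReLU(beta + gamma)."""
    both = relu([x + y for x, y in zip(beta, gamma)])
    rb, rg = relu(beta), relu(gamma)
    new_beta = [0.5 * (b1 + (s - g1)) for b1, g1, s in zip(rb, rg, both)]
    new_gamma = [0.5 * (g1 + (s - b1)) for b1, g1, s in zip(rb, rg, both)]
    return new_beta, new_gamma


def contextual_decomposition(x, phrase, layers):
    """Return (beta, gamma) at the output: beta is attributed to the feature indices in
    `phrase`, gamma to everything else, and beta + gamma equals the full output."""
    beta = [xi if i in phrase else 0.0 for i, xi in enumerate(x)]
    gamma = [0.0 if i in phrase else xi for i, xi in enumerate(x)]
    for depth, (W, b) in enumerate(layers):
        beta, gamma = split_linear(W, b, beta, gamma)
        if depth < len(layers) - 1:  # ReLU on hidden layers only
            beta, gamma = split_relu(beta, gamma)
    return beta, gamma


def forward(x, layers):
    """Ordinary forward pass of the same network, for comparison."""
    for depth, (W, b) in enumerate(layers):
        x = [a + bk for a, bk in zip(matvec(W, x), b)]
        if depth < len(layers) - 1:
            x = relu(x)
    return x


# Usage: a random 6-5-4-2 network standing in for a trained model (illustration only).
rng = random.Random(1)
sizes = [6, 5, 4, 2]
layers = [
    ([[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)],
     [rng.uniform(-0.1, 0.1) for _ in range(n_out)])
    for n_in, n_out in zip(sizes, sizes[1:])
]
x = [rng.uniform(-1, 1) for _ in range(sizes[0])]
beta, gamma = contextual_decomposition(x, {1, 2}, layers)  # phrase = features 1 and 2
out = forward(x, layers)
```

Because beta and gamma always add back up to the real activations, beta at the output can be read as the phrase's contribution to the prediction, and comparing beta for a phrase with the betas of its parts exposes interactions such as negation.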
When considering interactions between variables, the number of interactions quickly becomes too large for manual inspection, raising the question of how to automatically select and display an informative subset. In the fourth part, we introduce hierarchical interpretations to explain deep neural network (DNN) predictions through our proposed method, agglomerative contextual decomposition (ACD). Given a prediction from a trained DNN, ACD produces a hierarchical clustering of the input features, along with the contribution of each cluster to the final prediction. This hierarchy is optimized to identify clusters of features that the DNN has learned are predictive. We illustrate ACD with examples from the Stanford Sentiment Treebank and ImageNet, using it to diagnose incorrect predictions, identify dataset bias, and extract polarizing phrases of varying lengths. Through human experiments, we demonstrate that ACD enables users both to identify the more accurate of two DNNs and to better trust a DNN's outputs. We also find that ACD's hierarchy is largely robust to adversarial perturbations, implying that it captures fundamental aspects of the input and ignores spurious noise.
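The agglomeration step can be sketched as a greedy bottom-up merge over contiguous spans of a sentence. This sketch assumes a hypothetical `score(start, end)` function standing in for the CD importance of a span, and merges the adjacent pair with the largest absolute interaction |score(a+b) - score(a) - score(b)|. Both the criterion and the restriction to adjacent text spans are simplifying assumptions rather than ACD's exact procedure (for images, adjacency would be spatial instead). The word scores and the hand-coded negation effect are invented for illustration and do not come from any trained model.

```python
def agglomerate(words, score):
    """Greedily merge adjacent spans bottom-up.

    `score(start, end)` stands in for the importance of words[start:end]; the merge
    criterion is the absolute interaction |score(a+b) - score(a) - score(b)|.
    Returns the hierarchy as a list of levels of half-open (start, end) spans.
    """
    clusters = [(i, i + 1) for i in range(len(words))]
    levels = [list(clusters)]
    while len(clusters) > 1:
        best, best_gain = 0, None
        for j in range(len(clusters) - 1):
            (s1, e1), (s2, e2) = clusters[j], clusters[j + 1]
            gain = abs(score(s1, e2) - score(s1, e1) - score(s2, e2))
            if best_gain is None or gain > best_gain:  # ties keep the leftmost pair
                best, best_gain = j, gain
        (start, _), (_, end) = clusters[best], clusters[best + 1]
        clusters[best:best + 2] = [(start, end)]
        levels.append(list(clusters))
    return levels


# Hypothetical word scores with a hand-coded negation effect (illustration only).
WORD_SCORES = {"a": 0.0, "not": -0.2, "good": 1.0, "movie": 0.0}


def make_toy_score(words):
    def score(start, end):
        span = words[start:end]
        total = sum(WORD_SCORES.get(w, 0.0) for w in span)
        if "not" in span and "good" in span:
            total -= 2 * WORD_SCORES["good"]  # "not" flips the sign of "good"
        return total
    return score


words = ["a", "not", "good", "movie"]
levels = agglomerate(words, make_toy_score(words))
for level in levels:
    print([" ".join(words[s:e]) for s, e in level])
# ['a', 'not', 'good', 'movie']
# ['a', 'not good', 'movie']
# ['a not good', 'movie']
# ['a not good movie']
```

The first merge groups "not good" because that pair carries the strongest interaction, which is the kind of structure a hierarchical display is meant to surface; pairing each span with its score would give the per-cluster contributions described above.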