
Transformers Interpret

Transformers Interpret is a model explainability tool designed to work exclusively with 🤗 transformers.

In line with the philosophy of the transformers package, transformers-interpret allows any transformers model to be explained in just two lines. It even supports visualizations both in notebooks and as savable HTML files.

This package stands on the shoulders of the incredible work being done by the teams at PyTorch Captum and Hugging Face, and would not exist without the amazing job they are doing in the fields of model interpretability and NLP respectively.

Install

I recommend Python 3.6 or higher, PyTorch 1.5.0 or higher, transformers v3.0.0 or higher, and captum 0.3.1 (required). The package does not work with Python 2.7.

pip install transformers-interpret

Quick Start

Let's start by importing the auto model and tokenizer classes from transformers and initializing a model and tokenizer.

For this example we are using distilbert-base-uncased-finetuned-sst-2-english, a DistilBERT model fine-tuned on a sentiment analysis task.

from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

With both the model and tokenizer initialized, we are now able to get some explanations for some input text.

from transformers_interpret import SequenceClassificationExplainer
cls_explainer = SequenceClassificationExplainer("I love you, I like you", model, tokenizer)
attributions = cls_explainer()

Calling the explainer returns an attributions object; its word_attributions property is the list of (word, score) tuples below.

>>> attributions.word_attributions
[('[CLS]', 0.0),
 ('i', 0.46820533977856904),
 ('love', 0.4606184697303162),
 ('you', 0.5664126708457133),
 (',', -0.017154242497229605),
 ('i', -0.05376360639469018),
 ('like', 0.10987772217503108),
 ('you', 0.4822169265102293),
 ('[SEP]', 0.0)]

Positive attribution numbers indicate that a word contributes positively towards the predicted class; negative numbers indicate that a word contributes negatively towards the predicted class. Here we can see that "I love you" receives the strongest positive attributions.
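Because word_attributions is just a list of (word, score) tuples, it is easy to rank tokens by their contribution with plain Python. A minimal sketch, reusing the example output above (the scores are copied from the printed list, nothing here re-runs the model):

```python
# Rank tokens by attribution score, using the example output shown above.
# These scores are copied verbatim from the printed list; no model is called.
word_attributions = [
    ("[CLS]", 0.0),
    ("i", 0.46820533977856904),
    ("love", 0.4606184697303162),
    ("you", 0.5664126708457133),
    (",", -0.017154242497229605),
    ("i", -0.05376360639469018),
    ("like", 0.10987772217503108),
    ("you", 0.4822169265102293),
    ("[SEP]", 0.0),
]

# The three tokens pushing hardest toward the predicted class.
top_three = sorted(word_attributions, key=lambda pair: pair[1], reverse=True)[:3]
print(top_three)
```

Sorting like this is a quick way to spot which tokens dominate a prediction before reaching for the full visualization.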

In case you want to know what the predicted class actually is:

>>> cls_explainer.predicted_class_index
array(1)

And if the model has label names for each class:

>>> cls_explainer.predicted_class_name
'POSITIVE'

Visualizing attributions

Sometimes the numeric attributions can be difficult to read, particularly when there is a lot of text. To help with that there is also an inbuilt visualize method that utilizes Captum's built-in visualization library to create an HTML file highlighting attributions.

If you are in a notebook, calling the visualize() method will display the visualization inline; otherwise you can pass a file path in as an argument and an HTML file will be created so you can view the explanation in your browser.

cls_explainer.visualize("distilbert_viz.html")

Explaining Attributions for Non Predicted Class

Attribution explanations are not limited to the predicted class. Let's test a more complex sentence that contains mixed sentiments.

In the example below we pass class_name="NEGATIVE" as an argument, indicating that we would like the attributions to be explained for the NEGATIVE class regardless of what the actual prediction is. Effectively, because this is a binary classifier, we are getting the inverse attributions.

cls_explainer = SequenceClassificationExplainer("I love you, I like you, I also kinda dislike you", model, tokenizer)
attributions = cls_explainer(class_name="NEGATIVE")
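The "inverse attributions" point can be illustrated with plain Python (the scores below are made up for illustration, not model output): for a two-class model, a token that pushes toward POSITIVE pushes away from NEGATIVE by roughly the same amount, so flipping the sign of each score approximates the other class's attributions.

```python
# Illustrative only: these scores are invented, not produced by the model.
positive_view = [("kinda", -0.20), ("dislike", -0.35), ("love", 0.40)]

# For a binary classifier, attributions toward NEGATIVE are (approximately)
# the attributions toward POSITIVE with the sign flipped.
negative_view = [(word, -score) for word, score in positive_view]
print(negative_view)  # "kinda" and "dislike" now contribute positively toward NEGATIVE
```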

This still returns a prediction of the POSITIVE class:

>>> cls_explainer.predicted_class_name
'POSITIVE'

But when we visualize the attributions we can see that the words "...kinda dislike" in the sentence are contributing to a prediction of the "NEGATIVE" class.

cls_explainer.visualize("distilbert_negative_attr.html")

Getting attributions for different classes is particularly insightful for multiclass problems as it allows you to inspect model predictions for a number of different classes and sanity check that the model is "looking" at the right things.
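A sanity check like that can be sketched as a loop over class names, surfacing the strongest token for each class. The helper below is hypothetical; to keep the sketch runnable it stubs the explainer with canned scores rather than calling SequenceClassificationExplainer.

```python
# Hypothetical sketch of a per-class sanity check. `explain` stands in for a
# call to SequenceClassificationExplainer(text, model, tokenizer)(class_name=...);
# it is stubbed here with canned, invented scores so the sketch runs standalone.
def explain(text, class_name):
    canned = {
        "POSITIVE": [("love", 0.41), ("dislike", -0.30)],
        "NEGATIVE": [("love", -0.41), ("dislike", 0.30)],
    }
    return canned[class_name]

def top_token(text, class_name):
    # The single highest-attributed token for the given class.
    return max(explain(text, class_name), key=lambda pair: pair[1])[0]

text = "I love you, I also kinda dislike you"
for label in ("POSITIVE", "NEGATIVE"):
    print(label, "->", top_token(text, label))
# A model that is "looking" at the right things should surface sentiment
# words like these for each class, not punctuation or stopwords.
```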

For a detailed example of this, please check out the multiclass classification notebook.

Future Development

This package is still in its early days and there is much more planned. For a 1.0.0 release I'm aiming to have:

  • Proper documentation site
  • Support for Question Answering models
  • Support for NER models
  • Support for Multiple Choice models (not sure how feasible this is)
  • Ability to show attributions for each layer rather than a summary of all layers
  • Additional attribution methods
  • In depth examples
  • Get a nice logo
  • and more...

Questions / Get In Touch

The main contributor to this repository is @cdpierse.

If you have any questions, suggestions, or would like to make a contribution (please do 😁), please get in touch at [email protected]

I'd also highly suggest checking out Captum if you find model explainability and interpretability interesting. They are doing amazing and important work.

Captum Links

Below are some links I used to help me put this package together using Captum. Thank you to @davidefiocco for your very insightful gist.

