laura-rieger / deep-explanation-penalization

Code for using CDEP from the paper "Interpretations are useful: penalizing explanations to align neural networks with prior knowledge" https://arxiv.org/abs/1909.13584

License: MIT License

Languages: Jupyter Notebook 80.80%, Python 19.20%

Topics: interpretability neural-network machine-learning convolutional-neural-network pytorch explainability deep-learning explainable-ai artificial-intelligence ml ai python data-science jupyter-notebook feature-importance recurrent-neural-network fairness fairness-ml cdep interpretable-deep-learning

deep-explanation-penalization's Introduction

Making interpretations useful (CDEP) 🔨

Regularizes interpretations (computed via contextual decomposition) to improve neural networks. Official code for Interpretations are useful: penalizing explanations to align neural networks with prior knowledge (ICML 2020 pdf).
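In a nutshell, training adds an explanation penalty on top of the usual prediction loss. Below is a minimal sketch of this idea in PyTorch; the function name, the zero prior, and the shape of cd_scores are illustrative assumptions, not the repo's exact API:

import torch
import torch.nn.functional as F

# Minimal sketch of a CDEP-style objective, assuming the prior that the
# penalized features should not matter (their CD importance is pushed
# toward zero). `cd_scores` stands in for the output of contextual
# decomposition (see src/cd.py); all names here are illustrative.
def cdep_objective(logits, targets, cd_scores, lam=1.0):
    pred_loss = F.cross_entropy(logits, targets)
    expl_penalty = cd_scores.abs().mean()
    return pred_loss + lam * expl_penalty

# toy usage with dummy tensors
logits = torch.randn(8, 2, requires_grad=True)
targets = torch.randint(0, 2, (8,))
cd_scores = torch.randn(8, 10, requires_grad=True)
cdep_objective(logits, targets, cd_scores, lam=10.0).backward()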

Note: this repo is actively maintained. For any questions please file an issue.

(figure: fig_intro)

documentation

  • fully-contained data/models/code for reproducing and experimenting with CDEP
  • the src folder contains the core code for running and penalizing contextual decomposition
  • in addition, we run experiments on 4 datasets, each of which is located in its own folder
    • notebooks in these folders show demos for the different tasks

examples

ISIC skin-cancer classification - using CDEP, we can learn to avoid spurious patches present in the training set, improving test performance!

The segmentation maps of the patches can be downloaded here

ColorMNIST - penalizing the contributions of individual pixels allows us to teach a network to learn a digit's shape instead of its color, improving its test accuracy from 0.5% to 25.1%

Fixing text gender biases - CDEP can help a network avoid learning spurious biases in a dataset, such as gendered words

using CDEP on your own data

using CDEP requires two steps:

  1. run CD/ACD on your model. Specifically, three things must be altered:
  • the pred_ims function must be replaced by a function you write using your own trained model. This function gets predictions from a model given a batch of examples.
  • the model must be replaced with your model
  • the current CD implementation doesn't work for all types of networks. If you are getting an error inside cd.py, you may need to write a custom function that iterates through the layers of your network (for examples, see cd.py)
  2. add CD scores to the loss function (see the notebooks, and the sketch after this list)
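To make step 1 concrete, here is a hedged sketch of a pred_ims-style wrapper and of a layer-by-layer traversal of the kind cd.py performs. The model is assumed to be an nn.Sequential, and the traversal is only a naive stand-in, not real contextual decomposition; consult src/cd.py for the actual implementation:

import numpy as np
import torch

# A pred_ims-style wrapper: returns class probabilities for a batch of
# examples from your own trained model (names are illustrative).
def pred_ims(model, batch):
    model.eval()
    with torch.no_grad():
        x = torch.as_tensor(np.asarray(batch), dtype=torch.float32)
        return torch.softmax(model(x), dim=-1).numpy()

# If cd.py errors on your architecture, a custom traversal over your
# network's layers (assumed here to be an nn.Sequential) might look like:
def iterate_layers(model, rel, irrel):
    for layer in model:        # real CD splits each layer's output into
        rel = layer(rel)       # relevant and irrelevant contributions;
        irrel = layer(irrel)   # this naive version just forwards both
    return rel, irrel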

related work

  • ACD (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning
  • TRIM (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances for transformations of the input (e.g. assigning importances to different frequencies)
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests

reference

  • feel free to use/share this code openly
  • if you find this code useful for your research, please cite the following:
@inproceedings{rieger2020interpretations,
  title={Interpretations are useful: penalizing explanations to align neural networks with prior knowledge},
  author={Rieger, Laura and Singh, Chandan and Murdoch, William and Yu, Bin},
  booktitle={International Conference on Machine Learning},
  pages={8116--8126},
  year={2020},
  organization={PMLR}
}

deep-explanation-penalization's People

Contributors: laura-rieger, csinva

deep-explanation-penalization's Issues

How do I get segmentation files of ISIC dataset?

Hi, I appreciate your awesome code!!
By the way, how do I get the segmentation files for the ISIC dataset?
I can't find any code in 02_sort_imgs.py that saves
segmentation files to segmentation_path = os.path.join(data_path, "segmentation").

I found in another issue that I can download the SONIC study from the ISIC website,
but I still don't know whether this data is already included in the dataset downloaded by the
01_download_imgs.py file.

So exactly which studies in the ISIC dataset did you use for the ISIC result in your paper?
The following studies were available:

2018 JID Editorial Images (0 / 100)
Dermoscopedia (CC-BY) (0 / 5)
HAM10000 (0 / 10015)
MSK-1 (0 / 1100)
MSK-2 (0 / 1535)
MSK-3 (0 / 225)
MSK-4 (0 / 947)
MSK-5 (0 / 111)
SONIC (9251)
UDA-1 (0 / 557)
UDA-2 (0 / 60)

About segmentation images

I can't find the segmentation images in the segmentation data folder. Could you provide this segmentation data?

How come Gradient Sum and EG do two gradient steps?

Hi! Thanks for sharing this repo!

In the MNIST Decoy code, for methods 1 and 2 (gradient_sum and eg), there are two gradient steps per batch. The first step uses gradients from just the explanation_penalty, and the second step uses gradients from both the explanation_penalty and the log loss. I was wondering what the reason for this was?

Reference in code:

Thanks!
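For readers following this thread, here is a self-contained schematic of the two-step pattern being described. The model and data are toys, and `penalty` is only an illustrative stand-in for the gradient_sum / EG explanation penalty, not the repo's actual code:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))

def penalty(model, x):
    return model(x).abs().mean()  # stand-in explanation penalty

# step 1: update on the explanation penalty alone
optimizer.zero_grad()
penalty(model, x).backward()
optimizer.step()

# step 2: update on the log loss plus the penalty
optimizer.zero_grad()
(F.cross_entropy(model(x), y) + penalty(model, x)).backward()
optimizer.step()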

What's the performance when CNN is also trained?

Hi, Laura.
I found that with the VGG16 you trained on the ISIC dataset, actual training only happens on the FC layers at the end.
Have you tried training all the layers as well?
Thanks in advance!

ISIC images

Since the ISIC archive has changed, it's not clear anymore which images you used to train the models for the ISIC experiments. Specifically, it's unclear which images without patches were used. Would you be able to provide a table with the names of the included ISIC images?

Thanks!

Identifying Colorful Patches

Hello all,

First of all, I want to say great job on this paper and thank you so much for making the code available in such a usable fashion. I am trying to reproduce some of your results for the ISIC dataset. Specifically, I am trying to generate labels for benign images that identify the colorful patches. Is this something that can be done using your codebase? Let me know, thanks!

Hyperparameter of ColorMNIST

Hi, Laura.
I found that the ColorMNIST performance of CDEP reported by the Jupyter notebook in this repo (test accuracy = 23.41%) differs from the one in the ICML paper (test accuracy = 31.0%).
Could you tell me the hyperparameters you used to get that performance? (lr, lambda, batch_size, seed)
Thanks a lot!

Extension to natural images?

First, really interesting work -- I happened to chance upon this repo, as Github recommended it.

For computer vision (e.g., ColorMNIST), have you considered natural images? Is there a bottleneck in terms of memory or compute that required an MNIST variant? (Or is it just more difficult to find an obvious bias in natural images?)

The text gender bias is really interesting. I'll have to take a closer read before asking questions, though!

error in dataset creation of decoyMNIST

The dataset creation and loading for DecoyMNIST has a bug (I'm getting the error below).
Could it be looked into, @laura-rieger @csinva?

python train_mnist_decoy.py

Traceback (most recent call last):
  File "train_mnist_decoy.py", line 107, in <module>
    complete_dataset = utils.TensorDataset(train_x_tensor,train_y_tensor) # create your datset
  File "/home/avani.gupta/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/utils/data/dataset.py", line 365, in __init__
    assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
AssertionError: Size mismatch between tensors
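The assertion fires when the tensors passed to TensorDataset disagree in their first dimension. A small sanity check with illustrative shapes (the shapes here are assumptions, not the script's real ones) makes the mismatch easier to diagnose:

import torch
from torch.utils.data import TensorDataset

# Illustrative shapes; check dim 0 explicitly before building the dataset.
train_x_tensor = torch.randn(100, 1, 28, 28)
train_y_tensor = torch.randint(0, 10, (100,))
assert train_x_tensor.size(0) == train_y_tensor.size(0), (
    f"x has {train_x_tensor.size(0)} rows, y has {train_y_tensor.size(0)}")
complete_dataset = TensorDataset(train_x_tensor, train_y_tensor)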

Downloading data for the ISIC experiments

Hi!
I am trying to download the data for the reported ISIC experiments using the scripts you provided. I encountered the following issues:

  • 00_download_metadata.py prints "Fetching metadata for 23906 images", but the paper states "The ISIC dataset consists of 21,654 images (19,372 benign)". How can I figure out which images you used?
  • 01_download_imgs.py only downloads 50 images, although the value in the config is set to 25000. Am I expected to change start_offset = 0 or the limit parameter in the URL somehow? (See the sketch after this message.)

Thanks a lot for the great work!
Best regards
Verena
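For anyone hitting the same 50-image cap: APIs of this kind usually page results via limit/offset parameters, so a download loop has to keep advancing the offset. A hedged sketch follows; the endpoint URL and parameter names are assumptions based on this issue's description, not a verified ISIC API spec:

import requests

BASE_URL = "https://isic-archive.com/api/v1/image"  # assumed endpoint
limit, offset, images = 50, 0, []
while True:
    # request one page of metadata, then advance the offset
    resp = requests.get(BASE_URL, params={"limit": limit, "offset": offset})
    resp.raise_for_status()
    batch = resp.json()
    if not batch:  # empty page means we've paged past the last image
        break
    images.extend(batch)
    offset += limit
print(f"fetched metadata for {len(images)} images")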

Why not fine-tune the weights of the conv layers in the ISIC experiment?

Hi,

I notice that you chose to freeze the conv layers and train only the fully connected layers when using a pretrained VGG16 model in the ISIC experiment. Could you please, if convenient, tell me why you did not retrain the whole model, since that seems like it would give an extra accuracy gain compared to training only the FC layers?

All best.
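For context, the setup being asked about, freezing VGG16's convolutional features and training only the classifier head, looks roughly like this in PyTorch (a sketch, not the repo's exact training code):

import torch
import torchvision

# Freeze the pretrained convolutional features; only the classifier
# head stays trainable.
model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False  # conv layers stay fixed

# To fine-tune the whole network instead, skip the loop above (at the
# cost of more compute and a higher risk of overfitting on small data).
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)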
