
gradcam.pytorch's Introduction


This is the repository for a PyTorch implementation of "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization". If you have any issues regarding this repository, please contact [email protected].

You can see the original paper here

Modules

Requirements

See the installation instructions for a step-by-step installation guide. See the server instructions for server setup.

pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp27-none-linux_x86_64.whl
pip install torchvision
git clone https://github.com/meliketoy/gradcam.pytorch

Grad-CAM

"Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization".


In this repo, we train and test the model on a very simple cat-vs-dog dataset. You can view and download the dataset yourself by clicking the link above.

To run the pipeline on your own private data, you only need to change the data directories in the configuration file of each module.

STEP 1 : Data preparation

You can prepare your data with the preprocessing module. In the configuration file, set the directory to the directory containing the training data.

As we are fine-tuning the model, we will only be taking a small portion of the original training set.

$ cd ./1_preprocessor
$ python main

> Enter mode name : split # This will make a train-validation split in your 'split_dir' in config.py
> Enter mode name : check # This will print out the distribution of your split.
> Enter mode name : meanstd # This will print out the meanstd value of your train set.
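Conceptually, the split and check modes divide the labelled images into training and validation sets and then count how many images of each class ended up in each set. The sketch below shows that idea on an in-memory list of (path, label) pairs. It is not the preprocessor's code: the function names, the 80/20 ratio, and the per-class (stratified) split are assumptions made for illustration, and the real module works on the directories set in config.py.

```python
import random
from collections import Counter

# Hypothetical helpers illustrating the 'split' and 'check' modes; the real
# preprocessor reads directories from config.py instead of an in-memory list.
def train_val_split(samples, val_ratio=0.2, seed=0):
    """Split (path, label) pairs into train/val lists, stratified per class."""
    rng = random.Random(seed)
    by_class = {}
    for path, label in samples:
        by_class.setdefault(label, []).append((path, label))
    train, val = [], []
    for label in sorted(by_class):
        items = list(by_class[label])
        rng.shuffle(items)
        n_val = int(round(len(items) * val_ratio))
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val


def class_distribution(split):
    """Count samples per class, similar in spirit to the 'check' mode."""
    return Counter(label for _, label in split)


if __name__ == "__main__":
    samples = [(f"cat_{i}.jpg", "cat") for i in range(10)]
    samples += [(f"dog_{i}.jpg", "dog") for i in range(10)]
    train, val = train_val_split(samples, val_ratio=0.2)
    print("train:", dict(class_distribution(train)))  # train: {'cat': 8, 'dog': 8}
    print("val:", dict(class_distribution(val)))      # val: {'cat': 2, 'dog': 2}
```

Splitting each class separately keeps the class ratio the same in both sets. This matters when only a small part of the original training set is used.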

Copy the mean/std values printed by the third mode (meanstd) and paste them into the configuration files of module 3 and module 4. See README-preprocessor for further instructions.
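The mean/std values are per-channel statistics of the training images. The classifier and the detector use them to standardize their inputs. The sketch below shows one common way to compute and apply them. It assumes HxWx3 images already scaled to [0, 1] and pools every pixel of every image. The preprocessor's own meanstd mode may aggregate differently (for example, by averaging per-image statistics), and the function names are hypothetical.

```python
import numpy as np

# Hypothetical sketch of the 'meanstd' statistic: pools every pixel of every
# training image and computes one mean/std per RGB channel.
def channel_mean_std(images):
    """images: iterable of HxWx3 float arrays scaled to [0, 1]."""
    pixels = np.concatenate([img.reshape(-1, img.shape[-1]) for img in images], axis=0)
    return pixels.mean(axis=0), pixels.std(axis=0)


def normalize(img, mean, std):
    """Per-channel standardization of an HxWx3 array (torchvision's
    transforms.Normalize does the same per channel on CxHxW tensors)."""
    return (img - mean) / std


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    train_images = [rng.random((32, 32, 3)) for _ in range(8)]  # stand-ins for real photos
    mean, std = channel_mean_std(train_images)
    print("mean:", mean.round(4), "std:", std.round(4))
```

Because modules 3 and 4 both standardize with these values, they must match the values used during training. Otherwise the detector receives inputs on a different scale than the classifier was trained on.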

STEP 2 : Classification

Next, in the classifier module, run the following command:

$ ./scripts/train/resnet

This fine-tunes a pre-trained ResNet-50 model on your dataset. To train a different model or depth, see the scripts. See README-classifier for further instructions.

STEP 3 : Detection

After training, the model is saved in the checkpoint directory. The detector module finds this checkpoint automatically by looking it up under the directory name of your training set.

In the configuration of module 4, set the 'name' variable to the same 'name' you used for your classification training data directory.
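Because the detector finds the checkpoint by name, a mismatched 'name' simply means no checkpoint is found. The sketch below shows that lookup under a hypothetical layout, `<checkpoint_root>/<name>/<files>`. The folder layout, the file name `model.pth`, and both helper functions are assumptions for illustration, not the detector's actual code.

```python
import os
import tempfile

# Hypothetical layout: <checkpoint_root>/<name>/<checkpoint files>.
def dataset_name(train_dir):
    """Derive the 'name' from the training-data directory, ignoring a trailing slash."""
    return os.path.basename(os.path.normpath(train_dir))


def find_checkpoints(checkpoint_root, name):
    """List checkpoint files saved under checkpoint_root/name."""
    folder = os.path.join(checkpoint_root, name)
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"no checkpoint folder for name '{name}': {folder}")
    return sorted(os.path.join(folder, f) for f in os.listdir(folder))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        name = dataset_name("/data/cat_vs_dog/")  # 'cat_vs_dog'
        os.makedirs(os.path.join(root, name))
        open(os.path.join(root, name, "model.pth"), "w").close()  # placeholder file
        print(find_checkpoints(root, name))
```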

To generate a heatmap for each test image, run:

$ ./scripts/detect.sh

This generates heatmaps such as the following:

Attention for cat


Attention for dog


See README-detector for further instructions.
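The Grad-CAM paper computes each heatmap from one convolutional layer. It first averages the gradient of the class score y^c over the spatial positions of each feature map A^k, giving a weight alpha_k^c = (1/Z) * sum_ij dy^c/dA^k_ij. The map is then ReLU(sum_k alpha_k^c * A^k). The numpy sketch below implements that formula. It assumes the activations and gradients have already been captured, for example in PyTorch with a forward hook on the last convolutional block (`layer4` for ResNet) and `Tensor.register_hook`, and converted to arrays. The nearest-neighbour upsampling keeps the example dependency-free. The detector module may resize and colour the map differently, and the helper names are hypothetical.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM map for one image and one target class.

    activations: (K, H, W) feature maps A^k from the chosen conv layer
    gradients:   (K, H, W) gradients dy^c/dA^k for the target class c
    """
    weights = gradients.mean(axis=(1, 2))                  # alpha_k^c: spatially averaged gradients
    cam = np.tensordot(weights, activations, axes=(0, 0))  # sum_k alpha_k^c * A^k -> (H, W)
    cam = np.maximum(cam, 0)                               # ReLU keeps only positive evidence for c
    if cam.max() > 0:
        cam = cam / cam.max()                              # scale to [0, 1] for display
    return cam


def upsample_nearest(cam, out_h, out_w):
    """Nearest-neighbour resize so the map can be overlaid on the input image."""
    rows = np.arange(out_h) * cam.shape[0] // out_h
    cols = np.arange(out_w) * cam.shape[1] // out_w
    return cam[rows[:, None], cols[None, :]]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    acts = rng.random((2048, 7, 7))  # shape of ResNet-50 layer4 output for a 224x224 input
    grads = rng.standard_normal((2048, 7, 7))
    heatmap = upsample_nearest(grad_cam(acts, grads), 224, 224)
    print(heatmap.shape, float(heatmap.min()), float(heatmap.max()))
```

The map has the spatial resolution of the chosen layer (7x7 here), which is one reason Grad-CAM heatmaps look coarse and may not align exactly with thin or elongated objects.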

FUTURE WORKS : Semi-supervised Object Detection

This strategy could be used for semi-supervised detection, i.e., learning to detect objects when only image-level classification labels are given, without any localization annotations.
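One simple way to turn a heatmap into a detection is to threshold it and take the box around the activated pixels. The sketch below does this. The relative threshold of 0.5 and the single enclosing box are assumptions, not the method used for the leukocyte experiments. Images with several objects would need connected-component labelling first (for example `scipy.ndimage.label`), and box coordinates are in heatmap pixels, so the map should be upsampled to image size first.

```python
import numpy as np


def cam_to_bbox(cam, threshold=0.5):
    """Return (x_min, y_min, x_max, y_max) around pixels with cam >= threshold * max,
    or None if the map has no positive activation."""
    if cam.max() <= 0:
        return None
    ys, xs = np.nonzero(cam >= threshold * cam.max())
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


if __name__ == "__main__":
    cam = np.zeros((7, 7))
    cam[2:5, 1:3] = 1.0      # synthetic "hot" region
    print(cam_to_bbox(cam))  # (1, 2, 2, 4)
```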

An implementation on leukocyte detection (which I submitted a paper on) looks like the following:

If you want to change the model configuration, see the script or the configuration file.

gradcam.pytorch's People

Contributors

bmsookim


gradcam.pytorch's Issues

missing './logs/resnet.csv'

An error is encountered when training with ResNet-50. The error message is:
FileNotFoundError: [Errno 2] No such file or directory './logs/resnet.csv'
What could be a possible solution?

Curious About Your Paper!

At the end of the README, you mention "Implementation on leukocyte detection (which I submitted a paper on) will look like". I am curious about your work. Could you share a link to the paper or send it to '[email protected]'? Thanks a lot.

GradCAM does not match exactly. Why?

Hi, @meliketoy

Classifying with ResNet gives an accuracy of over 99%. However, when I generate a Grad-CAM heatmap of the object area with that model file, it does not match the object exactly. Why?

It seems to be a problem with Grad-CAM rather than with the ResNet classifier training. The objects to be localized are not compact blobs like dogs or cats, but closer to long straight lines. In this case, Grad-CAM seems to miss the object area. Have you experienced this?

For a well-trained ResNet-34 model, how do you optimize Grad-CAM?

Thanks, in advance.

from @bemoregt.
