Giter VIP home page Giter VIP logo

capsnet-pytorch's Introduction

Dynamic Routing Between Capsules - PyTorch implementation

PyTorch implementation of NIPS 2017 paper Dynamic Routing Between Capsules from Sara Sabour, Nicholas Frosst and Geoffrey E. Hinton.

The hyperparameters and data augmentation strategy strictly follow the paper.

This is a forked version. I made some annotation and make it compatible with pytorch 1.1.0

Requirements

Only PyTorch with torchvision is required (tested on pytorch 0.2.0 and 0.3.0). Jupyter and matplotlib is required to run the notebook with visualizations.

Usage

Train the model by running

python net.py

Optional arguments and default values:

  --batch-size N          input batch size for training (default: 128)
  --test-batch-size N     input batch size for testing (default: 1000)
  --epochs N              number of epochs to train (default: 250)
  --lr LR                 learning rate (default: 0.001)
  --no-cuda               disables CUDA training
  --seed S                random seed (default: 1)
  --log-interval N        how many batches to wait before logging training
                          status (default: 10)
  --routing_iterations    number of iterations for routing algorithm (default: 3)
  --with_reconstruction   should reconstruction layers be used

MNIST dataset will be downloaded automatically.

Results

The network trained with reconstruction and 3 routing iterations on MNIST dataset achieves 99.65% accuracy on test set. The test loss is still slightly decreasing, so the accuracy could probably be improved with more training and more careful learning rate schedule.

Visualizations

We can create visualizations of digit reconstructions from DigitCaps (e.g. Figure 3 in the paper)

Reconstructions

We can also visualize what each dimension of digit capsule represents (Section 5.1, Figure 4 in the paper).

Below, each row shows the reconstruction when one of the 16 dimensions in the DigitCaps representation is tweaked by intervals of 0.05 in the range [−0.25, 0.25].

Perturbations

We can see what individual dimensions represent for digit 7, e.g. dim6 - stroke thickness, dim11 - digit width, dim 15 - vertical shift.

Visualization examples are provided in a jupyter notebook

My Results

83.35% on CIFAR10 99.51% on MNIST,with reconstruction 99.58%

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.