
non-local-nn-pytorch

PyTorch Implementation of Non-Local Neural Network

This repository contains my implementation of Non-Local Neural Networks (CVPR 2018).

For more on the structure of the paper, you may refer to this slide deck and video (in Korean).

The experiment was run on the CIFAR-10 dataset mainly to verify that the code runs without errors.

Implementation Details

The original paper used ResNet-50 as its backbone for experiments on video datasets such as Kinetics and Charades.

As an initial study, I adopted the ResNet-56 structure for CIFAR-10, a 2D image classification task. The architecture is implemented in models/resnet2D.py.
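For reference, the standard CIFAR-style ResNet recipe (He et al., 2016) gives depth = 6n + 2, so ResNet-56 has n = 9 basic blocks in each of three stages. The sketch below tabulates that layout. It assumes models/resnet2D.py follows the standard recipe (16/32/64 channels, stride-2 downsampling at the start of stages 2 and 3), which this README does not confirm; the function name `cifar_resnet_layout` is hypothetical.

```python
def cifar_resnet_layout(depth, input_size=32):
    """Per-stage layout of a CIFAR-style ResNet with depth = 6n + 2.

    Assumes one 3x3 stem conv, three stages of n basic blocks (two 3x3 convs
    each) with 16/32/64 channels, and a final fully connected layer.
    """
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError("depth must be of the form 6n + 2 with n >= 1")
    n = (depth - 2) // 6
    stages = []
    size = input_size
    for i, channels in enumerate((16, 32, 64)):
        if i > 0:
            size //= 2  # the first block of stages 2 and 3 uses stride 2
        stages.append((channels, n, size))
    weighted_layers = 1 + 3 * n * 2 + 1  # stem + 2 convs per block + fc
    return stages, weighted_layers


stages, layers = cifar_resnet_layout(56)
print(layers)  # 56
for channels, blocks, size in stages:
    print(f"{blocks} blocks, {channels} channels, {size}x{size}")
```

For ResNet-56 this yields three stages of 9 blocks at 32×32, 16×16, and 8×8 resolution.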

The baseline model from the paper, called C2D, uses ResNet-50 as its backbone, with a single non-local block added after the 4th residual block. This structure is implemented in models/resnet3D.py. The details of the architecture are shown in the figure below:

The four pairwise functions discussed in the paper (Gaussian, embedded Gaussian, dot product, and concatenation) are implemented in models/non_local.py; you can select one by passing it as an argument. The details of the non-local block are shown in the figure below:
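The paper defines the non-local operation as y_i = (1/C(x)) Σ_j f(x_i, x_j) g(x_j), wrapped in a residual connection z_i = W_z y_i + x_i. The pure-Python sketch below spells out the four pairwise functions on a handful of flattened positions, treating θ, φ, g, and W_z as plain weight matrices (a 1×1 convolution evaluated at one position). It illustrates the formulas only and is not the code in models/non_local.py; the function names and the mode strings ("gaussian", "embedded", "dot", "concatenate") are assumptions and may not match the argument the repository actually accepts.

```python
import math
import random
from typing import List, Optional

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Apply a linear map, i.e. a 1x1 convolution evaluated at one position."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def non_local(xs: List[Vector], w_theta: Matrix, w_phi: Matrix, w_g: Matrix,
              w_z: Matrix, mode: str = "embedded",
              w_f: Optional[Vector] = None) -> List[Vector]:
    """Non-local block over N flattened positions, each a C-dim feature vector."""
    n = len(xs)
    theta = [matvec(w_theta, x) for x in xs]
    phi = [matvec(w_phi, x) for x in xs]
    g = [matvec(w_g, x) for x in xs]
    out = []
    for i in range(n):
        if mode == "gaussian":
            # f = exp(x_i . x_j) with C(x) = sum_j f, i.e. a softmax over j
            weights = softmax([dot(xs[i], xs[j]) for j in range(n)])
        elif mode == "embedded":
            # f = exp(theta(x_i) . phi(x_j)) with C(x) = sum_j f
            weights = softmax([dot(theta[i], phi[j]) for j in range(n)])
        elif mode == "dot":
            # f = theta(x_i) . phi(x_j) with C(x) = N
            weights = [dot(theta[i], phi[j]) / n for j in range(n)]
        elif mode == "concatenate":
            # f = ReLU(w_f . [theta(x_i), phi(x_j)]) with C(x) = N
            if w_f is None:
                raise ValueError("concatenate mode needs w_f")
            weights = [max(0.0, dot(w_f, theta[i] + phi[j])) / n for j in range(n)]
        else:
            raise ValueError(f"unknown mode: {mode}")
        # y_i = sum_j weights_ij * g(x_j)
        y = [sum(weights[j] * g[j][k] for j in range(n)) for k in range(len(g[0]))]
        # residual connection z_i = W_z y_i + x_i
        out.append([zk + xk for zk, xk in zip(matvec(w_z, y), xs[i])])
    return out


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
C, D, N = 4, 2, 6  # channels, inter-channels, positions (e.g. a 2x3 feature map)
xs = rand_matrix(N, C)
w_theta, w_phi, w_g = rand_matrix(D, C), rand_matrix(D, C), rand_matrix(D, C)
w_z = rand_matrix(C, D)
w_f = rand_matrix(1, 2 * D)[0]
for mode in ("gaussian", "embedded", "dot", "concatenate"):
    z = non_local(xs, w_theta, w_phi, w_g, w_z, mode=mode, w_f=w_f)
    print(mode, len(z), "positions x", len(z[0]), "channels")
```

Because the block is residual, setting W_z to zero turns it into an identity mapping. The paper uses this idea (by zero-initializing the scale of a batch-norm layer after W_z) so that a non-local block can be inserted into a pretrained network without changing its initial behavior.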

Finally, the original activity recognition experiment was replicated in the 3D_experiment folder. The necessary data preprocessing code was borrowed from https://github.com/kenshohara/3D-ResNets-PyTorch. Training runs without errors, but I did not have enough time to measure the performance gain from adding the non-local block.

Training

  1. To start training on CIFAR-10 with ResNet-56, simply execute run.sh.

  2. To start training on the HMDB51 dataset with C2D, first prepare the dataset as instructed in the 3D_experiment folder, then execute run.sh. Multiple GPUs may be needed due to memory constraints.

Results

The models were trained on CIFAR-10 for 200 epochs using the command in run.sh, on a single GTX 1080 Ti GPU. The results showed no significant performance gain for image classification on CIFAR-10. The graph below shows the loss curves for the two networks.

The top-1 validation accuracy of ResNet-56 was 93.97% without the non-local block and 93.98% with it.

This could be due to two reasons: 1) the non-local block was proposed mainly for video classification, and 2) CIFAR-10 inputs are so small (32×32) that little spatial information may remain after the second residual block.
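To make the second point concrete: a non-local block relates every position to every other, so it works over N = H×W positions (N = T×H×W for video) and builds an N×N affinity matrix. The back-of-the-envelope helper below, whose name is hypothetical, tabulates this for the 32×32, 16×16, and 8×8 feature maps of a standard CIFAR-style ResNet-56 (an assumption about models/resnet2D.py). In the last stage only 64 positions remain to attend over.

```python
def non_local_cost(height, width, frames=1):
    """Return (positions N, pairwise affinity entries N*N) for a feature map."""
    n = frames * height * width
    return n, n * n


# Feature-map sizes of the three ResNet-56 stages for a 32x32 CIFAR-10 input.
for stage, size in enumerate((32, 16, 8), start=1):
    n, pairs = non_local_cost(size, size)
    print(f"stage {stage}: {size}x{size} -> N={n}, N*N={pairs}")
```

The `frames` argument covers the spatiotemporal case, where the same count grows with the clip length.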

TO DO

  • Compare the results of the baseline and non-local models on CIFAR-10
  • Prepare video datasets (e.g., UCF-101, HMDB-51)
  • Modify the model code to adapt to spatiotemporal settings
  • Run tests on video datasets
  • Run tests on an image segmentation dataset (e.g., COCO)

Reference

This repo is adapted from several other existing works.
