MXNet-Gluon-SyncBN

created by Hang Zhang

A preview tutorial for MXNet Gluon Synchronized Batch Normalization (SyncBN) [1]. We follow the sync-once implementation described in [2]. If you are not familiar with Synchronized Batch Normalization, please see this blog post. Special thanks to Haibin for the technical support.
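As we understand the sync-once idea from [2], each GPU computes the per-channel sums of x and x^2 over its own slice, the devices exchange these partial sums in a single reduction, and every GPU then derives the same global mean and variance using Var[x] = E[x^2] - (E[x])^2. The plain-Python sketch below illustrates only that arithmetic and is not the MXNet kernel. Each inner list stands in for all N x H x W values of one channel on one GPU. The helper names (local_stats, all_reduce, sync_batch_norm, batch_norm) are hypothetical.

```python
import math
from typing import List, Tuple


def local_stats(x: List[float]) -> Tuple[float, float, int]:
    # Partial statistics one device computes for one channel: (sum x, sum x^2, count).
    return sum(x), sum(v * v for v in x), len(x)


def all_reduce(stats: List[Tuple[float, float, int]]) -> Tuple[float, float, int]:
    # Hypothetical stand-in for the single cross-device reduction.
    return (sum(s[0] for s in stats),
            sum(s[1] for s in stats),
            sum(s[2] for s in stats))


def sync_batch_norm(device_batches: List[List[float]], gamma: float = 1.0,
                    beta: float = 0.0, eps: float = 1e-5) -> List[List[float]]:
    # One synchronization: exchange partial sums, then derive the global statistics.
    total, total_sq, n = all_reduce([local_stats(x) for x in device_batches])
    mean = total / n
    var = total_sq / n - mean * mean  # E[x^2] - E[x]^2 (biased, as in BN training)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [[gamma * (v - mean) * inv_std + beta for v in x] for x in device_batches]


def batch_norm(x: List[float], gamma: float = 1.0, beta: float = 0.0,
               eps: float = 1e-5) -> List[float]:
    # Reference BN on a single device that sees the whole batch (two-pass variance).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]


# Two "devices", each holding part of the batch for one channel.
shards = [[1.0, 2.0], [3.0, 4.0, 5.0, 6.0]]
synced = [v for shard in sync_batch_norm(shards) for v in shard]
full = batch_norm([v for shard in shards for v in shard])
print(max(abs(a - b) for a, b in zip(synced, full)) < 1e-9)  # True
```

The synchronized result matches BN on the full batch. Normalizing each shard separately with batch_norm would not match, because each device would use only its own mean and variance.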

Install MXNet from Source

# clone the branch
git clone -b syncbatchnorm --recursive https://github.com/zhanghang1989/incubator-mxnet
# compile mxnet
cd incubator-mxnet && make -j $(nproc) USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1
# install python API
cd python && python setup.py install

How to use SyncBN

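Import BatchNorm and ModelDataParallel from syncbn. Use this BatchNorm in place of the standard Gluon BatchNorm layer, and wrap the network with ModelDataParallel (its input and output are both lists of NDArrays, one per GPU). Everything else works the same as before.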

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import nn
from mxnet.gluon.nn import Block
# import SyncBN here
from syncbn import BatchNorm, ModelDataParallel

# create your own Block
class Net(Block):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2D(in_channels=3, channels=10,
                              kernel_size=3, padding=1)
        self.bn = BatchNorm(in_channels=10)
        self.relu = nn.Activation('relu')

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# set the contexts (assuming 4 GPUs)
nGPUs = 4
ctx_list = [mx.gpu(i) for i in range(nGPUs)]
# create and initialize the model, then wrap it for multi-GPU execution
model = Net()
model.initialize()
model = ModelDataParallel(model, ctx_list)
# load the data: a batch of 8 random images, split into one slice per GPU
data = mx.random.uniform(-1, 1, (8, 3, 24, 24))
x = gluon.utils.split_and_load(data, ctx_list=ctx_list)
with autograd.record():
    # y is a list of NDArrays, one output per GPU
    y = model(x)
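With 8 samples and 4 GPUs, split_and_load gives each GPU a slice of 2 samples. A non-synchronized BatchNorm would therefore estimate each channel's mean and variance from only 2 x 24 x 24 = 1152 values, which is the small-batch problem SyncBN addresses. The plain-Python sketch below mirrors the even split along the batch axis. even_split is a hypothetical helper, not part of MXNet. To our understanding, gluon.utils.split_and_load also raises an error by default (even_split=True) when the batch size is not a multiple of the number of contexts.

```python
from typing import List, Tuple


def even_split(batch_size: int, num_slices: int) -> List[Tuple[int, int]]:
    """Return [start, end) ranges along the batch axis, one per device."""
    if batch_size % num_slices != 0:
        raise ValueError(
            f"batch of {batch_size} cannot be evenly split into {num_slices} slices")
    step = batch_size // num_slices
    return [(i * step, (i + 1) * step) for i in range(num_slices)]


# 8 samples across 4 GPUs, as in the example above
slices = even_split(8, 4)
print(slices)  # [(0, 2), (2, 4), (4, 6), (6, 8)]
```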

MNIST Example

Please see the Python notebook.

Load Pre-trained Network

TODO

Reference

[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." ICML 2015.
[2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." CVPR 2018.