Giter VIP home page Giter VIP logo

pytorch-mobilenet-v2's Introduction

A PyTorch implementation of MobileNetV2

This is a PyTorch implementation of MobileNetV2 architecture as described in the paper Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation.

[NEW] I fixed a difference in implementation compared to the official TensorFlow model. Please use the new model file and checkpoint!

Training Recipe

Recently I have figured out a good training setting:

  1. number of epochs: 150
  2. learning rate schedule: cosine learning rate, initial lr=0.05
  3. weight decay: 4e-5
  4. remove dropout

You should get >72% top-1 accuracy with this training recipe!

Accuracy & Statistics

Here is a comparison of statistics against the official TensorFlow implementation.

FLOPs Parameters Top1-acc Pretrained Model
Official TF 300 M 3.47 M 71.8% -
Ours 300.775 M 3.471 M 71.8% [google drive]

Usage

To use the pretrained model, run

from MobileNetV2 import MobileNetV2

net = MobileNetV2(n_class=1000)
state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu
net.load_state_dict(state_dict)

Data Pre-processing

I used the following code for data pre-processing on ImageNet:

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

input_size = 224
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0)), 
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=n_worker, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(int(input_size/0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=n_worker, pin_memory=True)

pytorch-mobilenet-v2's People

Contributors

tonylins avatar

Watchers

James Cloos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.