mobilenetv3-tensorflow's Introduction

MobileNetV3 TensorFlow

Unofficial implementation of the MobileNetV3 architecture described in the paper Searching for MobileNetV3. This repository contains the small and large MobileNetV3 architectures, implemented in TensorFlow with the tf.keras API.

Google Colab

  • MNIST
  • CIFAR10

Requirements

  • Python 3.6+
  • TensorFlow 1.13+
pip install -r requirements.txt

Build model

MobileNetV3 Small

from mobilenetv3_factory import build_mobilenetv3
model = build_mobilenetv3(
    "small",
    input_shape=(224, 224, 3),
    num_classes=1001,
    width_multiplier=1.0,
)

MobileNetV3 Large

from mobilenetv3_factory import build_mobilenetv3
model = build_mobilenetv3(
    "large",
    input_shape=(224, 224, 3),
    num_classes=1001,
    width_multiplier=1.0,
)
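The width_multiplier argument scales the channel count of every layer. A common convention in MobileNet code, including the official TensorFlow models, is to round each scaled count to a multiple of 8 without shrinking it by more than 10%. The sketch below shows that rounding for a few MobileNetV3-Large channel counts. Whether this repository applies exactly this rule is an assumption, and make_divisible and scale_channels are hypothetical names, not functions from mobilenetv3_factory.

```python
from typing import List, Optional


def make_divisible(value: float, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """Hypothetical helper: round a scaled channel count to a multiple of `divisor`."""
    if min_value is None:
        min_value = divisor
    # Round to the nearest multiple of `divisor`, but never below `min_value`.
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # If rounding down removed more than 10% of the channels, go up one step instead.
    if new_value < 0.9 * value:
        new_value += divisor
    return new_value


def scale_channels(channels: List[int], width_multiplier: float) -> List[int]:
    """Apply the width multiplier to each layer, then round (assumed convention)."""
    return [make_divisible(c * width_multiplier) for c in channels]


if __name__ == "__main__":
    base = [16, 24, 40, 80, 112, 160]  # a subset of MobileNetV3-Large output channels
    for wm in (0.75, 1.0):
        print(wm, scale_channels(base, wm))
```

With width_multiplier=1.0 the counts are unchanged because they are already multiples of 8; with 0.75, for example, 40 channels become 30 and then round up to 32.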

Train

CIFAR10 dataset

python train.py \
    --model_type small \
    --width_multiplier 1.0 \
    --height 128 \
    --width 128 \
    --dataset cifar10 \
    --lr 0.01 \
    --optimizer rmsprop \
    --train_batch_size 256 \
    --valid_batch_size 256 \
    --num_epoch 10 \
    --logdir logdir

MNIST dataset

python train.py \
    --model_type small \
    --width_multiplier 1.0 \
    --height 128 \
    --width 128 \
    --dataset mnist \
    --lr 0.01 \
    --optimizer rmsprop \
    --train_batch_size 256 \
    --valid_batch_size 256 \
    --num_epoch 10 \
    --logdir logdir

Evaluate

CIFAR10 dataset

python evaluate.py \
    --model_type small \
    --width_multiplier 1.0 \
    --height 128 \
    --width 128 \
    --dataset cifar10 \
    --valid_batch_size 256 \
    --model_path mobilenetv3_small_cifar10_10.h5

MNIST dataset

python evaluate.py \
    --model_type small \
    --width_multiplier 1.0 \
    --height 128 \
    --width 128 \
    --dataset mnist \
    --valid_batch_size 256 \
    --model_path mobilenetv3_small_mnist_10.h5
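The checkpoint names passed to --model_path above appear to follow the pattern mobilenetv3_{model_type}_{dataset}_{num_epoch}.h5 (both examples come from 10-epoch training runs). Assuming that pattern holds, a small Python sketch can assemble the evaluate.py command for each dataset using only the flags shown above. checkpoint_name and evaluate_argv are hypothetical helpers, and the naming rule is inferred from the two examples rather than confirmed in train.py.

```python
import shlex
from typing import List


def checkpoint_name(model_type: str, dataset: str, epoch: int) -> str:
    # Assumed pattern, inferred from mobilenetv3_small_cifar10_10.h5 and
    # mobilenetv3_small_mnist_10.h5 produced by 10-epoch runs.
    return f"mobilenetv3_{model_type}_{dataset}_{epoch}.h5"


def evaluate_argv(model_type: str, dataset: str, epoch: int, size: int = 128) -> List[str]:
    # Uses only the evaluate.py flags shown in the examples above.
    return [
        "python", "evaluate.py",
        "--model_type", model_type,
        "--width_multiplier", "1.0",
        "--height", str(size),
        "--width", str(size),
        "--dataset", dataset,
        "--valid_batch_size", "256",
        "--model_path", checkpoint_name(model_type, dataset, epoch),
    ]


if __name__ == "__main__":
    for ds in ("cifar10", "mnist"):
        # Quote each argument so the printed line can be pasted into a shell.
        print(" ".join(shlex.quote(arg) for arg in evaluate_argv("small", ds, 10)))
```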

TensorBoard

The graph and the training and evaluation metrics are saved as TensorBoard event files under the directory specified with the --logdir argument during training. You can launch TensorBoard with the following command:

tensorboard --logdir logdir

License

Apache License 2.0

mobilenetv3-tensorflow's People

Contributors

adelshb, martinkersner

mobilenetv3-tensorflow's Issues

Number of filters in comparison to official Keras implementation

Hello

As far as I can see, the large model contains layers with 16, 24, 24, 40, 40, 40, 80, 80, 80, 80, 112, 112, 160, 160, 160 filters. On the other hand, the official Keras MobileNetV2 implementation contains layers with 16, 24, 24, 32, 32, 32, 64, 64, 64, 64, 96, 96, 96, 160, 160, 160, 320 filters.

Why does this implementation here differ in terms of the number of filters?
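For reference, the sketch below transcribes the MobileNetV3-Large bottleneck table (Table 1 of Searching for MobileNetV3) as plain Python data and sets its output channels beside the MobileNetV2 sequence quoted in the issue. The two lists differ because MobileNetV3 is a different architecture, not a re-parameterized MobileNetV2: the paper tunes its per-layer filter counts with the NetAdapt search. The Bneck tuple is a hypothetical structure for illustration, not this repository's internal representation.

```python
from typing import List, NamedTuple


class Bneck(NamedTuple):
    kernel: int      # depthwise kernel size
    expansion: int   # expanded (hidden) channel count
    out: int         # output channels
    se: bool         # squeeze-and-excitation used
    activation: str  # "RE" = ReLU, "HS" = h-swish
    stride: int


# MobileNetV3-Large bottleneck blocks, transcribed from Table 1 of the paper.
V3_LARGE: List[Bneck] = [
    Bneck(3, 16, 16, False, "RE", 1),
    Bneck(3, 64, 24, False, "RE", 2),
    Bneck(3, 72, 24, False, "RE", 1),
    Bneck(5, 72, 40, True, "RE", 2),
    Bneck(5, 120, 40, True, "RE", 1),
    Bneck(5, 120, 40, True, "RE", 1),
    Bneck(3, 240, 80, False, "HS", 2),
    Bneck(3, 200, 80, False, "HS", 1),
    Bneck(3, 184, 80, False, "HS", 1),
    Bneck(3, 184, 80, False, "HS", 1),
    Bneck(3, 480, 112, True, "HS", 1),
    Bneck(3, 672, 112, True, "HS", 1),
    Bneck(5, 672, 160, True, "HS", 2),
    Bneck(5, 960, 160, True, "HS", 1),
    Bneck(5, 960, 160, True, "HS", 1),
]

# MobileNetV2 output channels per block, as quoted in the issue above.
V2_OUT = [16, 24, 24, 32, 32, 32, 64, 64, 64, 64, 96, 96, 96, 160, 160, 160, 320]

v3_out = [b.out for b in V3_LARGE]
print("V3-Large:", v3_out)
print("V2:      ", V2_OUT)
print("blocks (V3-Large, V2):", len(v3_out), len(V2_OUT))
```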

AttributeError: 'PrefetchDataset' object has no attribute 'output_shapes'

Could you please advise on this error?

I tried to resolve it with tf.compat.v1, but I'm unsure about this new output_shapes error:

line 34, in build_dataset
dataset["channels"] = ds_train.output_shapes["image"][-1].value

AttributeError: 'PrefetchDataset' object has no attribute 'output_shapes'

~Sanket

it seems that the model is overfitting

Hello, thanks for sharing your code.
But when I run the code, I find that the model is overfitting. For example, on the CIFAR10 dataset the training accuracy reaches 98%+, but the validation accuracy is only 90%+. What can I do?

Help: Keras with feed_dict mechanism of Session

Dear @martinkersner ,

This is a question rather than an issue. I'd be grateful if you could give your thought.

I'm using your code to test a small (but old) project. The project was built on the queue-based input pipeline and the feed_dict mechanism of tf.Session. That means I can only use your model (the build_mobilenetv3 function) without the other Keras-based functions (model.compile, model.fit, model.evaluate, etc.).

The problem is that I need to set Keras's learning_phase to true during training and to false during evaluation or inference. I tried several approaches:

  1. tf.keras.backend.set_learning_phase(True) / tf.keras.backend.set_learning_phase(False)
  2. feed_dict = {tf.keras.backend.learning_phase() : 1} / feed_dict = {tf.keras.backend.learning_phase() : 0}
  3. I also modified your code slightly, for example:
model = BuildMobileNetV3(
    num_classes=5,
    width_multiplier=1.0,
    l2_reg=1e-5,
)
input_shape = (150, 150, 3)
input_tensor = tf.keras.layers.Input(shape=input_shape)

and then called
model(input_tensor, training=True) / model(input_tensor, training=False)

However, none of these methods works. The training cost and accuracy look good, but the validation accuracy never improves. I suspect this is caused by the BN and dropout layers, but what else can I try beyond the three approaches above?

Thank you for your great work. I'm looking forward to discussing this question with you.

Thanks.
Cuong

Has anyone had the same problem?

2022-03-27 16:28:09.999279: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at iterator_ops.cc:867 : Cancelled: Operation was cancelled
Traceback (most recent call last):
File "/home/hanwenxing/mobilenet_v3/train.py", line 113, in <module>
main(args)
File "/home/hanwenxing/mobilenet_v3/train.py", line 81, in main
callbacks=callbacks,
File "/home/hanwenxing/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 727, in fit
use_multiprocessing=use_multiprocessing)
File "/home/hanwenxing/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 675, in fit
steps_name='steps_per_epoch')
File "/home/hanwenxing/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 300, in model_iteration
batch_outs = f(actual_inputs)
File "/home/hanwenxing/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3476, in __call__
run_metadata=self.run_metadata)
File "/home/hanwenxing/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1472, in __call__
run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Expected dimension in the range [0, 0), but got -1
[[{{node metrics/acc/ArgMax}}]]
(1) Invalid argument: Expected dimension in the range [0, 0), but got -1
[[{{node metrics/acc/ArgMax}}]]
[[metrics/acc/Identity/_941]]
0 successful operations.
0 derived errors ignored.

Process finished with exit code 1

Train on ImageNet?

Hi! Thanks for providing the code! Did you pretrain it on ImageNet, and if so, could you please share the weights?

Not including top layers

Hello

Thanks for the code and the work. How can I remove the top layers (the classification layer) so that I can add my own classification layers, similar to MobileNetV2(weights=None, include_top=False) in tensorflow.keras.applications?

Implementing Semantic Segmentation Head?

Firstly, thanks for the implementation of the MobileNetV3!

Are you interested in implementing the semantic segmentation head described in Section 6.4 of the paper? I am currently planning to implement it, so if you are interested, we could work on it together!

One question about the 5×5 kernel

When training on CIFAR10, the initial input size is 96×96. When the downscaled feature map becomes smaller than the 5×5 kernel, how are these layers handled?
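One way to see where this happens is to trace the spatial size through the network. The sketch below uses the (kernel, stride) pairs of MobileNetV3-Small from Table 2 of the paper and assumes 'same' padding, the usual choice in tf.keras implementations but not confirmed here for this repository. Under that assumption each layer outputs ceil(input / stride) pixels per side, and a 5×5 depthwise kernel applied to a 3×3 map still yields a 3×3 output: the input is zero-padded, so most of each window covers padding. trace_spatial_sizes is a hypothetical helper.

```python
import math
from typing import List, Tuple

# (kernel_size, stride) for the stem conv and the 11 bottleneck blocks of
# MobileNetV3-Small, transcribed from Table 2 of the paper.
V3_SMALL_LAYERS: List[Tuple[int, int]] = [
    (3, 2),                                  # stem conv2d
    (3, 2), (3, 2), (3, 1),
    (5, 2), (5, 1), (5, 1), (5, 1), (5, 1),
    (5, 2), (5, 1), (5, 1),
]


def trace_spatial_sizes(size: int, layers: List[Tuple[int, int]]) -> List[Tuple[int, int, int]]:
    """Return (input_size, kernel, output_size) per layer, assuming 'same' padding."""
    trace = []
    for kernel, stride in layers:
        out = math.ceil(size / stride)  # 'same' padding: output = ceil(input / stride)
        trace.append((size, kernel, out))
        size = out
    return trace


for i, (inp, k, out) in enumerate(trace_spatial_sizes(96, V3_SMALL_LAYERS)):
    note = "  <- kernel larger than input map" if k > inp else ""
    print(f"layer {i:2d}: {inp:3d}x{inp:<3d} k={k} -> {out}x{out}{note}")
```

For a 96×96 input, the last two blocks run 5×5 kernels on 3×3 maps.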

Training on simple classification tasks got nan result

INFO 05-13 18:44:56 classify_flowers_mbv3.py:94 - Epoch: 0, iter: 0, loss: 1.6094379425048828, train_acc: 0.0
pred:  tf.Tensor([[[[nan nan nan nan nan]]]], shape=(1, 1, 1, 5), dtype=float32)
pred:  tf.Tensor([[[[nan nan nan nan nan]]]], shape=(1, 1, 1, 5), dtype=float32)
pred:  tf.Tensor([[[[nan nan nan nan nan]]]], shape=(1, 1, 1, 5), dtype=float32)
pred:  tf.Tensor([[[[nan nan nan nan nan]]]], shape=(1, 1, 1, 5), dtype=float32)
pred:  tf.Tensor([[[[nan nan nan nan nan]]]], shape=(1, 1, 1, 5), dtype=float32)
pred:  tf.Tensor([[[[nan nan nan nan nan]]]], shape=(1, 1, 1, 5), dtype=float32)

have you trained your implementation?
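The iteration-0 loss in the log above can be checked by hand: for C classes, a prediction that assigns probability 1/C to every class has cross-entropy −ln(1/C) = ln C, and ln 5 ≈ 1.6094. The sketch below confirms that the logged value matches ln 5 to float32 precision. That suggests the first outputs were exactly uniform (all logits equal) and the loss was still finite at that point; this is an interpretation of the log, not a diagnosis confirmed in the issue. uniform_cross_entropy is a hypothetical helper.

```python
import math


def uniform_cross_entropy(num_classes: int) -> float:
    """Cross-entropy when every class gets probability 1 / num_classes."""
    return -math.log(1.0 / num_classes)  # equals ln(num_classes)


logged_loss = 1.6094379425048828  # iteration-0 loss from the log above
expected = uniform_cross_entropy(5)
print(expected, abs(logged_loss - expected) < 1e-6)
```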

Saving model

Hello

As part of the ModelCheckpoint callback, I would like to save the mobilenetv3 model (to HDF5) and load it later. When doing so, I get a NotImplementedError. Is it not possible to save the model?
