

Cat-recognition-train

This repository demonstrates how to use TensorFlow to train a cat vs. dog recognition model and export it to an optimized frozen graph that is easy to deploy. If you want to know how to deploy a Flask app that recognizes cats and dogs using TensorFlow, please visit cat-recognition-app.

Requirements

  • Python3 (Tested on 3.6.8)
  • TensorFlow (Tested on 1.12.0)
  • NumPy (Tested on 1.15.1)
  • tqdm (Tested on 4.29.1)
  • Dogs vs. Cats dataset from https://www.kaggle.com/c/dogs-vs-cats
  • (Optional if you want to run tests) PyTorch (Tested on 1.0.0 and 1.0.1)

Build environment

We recommend using Anaconda3 / Miniconda3 to manage your Python environment.

If the machine you're using does not have a GPU, you can simply run:

$ pip install -r requirements.txt

or

$ conda install --file requirements.txt

However, if you want to use a GPU to accelerate training, please see TensorFlow - GPU support for more information.

Train a Convolutional Neural Network

In this part, we use TensorFlow to train a CNN that distinguishes cat images from dog images, using the Kaggle Dogs vs. Cats dataset. We will do the following:

  • Create the training/validation sets (dataset.py)
  • Load, augment, resize and normalize the images using the tf.data.Dataset API (dataset.py)
  • Define a CNN model (net.py)
    • Here we use the ShuffleNetV2 architecture, which achieves a good balance between speed and accuracy.
    • We do transfer learning on ShuffleNetV2 using the pretrained weights from https://github.com/ericsun99/Shufflenet-v2-Pytorch.
    • If you want to know how to load PyTorch weights into a TensorFlow model graph, please check convert_pytorch_weight_test, starting from line 44 in module_tests.py.
  • Train the CNN model (train.py)
  • Serialize the model for deployment (train.py)
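As a rough illustration of the first step, the sketch below splits the Kaggle filenames into training and validation lists. It assumes labels are inferred from the filename prefix (cat = 1, dog = 0) and that the held-out fraction corresponds to the --valset_ratio flag of train.py; the make_split function and its fixed seed are hypothetical and not taken from dataset.py.

```python
import random
from pathlib import Path


def make_split(filenames, valset_ratio=0.1, seed=42):
    """Hypothetical helper: split Kaggle filenames into (train, val) lists of (name, label).

    Assumption: label 1 means cat and 0 means dog, read from the filename prefix.
    """
    labeled = []
    for name in filenames:
        prefix = Path(name).name.split(".")[0]
        if prefix not in ("cat", "dog"):
            continue  # skip anything that is not cat.N.jpg / dog.N.jpg
        labeled.append((name, 1 if prefix == "cat" else 0))

    rng = random.Random(seed)  # fixed seed so the split is reproducible
    rng.shuffle(labeled)
    n_val = int(len(labeled) * valset_ratio)  # number of held-out images
    return labeled[n_val:], labeled[:n_val]


if __name__ == "__main__":
    names = [f"cat.{i}.jpg" for i in range(10)] + [f"dog.{i}.jpg" for i in range(10)]
    train, val = make_split(names, valset_ratio=0.1)
    print(len(train), len(val))  # 18 2
```

Shuffling before slicing keeps both classes represented in the validation set in expectation; a stratified split per class would guarantee it.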

If you want to run the code, make sure all required packages are installed and the Dogs vs. Cats training dataset is placed in the datasets folder. The folder structure should look like this:

cat-recognition-train
+-- train.py
+-- net.py
+-- dataset.py
+-- datasets
    +-- train
    |   +-- cat.0.jpg
    |   +-- cat.1.jpg
    |   ...
    |   +-- cat.12499.jpg
    |   +-- dog.0.jpg
    |   +-- dog.1.jpg
    |   ...
    |   +-- dog.12499.jpg
+-- ...
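Before starting a long training run, it can help to check this layout. The hypothetical check_dataset helper below (not part of the repository) counts files matching the cat.*.jpg and dog.*.jpg patterns shown in the tree. With the full Kaggle training set it should find 12,500 images per class (cat.0.jpg through cat.12499.jpg).

```python
from pathlib import Path


def check_dataset(root="datasets/train"):
    """Hypothetical helper: count cat/dog images under root, raising if none are found."""
    train_dir = Path(root)
    if not train_dir.is_dir():
        raise FileNotFoundError(f"expected the Kaggle training images in {train_dir}/")
    cats = list(train_dir.glob("cat.*.jpg"))
    dogs = list(train_dir.glob("dog.*.jpg"))
    if not cats or not dogs:
        raise ValueError(f"no cat.*.jpg or dog.*.jpg files found in {train_dir}/")
    return len(cats), len(dogs)
```

For example, calling check_dataset() from the repository root should return (12500, 12500) when the full dataset is in place.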

Once all requirements are set up, run the following command to train with the default arguments:

$ python train.py

Or you can pass your desired arguments:

$ python train.py --epochs 30 --batch_size 32 --valset_ratio .1 --optim sgd --lr_decay_step 10

See train.py for available arguments.
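To make the flag types in the example command explicit, here is a sketch of how such a parser could be declared with argparse. The flag names come from the command above and the types are inferred from its example values; the description of --lr_decay_step is an assumption. No defaults are set, because the real ones live in train.py and are not reproduced here.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags in the example command (not train.py itself)."""
    parser = argparse.ArgumentParser(description="Train a cat vs. dog classifier")
    # No defaults: unset flags parse to None; train.py defines its own defaults.
    parser.add_argument("--epochs", type=int, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, help="images per training batch")
    parser.add_argument("--valset_ratio", type=float, help="fraction of images held out for validation")
    parser.add_argument("--optim", type=str, help="optimizer name, e.g. sgd")
    parser.add_argument("--lr_decay_step", type=int, help="assumed: epochs between learning-rate decays")
    return parser


if __name__ == "__main__":
    argv = ["--epochs", "30", "--batch_size", "32", "--valset_ratio", ".1",
            "--optim", "sgd", "--lr_decay_step", "10"]
    args = build_parser().parse_args(argv)
    print(args.epochs, args.valset_ratio, args.optim)  # 30 0.1 sgd
```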

Visualizing Learning using TensorBoard

During training, you can monitor its progress by running:

$ tensorboard --logdir runs

Then open localhost:6006 in your browser to view the TensorBoard summaries.

Training and Validation Flow

The whole training and validation flow, including the CNN model and other training/validation operations such as the optimizer, saver, and accuracy counter.

Model Performance

Training/validation loss and validation accuracy for each epoch

Optimized Network Graph


Predict Using Optimized Frozen Graph

See predict.py for details and demo.

Default image used for the predict.py demo

You can run

$ python predict.py

The result should be:

Predicting catness on images/test.png using model from baseline_model/optimized_net_best_acc.pb
Catness: 16.460064
Cat Probability: 1.000000
It's a cat.

This runs the demo on the default image. If you have your own cat or dog photo to test, run

$ python predict.py --path path/to/your/img.png

PNG, JPG, and BMP images are supported.
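The demo output is consistent with Cat Probability being the logistic sigmoid of Catness: sigmoid(16.460064) is about 0.99999993, which rounds to 1.000000 at six decimal places. Assuming that relationship and a 0.5 decision threshold (neither is confirmed here; see predict.py), the printed lines could be reproduced as below. interpret_catness is a hypothetical helper, and the dog message is an assumption since only the cat case is shown above.

```python
import math


def interpret_catness(catness, threshold=0.5):
    """Hypothetical helper: map a raw 'catness' logit to (probability, verdict).

    Assumes the probability is the logistic sigmoid of the logit.
    """
    # Numerically stable sigmoid: avoid math.exp overflow for large negative logits.
    if catness >= 0:
        prob = 1.0 / (1.0 + math.exp(-catness))
    else:
        z = math.exp(catness)
        prob = z / (1.0 + z)
    verdict = "It's a cat." if prob >= threshold else "It's a dog."  # dog text assumed
    return prob, verdict


if __name__ == "__main__":
    prob, verdict = interpret_catness(16.460064)
    print(f"Cat Probability: {prob:.6f}")  # Cat Probability: 1.000000
    print(verdict)  # It's a cat.
```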


Contributors

leemengtw, mnicnc404


cat-recognition-train's Issues

How can I train on grayscale images?

Hello, I want to train on a dataset consisting of grayscale images, but the following error occurs:
    self.global_step: epoch})
  File "/home/lijh/anaconda3/envs/weiwuhuhu/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/lijh/anaconda3/envs/weiwuhuhu/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/lijh/anaconda3/envs/weiwuhuhu/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/lijh/anaconda3/envs/weiwuhuhu/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with 'BM6\264\004\000\000\000\000\0006\004\000\000(\000'
     [[{{node DecodeJpeg}}]]
     [[IteratorGetNext]]
It looks like a format error. How can I solve this? Thank you in advance for your reply!
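The bytes quoted in the error begin with BM, the signature of a BMP file, so the image reaching the DecodeJpeg node is most likely a BMP (possibly saved with a .jpg extension), while the error message says only JPEG, PNG, or GIF are expected. A minimal standard-library sketch for finding such files by their leading "magic" bytes follows; sniff_format and find_non_decodable are hypothetical helpers, and converting the flagged files to JPEG or PNG before training is one possible fix.

```python
from pathlib import Path

# Leading "magic" bytes of common image formats.
MAGIC = [
    (b"\xff\xd8\xff", "jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png"),
    (b"GIF87a", "gif"),
    (b"GIF89a", "gif"),
    (b"BM", "bmp"),
]


def sniff_format(header: bytes) -> str:
    """Guess an image format from its first bytes, ignoring the file extension."""
    for magic, name in MAGIC:
        if header.startswith(magic):
            return name
    return "unknown"


def find_non_decodable(folder, accepted=("jpeg", "png", "gif")):
    """List (filename, format) for files whose real format is not in `accepted`."""
    flagged = []
    for path in sorted(Path(folder).iterdir()):
        if path.is_file():
            with path.open("rb") as f:
                fmt = sniff_format(f.read(8))
            if fmt not in accepted:
                flagged.append((path.name, fmt))
    return flagged


if __name__ == "__main__":
    # The bytes from the error message above ('\264' is octal for 0xB4).
    print(sniff_format(b"BM6\xb4\x04\x00\x00\x00\x00\x006\x04"))  # bmp
```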
