PyTorch Image Classification

Project Summary

This project classifies flower images using a deep learning model trained with PyTorch. The model builds on a pretrained feature extractor and is trained with data augmentation and normalization to improve accuracy. A command-line application lets users train a new network on a given dataset, predict the class of a flower image, and display the top K classes with their associated probabilities. The project demonstrates how to load data, train a deep learning model, and create a command-line interface to interact with the model.
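The data augmentation and normalization mentioned above can be illustrated with a small, dependency-free sketch. It is not the repository's code: the summary does not say which augmentations are used, so the horizontal flip is an assumption, as are the per-channel mean and standard deviation (the values commonly used with ImageNet-pretrained torchvision models). In a PyTorch project these steps would more likely be expressed with torchvision.transforms such as RandomHorizontalFlip and Normalize.

```python
import random

# Assumed statistics: the values commonly paired with ImageNet-pretrained models.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Standardize one RGB pixel whose channels are already scaled to [0, 1]."""
    return tuple((value - m) / s for value, m, s in zip(rgb, MEAN, STD))


def random_horizontal_flip(image, p=0.5, rng=random):
    """Flip an image (a list of rows of pixels) left to right with probability p."""
    if rng.random() < p:
        return [list(reversed(row)) for row in image]
    return [list(row) for row in image]


# A hypothetical 1x2 image: one red pixel followed by one green pixel.
image = [[(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]]

# p=1.0 forces the flip so the effect is visible; training would use p < 1.
flipped = random_horizontal_flip(image, p=1.0)
normalized = [[normalize_pixel(px) for px in row] for row in flipped]
print(normalized)
```

Augmentation is applied only to training images, while normalization is applied to every image the network sees, including those passed to prediction.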

Files

The following files are included in this repository:

train.py: This script trains a new network on a dataset of images. It lets users set the learning rate, the number of hidden units, and the number of training epochs. It also lets users choose from at least two different architectures available from torchvision.models and choose whether to train the model on a GPU. The training loss, validation loss, and validation accuracy are printed as the network trains.

predict.py: This script is used to predict the class of a flower image and display the top K classes with associated probabilities. It allows users to load a trained model checkpoint, map class values to other category names using a JSON file, and to choose whether to use the GPU to calculate the predictions.

get_function.py: This script contains utility functions used in train.py and predict.py, including a function to load and preprocess images.

get_model.py: This script contains the function to load a pretrained feature extractor and define a new classifier.
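The prediction step described for predict.py, turning raw model scores into the top K classes with probabilities and optional category names, can be sketched without any dependencies. This is an illustration under stated assumptions, not the repository's implementation: the function names, the example scores, and the category names are all hypothetical, and a PyTorch version would more likely use torch.softmax and torch.topk on the model output.

```python
import json
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtracting the maximum keeps exp() numerically stable
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_classes(logits, idx_to_class, k=5, cat_to_name=None):
    """Return up to k (label, probability) pairs, most likely first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    results = []
    for i in ranked:
        label = idx_to_class[i]
        if cat_to_name is not None:
            # Fall back to the raw class value when no name is mapped.
            label = cat_to_name.get(label, label)
        results.append((label, probs[i]))
    return results


# Hypothetical stand-in for the contents of a category-names JSON file.
cat_to_name = json.loads('{"1": "rose", "2": "tulip", "3": "daisy"}')
# Hypothetical mapping from model output index to class value.
idx_to_class = {0: "1", 1: "2", 2: "3"}
# Hypothetical raw scores for one image.
logits = [2.0, 0.5, 1.0]

top = top_k_classes(logits, idx_to_class, k=2, cat_to_name=cat_to_name)
for name, prob in top:
    print(f"{name}: {prob:.3f}")
```

Keeping the index-to-class mapping separate from the class-to-name mapping means the category names file stays optional, which matches how --category_names is described.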

Usage

Training

To train a new network, run the train.py script from the command line with the following arguments:

python train.py data_directory --arch "resnet18" --learning_rate 0.0003 --hidden_units 5120 --epochs 10 --gpu
  • data_directory : the directory where the training data is stored
  • --save_dir : type=str, the directory where checkpoints will be saved
  • --arch : default='resnet18', choices=['efficientnet_v2_l', 'densenet121'], the architecture to use for the network
  • --learning_rate : type=float, default=0.0003, the learning rate to use for the optimizer
  • --hidden_units : type=int, default=5120, the number of units in the hidden layer
  • --epochs : type=int, default=10, the number of epochs to train for
  • --gpu : toggle to use GPU for training
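The argument list above maps naturally onto Python's argparse module. The following is a minimal sketch of how such a parser might be declared, not the actual contents of train.py. The function name build_train_parser and the dataset directory "flowers" are hypothetical, --save_dir is left without a default because none is documented, and the --arch choices are deliberately not enforced here since the accepted values in the repository may differ.

```python
import argparse


def build_train_parser():
    """Build a parser that mirrors the documented train.py arguments."""
    parser = argparse.ArgumentParser(description="Train a new network on an image dataset.")
    parser.add_argument("data_directory",
                        help="the directory where the training data is stored")
    parser.add_argument("--save_dir", type=str,
                        help="the directory where checkpoints will be saved")
    # No choices= here: this sketch does not assume which architectures are accepted.
    parser.add_argument("--arch", default="resnet18",
                        help="the architecture to use for the network")
    parser.add_argument("--learning_rate", type=float, default=0.0003,
                        help="the learning rate to use for the optimizer")
    parser.add_argument("--hidden_units", type=int, default=5120,
                        help="the number of units in the hidden layer")
    parser.add_argument("--epochs", type=int, default=10,
                        help="the number of epochs to train for")
    # store_true makes --gpu a toggle: absent means False, present means True.
    parser.add_argument("--gpu", action="store_true",
                        help="toggle to use GPU for training")
    return parser


# Parse an explicit list instead of sys.argv so the example runs anywhere.
args = build_train_parser().parse_args(
    ["flowers", "--arch", "densenet121", "--epochs", "3", "--gpu"]
)
print(args.data_directory, args.arch, args.learning_rate, args.epochs, args.gpu)
```

Any option left off the command line falls back to its documented default, so only data_directory is required.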

Prediction

To predict the class of a flower image, run the predict.py script from the command line with the following arguments:

python predict.py "/path/to/image" checkpoint --category_names cat_to_name.json --top_k 5 --gpu
  • input_path : the path to the image file
  • checkpoint : the path to the checkpoint file
  • --top_k : type=int, default=5, return the top K most likely classes
  • --category_names : the file containing the category names
  • --gpu : toggle to use GPU for inference
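As with training, these arguments can be declared with argparse. The sketch below is an assumption about how predict.py might parse them and how the --gpu toggle could be turned into a device name. The names build_predict_parser and select_device and the file names "flower.jpg" and "checkpoint.pth" are hypothetical. In a real script the cuda_available value would typically come from torch.cuda.is_available().

```python
import argparse


def build_predict_parser():
    """Build a parser that mirrors the documented predict.py arguments."""
    parser = argparse.ArgumentParser(description="Predict the class of a flower image.")
    parser.add_argument("input_path", help="the path to the image file")
    parser.add_argument("checkpoint", help="the path to the checkpoint file")
    parser.add_argument("--top_k", type=int, default=5,
                        help="return the top K most likely classes")
    parser.add_argument("--category_names",
                        help="the file containing the category names")
    parser.add_argument("--gpu", action="store_true",
                        help="toggle to use GPU for inference")
    return parser


def select_device(use_gpu, cuda_available):
    """Return "cuda" only when the GPU was requested and is actually available."""
    return "cuda" if use_gpu and cuda_available else "cpu"


# Parse an explicit list instead of sys.argv so the example runs anywhere.
args = build_predict_parser().parse_args(
    ["flower.jpg", "checkpoint.pth", "--category_names", "cat_to_name.json",
     "--top_k", "3", "--gpu"]
)
# cuda_available is hard-coded to False here to show the CPU fallback.
device = select_device(args.gpu, cuda_available=False)
print(args.input_path, args.top_k, device)
```

Falling back to the CPU when no GPU is present keeps the same command usable on machines without CUDA.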

Dependencies

This project requires the following dependencies:

  • Python 3.x
  • torch==1.13.1
  • torchvision==0.14.1
  • numpy
  • pandas
  • matplotlib

These dependencies can be installed using pip and the requirements.txt file included in this repository:

pip install -r requirements.txt

Acknowledgements

This project was completed as part of the Udacity AI and ML Nanodegree program. The flower dataset used in this project is the 102 Category Flower Dataset by Maria-Elena Nilsback and Andrew Zisserman.
