
ImageClassifier

Code for the first homework assignment of the machine learning course DATA620004, School of Data Science, Fudan University. Follow the instructions below to run the code.

Prepare dataset

In the data/ directory, run download.sh to download the Fashion-MNIST dataset.
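The dataset is distributed as gzipped IDX files (a small binary header followed by raw unsigned bytes). As a minimal sketch, assuming download.sh stores the files under the standard Fashion-MNIST release names (for example train-images-idx3-ubyte.gz) in data/, they could be read into NumPy arrays as follows. The helper names load_idx and load_fashion_mnist are hypothetical and are not taken from this repository's code.

```python
import gzip
import os

import numpy as np


def load_idx(path):
    """Read an (optionally gzipped) unsigned-byte IDX file into a NumPy array.

    IDX header: two zero bytes, a type code (0x08 = unsigned byte),
    the number of dimensions, then one big-endian uint32 per dimension.
    """
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        data = f.read()
    if data[0] != 0 or data[1] != 0 or data[2] != 0x08:
        raise ValueError(f"{path}: not an unsigned-byte IDX file")
    ndim = data[3]
    shape = tuple(
        int.from_bytes(data[4 + 4 * i: 8 + 4 * i], "big") for i in range(ndim)
    )
    offset = 4 + 4 * ndim  # pixel/label bytes start right after the header
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(shape)


def load_fashion_mnist(data_dir="data", split="train"):
    """Hypothetical loader; file names are ASSUMED to follow the official release."""
    prefix = "train" if split == "train" else "t10k"
    images = load_idx(os.path.join(data_dir, f"{prefix}-images-idx3-ubyte.gz"))
    labels = load_idx(os.path.join(data_dir, f"{prefix}-labels-idx1-ubyte.gz"))
    # Flatten each 28x28 image into a 784-dim vector scaled to [0, 1].
    x = images.reshape(len(images), -1).astype(np.float32) / 255.0
    return x, labels.astype(np.int64)
```

For example, load_fashion_mnist("data", "test") would read the t10k-* files, which hold the 10,000 test images.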

Train and test

To train and test the model with a given set of hyperparameters, modify the configuration in main.py if needed:

# configs
hidden_sizes = [512, 128]
activation_type = 'relu'
init_lr = 1e-3
gamma = 0.95
l2_reg = 1e-3
train_ratio = 0.95

The above parameters are described as follows:

  • hidden_sizes: the number and sizes of hidden layers. For example, [512, 128] means there are two hidden layers with sizes 512 and 128, respectively.
  • activation_type: the activation function used in the hidden layers; relu, sigmoid and tanh are supported.
  • init_lr: the initial learning rate.
  • gamma: the decay factor for exponential learning-rate decay during training, which lies in (0, 1). A smaller value makes the learning rate decrease faster.
  • l2_reg: the coefficient of the L2 regularization term during training.
  • train_ratio: the fraction of the original training set used for training. Fashion-MNIST does not provide a separate validation set, so the remaining (1 - train_ratio) of the original training samples are held out for validation.
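To make the parameter descriptions concrete, here is a minimal sketch of what each setting plausibly controls. It is an illustration under stated assumptions, not the code in main.py: the decay is assumed to be applied once per epoch as init_lr * gamma ** epoch, the L2 term is assumed to be 0.5 * l2_reg * (sum of squared weights), and all function names are hypothetical.

```python
import numpy as np

# Fashion-MNIST images are 28x28 = 784 pixels; there are 10 classes.
INPUT_DIM, NUM_CLASSES = 784, 10


def layer_shapes(hidden_sizes):
    """Weight-matrix shapes implied by hidden_sizes, e.g. [512, 128]."""
    dims = [INPUT_DIM] + list(hidden_sizes) + [NUM_CLASSES]
    return list(zip(dims[:-1], dims[1:]))


# The three supported values of activation_type.
ACTIVATIONS = {
    "relu": lambda z: np.maximum(z, 0.0),
    "sigmoid": lambda z: 1.0 / (1.0 + np.exp(-z)),
    "tanh": np.tanh,
}


def learning_rate(init_lr, gamma, epoch):
    """ASSUMED per-epoch exponential decay: lr_t = init_lr * gamma ** t."""
    return init_lr * gamma ** epoch


def l2_penalty(weights, l2_reg):
    """ASSUMED form of the L2 term: 0.5 * l2_reg * sum of squared weights."""
    return 0.5 * l2_reg * sum(float(np.sum(w ** 2)) for w in weights)


def train_val_split(x, y, train_ratio, seed=0):
    """Shuffle, keep train_ratio of the samples for training, hold out the rest."""
    idx = np.random.default_rng(seed).permutation(len(x))
    n_train = round(len(x) * train_ratio)
    tr, va = idx[:n_train], idx[n_train:]
    return x[tr], y[tr], x[va], y[va]


print(layer_shapes([512, 128]))  # [(784, 512), (512, 128), (128, 10)]
```

With the default gamma = 0.95, the learning rate after 10 epochs under this assumed schedule is about 0.6 * init_lr, while gamma = 0.5 would shrink it by a factor of about 1000, which is why a smaller gamma means faster decay. With train_ratio = 0.95, the 60,000 Fashion-MNIST training images would split into 57,000 for training and 3,000 for validation.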

Use the following command to train and test a model:

python main.py

The parameters of the trained model are saved to ./outputs/*.pkl, and the training and validation loss and accuracy are saved to ./outputs/*.csv. Plots of the loss and accuracy curves, together with visualizations of the model parameters and of samples from the corresponding categories, are saved to ./outputs/*.png.

Grid search for hyperparameters

The candidate hyperparameters are defined in grid_search.py:

# configs
param_grid = {
    'hidden_size_1': [128, 512, 1024],
    'hidden_size_2': [32, 64, 128],
    'init_lr': [0.01, 0.005, 0.001, 0.0005, 0.0001],
    'l2_reg': [0.01, 0.001, 0.0001, 0.00001],
}
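An exhaustive grid search trains one model per combination of candidate values, so the grid above yields 3 * 3 * 5 * 4 = 180 configurations. The sketch below shows one way to enumerate them with itertools.product; it is not necessarily how grid_search.py does it, and the helper name iter_configs is hypothetical. Each configuration's hidden_size_1 and hidden_size_2 would correspond to hidden_sizes = [hidden_size_1, hidden_size_2] in main.py.

```python
import itertools

param_grid = {
    'hidden_size_1': [128, 512, 1024],
    'hidden_size_2': [32, 64, 128],
    'init_lr': [0.01, 0.005, 0.001, 0.0005, 0.0001],
    'l2_reg': [0.01, 0.001, 0.0001, 0.00001],
}


def iter_configs(grid):
    """Yield one dict per point of the Cartesian product of the grid values."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


configs = list(iter_configs(param_grid))
print(len(configs))  # 3 * 3 * 5 * 4 = 180
print(configs[0])    # {'hidden_size_1': 128, 'hidden_size_2': 32, 'init_lr': 0.01, 'l2_reg': 0.01}
```

Because the cost grows multiplicatively with each added candidate value, adding a single extra learning rate here would add 36 more training runs.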

Use the following command to run the hyperparameter search; each candidate model is trained and tested automatically:

python grid_search.py

The resulting validation and test accuracies, along with the corresponding hyperparameters, are saved to ./outputs/grid_search_results.csv.
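When picking the final configuration, the usual practice is to select by validation accuracy and report test accuracy only for the chosen model, so the test set does not influence model selection. A minimal sketch using the standard csv module follows; the column name "val_acc" is a placeholder assumption, since the actual header of grid_search_results.csv is not documented here, and best_config is a hypothetical helper.

```python
import csv


def best_config(csv_path, metric="val_acc"):
    """Return the result row with the highest value in the `metric` column.

    "val_acc" is a PLACEHOLDER column name; check the header of
    grid_search_results.csv and pass the real validation-accuracy column.
    """
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        raise ValueError(f"{csv_path} contains no result rows")
    return max(rows, key=lambda row: float(row[metric]))


# Usage:
# best = best_config("./outputs/grid_search_results.csv")
```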


Contributors

yz-cai
