csc1004-python-imageclassification's Introduction

CSC1004-python-project

Scores: 100/100

Task description

The project uses a Convolutional Neural Network (CNN) to classify handwritten digits from the MNIST dataset.

To achieve it, the code will do the following:

  • Load the training config.
  • Download the MNIST dataset (divided into training and testing sets).
  • Construct the neural network.
  • Update the network parameters on the training dataset by minimizing the loss (training).
  • Test the neural network on the testing dataset (testing).
  • Plot the results.

How the deep learning is implemented

The given code is in main_given.py

My implementation is in main.py

See original repository: here

My repository: here

main function
  1. Load the config and random seeds.
  2. Run the multiprocessing with different seeds.
  3. Plot the results, which are read back from the recorded .txt files.
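
As a rough illustration of step 2, the sketch below starts one worker process per seed with the standard library's multiprocessing.Pool. The config keys, the body of run, and the signatures of run_with_seed and main are hypothetical stand-ins, not the actual code in main.py.

```python
import copy
import multiprocessing as mp


def run(config: dict) -> dict:
    # Hypothetical stand-in for the real training run.
    return {"seed": config["seed"], "epochs": config["epochs"]}


def run_with_seed(args: tuple[dict, int]) -> dict:
    config, seed = args
    # Copy first so the seeds never share (or overwrite) one config object.
    seeded = copy.deepcopy(config)
    seeded["seed"] = seed
    return run(seeded)


def main(config: dict, seeds: list[int]) -> list[dict]:
    # One worker process per seed; map() returns results in seed order.
    with mp.Pool(processes=len(seeds)) as pool:
        return pool.map(run_with_seed, [(config, s) for s in seeds])
```

When run as a script, main should be called under an `if __name__ == "__main__":` guard so that worker processes do not start pools of their own on platforms that spawn new interpreters.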
run_with_seed and run
  1. Set the seed in the config to the assigned seed, then run with the updated config.
  2. In the run function, first assign the seed and device (CPU or CUDA).
  3. The rest of the run function was given. It mainly does the following:
    1. Set DataLoader arguments based on whether CUDA is available
    2. Load and preprocess the MNIST dataset
    3. Initialize the DataLoaders for training and testing, model, optimizer, and learning rate scheduler
    4. Train and test the model for each epoch and step the learning rate scheduler, recording the statistics along the way.
    5. Plot the training and testing performance
    6. Save the results to a file and save the trained model
train function
  1. Iterate through each batch in the training dataset loader.
  2. Move the input data and target labels to the device.
  3. Reset the gradients to zero for the optimizer.
  4. Forward pass: compute the model's output based on the input data.
  5. Compute the loss between the model's output and target labels.
  6. Backward pass: compute gradients of the loss with respect to the model's parameters.
  7. Update the model's parameters using the gradients.
  8. Add the current batch's loss to the running total.
  9. Get the index of the max log-probability (i.e., the predicted class) for each sample in the batch.
  10. Compare the predicted class with the target class and count the number of correct predictions.
  11. After the loop, compute the average loss and accuracy over the training set.
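
The steps above map onto a PyTorch loop along the following lines. This is a minimal sketch and not the code in main.py: it assumes the model outputs log-probabilities and uses the negative log-likelihood loss, and the tiny linear model, SGD optimizer, and random data are placeholders for the real CNN, optimizer, and MNIST loader.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train(model, device, train_loader, optimizer):
    model.train()  # enable training-mode layers such as dropout
    total_loss, correct = 0.0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()                     # reset the gradients
        output = model(data)                      # forward pass (log-probabilities)
        loss = F.nll_loss(output, target)         # mean loss over the batch
        loss.backward()                           # backward pass
        optimizer.step()                          # parameter update
        total_loss += loss.item() * data.size(0)  # batch mean -> batch sum
        pred = output.argmax(dim=1)               # index of the max log-probability
        correct += (pred == target).sum().item()
    n = len(train_loader.dataset)
    return total_loss / n, correct / n


# Tiny random stand-in for MNIST: 32 "images" of 28x28 with labels 0..9.
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))
loader = DataLoader(TensorDataset(images, labels), batch_size=8)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
avg_loss, accuracy = train(model, torch.device("cpu"), loader, optimizer)
print(f"loss={avg_loss:.4f} accuracy={accuracy:.2%}")
```

Multiplying each batch's mean loss by the batch size before accumulating keeps the final average correct even when the last batch is smaller than the others.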
test function
  1. Iterate through each batch in the test dataset loader.
  2. Move the input data and target labels to the device.
  3. Forward pass: compute the model's output based on the input.
  4. Compute the loss and add it to the running sum.
  5. Get the index of the max log-probability (i.e., the predicted class) for each sample in the batch.
  6. Count the number of correct predictions.
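
A matching evaluation loop might be sketched as follows. It is named evaluate here only to keep the example separate from the project's test function. Switching to eval mode, disabling gradient tracking, and averaging at the end are assumptions that the steps above do not spell out, and the model and data are again random placeholders.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, device, test_loader):
    model.eval()  # disable training-only behaviour such as dropout
    total_loss, correct = 0.0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)  # forward pass (log-probabilities)
            # Sum the per-sample losses so batches of any size add up correctly.
            total_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)  # index of the max log-probability
            correct += (pred == target).sum().item()
    n = len(test_loader.dataset)
    return total_loss / n, correct / n


# Random placeholder data and an untrained linear model.
images = torch.randn(16, 1, 28, 28)
labels = torch.randint(0, 10, (16,))
loader = DataLoader(TensorDataset(images, labels), batch_size=8)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
test_loss, test_acc = evaluate(model, torch.device("cpu"), loader)
print(f"loss={test_loss:.4f} accuracy={test_acc:.2%}")
```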
plot function & plot_mean

The plot function uses the pyplot API to plot the performance data against the epoch, with the title and labels as required.

The plot_mean function reads the data recorded in the text files. It stores the epochs in a list and keeps a dictionary for each epoch holding the four values used for plotting. It then gathers the values that belong to the same plot into one list, which is passed as the "performance" parameter of the plot function.
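
The aggregation behind plot_mean can be sketched with the standard library alone. The line format (`epoch train_loss train_acc test_loss test_acc`), the metric names, and the function name mean_by_epoch are assumptions made for this example, since the actual layout of the .txt files is not described here.

```python
from collections import defaultdict
from statistics import mean

# Assumed names for the four values recorded per epoch.
METRICS = ("train_loss", "train_acc", "test_loss", "test_acc")


def mean_by_epoch(runs):
    """Average each metric per epoch across runs.

    runs: one list of text lines per seed, each line being
    'epoch train_loss train_acc test_loss test_acc' (assumed format).
    """
    per_epoch = defaultdict(lambda: {m: [] for m in METRICS})
    for lines in runs:
        for line in lines:
            fields = line.split()
            if not fields:
                continue  # skip blank lines
            epoch = int(fields[0])
            for name, value in zip(METRICS, fields[1:]):
                per_epoch[epoch][name].append(float(value))
    epochs = sorted(per_epoch)
    # One list per plot, ordered by epoch, ready to pass as "performance".
    curves = {m: [mean(per_epoch[e][m]) for e in epochs] for m in METRICS}
    return epochs, curves


runs = [
    ["1 1.0 0.5 1.5 0.25", "2 0.5 0.75 1.0 0.5"],
    ["1 3.0 0.5 0.5 0.75", "2 1.5 0.25 1.0 0.5"],
]
epochs, curves = mean_by_epoch(runs)
print(epochs, curves["train_loss"])
```

Each list in curves can then be handed to the plot function together with epochs.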
