
Fully Connected GAN (also known as Vanilla GAN) in PyTorch

This repository contains code for a fully connected GAN (FCGAN), trained and tested on the MNIST and CIFAR-10 datasets. It is built on the PyTorch framework.

Generative Adversarial Networks

GANs are generally made up of two models: the Artist (Generator) and the Critic (Discriminator). The generator creates an image from random noise, and the discriminator compares the generated image against the images in the given dataset. The two models are trained with a minimax objective: the generator tries to fool the discriminator by producing realistic-looking images, while the discriminator gets better at distinguishing real images from fake ones. This two-player game improves both models until the generator produces realistic images or the system reaches a Nash equilibrium.
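The minimax game above is usually implemented with binary cross-entropy losses, alternating between a discriminator step and a generator step. A minimal sketch of one such training step (function and variable names here are illustrative, not the repo's actual API):

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()  # binary cross-entropy on D's sigmoid outputs

def train_step(G, D, real_images, opt_g, opt_d, noise_dim=100):
    """One alternating GAN update; G and D are any compatible modules."""
    batch = real_images.size(0)
    real_labels = torch.ones(batch, 1)
    fake_labels = torch.zeros(batch, 1)

    # Discriminator step: push D(real) toward 1 and D(G(z)) toward 0.
    opt_d.zero_grad()
    z = torch.randn(batch, noise_dim)
    fake_images = G(z).detach()  # detach so this step doesn't update G
    d_loss = bce(D(real_images), real_labels) + bce(D(fake_images), fake_labels)
    d_loss.backward()
    opt_d.step()

    # Generator step (non-saturating form): label fakes as real to fool D.
    opt_g.zero_grad()
    z = torch.randn(batch, noise_dim)
    g_loss = bce(D(G(z)), real_labels)
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```

Detaching the fake batch in the discriminator step is what keeps the two updates adversarial rather than cooperative.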

Contents

  1. Setup Instructions and Dependencies
  2. Training Model from Scratch
  3. Generating Images from Trained Models
  4. Model Architecture
  5. Repository Overview
  6. Results Obtained
    1. Generated Images
    2. Parameters Used
    3. Loss Curves
  7. Observations
  8. Credits

1. Setup Instructions and Dependencies

You may set up the repository on your local machine either by downloading it or by running the following command in a terminal.

git clone https://github.com/h3lio5/gan-pytorch.git

The trained models are large, so Google Drive links to them are provided in the model.txt file.

The data required for training is automatically downloaded when running train.py.

All dependencies required by this repo can be installed by creating a virtual environment with Python 3.7 and running

pip install -r requirements.txt

Make sure to have CUDA 10.0.130 and cuDNN 7.6.0 installed in the virtual environment. For a conda environment, this can be done by using the following commands:

conda install cudatoolkit=10.0
conda install cudnn=7.6.0

2. Training Model from Scratch

To train your own model from scratch, run

python train.py -config path/to/config.ini
  • The parameters for your experiment are all set by default, but you are free to set them yourself.
  • The training script will create a folder exp_name as specified in your config.ini file.
  • This folder will contain all data related to your experiment such as tensorboard logs, images generated during training and training checkpoints.
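A config file might look something like the following. This is only a hypothetical layout: the actual section and key names depend on how train.py parses the file, so check the repo's sample config before relying on these.

```ini
; Hypothetical config.ini -- section and key names are assumptions,
; with values taken from the parameters listed later in this README.
[experiment]
exp_name = mnist_run_1
dataset = mnist

[training]
epochs = 100
batch_size = 128
lr = 0.0002
beta1 = 0.5
noise_dim = 100
```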

3. Generating Images from Trained Models

To generate images from trained models, run

python generate.py --dataset mnist/cifar10 --load_path path/to/checkpoint --grid_size n --save_path directory/where/images/are/saved

The arguments used are explained as follows

  • --dataset requires either mnist or cifar10 according to what dataset the model was trained on.
  • --load_path requires the path to the training checkpoint to load. Point this to the *.index file without the extension, e.g. --load_path training_checkpoints/ckpt-1.
  • --grid_size requires integer n and will generate n*n images in a grid.
  • --save_path requires the path to the directory where the generated images will be saved. If the directory doesn't exist, the script will create it.

To generate images from pre-trained models, download checkpoint files from the Google Drive link given in the model.txt file.

4. Model Architecture

Generator Model

  • MNIST: The generator is a 5-layer MLP with LeakyReLU activations and a Tanh non-linearity in the final layer.
  • CIFAR-10: The generator is a 6-layer MLP with LeakyReLU activations and a Tanh non-linearity in the final layer.
  • The input is a 100-dimensional noise vector, which the network maps to either a 28x28x1 (MNIST) or 32x32x3 (CIFAR-10) image.
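A plausible PyTorch sketch of the MNIST generator described above. The layer count, final Tanh, and 100-dimensional input follow the README; the hidden-layer widths are assumptions, and the repo's model.py is authoritative.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """5-layer MLP generator for MNIST; hidden sizes are illustrative."""
    def __init__(self, noise_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        out_dim = img_shape[0] * img_shape[1] * img_shape[2]
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 256),       nn.LeakyReLU(0.2),
            nn.Linear(256, 512),       nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),      nn.LeakyReLU(0.2),
            nn.Linear(1024, out_dim),  nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z):
        # Map noise (batch, noise_dim) to images (batch, C, H, W).
        return self.net(z).view(z.size(0), *self.img_shape)
```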

Discriminator Model

  • MNIST: The discriminator is a 3-layer MLP with LeakyReLU activations and a Sigmoid non-linearity in the final layer.
  • CIFAR-10: The discriminator is a 4-layer MLP with LeakyReLU activations and a Sigmoid non-linearity in the final layer.
  • The output is a single probability indicating whether the input image is real or fake/generated.
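The matching MNIST discriminator could be sketched as follows. Again, the 3-layer MLP shape and Sigmoid output come from the description above, while the hidden widths are assumptions:

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """3-layer MLP discriminator for MNIST; hidden sizes are illustrative."""
    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        in_dim = img_shape[0] * img_shape[1] * img_shape[2]
        self.net = nn.Sequential(
            nn.Linear(in_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256),    nn.LeakyReLU(0.2),
            nn.Linear(256, 1),      nn.Sigmoid(),  # probability input is real
        )

    def forward(self, img):
        # Flatten (batch, C, H, W) images before the MLP.
        return self.net(img.view(img.size(0), -1))
```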

5. Repository Overview

This repository contains the following files and folders

  1. experiments: This folder contains data for different runs.

  2. resources: Contains media for readme.md.

  3. data_loader.py: Contains helper functions that load and preprocess data.

  4. generate.py: Used to generate and save images from trained models.

  5. model.py: Contains helper functions that create generator and discriminator models.

  6. model.txt: Contains Google Drive links to trained models.

  7. requirements.txt: Lists dependencies for easy setup in virtual environments.

  8. train.py: Contains code to train models from scratch.

6. Results Obtained

i. Generated Images

Samples generated after training model for 100 epochs on MNIST.

mnist_generated

Samples generated after training model for 200 epochs on CIFAR-10.

cifar_generated

ii. Parameters Used

  • Optimizer used is Adam
  • Learning rate 0.0002, beta-1 0.5
  • Trained for 100 epochs (MNIST) and 200 epochs (CIFAR-10)
  • Batch size is 128 for both MNIST and CIFAR-10
  • The model uses label flipping (i.e. real images are assigned label 0 and fake images label 1)
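These hyperparameters translate directly into PyTorch. A small sketch (the helper names are mine, not the repo's):

```python
import torch

def make_optimizers(G, D, lr=2e-4, beta1=0.5):
    # Adam with lr=0.0002 and beta1=0.5, as listed above.
    opt_g = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))
    opt_d = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))
    return opt_g, opt_d

def flipped_labels(batch_size):
    # Label flipping: real images get target 0, fake images get target 1.
    real_targets = torch.zeros(batch_size, 1)
    fake_targets = torch.ones(batch_size, 1)
    return real_targets, fake_targets
```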

iii. Loss Curves

MNIST

CIFAR-10

7. Observations

MNIST

The model took around 12 minutes to train for 100 epochs on a GPU. The generated images are not very sharp but somewhat resemble the real data. The model is also prone to mode collapse.

Training for longer (150+ epochs) does not seem to improve the model's performance and sometimes even degrades the quality of the generated images.

CIFAR-10

Training on the CIFAR-10 dataset was challenging: the dataset is more varied and the network has more parameters to train. The model was trained for 200 epochs, which took about 30 minutes.

However, the main challenge was inspecting 32x32 images and judging whether they were 'good enough'. The images are too low-resolution to make out the subject clearly, but they are easily passable since they look quite similar to the real data.

Some images contain noise, but most do not have many artifacts. This is partly because the network trains on all 10 labels of the CIFAR-10 dataset. Better results could be obtained by training the network on one label at a time, but this takes away the robustness of the model.

8. Credits

This repository was built with reference to multiple sources.
