PyTorch WGAN GP

This repository currently contains training code only. Pretrained weights will be uploaded later.

Version

  • PyTorch 1.4.0
  • Python 3.6
  • CUDA 10.0.x
  • cuDNN 7.6.3
  • environment.yaml is provided for reference only, because it lists many more dependencies than this project needs.

Dataset

Dataset       CelebA HQ      FFHQ (thumbnails)
Size          1024 x 1024    128 x 128
# of images   30,000         70,000

(During training, CelebA HQ images are resized to 128 x 128 or 64 x 64.)

Before training, change the dataset directory name in data_loder.py to match the location of your data.

Options and Help

wgan_gp$ python main.py -h
usage: main.py [-h] [--main_gpu MAIN_GPU] [--use_tensorboard USE_TENSORBOARD]
               [--checkpoint_dir CHECKPOINT_DIR] [--log_dir LOG_DIR]
               [--image_name IMAGE_NAME] [--train_data TRAIN_DATA]
               [--optim OPTIM] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]
               [--latent_dim LATENT_DIM]
               [--generator_upsample GENERATOR_UPSAMPLE]
               [--weight_init WEIGHT_INIT] [--norm_g NORM_G] [--norm_d NORM_D]
               [--nonlinearity NONLINEARITY] [--slope SLOPE]
               [--batch_size BATCH_SIZE] [--iter_num ITER_NUM]
               [--img_size IMG_SIZE] [--loss LOSS] [--n_critic N_CRITIC]
               [--lambda_gp LAMBDA_GP]

optional arguments:
  -h, --help            show this help message and exit
  --main_gpu MAIN_GPU   main gpu index
  --use_tensorboard USE_TENSORBOARD
                        Tensorboard
  --checkpoint_dir CHECKPOINT_DIR
                        full name is './checkponit'.format(main_gpu)
  --log_dir LOG_DIR     dir for tensorboard
  --image_name IMAGE_NAME
                        sample image name
  --train_data TRAIN_DATA
                        celeba or ffhq
  --optim OPTIM         Adam or RMSprop
  --lr LR               learning rate
  --beta1 BETA1         For Adam optimizer.
  --beta2 BETA2         For Adam optimizer.
  --latent_dim LATENT_DIM
                        dimension of latent vector
  --generator_upsample GENERATOR_UPSAMPLE
                        if False, using ConvTranspose.
  --weight_init WEIGHT_INIT
                        weight init from normal dist
  --norm_g NORM_G       inorm : instancenorm, bnorm : batchnorm, lnorm :
                        layernorm or None for Generator
  --norm_d NORM_D       inorm : instancenorm, bnorm : batchnorm, lnorm :
                        layernorm or None for discriminator(critic)
  --nonlinearity NONLINEARITY
                        relu or leakyrelu
  --slope SLOPE         if using leakyrelu, you can use this option.
  --batch_size BATCH_SIZE
                        size of the batches
  --iter_num ITER_NUM   number of iterations of training
  --img_size IMG_SIZE   size of each image dimension
  --loss LOSS           wgangp or bce, default is wgangp
  --n_critic N_CRITIC   number of training steps for discriminator per iter
  --lambda_gp LAMBDA_GP
                        amount of gradient penalty loss
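The help output shows that boolean options such as --use_tensorboard and --generator_upsample take a value (the example command below passes --generator_upsample True). How main.py parses these values is not shown here. If you write a similar parser, note that argparse's type=bool does not work for this, because bool("False") is True. The sketch below uses a hypothetical str2bool helper; it covers only these two flags, and its defaults are placeholders, not the repository's defaults.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false strings to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    # argparse turns this exception into a clean usage error.
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


def build_parser():
    """Illustrative parser for two boolean flags; defaults are placeholders."""
    parser = argparse.ArgumentParser()
    # type=bool would be a bug here: bool("False") evaluates to True.
    parser.add_argument("--use_tensorboard", type=str2bool, default=True,
                        help="Tensorboard")
    parser.add_argument("--generator_upsample", type=str2bool, default=True,
                        help="if False, using ConvTranspose.")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--generator_upsample", "False"])
    print(args.generator_upsample)
```

With this approach, --generator_upsample False really disables upsampling instead of being silently read as True.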

Training

In my experiments, RMSprop gave the best training performance.

wgan_gp$ python main.py --main_gpu 4 \
                        --log_dir gpu4 \
                        --train_data celeba \
                        --latent_dim 128 \
                        --image_name gpu_4.png \
                        --batch_size 32 \
                        --n_critic 5 \
                        --lr 0.00005 \
                        --lambda_gp 10 \
                        --optim RMSprop \
                        --generator_upsample True \
                        --norm_g bnorm \
                        --norm_d lnorm \
                        --nonlinearity leakyrelu \
                        --slope 0.2 \
                        --loss wgangp
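The --loss wgangp, --lambda_gp and --n_critic options correspond to the WGAN-GP objective of Gulrajani et al. (2017). The critic minimizes E[D(fake)] - E[D(real)] + lambda_gp * E[(||grad D(x_hat)||_2 - 1)^2], where x_hat is a random interpolation between a real and a generated sample, and the generator minimizes -E[D(fake)]. The sketch below is a minimal, generic PyTorch version written for illustration; it is not taken from this repository, and the function names gradient_penalty, critic_loss and generator_loss are hypothetical.

```python
import torch


def gradient_penalty(critic, real, fake):
    """Two-sided WGAN-GP penalty on random interpolates between real and fake."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over the other dims.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    # Per-sample L2 norm of the gradient with respect to the input.
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


def critic_loss(critic, real, fake, lambda_gp=10.0):
    """Critic objective to minimize: E[D(fake)] - E[D(real)] + lambda_gp * GP."""
    fake = fake.detach()  # the generator is not updated during the critic step
    wasserstein_term = critic(fake).mean() - critic(real).mean()
    return wasserstein_term + lambda_gp * gradient_penalty(critic, real, fake)


def generator_loss(critic, fake):
    """Generator objective to minimize: -E[D(G(z))]."""
    return -critic(fake).mean()
```

With the settings in the command above, a training loop would minimize critic_loss with lambda_gp = 10 for n_critic = 5 critic steps (for example with torch.optim.RMSprop at lr 0.00005) before each generator step on generator_loss.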

Results - CelebA HQ

Tensorboard

If you train on a remote server, you can view TensorBoard in a local browser by forwarding a port over SSH. Run this command on your local machine:

ssh -NfL localhost:8898:localhost:6009 [USERID]@[IP]

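Here 8898 is an arbitrary local port and 6009 is an arbitrary remote port. TensorBoard on the server must be running on the remote port (for example, started with --port 6009); you can then open http://localhost:8898 in your local browser.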

Contributors

hichoe95

pytorch_wgan_gp's Issues

A question about the BN layer

Hi, thanks for sharing your code.

I read through your code and noticed that a BN (batch normalization) layer is used in the discriminator. Did that cause any problems in training? In theory it makes the gradient penalty loss useless. In the paper <Improved..>, the authors suggest using layer normalization instead of BN, but I could not find any further information about this.
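The concern is batch dependence. In training mode, batch normalization normalizes each feature with statistics computed over the whole mini-batch, so the critic's output for one sample depends on the other samples, whereas the gradient penalty is defined per sample. Layer normalization computes its statistics within each sample and avoids this; the help output lists lnorm as a --norm_d choice, and the example training command uses it. The NumPy sketch below only illustrates this dependence (affine parameters omitted; the helper names batch_norm and layer_norm are hypothetical) and is not code from this repository.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    """Training-mode batch norm without affine params: stats over the batch axis."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def layer_norm(x, eps=1e-5):
    """Layer norm without affine params: stats over each sample's features."""
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 8))  # 4 samples, 8 features
perturbed = batch.copy()
perturbed[1] += 5.0  # change only sample 1

# How much does sample 0's output move when only sample 1 changed?
bn_shift = np.abs(batch_norm(batch)[0] - batch_norm(perturbed)[0]).max()
ln_shift = np.abs(layer_norm(batch)[0] - layer_norm(perturbed)[0]).max()
print("batch norm output of sample 0 changed:", bn_shift > 1e-3)
print("layer norm output of sample 0 changed:", ln_shift > 1e-12)
```

Because batch norm couples the samples, the gradient of one sample's critic score also depends on the rest of the batch, which is why layer normalization is the usual choice for a WGAN-GP critic.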
