
big-discriminator-batch-spoofing-gan's Introduction

BBMSG-GAN

BMSG-GAN with batch spoofing (gradient accumulation), FID computation during training, and a bigger discriminator.
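Here, "batch spoofing" refers to gradient accumulation: each micro-batch loss is divided by the number of accumulation steps and gradients are summed before a single optimizer step, emulating a much larger effective batch. A minimal sketch of the pattern (not the repository's code; `model`, `loss_fn`, and the data iterator are placeholders):

```python
import torch

def spoofed_step(model, loss_fn, batch_iter, optimizer, spoofing_factor):
    """One optimizer step whose gradients are accumulated over
    `spoofing_factor` micro-batches, mimicking a batch that is
    `spoofing_factor` times larger than what fits in GPU memory."""
    optimizer.zero_grad()
    for _ in range(spoofing_factor):
        inputs, targets = next(batch_iter)
        # Scale each micro-batch loss so the summed gradients match the
        # average over the whole "spoofed" batch.
        loss = loss_fn(model(inputs), targets) / spoofing_factor
        loss.backward()  # gradients accumulate in the .grad buffers
    optimizer.step()
```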

big-discriminator-batch-spoofing-gan's People

Contributors

akanimax

big-discriminator-batch-spoofing-gan's Issues

Failed to train with the 102 Flowers dataset

Hi, I tried BMSG-GAN with the 102 Flowers dataset and the results were very good, so I wanted to try BBMSG-GAN for even better results. However, I cannot get it to train on the 102 Flowers dataset with the default parameters.

After training for 400 epochs, the generated samples still look like this:
[generated sample image: gen_400_18]

Can you check it?

Issues training on 2 GPUs

Hi there, I've been trying to get BBMSG-GAN running but haven't been able to get training started. I have two 11 GB 1080 Ti cards on Ubuntu 18.10 with PyTorch 1.0.1 and Python 3.6.

Once I run train.py (with python train.py --depth=8 --batch_size=32 --fid_batch_size=32 --spoofing_factor=64 --latent_size=512 --images_dir=datasets/epskal/ --sample_dir=samples/epskal_1 --model_dir=models/epskal_1) it gets to the first epoch and then stops doing anything. I can see that the memory has been allocated on my GPUs, but the utilization is 0%, so it's not actually doing any work. I've tried waiting overnight to see if it would start processing, but the next morning nothing had happened.

When trying to quit, the processes hang and I can't use my GPUs again without restarting the computer. I've added the logs below; here I pressed Ctrl+C twice, after which nothing else kills the processes. The same thing happens whether or not the discriminator is set to be parallel.

Starting the training process ...

Epoch: 1
^CTraceback (most recent call last):
  File "train.py", line 310, in <module>
    main(parse_arguments())
  File "train.py", line 304, in main
    fid_batch_size=args.fid_batch_size
  File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/GAN.py", line 563, in train
    num_accumulations=spoofing_factor)
  File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/GAN.py", line 310, in optimize_discriminator
    loss = loss_fn.dis_loss(real_batch, fake_samples) / num_accumulations
  File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/Losses.py", line 200, in dis_loss
    r_preds = self.dis(real_samps)
  File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/GAN.py", line 197, in forward
    y = self.rgb_to_features[self.depth - 2](inputs[self.depth - 1])
  File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 75, in parallel_apply
    thread.join()
  File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1056, in join
    self._wait_for_tstate_lock()
  File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1072, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
^CException ignored in: <module 'threading' from '/home/hans/.conda/envs/pix/lib/python3.6/threading.py'>
Traceback (most recent call last):
  File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1294, in _shutdown
    t.join()
  File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1056, in join
    self._wait_for_tstate_lock()
  File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1072, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt

Any idea what I might be doing wrong / how I can get it working correctly?
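In case it helps to narrow things down, a bare nn.DataParallel forward pass (a hypothetical diagnostic, unrelated to the repository's code) can show whether the hang comes from multi-GPU communication on this machine rather than from BBMSG-GAN itself:

```python
import torch
import torch.nn as nn

# If even this minimal DataParallel forward pass hangs, the problem is in
# inter-GPU communication on this machine (driver / peer-to-peer access),
# not in the training code.
net = nn.DataParallel(nn.Linear(512, 512)).cuda()
out = net(torch.randn(64, 512).cuda())
print(out.shape)  # expected: torch.Size([64, 512])
```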

P.S. Am I right to assume that as long as batch_size * spoofing_factor = 2048, the results will be the same? Without decreasing the batch size quite a bit I run out of memory, especially at 1024x1024.
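For what it's worth, here is a rough numerical check (a standalone sketch, not the repository's code) of the underlying assumption: accumulating `spoofing_factor` micro-batch losses, each divided by `spoofing_factor`, produces the same gradients as a single batch of `batch_size * spoofing_factor`, provided no layer computes statistics across the batch (e.g. BatchNorm or minibatch-stddev tricks):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
data, target = torch.randn(64, 8), torch.randn(64, 1)
mse = torch.nn.MSELoss()

# Gradient from one "big" batch of 64.
model.zero_grad()
mse(model(data), target).backward()
full_grad = model.weight.grad.clone()

# Gradient from 4 accumulated micro-batches of 16 (64 = 16 * 4),
# each loss divided by the accumulation count, as in optimize_discriminator.
model.zero_grad()
for x_chunk, y_chunk in zip(data.chunk(4), target.chunk(4)):
    (mse(model(x_chunk), y_chunk) / 4).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # -> True
```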

Training not progressing at 1024x1024 on multi-GPU

I ran training for a few hours on 8 GPUs without any progress. Each sample is a pixel-wise copy of the others at every resolution.

[4x4 sample image]

[64x64 sample image]

setup

git clone git@github.com:akanimax/BBMSG-GAN.git
conda create -n bbmsg python==3.7
conda activate bbmsg
conda install pytorch torchvision cudatoolkit=10.0 cudnn scipy==1.2.0 tensorboard -c pytorch
pip install tensorboardX tqdm

training

# calc real fid stats before training
python ../BBMSG-GAN/sourcecode/train.py \
    --images_dir="$IMGS" \
    --sample_dir="$SAMPLES" \
    --model_dir="$MODELS" \
    --depth=9 \
    --batch_size=24 \
    --num_samples=36 \
    --feedback_factor=5 \
    --checkpoint_factor=1 \
    --num_epochs=50000 \
    --num_workers=90 \
    --log_fid_values=True \
    --fid_temp_folder=/tmp/fid_tmp \
    --fid_real_stats="$FID" \
    --fid_batch_size=64 \
    --num_fid_images=5000

Wrong calculation of LSGAN

There appears to be an error in the LSGAN implementation here:

https://github.com/akanimax/BBMSG-GAN/blob/da82aa2e8507d17801bd2134a4ae754335d716f5/sourcecode/MSG_GAN/Losses.py#L145

Correct implementation:

import torch as th  # as used elsewhere in MSG_GAN/Losses.py


class LSGAN(GANLoss):  # GANLoss is the base class defined in Losses.py
    """Least-squares GAN loss (Mao et al., 2017)."""

    def __init__(self, dis):
        super().__init__(dis)

    def dis_loss(self, real_samps, fake_samps):
        # Push real scores toward 1 and fake scores toward 0.
        real_scores = th.mean((self.dis(real_samps) - 1) ** 2)
        fake_scores = th.mean(self.dis(fake_samps) ** 2)
        return 0.5 * (real_scores + fake_scores)

    def gen_loss(self, _, fake_samps):
        # Generator tries to push fake scores toward 1.
        return 0.5 * th.mean((self.dis(fake_samps) - 1) ** 2)

Reference: Equation 8 of [Least Squares Generative Adversarial Networks] (Mao et al., 2017).
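For reference, the objective that the corrected code above implements (using target label 0 for fake samples and 1 for real samples and for the generator's target) can be written as:

$$ L_D = \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\mathrm{data}}}\big[(D(x) - 1)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[D(G(z))^2\big], \qquad L_G = \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - 1)^2\big] $$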

