
Comments (34)

gwern avatar gwern commented on May 29, 2024 11

So to bring some Twitter comments back: as mentioned in #4, @FeepingCreature and I have tried changing the architecture in a few ways to try to improve learning, and we have begun to wonder what exactly the Loss_D means.

In particular, compared to IllustrationGAN and StackGAN, WGAN struggles to handle 128px resolution and global coherency (e.g. in anime faces, severe heterochromia - the convolutions apparently decide individually on plausible yet mismatched eye colors). If we add on, using --n_extra_layers, 3 or 6 additional convolution layers for a 64px anime face model, training struggles even more: 6 is worse than 3, which is worse than none. (Albeit hilarious looking.) When we add in 1-4 fully-connected layers at the beginning of the generator like this:

+            nn.ConvTranspose2d(nz, nz, 1, 1, 0),
+            nn.LeakyReLU(0.2, inplace=True),

WGAN tends to get better and more globally coherent (but we still haven't matched IllustrationGAN/StackGAN). My interpretation is that the fully-connected layers are transforming the latent-z/noise into a sort of global template which the subsequent convolution layers can then fill in more locally. For symmetry, feep also tried adding additional layers to the discriminator with their associated batchnorms and changing the penultimate layers.
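For concreteness, a minimal sketch of what this looks like in a DCGAN-style PyTorch generator (not our exact code; nz/ngf/nc follow the repo's conventions and the number of 1x1 "fully-connected" layers is illustrative):

    import torch
    import torch.nn as nn

    nz, ngf, nc = 100, 64, 3   # latent size, generator width, image channels (repo defaults)

    netG = nn.Sequential(
        # "fully-connected" layers: 1x1 transposed convs acting on the 1x1 spatial latent,
        # meant to mix the noise into a global template before any upsampling happens
        nn.ConvTranspose2d(nz, nz, 1, 1, 0),
        nn.LeakyReLU(0.2, inplace=True),
        nn.ConvTranspose2d(nz, nz, 1, 1, 0),
        nn.LeakyReLU(0.2, inplace=True),
        # standard DCGAN upsampling stack from here on (1x1 -> 4x4 -> ... -> 64x64)
        nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf * 8),
        nn.ReLU(True),
        nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 4),
        nn.ReLU(True),
        nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),
        nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(True),
        nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
        nn.Tanh(),
    )

    fake = netG(torch.randn(16, nz, 1, 1))   # -> (16, 3, 64, 64)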

One interesting thing about adding convolution+batchnorm or fully-connected layers to the 128px model is that it makes Loss_D wildly different. This isn't too surprising, since it's noted that Loss_D isn't directly comparable across different models, but what's a little more unpleasantly surprising is that it seems to affect training. If we add in more FC layers, Loss_D can start at -300 or higher and training proceeds somewhat normally there, but with considerable difficulty and lots of oscillations in loss & sample appearance overnight. Whereas when we made the discriminator changes, the starting Loss_D was suddenly -0.0003 and it made no progress beyond static overnight. Increasing --Diters dramatically to 25 or 50 or higher, as recommended in #2, does not fix this at all. Another change that dramatically affects the loss scale is changing the clipping weights: increasing --clip_lower/--clip_upper to 0.05 will push the loss into the 100s.

I began thinking about what the Wasserstein distance/Loss_D means and wondered: it defines the loss for the gradients, right? Literally? So wouldn't -300 represent absurdly large gradients, and -0.0003 represent almost no gradients at all? And the former would be only ~10x what reasonable Loss_Ds are (1-10) while the latter represents gradients 10,000x too small, explaining why they worked so differently. Plus, isn't the effect of batchnorm to rescale outputs to N(0,1), so a batchnorm near the top of the discriminator would tend to make its outputs small, while on the other hand adding additional layers can make the final numbers much larger because they can vary over a wider range?
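Here is a toy sanity check of that reading (a sketch, not code from the repo): scaling the critic's output by a constant scales the gradients flowing into the generator by the same constant, so for a fixed learning rate a Loss_D around -300 and one around -0.0003 imply wildly different effective step sizes.

    import torch
    import torch.nn as nn

    netG = nn.Linear(10, 10)    # stand-in generator
    netD = nn.Linear(10, 1)     # stand-in critic

    z = torch.randn(4, 10)
    for scale in (1.0, 1000.0):
        netG.zero_grad()
        fake = netG(z)
        # WGAN generator step: maximize the critic score on fakes, i.e. minimize -D(fake)
        loss = -(scale * netD(fake)).mean()
        loss.backward()
        print(scale, netG.weight.grad.norm().item())   # grad norm grows linearly with the scale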

I tried multiplying one/mone in main.py by 1000 to see if that helped but it didn't really, so I switched to changing the learning rates. For the -0.0003 WGAN, I found that a lrD/G of 0.001 (20x higher than the default 0.00005; I assume that the reason it's not 10,000x higher has to do with the exact definition in the gradient descent formula) made steady progress and has yielded respectable 128px faces overnight.
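(For reference, that experiment only touches the optimizer setup; a sketch with stand-in models, since by default the repo uses RMSprop at lr=0.00005:)

    import torch.nn as nn
    import torch.optim as optim

    netG = nn.Linear(100, 64 * 64 * 3)   # stand-ins for the repo's generator/critic
    netD = nn.Linear(64 * 64 * 3, 1)

    # default --lrD/--lrG is 0.00005; for the -0.0003 Loss_D model, ~20x higher worked
    optimizerD = optim.RMSprop(netD.parameters(), lr=0.001)
    optimizerG = optim.RMSprop(netG.parameters(), lr=0.001)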

If it were merely the case that we weren't training the discriminator adequately, then my first fix of increasing --Diters should've worked. And if that were the case, then the -0.0003 WGAN should've worked brilliantly from the start - after all, it always had a tiny Loss_D, which should indicate an awesome discriminator/generator. And further, if that were the case, then increasing the learning rate so much should have broken it, not made it work.

So it seems like the absolute size of Loss_D matters somehow to training, and there's a sweet spot for a model with lr=0.00005: start with a Loss_D of 1-10 in magnitude; smaller, and it doesn't learn at a reasonable speed; larger, and it oscillates wildly without making progress (learning too much from each minibatch?).

There doesn't seem to be any important meaning to the Loss_D being 3 rather than 300 or 0.3; it's just there to provide gradients for the generator, right? It's a weak distance. So maybe a robust WGAN would do something like take the Loss_D in the first minibatch after training the discriminator to optimality, and then rescale it to 1-10 thereafter? (i.e. something like batchnorm for the final layer/output - although I don't know if it's better for the loss to be scaled every minibatch like in batchnorm, because then you wouldn't see it dropping as progress is made.)
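To make that concrete, a sketch of the rescaling idea (untested; the helper and the target range are made up for illustration): freeze the scale on the first minibatch after the critic has been trained to optimality, then divide every later Loss_D by it before backprop.

    class LossRescaler:
        # Rescale the critic loss to a fixed target magnitude, using the scale observed
        # on the first minibatch after the critic was trained to optimality.
        # Sketch of the proposal only; target=5.0 is an arbitrary point in the 1-10 range.
        def __init__(self, target=5.0):
            self.target = target
            self.scale = None

        def __call__(self, loss_d):
            if self.scale is None:                     # freeze the scale on the first call
                self.scale = abs(float(loss_d)) + 1e-12
            return loss_d * (self.target / self.scale)

    # usage in the training loop, where errD = errD_real - errD_fake as in main.py:
    #   rescaler = LossRescaler()
    #   errD_scaled = rescaler(errD)
    #   errD_scaled.backward()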


martinarjovsky avatar martinarjovsky commented on May 29, 2024 4

Protip: something else you can do is detect when the curve is making a big jump like this and keep iterating the critic till the curve goes up again to roughly where it was before the jump.

We hope to have a more precise "convergence criterion" for the critic further on so that this is no longer an issue.
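(A rough sketch of that protip as it might sit in a training loop; wass_history and train_critic_one_iter are hypothetical names, and the threshold/cap are arbitrary:)

    def keep_critic_near_optimum(wass_history, train_critic_one_iter,
                                 jump_threshold=1.0, max_extra_iters=200):
        # If the plotted Wasserstein estimate just dropped sharply, keep running critic
        # iterations until it climbs back to roughly its pre-jump level.
        if len(wass_history) < 2:
            return
        prev, curr = wass_history[-2], wass_history[-1]
        if prev - curr > jump_threshold:               # sudden large drop in the curve
            for _ in range(max_extra_iters):
                curr = train_critic_one_iter()         # one more critic update; returns the new estimate
                wass_history.append(curr)
                if curr >= prev - jump_threshold:      # back to roughly where it was before
                    break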


pavanramkumar avatar pavanramkumar commented on May 29, 2024 4

Very useful perspective here. I'm finding that my generator loss decreases over time but the sign goes from positive to negative. Do you have intuitions about how to interpret the sign?


soumith avatar soumith commented on May 29, 2024 1

did you by any chance crank up the learning rate? or change the number of D iterations per G iteration? if either of these is true, the wasserstein approximation might be off.


NickShahML avatar NickShahML commented on May 29, 2024 1

hey @KingStorm yeah, I got an RNN generator and RNN discriminator to work, but they just don't perform as well as conv1d -- the trick was to use the improved WGAN paper that was just released. I don't know what the problem is, but they just don't converge as well as conv1d (which is a good step but not that good).


NickShahML avatar NickShahML commented on May 29, 2024 1

Yes, I did every combination of LSTM and conv for the generator and discriminator. Unfortunately, the conv generator and conv discriminator have yielded the best results so far. I tried densenets as well and they did worse. Also tried fractal discriminators, but those did worse. If one of the authors wants to help with this discussion for more ideas, I'm up for trying anything. Really not sure why LSTMs don't converge better.


rafaelvalle avatar rafaelvalle commented on May 29, 2024 1

@thvasilo hey theodore! I didn't know you were interested in GAN games ;-)


NickShahML avatar NickShahML commented on May 29, 2024 1

@thvasilo I would point you to this repo -- has much more advanced techniques already working :)
https://github.com/amirbar/rnn.wgan/


LukasMosser avatar LukasMosser commented on May 29, 2024

Yes I did, the learning rate was set to 5e-4 for both D and G.

Are there any theoretical or practical bounds on these? (I'll check the paper as well)


martinarjovsky avatar martinarjovsky commented on May 29, 2024

This type of curve is a strong sign that the discriminator isn't trained till optimality (and therefore its error doesn't correspond to the performance of the generator). That can be for a number of reasons:

  • Momentum on the discriminator.
  • High learning rates.
  • Low number of disc iterations.

There can be other reasons that are problem dependent, but the general principle is that anything that helps get the disc trained till optimality is going to help with this issue.
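(In the repo's terms that translates roughly into the following setup; a sketch with a stand-in critic, using the repo's RMSprop default, which carries no momentum:)

    import torch.nn as nn
    import torch.optim as optim

    netD = nn.Linear(64 * 64 * 3, 1)   # stand-in for the critic

    # 1. no momentum on the critic: plain RMSprop (not Adam with momentum/betas)
    # 2. keep the learning rate low (0.00005 is the repo default)
    optimizerD = optim.RMSprop(netD.parameters(), lr=0.00005)
    # 3. enough critic iterations per generator step (--Diters defaults to 5; raise it
    #    if the loss curve suggests the critic is lagging behind the generator)
    Diters = 5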


LukasMosser avatar LukasMosser commented on May 29, 2024

Loving the protip and the quick responses!
Gonna give it a shot, thanks!


rafaelvalle avatar rafaelvalle commented on May 29, 2024

Lukas, did you try Martin's suggestion to create heuristics for the number of iterations of the Discriminator? I'm having a similar issue


LukasMosser avatar LukasMosser commented on May 29, 2024

@rafaelvalle Yes, I have tried increasing the number of iterations for the discriminator.
What I found was that the increase leads to
a) a much longer runtime
b) no visible decrease in lossD (after 10000 iterations)

OK, I am training this on a different dataset, but still, if you look at the paper, we're talking 100k generator iterations, which is 500k discriminator iterations with the default values. Depending on your hardware, that might take a while 'til you start seeing results.

Nevertheless, increasing the learning rate and not worrying about sudden drops gave me the best results in a short time, but I have not checked for mode collapse etc.



pavanramkumar avatar pavanramkumar commented on May 29, 2024

@rafaelvalle is there something specific about RNNs that doesn't agree with WGANs?

I'm using LSTMs to generate sequences with a WGAN loss. My model is learning to decrease d_loss rapidly, with improvement in generated samples until the point when it results in something like mode collapse.

I'd like to diagnose the issue correctly. Any insights? Thanks!


rafaelvalle avatar rafaelvalle commented on May 29, 2024

@pavanramkumar Assuming that a recurrent GAN is harder to train than a DCGAN, mode collapse can come from the critic not being trained till optimality (the claim being that training the critic till optimality "makes it impossible to collapse modes").
Also, covariate shift might be an issue.


pavanramkumar avatar pavanramkumar commented on May 29, 2024

@rafaelvalle thanks!

Increasing d_iters beyond 100 does not change the learning curve much, so I'm wondering what other metrics of optimal critic training there are, if any.

I think covariate shift is likely due to several deterministic layers. We've used batch norm throughout, but it's probably inadequate for the covariate shift that our deterministic layers have created.


NickShahML avatar NickShahML commented on May 29, 2024

@pavanramkumar is your discriminator an RNN as well? Have you considered switching the disc/critic to a CNN?


pavanramkumar avatar pavanramkumar commented on May 29, 2024

@NickShahML a lot of our mode collapse problems went away when we stopped using LSTMs in both the generator and the discriminator. Tuning the relative complexity and regularization in both G and D is still a black art though!


rafaelvalle avatar rafaelvalle commented on May 29, 2024

@pavanramkumar were you able to train an RNN generator with an adversarial loss?


NickShahML avatar NickShahML commented on May 29, 2024

@pavanramkumar @rafaelvalle ... as a follow-up, what exactly are you passing to the discriminator from the generator? Do you pass the softmaxed logits?

Suppose we are softmaxing over 100 different chars for each timestep:

Generator --> Logits shape [batch size, timesteps, 100] --> Discriminator

I'm trying to figure out the best way to pass outputs from the generator to the discriminator to lessen the complexity of the task.


martinarjovsky avatar martinarjovsky commented on May 29, 2024

Quick thing: it is possible that the clipping in RNNs creates a lot of vanishing gradients. I've heard that using layer normalization alleviates this on wgan-rnns, so it's a good thing to try.

I will look at this more in depth next week when I have more time.
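(A minimal sketch of that suggestion, applying nn.LayerNorm to an LSTM critic's hidden states before the final projection; the sizes are made up:)

    import torch
    import torch.nn as nn

    class LSTMCritic(nn.Module):
        # Sketch of an RNN critic with layer normalization, as suggested above
        # to counter vanishing gradients under weight clipping.
        def __init__(self, n_symbols=100, hidden=256):
            super().__init__()
            self.lstm = nn.LSTM(n_symbols, hidden, batch_first=True)
            self.norm = nn.LayerNorm(hidden)    # normalize the hidden state at every timestep
            self.out = nn.Linear(hidden, 1)     # scalar critic score

        def forward(self, x):                   # x: (batch, timesteps, n_symbols)
            h, _ = self.lstm(x)
            return self.out(self.norm(h)[:, -1])

    critic = LSTMCritic()
    score = critic(torch.randn(8, 20, 100))     # -> (8, 1)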


NickShahML avatar NickShahML commented on May 29, 2024

@martinarjovsky , I'm only clipping on the discriminator which is a stacked CNN. I'm specifically testing the difference between RNN and CNN in the generator. Are we supposed to clip in the generator as well? That doesn't make too much sense to me. If you're talking specifically about the discriminator, then I would understand.


martinarjovsky avatar martinarjovsky commented on May 29, 2024

Ah I see. No no, don't clip on the generator :).

OK will take a look.


pavanramkumar avatar pavanramkumar commented on May 29, 2024

Yes, we generate a softmax over the vocabulary (batch_size, n_symbols, n_sequence_length) to feed into the discriminator, along with several other tricks that may not be directly relevant to your problem.

We also observed that increasing the clipping threshold in the discriminator led to divergence (the generator produces random samples but doesn't really learn to move towards the data).

Using dense layers in both discriminator and generator helped alleviate both mode collapse and divergence.
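To make the shapes concrete, here is a sketch of that hand-off (a toy generator emitting per-timestep softmax distributions over the vocabulary, fed into the kind of conv1d critic discussed above; all names and sizes are illustrative, and real sequences would be one-hot encoded to match):

    import torch
    import torch.nn as nn

    batch, timesteps, n_symbols, nz = 8, 20, 100, 128

    class ToyGenerator(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(nz, timesteps * n_symbols)

        def forward(self, z):
            logits = self.fc(z).view(-1, timesteps, n_symbols)
            return torch.softmax(logits, dim=-1)    # soft distributions, so gradients flow to G

    # conv1d critic: expects (batch, channels=n_symbols, length=timesteps)
    critic = nn.Sequential(
        nn.Conv1d(n_symbols, 64, kernel_size=5, padding=2),
        nn.LeakyReLU(0.2),
        nn.Conv1d(64, 1, kernel_size=timesteps),    # collapse the sequence to a single score
    )

    G = ToyGenerator()
    fake = G(torch.randn(batch, nz))                # (batch, timesteps, n_symbols)
    score = critic(fake.transpose(1, 2))            # -> (batch, 1, 1)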


NickShahML avatar NickShahML commented on May 29, 2024

Thanks @pavanramkumar -- I'll try adding more dense layers and see if that helps resolve the issues. I'll also keep the clipping at [-0.01, 0.01]. So to clarify, did you actually get an RNN generator to work? Or did you stick with CNNs for both the generator and discriminator?


pavanramkumar avatar pavanramkumar commented on May 29, 2024

@NickShahML you're welcome. We just used dense layers for both. In fact, some Keras code is here if it helps: https://github.com/tree-gan/BonsaiNet

We hope to improve our README soon.


NickShahML avatar NickShahML commented on May 29, 2024

Oh wow I didn't realize you literally only used Dense layers. I'm still working on correcting the negative loss for the discriminator. But definitely good to know that Dense layers helped you reach convergence on your task.


KingStorm avatar KingStorm commented on May 29, 2024

@NickShahML Hi Nick, I am also using an RNN as the generator with a CNN as the critic, and the D_loss fluctuates a lot, which seems weird. Did you by any chance make your RNN generator work?


KingStorm avatar KingStorm commented on May 29, 2024

@NickShahML Thanks for the reply :). May I ask, have you ever tried using a conv net as the discriminator and an RNN as the generator while using "Improved Training of WGAN"? TensorFlow seems to have some problems calculating second-order derivatives of LSTMs.


thvasilo avatar thvasilo commented on May 29, 2024

@NickShahML Do you have a repo with your experiments? I'd like to try out some of these combinations as well.


Naxter avatar Naxter commented on May 29, 2024

Hey guys, @NickShahML @KingStorm @pavanramkumar, have you been successful in the end?
I also face the same problem as you. My normal RNN GAN (generator: Dense+LSTM layers; discriminator: conv1D) has the problem of mode collapse.
I switched to WassersteinGAN, but as you described, the loss is going crazy now. (Anyway, getting better results than before.)
So did "Improved Training of WGAN" really improve your results? Did the loss get more meaningful?


fsalmasri avatar fsalmasri commented on May 29, 2024

I'm wondering whether you're iterating over all of your samples inside each D iteration, or doing each D iteration on one minibatch sample?


Kaede93 avatar Kaede93 commented on May 29, 2024

@gwern
I'm doing some work similar to what you did, and I found that the WGAN loss seems very sensitive to the network structure, input shape, etc.

As you mentioned, the Wasserstein distance can take very small values (about 1e-4), and the same phenomenon occurred in my experiments. Based on my observations I agree with your gradient interpretation of the WD, but it seems very hard (or impossible) to train a GAN with a very small WD at the beginning of training; in fact, in my experiments the model suffers from vanishing gradients in that situation. By contrast, the GAN is much easier to train with a large WD at the beginning. So I think the exact scale of the WD is not that important, but we should expect it to be a large value (maybe 1e-1~1e2) rather than a tiny one (~1e-4) in our GANs.

