Comments (34)
So to bring some Twitter comments back: as mentioned in #4, @FeepingCreature and I have tried changing the architecture in a few ways to try to improve learning, and we have begun to wonder what exactly the Loss_D means.
In particular, compared to IllustrationGAN and StackGAN, WGAN struggles to handle 128px resolution and global coherency (e.g. in anime faces, severe heterochromia: the convolutions apparently decide individually on plausible yet mismatched eye colors). If we add 3 or 6 extra convolution layers via --n_extra_layers to a 64px anime face model, training struggles even more, and 6 is worse than 3, which is worse than none. (Albeit hilarious-looking.) When we add 1-4 fully-connected layers at the beginning of the generator like this:
+ nn.ConvTranspose2d(nz, nz, 1, 1, 0),
+ nn.LeakyReLU(0.2, inplace=True),
WGAN tends to get better and more globally coherent (but we still haven't matched IllustrationGAN/StackGAN). My interpretation is that the fully-connected layers are transforming the latent-z/noise into a sort of global template which the subsequent convolution layers can then fill in more locally. For symmetry, feep also tried adding additional layers to the discriminator with their associated batchnorms and changing the penultimate layers.
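For concreteness, here is a minimal, hypothetical sketch (batch size and layer count assumed) of why a 1x1 ConvTranspose2d on a 1x1 spatial input behaves like a fully-connected layer over the latent vector:

```python
import torch
import torch.nn as nn

nz = 100  # latent dimension (the DCGAN-style default; an assumption here)

# On a 1x1 spatial input, a 1x1 transposed convolution is just a dense map
# across channels -- i.e. a fully-connected layer over z, which can form a
# "global template" before the upsampling convolutions fill in local detail.
fc_stack = nn.Sequential(
    nn.ConvTranspose2d(nz, nz, 1, 1, 0),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(nz, nz, 1, 1, 0),
    nn.LeakyReLU(0.2, inplace=True),
)

z = torch.randn(16, nz, 1, 1)  # a minibatch of latent vectors
out = fc_stack(z)
print(out.shape)  # torch.Size([16, 100, 1, 1]) -- shape unchanged, ready for upsampling
```

The spatial shape is untouched, so the stack can be prepended to the generator without changing anything downstream.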
One interesting thing about adding convolution+batchnorm or fully-connected layers to the 128px model is that it makes Loss_D wildly different. This isn't too surprising, since it's noted that Loss_D isn't directly comparable across different models, but what's more unpleasantly surprising is that it seems to affect training. If we add in more FC layers, Loss_D can start at -300 or higher, and training proceeds somewhat normally there, but with considerable difficulty and lots of oscillation in loss & sample appearance overnight. Whereas when we made the discriminator changes, suddenly the starting Loss_D is -0.0003 and it makes no progress beyond static overnight. Increasing --Diters dramatically to 25 or 50 or higher, as recommended in #2, does not fix this at all. Another change that dramatically affects the loss scale is the clipping weights: increasing --clip_lower/--clip_upper to 0.05 will push the loss into the 100s.
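For reference, the clipping those flags control is just a clamp on every critic parameter after each update; a minimal sketch (the Linear stand-in for the critic is invented):

```python
import torch
import torch.nn as nn

# --clip_lower/--clip_upper defaults; widening these lets the critic's output
# (and hence Loss_D) span a much larger range, consistent with the loss
# jumping into the 100s at 0.05.
clip_lower, clip_upper = -0.01, 0.01

netD = nn.Linear(8, 1)  # stand-in for the critic
with torch.no_grad():
    netD.weight.mul_(10)  # force some weights out of range

# The WGAN clipping step, done after every critic update:
for p in netD.parameters():
    p.data.clamp_(clip_lower, clip_upper)

print(float(netD.weight.min()), float(netD.weight.max()))  # both inside [-0.01, 0.01]
```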
I began thinking about what the Wasserstein distance/Loss_D means and wondered: it defines the loss for the gradients, right? Literally? So wouldn't -300 represent absurdly large gradients, and -0.0003 represent almost no gradients at all? And the former would be only ~10x what reasonable Loss_Ds are (1-10) while the latter represents gradients 10,000x too small, explaining why they worked so differently. Plus, isn't the effect of batchnorm to rescale outputs to N(0,1), so a batchnorm near the top of the discriminator would tend to make its outputs small, while on the other hand adding additional layers can make the final numbers much larger because they can vary over a wider range?
I tried multiplying one/mone in main.py by 1000 to see if that helped, but it didn't really, so I switched to changing the learning rates. For the -0.0003 WGAN, I found that a lrD/lrG of 0.001 (20x higher than the default 0.00005; I assume the reason it's not 10,000x higher has to do with the exact definition in the gradient-descent formula) made steady progress and has yielded respectable 128px faces overnight.
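One plausible reason (my speculation) the one/mone rescaling did so little while the learning rate mattered so much: the reference setup defaults to RMSProp, which divides each gradient by a running RMS of past gradients, so a constant rescaling of the loss largely cancels out of the update, whereas the learning rate scales the step directly. A toy sketch of one RMSProp update from a zero state (standard update rule, not the repo's code):

```python
import torch

def rmsprop_first_step(grad, lr=0.00005, alpha=0.99, eps=1e-8):
    # One RMSProp update starting from a zero squared-gradient average:
    # step = lr * g / (sqrt((1 - alpha) * g^2) + eps)
    sq_avg = (1 - alpha) * grad ** 2
    return lr * grad / (sq_avg.sqrt() + eps)

g = torch.tensor([3.0])
step_small = rmsprop_first_step(g)
step_big = rmsprop_first_step(1000 * g)  # gradient 1000x larger, as if one/mone were scaled
print(step_small, step_big)  # the two update steps are nearly identical
```

The normalization makes the step depend mostly on the gradient's sign, not its magnitude, which would explain why scaling one/mone by 1000 "didn't really" help.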
If it were merely the case that we aren't training the discriminator adequately, then my first fix of increasing --Diters should've worked. And if that were the case, then the -0.0003 WGAN should've worked brilliantly from the start - after all, it always had a tiny Loss_D, which should indicate an excellent discriminator/generator. And further, if that were the case, then increasing the learning rate so much should have broken it, not made it work.
So it seems like the absolute size of Loss_D matters somehow to training, and there's a sweet spot for a model with lr=0.00005 to start with a Loss_D of 1-10 in magnitude: smaller, and it doesn't learn at a reasonable speed; larger, and it oscillates wildly without making progress (learning too much from each minibatch?).
There doesn't seem to be any important meaning to the Loss_D being 3 rather than 300 or 0.3, it's just there to provide gradients for the generator, right? It's a weak distance. So maybe a robust WGAN would do something like take the Loss_D in the first minibatch after training the discriminator to optimality, and then rescale it to 1-10 thereafter? (ie something like batchnorm for the final layer/output - although I don't know if it's better for the loss to be scaled every minibatch like in batchnorm, because then you wouldn't see it dropping as progress is made.)
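A hypothetical sketch of that rescaling idea (the class name, target scale, and calibration-on-first-minibatch policy are all invented for illustration):

```python
# Calibrate a scale factor from the first Loss_D observed after training the
# critic to optimality, then divide all later losses by it so training always
# starts in the ~1-10 range -- while still letting the loss fall over time.
class LossRescaler:
    def __init__(self, target=5.0):
        self.target = target
        self.scale = None

    def __call__(self, loss_d):
        if self.scale is None:  # calibrate on the first observed loss
            self.scale = abs(loss_d) / self.target
        return loss_d / self.scale

rescale = LossRescaler()
print(rescale(-300.0))  # first call calibrates the scale: prints -5.0
print(rescale(-150.0))  # later losses still shrink with progress: prints -2.5
```

Unlike batchnorm's per-minibatch renormalization, a one-time calibration preserves the downward trend of the loss, so progress remains visible.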
from wassersteingan.
Protip: something else you can do is detect when the curve is making a big jump like this and keep iterating the critic till the curve goes up again to roughly where it was before the jump.
We hope to have a more precise "convergence criterion" for the critic further on so that this is no longer an issue.
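That heuristic could be sketched roughly as follows (function name, thresholds, and iteration counts are all invented for illustration):

```python
# If Loss_D suddenly jumps in magnitude relative to its recent level, run
# extra critic iterations until the curve recovers; otherwise use the default.
def critic_iters_needed(history, jump_factor=2.0, base_iters=5, extra_iters=20):
    if len(history) < 2:
        return base_iters
    prev, cur = history[-2], history[-1]
    if abs(cur) > jump_factor * abs(prev):  # big jump away from the trend
        return extra_iters
    return base_iters

print(critic_iters_needed([-3.0, -2.9]))  # steady curve: 5 iterations
print(critic_iters_needed([-3.0, -9.0]))  # big jump: 20 iterations
```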
from wassersteingan.
Very useful perspective here. I'm finding that my generator loss decreases over time but the sign goes from positive to negative. Do you have intuitions about how to interpret the sign?
from wassersteingan.
did you by any chance crank up the learning rate? or change the number of D iterations per G iteration? if either of these is true, the wasserstein approximation might be off.
from wassersteingan.
hey @KingStorm Yea I got RNN generator and RNN Discriminator to work but they just don't perform as well as conv1d -- the trick was to use the improved WGAN paper that was just released. I just don't know what the problem is, but they just don't converge as well as conv1d (which is a good step but not that good).
from wassersteingan.
Yes I did every combination: LSTM generator and LSTM discriminator, along with conv generator and conv discriminator. Unfortunately, conv generator and conv discriminator have yielded the best results so far. I tried densenets as well and they did worse. Also tried fractal discriminators, but those did worse. If one of the authors wants to help with this discussion for more ideas, I'm up for trying anything. Really not sure why LSTMs don't converge better.
from wassersteingan.
@thvasilo hey theodore! I didn't know you were interested in GAN games ;-)
from wassersteingan.
@thvasilo I would point you to this repo -- has much more advanced techniques already working :)
https://github.com/amirbar/rnn.wgan/
from wassersteingan.
Yes I did, the learning rate was set to 5e-4 for both D and G.
Are there any theoretical or practical bounds on these? (I'll check the paper as well)
from wassersteingan.
This type of curve is a strong sign that the discriminator isn't trained till optimality (and therefore its error doesn't correspond to the performance of the generator). That can be for a number of reasons:
- Momentum on the discriminator.
- High learning rates.
- Low number of discriminator iterations.
There can be other reasons that are problem-dependent, but the general principle is that anything that helps get the discriminator trained till optimality is going to help with this issue.
from wassersteingan.
Loving the protip and the quick responses!
Gonna give it a shot, thanks!
from wassersteingan.
Lukas, did you try Martin's suggestion to create heuristics for the number of iterations of the Discriminator? I'm having a similar issue
from wassersteingan.
@rafaelvalle Yes, I have tried increasing the number of iterations for the discriminator.
What I found was that the increase leads to
a) a much longer runtime
b) no visible decrease in lossD (after 10000 iterations)
OK, I am training this on a different dataset, but still: if you look at the paper, we're talking 100k generator iterations; that's 500k discriminator iterations with default values. Depending on your hardware, that might take a while 'til you start seeing results.
Nevertheless, increasing the learning rate and not worrying about sudden drops gave me the best results in a short time, but I have not checked for mode collapse etc.
from wassersteingan.
@rafaelvalle is there something specific about RNNs that doesn't agree with WGANs?
I'm using LSTMs to generate sequences with a WGAN loss. My model learns to decrease d_loss rapidly, with improvement in generated samples until the point where it results in something like mode collapse.
I'd like to diagnose the issue correctly. Any insights? Thanks!
from wassersteingan.
@pavanramkumar Assuming that a recurrent GAN is harder to train than DCGAN, mode collapse can come from the critic not being trained till optimality (a critic trained till optimality should make it impossible to collapse modes).
Also, covariate shift might be an issue.
from wassersteingan.
@rafaelvalle thanks!
Increasing d_iters beyond 100 does not change the learning curve much, so I'm wondering what, if any, other metrics of optimal critic training there are.
I think covariate shift is likely due to several deterministic layers. We've used batch norm throughout, but it's probably inadequate for the covariate shift that our deterministic layers have created.
from wassersteingan.
@pavanramkumar is your discriminator a RNN as well? Have you considered switching the disc/critic to a CNN?
from wassersteingan.
@NickShahML a lot of our mode collapse problems went away when we stopped using LSTMs in both the generator and the discriminator. Tuning the relative complexity and regularization in both G and D is still a black art though!
from wassersteingan.
@pavanramkumar were you able to train a RNN generator with adversarial loss?
from wassersteingan.
@pavanramkumar @rafaelvalle ... as a follow-up, what exactly are you passing to the discriminator from the generator? Do you pass the softmaxed logits?
Suppose we are softmaxing over 100 different chars for each timestep:
Generator --> Logits shape [batch size, timesteps, 100] --> Discriminator
Trying to figure out the best way to pass outputs from the generator to the discriminator to lessen the complexity of the task.
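A hedged sketch of that pipeline (all sizes assumed): softmax over the vocab axis to get a differentiable "soft" one-hot sequence, then move the vocab axis to the channel axis for a Conv1d critic:

```python
import torch
import torch.nn as nn

batch, timesteps, vocab = 4, 16, 100
logits = torch.randn(batch, timesteps, vocab)   # stand-in for generator output
soft_onehot = torch.softmax(logits, dim=-1)     # differentiable, sums to 1 per timestep

# A Conv1d discriminator expects [batch, channels, length], so the vocab
# axis becomes the channel axis.
disc_in = soft_onehot.transpose(1, 2)           # -> [batch, vocab, timesteps]
critic = nn.Conv1d(vocab, 1, kernel_size=5, padding=2)
print(critic(disc_in).shape)  # torch.Size([4, 1, 16])
```

Passing the softmaxed distribution (rather than sampled hard tokens) keeps the generator-to-discriminator path differentiable.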
from wassersteingan.
Quick thing: it is possible that the clipping in RNNs creates a lot of vanishing gradients. I've heard that using layer normalization alleviates this on WGAN-RNNs, so it's a good thing to try.
I will look at this more in depth next week when I have more time.
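A hypothetical sketch of that suggestion (the module layout, sizes, and time-averaged score are all assumptions, not the repo's code): insert LayerNorm on the recurrent features before the critic's final unbounded score:

```python
import torch
import torch.nn as nn

class RNNCritic(nn.Module):
    def __init__(self, vocab=100, hidden=64):
        super().__init__()
        self.rnn = nn.LSTM(vocab, hidden, batch_first=True)
        self.norm = nn.LayerNorm(hidden)   # normalize per-timestep features
        self.score = nn.Linear(hidden, 1)  # unbounded critic output (no sigmoid, per WGAN)

    def forward(self, x):  # x: [batch, timesteps, vocab]
        h, _ = self.rnn(x)
        return self.score(self.norm(h)).mean()  # average score over batch and time

critic = RNNCritic()
out = critic(torch.randn(4, 16, 100))
print(out.shape)  # torch.Size([]) -- a single scalar critic score
```

Unlike batchnorm, LayerNorm normalizes within each sample, so it applies cleanly at every timestep of a recurrent critic.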
from wassersteingan.
@martinarjovsky , I'm only clipping on the discriminator which is a stacked CNN. I'm specifically testing the difference between RNN and CNN in the generator. Are we supposed to clip in the generator as well? That doesn't make too much sense to me. If you're talking specifically about the discriminator, then I would understand.
from wassersteingan.
Ah I see. Nono don't clip on the generator :).
OK will take a look.
from wassersteingan.
Yes, we generate a softmax over the vocabulary (batch_size, n_symbols, n_sequence_length) to feed into the discriminator, in addition to several additional tricks that may not be directly relevant to your problem.
We also observed that increasing the clipping threshold in the discriminator led to divergence (the generator produces random samples but doesn't really learn to move towards the data).
Using dense layers in both discriminator and generator helped alleviate both mode collapse and divergence.
from wassersteingan.
Thanks @pavanramkumar -- I'll try adding more dense layers and see if that helps resolve issues. I'll also keep the clipping at [-0.01, 0.01]. So to clarify, did you actually get a RNN generator to work? Or did you stick with CNNs for both the generator and discriminator?
from wassersteingan.
@NickShahML you're welcome. we just used dense layers for both. In fact, some keras code is here if it helps: https://github.com/tree-gan/BonsaiNet
We hope to improve our README soon.
from wassersteingan.
Oh wow I didn't realize you literally only used Dense layers. I'm still working on correcting the negative loss for the discriminator. But definitely good to know that Dense layers helped you reach convergence on your task.
from wassersteingan.
@NickShahML Hi Nick, I am also using an RNN as generator with a CNN as critic, and the D_loss fluctuates a lot, which seems weird. Did you by any chance make your RNN generator work?
from wassersteingan.
@NickShahML Thanks for the reply :). May I ask, have you ever tried using a conv net as discriminator and a RNN as generator while using "Improved Training of WGAN"? TensorFlow seems to have some problems calculating second-order derivatives of LSTMs.
from wassersteingan.
@NickShahML Do you have a repo with your experiments? I'd like to try out some of these combinations as well.
from wassersteingan.
Hey guys, @NickShahML @KingStorm @pavanramkumar have you guys been successful in the end?
I also face the same problem as you guys. My normal RNN GAN (generator: Dense+LSTM layers; discriminator: Conv1D) has the problem of mode collapse.
I switched to Wasserstein GAN, but as you described, the loss is going crazy now. (Anyway, I'm having better results than before.)
So did "Improved Training of WGAN" really improve your results? Did the loss get more meaningful?
from wassersteingan.
I'm wondering if you're iterating over all your samples inside each D iteration, or doing each D iteration on one minibatch sample?
from wassersteingan.
@gwern
I'm doing some work similar to yours, and I found that the WGAN loss seems very sensitive to the network structure, input shape, etc.
As you mentioned, the Wasserstein distance can be a very small value (about 1e-4), and the same phenomenon occurred in my experiments. I agree with your gradient theory of the WD definition based on my observations, but it seems very hard (or impossible) to train a GAN with a very small WD at the beginning of training; in fact, the model suffered from vanishing gradients in my experiments in this situation. In contrast, the GAN is much easier to train with a large WD at the beginning. So I think the scale of the WD is not very important, but we should expect it to be a large value (maybe 1e-1 to 1e2) rather than small (~1e-4) in our GAN.
from wassersteingan.