rafaelvalle commented on May 30, 2024

Here's a summary of the thread and conclusions for future visitors.

Crepe Network with WGAN presented mode collapse when using a conditional model in which the noise and the condition were concatenated at the first layer. This most likely happened because the generator ignored the noise distribution and relied mainly on the labels. The problem was solved by first passing the noise through a dense layer and then concatenating the output of this dense layer with the labels.
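The fix described above can be sketched as a forward pass in plain Python. This is a toy illustration under stated assumptions, not the thread's Lasagne code: the sizes `NOISE_SIZE`, `NOISE_HIDDEN` and `N_CONDS`, the ReLU, and the random initialization are all hypothetical. It only shows the wiring, where the noise gets its own dense layer before being concatenated with the one-hot label.

```python
import random

# Hypothetical sizes, chosen only for illustration.
NOISE_SIZE = 8      # dimensionality of the noise vector z
NOISE_HIDDEN = 16   # width of the dense layer applied to z alone
N_CONDS = 4         # number of one-hot condition labels


def dense(x, weights, bias):
    """Fully connected layer followed by a ReLU: relu(W x + b)."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
            for row, b in zip(weights, bias)]


def init_dense(n_in, n_out, rng):
    """Small random weights and zero bias (an assumed initialization)."""
    weights = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def generator_input(z, label, noise_layer):
    """Noise goes through its own dense layer, then is concatenated with
    the one-hot label; the result would feed the rest of the generator."""
    h = dense(z, *noise_layer)
    one_hot = [1.0 if i == label else 0.0 for i in range(N_CONDS)]
    return h + one_hot


rng = random.Random(0)
noise_layer = init_dense(NOISE_SIZE, NOISE_HIDDEN, rng)
z = [rng.gauss(0.0, 1.0) for _ in range(NOISE_SIZE)]
features = generator_input(z, label=2, noise_layer=noise_layer)
print(len(features))  # 20 = NOISE_HIDDEN + N_CONDS
```

Because the noise now has dedicated weights of its own, the downstream layers see a learned noise representation alongside the label rather than a few raw noise values next to a strong one-hot signal.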

Crepe Network with WGAN did not converge if batch norm was applied to it. Xiang Zhang, the main author of the Crepe Network paper, explained that zero-padding makes the last features in the length dimension have close to zero variance, which makes the gradients blow up when backpropagated through batch norm.
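A toy sketch of the mechanism Xiang Zhang describes, assuming a batch-norm epsilon of 1e-4 and hypothetical sequence sizes: batch norm divides each feature position by sqrt(var + eps), and its backward pass multiplies the gradient for that position by the same factor 1/sqrt(var + eps). Positions that are zero-padded in every sample have zero batch variance, so their gradients get amplified by about 1/sqrt(eps).

```python
import math
import random

EPS = 1e-4  # assumed batch-norm epsilon


def batch_norm_scale(batch, position):
    """1 / sqrt(var + eps) for one feature position across the batch.

    Batch norm multiplies the gradient flowing back through this position
    by this factor, so a near-zero batch variance amplifies it."""
    values = [sample[position] for sample in batch]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return 1.0 / math.sqrt(var + EPS)


# Hypothetical sizes: 32 sequences of length 10, zero-padded from step 7 on.
SEQ_LEN, PADDED_FROM, BATCH = 10, 7, 32
rng = random.Random(0)
batch = [[rng.gauss(0.0, 1.0) if t < PADDED_FROM else 0.0
          for t in range(SEQ_LEN)]
         for _ in range(BATCH)]

real_scale = batch_norm_scale(batch, 0)           # close to 1 here
pad_scale = batch_norm_scale(batch, SEQ_LEN - 1)  # 1 / sqrt(EPS), about 100
print(f"real position: {real_scale:.2f}, padded position: {pad_scale:.2f}")
```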

from wassersteingan.

omair-kg commented on May 30, 2024

@martinarjovsky Hi, you mentioned above that the loss of the critic should never be negative. If it does go negative, what is the countermeasure? Increasing the number of critic iterations?

martinarjovsky commented on May 30, 2024

Interesting. A lot of weird things can happen when switching to RNNs. Can I take a look at the architecture and training curves? Are you using batch norm / layer normalization?

rafaelvalle commented on May 30, 2024

@martinarjovsky Yes, but note that I am using CNNs for sequence generation.
I'm running this file https://github.com/rafaelvalle/neural_network_control_improvisation/blob/master/wcgan_text.py and the models are in https://github.com/rafaelvalle/neural_network_control_improvisation/blob/master/models.py

The critic is a Crepe CNN and the generator is a simple architecture with two transposed convolutions.
Thanks for the help!

rafaelvalle commented on May 30, 2024

Generator

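        from lasagne.layers import (InputLayer, DenseLayer, ReshapeLayer,
                                    Deconv2DLayer, batch_norm, concat)
        from lasagne.nonlinearities import LeakyRectify

        lrelu = LeakyRectify(0.2)  # assumed slope; models.py defines its own lrelu

        def build_generator(input_var, cond_var, noise_size, n_conds,
                            tanh_temperature):
            # tanh_temperature: custom output nonlinearity defined in models.py
            layer = InputLayer(shape=(None, noise_size), input_var=input_var)
            cond_in = InputLayer(shape=(None, n_conds), input_var=cond_var)
            # noise and one-hot condition are concatenated at the first layer
            layer = concat([layer, cond_in])
            # two fully-connected layers
            layer = batch_norm(DenseLayer(layer, 1024))
            layer = batch_norm(DenseLayer(layer, 1024))
            # project and reshape to 128 feature maps of 34x34
            layer = batch_norm(DenseLayer(layer, 128*34*34))
            layer = ReshapeLayer(layer, ([0], 128, 34, 34))
            # two fractional-stride convolutions: 34x34 -> 67x67 -> 128x128
            layer = batch_norm(Deconv2DLayer(
                layer, 128, 5, stride=2, crop='same', b=None, nonlinearity=lrelu))
            layer = Deconv2DLayer(
                layer, 1, 6, stride=2, crop='full', b=None,
                nonlinearity=tanh_temperature)
            return layer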
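As a sanity check on the generator, the spatial size after each transposed convolution can be worked out by hand. The helper below is a sketch that assumes Lasagne's cropping convention for `Deconv2DLayer` (`'same'` crops `filter_size // 2`, `'full'` crops `filter_size - 1`); under that assumption the 34x34 maps grow to 67x67 and then 128x128, matching the critic's `(None, 1, 128, 128)` input.

```python
def transposed_conv_output(size, filter_size, stride, crop):
    """Output length of a transposed convolution along one axis.

    Uses (size - 1) * stride - 2 * pad + filter_size, where the crop mode
    sets pad (assumed to match Lasagne's Deconv2DLayer convention)."""
    pad = {"valid": 0, "same": filter_size // 2, "full": filter_size - 1}[crop]
    return (size - 1) * stride - 2 * pad + filter_size


# Generator above: 34x34 feature maps, then two stride-2 deconvolutions.
after_first = transposed_conv_output(34, 5, 2, "same")
after_second = transposed_conv_output(after_first, 6, 2, "full")
print(after_first, after_second)  # 67 128
```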

Critic

        from lasagne.layers import (InputLayer, DenseLayer, Conv2DLayer,
                                    MaxPool2DLayer, batch_norm, concat,
                                    dropout, flatten)
        from lasagne.nonlinearities import LeakyRectify, rectify

        lrelu = LeakyRectify(0.2)  # assumed slope; models.py defines its own lrelu

        def build_critic(input_var, cond_var, n_conds):
            # CREPE
            layer = InputLayer(shape=(None, 1, 128, 128), input_var=input_var,
                               name='d_in_data')
            # form words from sequence of characters
            layer = batch_norm(Conv2DLayer(layer, 1024, (7, 128), nonlinearity=lrelu))
            layer = MaxPool2DLayer(layer, (3, 1))
            # temporal convolution, 7-gram
            layer = batch_norm(Conv2DLayer(layer, 512, (7, 1), nonlinearity=lrelu))
            layer = MaxPool2DLayer(layer, (3, 1))
            # temporal convolutions, 3-gram
            layer = Conv2DLayer(layer, 256, (3, 1), nonlinearity=lrelu)
            layer = Conv2DLayer(layer, 256, (3, 1), nonlinearity=lrelu)
            layer = Conv2DLayer(layer, 256, (3, 1), nonlinearity=lrelu)
            layer = Conv2DLayer(layer, 256, (3, 1), nonlinearity=lrelu)
            layer = flatten(layer)
            # fully-connected layers
            layer = dropout(DenseLayer(layer, 1024, nonlinearity=rectify))
            layer = dropout(DenseLayer(layer, 1024, nonlinearity=rectify))
            # one-hot condition
            layer_cond = InputLayer(shape=(None, n_conds), input_var=cond_var,
                                    name='d_in_condition')
            layer_cond = batch_norm(DenseLayer(layer_cond, 1024, nonlinearity=lrelu))
            layer = concat([layer, layer_cond])
            # single linear output: a WGAN critic has no sigmoid
            layer = DenseLayer(layer, 1, nonlinearity=None, b=None)
            return layer
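The same kind of bookkeeping for the critic, as a sketch: it assumes "valid" (unpadded) stride-1 convolutions and `MaxPool2DLayer`'s default of stride equal to pool size with the incomplete border window dropped. It also assumes, from the `(7, 128)` first filter, that height is the sequence axis and width the 128-wide character axis.

```python
def conv_valid(size, filter_size):
    """Output length of an unpadded ('valid') stride-1 convolution."""
    return size - filter_size + 1


def max_pool(size, pool_size):
    """Output length of max pooling with stride == pool_size, dropping the
    incomplete window at the border (assumed ignore_border=True)."""
    return (size - pool_size) // pool_size + 1


# Height is the sequence axis; width is the 128-wide character axis.
h, w = 128, 128
h, w = conv_valid(h, 7), conv_valid(w, 128)  # (7, 128) kernel spans all columns
h = max_pool(h, 3)
h = conv_valid(h, 7)                         # 7-gram convolution
h = max_pool(h, 3)
for _ in range(4):                           # four 3-gram convolutions
    h = conv_valid(h, 3)
flat_features = 256 * h * w                  # 256 channels after the last conv
print(h, w, flat_features)  # 3 1 768
```

So the first dense layer of the critic sees 768 features per example under these assumptions.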

rafaelvalle commented on May 30, 2024

These are results from the training run that produced the mode-collapsed text above, with batch norm on the generator only and a learning rate of 5e-5.
[Image: g_updates30]

These are results with batch norm on both generator and critic and a learning rate of 1e-5. The produced text looks like garbage.
[Image: g_updates22]

rafaelvalle commented on May 30, 2024

I suspect that the generator is ignoring the noise input and relying on conditions to generate output. I'll modify the model and post the results.
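One way to test this suspicion, not something done in the thread, is to fix the label, vary only the noise, and measure how much the output moves. The sketch below uses two hypothetical toy generators (plain Python functions) in place of the real model; a spread near zero means one output per label regardless of the noise.

```python
import random
import statistics


def output_spread(generator, label, noise_size, n_samples=64, seed=0):
    """Mean per-dimension standard deviation of generator outputs for a
    fixed label across different noise draws."""
    rng = random.Random(seed)
    outputs = [generator([rng.gauss(0.0, 1.0) for _ in range(noise_size)], label)
               for _ in range(n_samples)]
    per_dim = [statistics.pstdev([out[d] for out in outputs])
               for d in range(len(outputs[0]))]
    return statistics.fmean(per_dim)


def ignores_noise(z, label):
    # Hypothetical collapsed generator: output depends only on the label.
    return [float(label), 2.0 * label]


def uses_noise(z, label):
    # Hypothetical healthy generator: output mixes label and noise.
    return [label + z[0], 2.0 * label + 0.5 * z[1]]


print(output_spread(ignores_noise, label=1, noise_size=4))  # 0.0
print(output_spread(uses_noise, label=1, noise_size=4))     # clearly above 0
```

With a real model, the same idea applies to decoded outputs or to critic-input tensors; what counts as "near zero" would have to be judged against the data.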

NickShahML commented on May 30, 2024

@rafaelvalle I've experimented with WGANs and text generation myself. For me, the discriminator stays at a loss of -8.2 and the generator stays at a loss of -0.00012 (it changes a little but stays in that range). I think the discriminator is getting stuck, which seems to be your problem as well. I've tried annealing the learning rates, but it had no effect.

I'm considering different discriminator architectures, and I'm convinced that they need to be CNNs. (RNNs have a really hard time with gradients and don't converge for me.)

Also, how are you actually generating your characters? I haven't had a chance to look at your code, but do you generate softmaxed probability distributions and pass those to the discriminator?

rafaelvalle commented on May 30, 2024

@NickShahML I think our problems are different. I've had generators that produce words with high probability in my data. However, the generator has a single mode for each label, and I suspect this is because it ignores the noise input and relies on the labels only: I concatenate the noise and the labels at the first layer. I'm slightly changing the architecture so that the noise goes through a fully connected dense layer before being concatenated with the labels. Characters are generated with argmax post-processing.
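Argmax post-processing can be sketched as follows. It assumes, based on the critic's `(7, 128)` first filter, that each row of the generator's output is one time step and each column one character; the 4-character alphabet and the scores are hypothetical stand-ins for the real 128-wide output.

```python
def decode_argmax(char_scores, alphabet):
    """Turn a (sequence_length x alphabet_size) score matrix into a string
    by picking the highest-scoring character at each time step."""
    return "".join(alphabet[max(range(len(row)), key=row.__getitem__)]
                   for row in char_scores)


# Hypothetical 4-character alphabet and a 3-step generator output.
alphabet = "abc "
scores = [
    [0.9, 0.1, 0.0, 0.2],   # highest at index 0 -> 'a'
    [0.1, 0.2, 0.8, 0.3],   # highest at index 2 -> 'c'
    [0.0, 0.1, 0.2, 0.7],   # highest at index 3 -> ' '
]
print(repr(decode_argmax(scores, alphabet)))  # 'ac '
```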

On another note, it really surprises me that the model does not learn any words if the discriminator has batch norm. This run was made without conditioning and with batch norm on both critic and generator.
[Image: g_updates99]

NickShahML commented on May 30, 2024

@rafaelvalle If I use an RNN generator without batch norm on the critic, the model does not converge. The discriminator basically overpowers the generator. In other words, I have to use batch norm on the critic to get anywhere at all. You may be right that you have to use a convolutional generator to actually reach convergence.

martinarjovsky commented on May 30, 2024

I will have a deeper look at all of this when I have more time next week, but the general observation from the curves is that the critic is not trained to optimality. The loss of the critic should never be negative, since outputting 0 would yield a better loss, so this is a huge red flag.

@NickShahML there is no overpowering. What you are seeing is that without batch norm on the critic, the critic is not trained to optimality; therefore the loss ceases to be a good Wasserstein estimate, the numbers mean very little and the updates to the generator are crap.

For now batchnorm seems to be important on the critic to get it close to optimality in a few iterations. We've stressed this in other issues as well.
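Martin's argument can be made concrete with a toy example. Reading the critic's loss as the Wasserstein estimate E[f(real)] - E[f(fake)] that the critic maximizes (code that minimizes the negated objective sees the opposite sign), the constant critic f(x) = 0 always scores exactly 0, so an optimal critic can never score below 0. The sketch assumes a hypothetical 1-D linear critic f(x) = w * x with w clipped to [-0.01, 0.01] and made-up samples.

```python
import statistics

CLIP = 0.01  # weight-clipping range [-CLIP, CLIP], as discussed in the thread


def critic_estimate(w, real, fake):
    """E[f(real)] - E[f(fake)] for a hypothetical linear critic f(x) = w * x."""
    return w * statistics.fmean(real) - w * statistics.fmean(fake)


def train_critic(w, real, fake, lr=0.005, steps=10):
    """Gradient ascent on the estimate, clipping w after every step."""
    grad = statistics.fmean(real) - statistics.fmean(fake)  # d(estimate)/dw
    for _ in range(steps):
        w = max(-CLIP, min(CLIP, w + lr * grad))
    return w


real = [1.8, 2.1, 2.0, 2.3]   # toy "real" samples
fake = [0.2, -0.1, 0.4, 0.1]  # toy "generated" samples

w0 = -CLIP                    # an untrained critic pointing the wrong way
w = train_critic(w0, real, fake)
print(critic_estimate(w0, real, fake))   # negative: worse than f(x) = 0
print(critic_estimate(0.0, real, fake))  # 0.0 for the constant critic
print(critic_estimate(w, real, fake))    # positive once trained
```

A negative estimate therefore means the critic is doing worse than a trivial one, which is why it signals a critic that has not been trained close to optimality.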

NickShahML commented on May 30, 2024

@martinarjovsky Thank you for that clarification. I didn't realize that negative losses in the critic are a huge red flag. Can you clarify what this indicates? I'll investigate why this is occurring.

If the generator loss is also negative, is that a red flag? I would imagine that a raw discriminator logit could be negative or positive (since we are clipping the weights to [-0.01, 0.01]), so the loss of the generator could potentially be negative.

        self._g._loss = tf.reduce_mean(d._fake_logits_evaluation)

The raw outputs of the discriminator can be negative meaning that the generator's loss could be negative itself. Let me know if this is incorrect. Really appreciate your help!
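A fact that bears on this question: a WGAN critic is only determined up to an additive constant. Adding a constant to every critic output leaves E[f(real)] - E[f(fake)] unchanged but shifts the generator's loss by the same amount (up to sign), so the sign of the generator loss alone says little. The toy sketch below assumes a hypothetical linear critic, made-up samples, and the convention that the generator minimizes -E[f(fake)].

```python
import statistics


def critic_estimate(critic, real, fake):
    """E[f(real)] - E[f(fake)]: the critic's Wasserstein estimate."""
    return (statistics.fmean([critic(x) for x in real])
            - statistics.fmean([critic(x) for x in fake]))


def generator_loss(critic, fake):
    """Assumed convention: the generator minimizes -E[f(fake)]."""
    return -statistics.fmean([critic(x) for x in fake])


def make_critic(shift):
    """Hypothetical linear critic 0.01 * x, offset by a constant."""
    def critic(x):
        return 0.01 * x + shift
    return critic


real = [1.0, 2.0, 3.0]  # toy "real" samples
fake = [0.0, 0.5, 1.0]  # toy "generated" samples

# The estimate is the same for every shift; the generator loss changes sign.
for shift in (-0.05, 0.0, 0.05):
    critic = make_critic(shift)
    print(shift, round(critic_estimate(critic, real, fake), 6),
          round(generator_loss(critic, fake), 6))
```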

rafaelvalle commented on May 30, 2024

Here are some plots of the generated data with and without batch norm on the Crepe critic.
Note that without batch norm there is some change in the generator's output, suggesting that it's learning something; with batch norm there's little to no change in the output of the generator.

Loss curve without batch norm: [image]
Generator output without batch norm at the 0th, 5th and 10th iterations: [images]

Loss curve with batch norm: [image]
Generator output with batch norm at the 0th, 5th and 10th iterations: [images]

NickShahML commented on May 30, 2024

@rafaelvalle can you quickly report what the generator and discriminator architectures were? RNN or CNN or Dense?

I've used CNNs for both generator and critic and I'm still not reaching convergence. The critic immediately shoots down to a loss of -5.0 and the generator goes to a loss of 0.0.

rafaelvalle commented on May 30, 2024

Both critic and generator are CNNs. I asked Xiang Zhang, the main author of the Crepe Network paper, about batch norm in Crepe networks. He explained that he zero-padded the data so that it would fit the input layer of the models he was training. This zero-padding makes the last features in the length dimension have close to zero variance, which makes the gradients blow up when backpropagated.
Does anyone think that noise-padding instead of zero-padding would allow the use of batch-norm and not significantly perturb training?

NickShahML commented on May 30, 2024

@rafaelvalle I understand that you're trying to make this WGAN work, but why are you trying to make the Crepe CNN work? It may have some great features, but first try to make a regular convolutional network work for your task. If that works, then I would try Crepe networks. There are also other architectures to explore, such as ByteNet and dense CNNs.

rafaelvalle commented on May 30, 2024

Temporal CNNs, of which the Crepe network is a type, have been used for text encoding with adversarial training.
I'm comparing the performance of different pairs of critic/generator networks: naive CNN, DCGAN, Crepe, LSTM, CNN-RNN...

Learning Deep Representations of Fine-Grained Visual Descriptions
https://arxiv.org/pdf/1605.05395.pdf

Generative Adversarial Text to Image Synthesis
https://arxiv.org/pdf/1605.05396.pdf
