
pytorch-gan's Issues

Training ACGAN

Sorry this is not an issue. I have some questions regarding the implementation of ACGAN and training ACGAN.

  1. In your implementation the encoded label vector is multiplied with the noise vector and given as the input to the G. But shouldn't it be concatenated?
  2. The CrossEntropy loss in PyTorch already includes a softmax function. Therefore I am unclear whether a softmax function should be included in the Discriminator or not.
  3. I am unclear about when to stop training. According to the DCGAN tutorial (https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html), the accuracy of D on real vs. fake images is initially high for the real images but low for the fake images. However, when I run ACGAN, both are high initially (over 95%) and then drop to ~70%; the accuracy never goes below 60%. Do you think my training is correct? And what should the stopping criterion be: this accuracy getting close to 50%, the loss of D converging, or the loss of G converging?
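
On question 2, a framework-free sketch may help. It relies only on what the question already states, that the loss applies softmax itself (PyTorch documents nn.CrossEntropyLoss as combining log-softmax with the negative log-likelihood loss), and shows what happens when softmax is applied a second time. The helper names are made up for illustration.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_from_logits(logits, target):
    """Cross-entropy that applies softmax internally, then takes -log p[target]."""
    return -math.log(softmax(logits)[target])

logits = [4.0, 1.0, 0.0]   # raw class scores from the auxiliary classifier head
target = 0

# Intended use: pass raw logits, so softmax is applied once, inside the loss.
loss_logits = cross_entropy_from_logits(logits, target)

# Softmax layer in D plus a softmax-including loss: softmax is applied twice.
loss_double = cross_entropy_from_logits(softmax(logits), target)

print(f"loss on logits:          {loss_logits:.4f}")
print(f"loss on softmax outputs: {loss_double:.4f}")
```

Because softmax outputs lie in [0, 1], the second softmax can never become confident, so the loss in the second case cannot approach zero however good the classifier is. Under that reading, the class head of D would output raw scores when it is paired with a loss of this kind.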

About the embedding in CGAN

Hi, I have a question about the cgan implementation.
In your code, you use nn.Embedding to embed the prior labels. The problem is that when the learnable weights are not specified, the embedding table is randomly initialized.

The generator and the discriminator each use their own nn.Embedding, and the two are initialized differently. So when we generate a fake image we use one embedding, but when the discriminator judges that fake image it uses another. Will this affect the final performance?

I am not very familiar with GANs, but this seems strange to me. It's true that we still use the same labels, but the actual embeddings are different. Wouldn't using the same embedding for the discriminator and generator be more reasonable?

CycleGan error: weight of size [64, 3, 7, 7], expected input[1, 1, 262, 262] to have 3 channels, but got 1 channels instead

I get this error when using horse2zebra. I have checked all the input images and they all have size [3, 265, 265], so the error is most probably caused by G_AB(real_B). However, I have tried CycleGAN with cifar-100 and monet2photo and everything went fine. I am using PyTorch 0.4 and Python 3.6.

line 157, in <module>
    loss_id_B = criterion_identity(G_AB(real_B), real_B)
...
Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1, 262, 262] to have 3 channels, but got 1 channels instead

wgan retain_graph

Running wgan.py:

Traceback (most recent call last):
  File "wgan.py", line 179, in <module>
    gen_validity.backward(valid)
  File "/home/alcaster/.pyenv/versions/ml/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/alcaster/.pyenv/versions/ml/lib/python3.6/site-packages/torch/autograd/__init__.py", line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

My versions of packages:
torch==0.4.0
torchvision==0.2.1

WGAN implementation error

In your WGAN implementation, the n_critic loop is wrapped around the generator update when it should be around the discriminator update. In WGAN the critic is the discriminator.
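
The schedule this issue asks for, with the critic updated more often than the generator, can be sketched without any framework. The two callbacks are hypothetical stand-ins for the real optimizer steps, and n_critic follows the naming in the WGAN paper.

```python
def run_epoch(num_batches, n_critic, update_critic, update_generator):
    """Update the critic on every batch and the generator every n_critic batches."""
    for i in range(num_batches):
        update_critic(i)
        if (i + 1) % n_critic == 0:
            update_generator(i)

critic_steps, generator_steps = [], []
run_epoch(
    num_batches=10,
    n_critic=5,
    update_critic=critic_steps.append,        # records the batch index
    update_generator=generator_steps.append,  # records the batch index
)
print(len(critic_steps), len(generator_steps))
```

With 10 batches and n_critic=5, the critic is updated 10 times and the generator twice.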

NameError: name 'FeatureExtractor' is not defined

When evaluating ESRGAN and SRGAN, I can see that the class FeatureExtractor() is not defined anywhere. I can see the latest commit was 13 days ago, so I assume you are currently working on implementing these models?

Some questions about cyclegan.py

I'm a little confused about why mse_loss is used for the GAN loss while L1 loss is used for the cycle loss and the identity loss; I didn't find this in the paper.
My second question is why fake_A_buffer.push_and_pop() is used. It seems to do nothing while len < 50, and once len > 50 it makes a random choice among the samples? I'm really confused about this.
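
On the second question, a history buffer of this kind is commonly written along the following lines. This is a framework-free sketch of the idea, not a copy of the repo's code; the class name HistoryBuffer, the rng parameter, and the 0.5 swap probability are assumptions for illustration.

```python
import random

class HistoryBuffer:
    """Keeps up to max_size previously generated samples."""

    def __init__(self, max_size=50, rng=None):
        self.max_size = max_size
        self.data = []
        self.rng = rng or random.Random()

    def push_and_pop(self, batch):
        """Store new samples and return the batch to show the discriminator."""
        out = []
        for sample in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the sample and pass it through.
                self.data.append(sample)
                out.append(sample)
            elif self.rng.random() > 0.5:
                # Buffer full: return an older sample and store the new one in its slot.
                i = self.rng.randrange(self.max_size)
                out.append(self.data[i])
                self.data[i] = sample
            else:
                # Buffer full: pass the new sample through without storing it.
                out.append(sample)
        return out

buffer = HistoryBuffer(max_size=3, rng=random.Random(0))
first = buffer.push_and_pop(["fake_0", "fake_1", "fake_2"])   # fills the buffer
later = buffer.push_and_pop(["fake_3", "fake_4", "fake_5"])   # may return older fakes
print(first, later)
```

So while the buffer is filling, samples are stored and returned unchanged; once it is full, the discriminator sees a mix of current and older generator outputs.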

About the Identity loss in cyclegan.py

The source code of Identity loss is shown below:
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)

This seems a little bit weird to me, maybe it should be:
loss_id_A = criterion_identity(G_AB(real_A), real_A)
loss_id_B = criterion_identity(G_BA(real_B), real_B)

Add GAN implementation in NLP field

Hi, thanks to every contributor!
I was wondering if this repo could add some GAN implementations from the NLP field?
Is it possible to add DPGAN and TextGAN?

I wish I could contribute, but I am new to PyTorch. I would be happy to see PyTorch implementations of these GANs!

WGAN-GP gradient penalty not calculated correctly

The L2 norm of the gradient penalty term in WGAN-GP should be calculated across all dimensions of an image, but the current implementation calculates it across each dimension of an image separately (i.e., the absolute value of each pixel in an image is calculated).

Indeed, in the following line, gradients is a tensor of size (batch_size, nb_channels, img_width, img_height):

gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

To solve the issue, the 4-dimensional tensor containing the gradients should be flattened across the last 3 dimensions:

gradients = gradients.view(real_samples.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
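
The intended computation, one norm per sample taken over all channels and pixels, can be checked on a toy example. This is a plain-Python sketch in which nested lists stand in for the (batch_size, nb_channels, img_width, img_height) tensor; the helper names are invented for illustration.

```python
import math

def flatten(x):
    """Yield every scalar in an arbitrarily nested list."""
    for item in x:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item

def per_sample_l2_norm(gradients):
    """One L2 norm per sample, taken over all channels and pixels."""
    return [math.sqrt(sum(g * g for g in flatten(sample))) for sample in gradients]

def gradient_penalty(gradients):
    """Mean of (norm - 1)^2 over the batch."""
    norms = per_sample_l2_norm(gradients)
    return sum((n - 1.0) ** 2 for n in norms) / len(norms)

# Toy gradients with shape (batch=2, channels=1, height=1, width=2).
gradients = [
    [[[3.0, 4.0]]],   # sample 0: norm over all entries is 5
    [[[0.6, 0.8]]],   # sample 1: norm over all entries is 1
]
print(per_sample_l2_norm(gradients))
print(gradient_penalty(gradients))
```

Each sample contributes a single scalar norm, so the penalty is averaged over the batch only, not over pixels.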

UserWarning: nn.Upsample is deprecated.

Hi~

I found the following warnings while using InfoGAN. I hope they can be fixed.

/usr/local/lib/python3.7/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
/usr/local/lib/python3.7/site-packages/torch/nn/modules/container.py:92: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  input = module(input)

Thank you for writing the code; it gave me a lot of inspiration!

Link to img_align_celeba.zip has been turned off by Dropbox

Hi there, thanks very much for the wonderful repo.
When I tried to download img_align_celeba.zip from Dropbox, I found that the link has been turned off. Could you update the link or share a private link for downloading the dataset?

Thanks very much.

Classify the generated image

I ran your cgan code and the generated images look quite good. However, when I tried to classify the generated images with another network (ResNet18), the predicted label is always 'eight'. Is this a common feature of cgan?

Switched index in loading dataset?

Hello, I noticed that in your implementation of pix2pix, the Dataset returns images in this form:

return {"A": img_A, "B": img_B}

But when the batch is read in the training loop, it is written as:

real_A = Variable(batch["B"].type(Tensor))
real_B = Variable(batch["A"].type(Tensor))

This happened multiple times within pix2pix.py when loading the image. Is this switching intentional?

loss_D in SRGAN

Hello and my thanks for this great repo. I like how your code is simple and effective.

I would like to point out that you are using MSE for your Discriminator Loss instead of Binary Cross Entropy. If you have a specific reason for why you are doing that, could you share it?

wgan issue

When I run python3 wgan.py, the following line fails:
print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs,batches_done % len(dataloader), len(dataloader),d_loss.item(), gen_validity.item()))
ValueError: only one element tensors can be converted to Python scalars
How can I fix it?

Possible error in code

The code at line 265 does not look right:

code_input = Variable(FloatTensor(np.random.normal(-1, 1, (batch_size, opt.code_dim))))

Shouldn't it be 'uniform' instead of 'normal'?

  • Mirtha

Possible error of began

In line 161 of the 'began' implementation, should it be

g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - real_imgs))

instead of

g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs))

Please correct me if I am missing something. Many thanks!

Possible error of relativistic gan

In https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/relativistic_gan/relativistic_gan.py:

        if opt.rel_avg_gan:
            g_loss = adversarial_loss(fake_pred - real_pred.mean(0, keepdim=True), valid)
        else:
            g_loss = adversarial_loss(fake_pred - real_pred, valid)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

Is this expected? Does it look like g_loss is getting overwritten?

Issue in CycleGan-Pix2Pix while calling Discriminator()

When calling Discriminator() from models.py in CycleGAN and pix2pix, I get the following syntax error:

Traceback (most recent call last):
  File "cyclegan.py", line 15, in <module>
    from models import *
  File "./PyTorch-GAN/implementations/cyclegan/models.py", line 167
    *discriminator_block(64, 128, 2, True),
    ^
SyntaxError: invalid syntax

This happens with both python3 and python2.7.
It looks like the * unpacking does not work, and I could not find a way to make it work.

Any bits of advice?

Thanks
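
One possible cause, assuming the interpreter is older than Python 3.5: several * unpackings inside a single call were only allowed from Python 3.5 onward (PEP 448), and the caret in the traceback points at exactly such a *. A workaround that avoids the newer syntax is to concatenate the lists first and unpack once. Both functions below are hypothetical stand-ins, not the repo's code.

```python
def discriminator_block(in_filters, out_filters, normalize=True):
    """Hypothetical stand-in that returns a list of layer descriptions."""
    layers = ["Conv(%d->%d)" % (in_filters, out_filters)]
    if normalize:
        layers.append("Norm(%d)" % out_filters)
    layers.append("LeakyReLU")
    return layers

def sequential(*layers):
    """Hypothetical stand-in for a container that takes layers as positional arguments."""
    return list(layers)

# Concatenate the lists first, then unpack once; this form does not rely on
# multiple * unpackings in a single call.
layers = []
layers += discriminator_block(3, 64, normalize=False)
layers += discriminator_block(64, 128)
layers += discriminator_block(128, 256)
model = sequential(*layers)
print(model)
```

Checking the interpreter version with python3 --version would confirm or rule out this explanation.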

Runtime error

An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.
ForkingPickler(file, protocol).dump(obj)

BrokenPipeError: [Errno 32] Broken pipe

Namespace(b1=0.5, b2=0.999, batch_size=1, channels=3, checkpoint_interval=-1, dataset_name='facades', decay_epoch=100, epoch=0, img_height=256, img_width=256, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=500)

Namespace(b1=0.5, b2=0.999, batch_size=1, channels=3, checkpoint_interval=-1, dataset_name='facades', decay_epoch=100, epoch=0, img_height=256, img_width=256, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=500)

What seems abnormal to me is that this line was printed twice.
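
The message itself names the usual remedy: keep everything that starts worker processes behind an if __name__ == '__main__': guard. A plausible explanation for the duplicated Namespace line, assuming the platform starts workers by spawning a fresh interpreter (as on Windows), is that each worker re-imports the script and re-runs its top-level code. The sketch below uses placeholder functions in place of the real argument parsing and training code.

```python
def build_options():
    """Placeholder for the script's argparse setup."""
    return {"batch_size": 1, "n_cpu": 8}

def train(opt):
    """Placeholder for building the data loader and running the training loop."""
    return "training with %d worker processes" % opt["n_cpu"]

def main():
    opt = build_options()
    print(opt)          # printed once, by the parent process only
    print(train(opt))

# Worker processes that re-import this file see a different __name__,
# so they skip main() instead of re-running the script.
if __name__ == "__main__":
    main()
```

If restructuring the script is not an option, loading data in the main process (n_cpu set to 0, assuming the script passes it to the data loader as the number of workers) avoids starting worker processes at all.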

Brightness problem in SRGAN

I notice that the generated images have higher brightness and more colors than the original image, or the resulting images of other approaches. What causes it?

query in Energy Based GAN (EBGAN)

Hi
Thank you for your wonderful effort in implementing so many papers.
I have a query regarding your EBGAN implementation.
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/ebgan/ebgan.py

In line 175, when you are optimizing the generator G, why is the pixelwise_loss computed using gen_imgs.detach() and not simply gen_imgs? If we call .detach() while updating G, then the pixelwise_loss will not contribute any gradients towards optimizing the generator weights. Is that the right way to do it?

Please clarify my doubt.
Thank you in advance!

Problem in DCGAN

DCGAN fails to learn the MNIST dataset. Is there a problem in the implementation?

wgan-gp

Hi, I believe the implementation of wgan-gp is buggy. The interpolation in

alpha = Tensor(np.random.random(size=real_samples.shape))

uses a random number for each pixel, whereas the pseudocode in the paper says to use a random number for each example.

I believe the line should be replaced by

alpha = Tensor(np.random.random(size=(real_samples.shape[0], 1, 1, 1)))
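
What "one random number per example" means can be shown on toy data. This is a plain-Python sketch with flattened lists in place of image tensors; the function name and the rng parameter are invented for illustration. In the tensor version, an alpha of shape (batch, 1, 1, 1) gives the same result through broadcasting.

```python
import random

def interpolate(real_samples, fake_samples, rng=random):
    """Mix each real/fake pair with a single alpha shared by all of its pixels."""
    mixed, alphas = [], []
    for real, fake in zip(real_samples, fake_samples):
        alpha = rng.random()   # one draw per example, not per pixel
        alphas.append(alpha)
        mixed.append([alpha * r + (1.0 - alpha) * f for r, f in zip(real, fake)])
    return mixed, alphas

# Two flattened toy "images" with three pixels each.
real = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
fake = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
mixed, alphas = interpolate(real, fake, rng=random.Random(0))
print(mixed)
```

Every pixel of an interpolated sample then lies at the same fraction of the way between its real and fake counterparts, so the sample sits on the straight line between the two images.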

auxiliary_loss not used in cGAN

Hi, I have a question about cgan. The auxiliary_loss is not used in optimizer_G.step(), but the CGAN still trains normally and gets correct results. Maybe I have overlooked some significant details. Can anyone give me some tips? Thanks!

AttributeError: module 'torchvision.transforms' has no attribute 'Resize'

I'm running the srgan.py implementation and receive the following error:

Namespace(b1=0.5, b2=0.999, batch_size=1, channels=3, checkpoint_interval=-1, dataset_name='img_align_celeba', decay_epoch=100, epoch=0, hr_height=256, hr_width=256, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=100)
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /Users/Username/.torch/models/vgg19-dcbb9e9d.pth
100%|███████████████████████████████████████████| 574673361/574673361 [00:34<00:00, 16895858.86it/s]
Traceback (most recent call last):
  File "srgan.py", line 104, in <module>
    lr_transforms = [   transforms.Resize((opt.hr_height//4, opt.hr_height//4), Image.BICUBIC),
AttributeError: module 'torchvision.transforms' has no attribute 'Resize'

If you need any additional information let me know...

2 small suggestions

  1. The models could be listed in chronological order of the papers rather than alphabetically by model name, so that the development of GANs is easier to follow.

  2. It would be better if you could briefly summarize the connections and differences between the models in the papers.

no results of CycleGAN

I ran CycleGAN following all the commands the author gave, but the images and saved_models directories are empty. I checked and found that cyclegan.py never enters the loop at line 140, "for i, batch in enumerate(dataloader):".
Any advice is appreciated.

loss_D about wgan

Hi:
Thanks for sharing your impressive repo. I found that the loss_D in wgan.py is:
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
but the WGAN paper describes loss_D as follows:
[image: loss_D as given in the WGAN paper]

So is it a bug, or is there another reason?
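
For reference, the critic step in Algorithm 1 of the WGAN paper is written as gradient ascent on the objective on the left below (notation as in the paper: f_w is the critic, g_θ the generator, m the batch size). Minimizing its negative with a standard optimizer is the same update, which is one way to read the sign in the loss_D line quoted in this issue.

```latex
\max_{w}\; \frac{1}{m}\sum_{i=1}^{m} f_w\big(x^{(i)}\big) - \frac{1}{m}\sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big)
\quad\Longleftrightarrow\quad
\min_{w}\; L_D = -\frac{1}{m}\sum_{i=1}^{m} f_w\big(x^{(i)}\big) + \frac{1}{m}\sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big)
```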
