
beta-vae's People

Contributors

1konny, tonymetger

beta-vae's Issues

Training on dsprites fails with 0-dim tensor

I'm running ./run_dsprites_B_gamma100_z10.sh with a reduced number of iterations.
The error is below:

=> no checkpoint found at 'checkpoints/dsprites_B_gamma100_z10/last'
 67%|██████████████████████████████████████████████████                         | 10000/15000.0 [34:22<16:59,  4.91it/s]
Traceback (most recent call last):
  File "main.py", line 69, in <module>
    main(args)
  File "main.py", line 24, in main
    net.train()
  File "/Users/rlee18/git/Beta-VAE/solver.py", line 182, in train
    self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
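Since PyTorch 0.4, a fully reduced loss is a 0-dim tensor, and indexing it with `.data[0]` raises this IndexError in current releases. `.item()` returns the value as a Python number and also works for one-element tensors. The sketch below shows a corrected logging call. It assumes the three losses are scalar tensors, and the format string is illustrative rather than the one in solver.py.

```python
import torch

# Assumed example values: reduced losses are 0-dim tensors in PyTorch >= 0.4.
recon_loss = torch.tensor(0.25)
total_kld = torch.tensor(3.5)
mean_kld = torch.tensor(0.35)
global_iter = 10000

# `.data[0]` tries to index a 0-dim tensor and fails; `.item()` extracts the number.
msg = '[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
    global_iter, recon_loss.item(), total_kld.item(), mean_kld.item())
print(msg)
```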

Disentanglement metric

Hello.

I haven't found it in your code. Is a disentanglement metric implemented anywhere?

Thanks in advance

About disentanglement

Why is the disentanglement so poor when beta is set to 4, with a MIG score of only 0.06? When I set it to 8, the results are about the same as what others get with beta set to 4.
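The MIG score mentioned here is the mutual information gap of Chen et al. (2018). The procedure has four steps:

1. For each ground-truth factor v_k, estimate the mutual information I(z_j; v_k) with every latent z_j.
2. Take the difference between the largest and second-largest values.
3. Normalize that difference by the entropy H(v_k).
4. Average over factors.

The sketch below is not taken from this repository. It assumes you already have latent codes (for example, encoder means) and integer ground-truth factor labels for the same samples. It discretizes each latent with a 20-bin histogram, which is an arbitrary choice.

```python
import numpy as np


def discretize(z, num_bins=20):
    """Map each continuous latent column to integer histogram-bin indices."""
    z_disc = np.zeros(z.shape, dtype=np.int64)
    for j in range(z.shape[1]):
        edges = np.histogram_bin_edges(z[:, j], bins=num_bins)
        # Use interior edges only, so indices fall in 0 .. num_bins - 1.
        z_disc[:, j] = np.digitize(z[:, j], edges[1:-1])
    return z_disc


def entropy(v):
    """Entropy (in nats) of a discrete 1-D array."""
    _, counts = np.unique(v, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log(p)))


def mutual_info(a, b):
    """Plug-in mutual information (in nats) between two discrete 1-D arrays."""
    _, a_idx = np.unique(a, return_inverse=True)
    _, b_idx = np.unique(b, return_inverse=True)
    joint = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(joint, (a_idx, b_idx), 1.0)
    joint /= joint.sum()
    outer = joint.sum(axis=1, keepdims=True) * joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / outer[nz])))


def mig(z, factors, num_bins=20):
    """Mutual information gap.

    z:       (n_samples, n_latents) float array, n_latents >= 2
    factors: (n_samples, n_factors) integer array of ground-truth factors
    """
    z_disc = discretize(z, num_bins)
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        mi = np.array([mutual_info(z_disc[:, j], v) for j in range(z_disc.shape[1])])
        top = np.sort(mi)[::-1]
        gaps.append((top[0] - top[1]) / entropy(v))
    return float(np.mean(gaps))
```

A latent that copies one factor in a single dimension scores close to 1. If the same factor is spread equally over two dimensions, the gap, and with it the score, drops to 0.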

Couldn't run well with TypeError

/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/model.py:150: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
  init.kaiming_normal(m.weight)
unorderable types: float() > NoneType()
unorderable types: float() > NoneType()
unorderable types: float() > NoneType()
Visdom python client failed to establish socket to get messages from the server. This feature is optional and can be disabled by initializing Visdom with use_incoming_socket=False, which will prevent waiting for this request to timeout.
=> no checkpoint found at 'checkpoints/main/last'
Traceback (most recent call last):
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/main.py", line 69, in <module>
    main(args)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/main.py", line 21, in main
    net = Solver(args)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/solver.py", line 140, in __init__
    self.data_loader = return_data(args)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/dataset.py", line 80, in return_data
    train_data = dset(**train_kwargs)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/dataset.py", line 18, in __init__
    super(CustomImageFolder, self).__init__(root, transform)
  File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/folder.py", line 99, in __init__
    classes, class_to_idx = find_classes(root)
  File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/folder.py", line 24, in find_classes
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
TypeError: argument should be string, bytes or integer, not PosixPath
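The last frame shows torchvision's find_classes passing a pathlib.PosixPath to os.listdir. The interpreter here is Python 3.5 (see the dist-packages path), and os.listdir accepts path-like objects only from Python 3.6 on (PEP 519). Either of two changes should avoid the error:

- Convert the dataset root to a plain string before it reaches the ImageFolder constructor.
- Upgrade Python.

The function below is a hypothetical stand-in that mirrors the failing listing logic with the conversion applied. It is not torchvision's code.

```python
import os
from pathlib import Path
from typing import Union


def find_classes(root: Union[str, Path]):
    # str() turns a pathlib.Path into a plain string, which os.listdir accepts
    # on every Python 3 version (Path objects are accepted only from 3.6 on).
    root = str(root)
    classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx
```

In the repository's own code, the equivalent change would be wrapping the root in str() wherever it is passed to the dataset class.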

TypeError: save_image() got an unexpected keyword argument 'filename'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/visdom/__init__.py", line 708, in _send
    return self._handle_post(
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/visdom/__init__.py", line 677, in _handle_post
    r = self.session.post(url, data=data)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/sessions.py", line 590, in post
    return self.request('POST', url, data=data, json=json, **kwargs)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/sessions.py", line 542, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/sessions.py", line 655, in send
    r = adapter.send(request, **kwargs)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/adapters.py", line 516, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f23720d29d0>: Failed to establish a new connection: [Errno 111] Connection refused'))
Traceback (most recent call last):
  File "main.py", line 69, in <module>
    main(args)
  File "main.py", line 24, in main
    net.train()
  File "/home/ye/python_scripts/Beta-VAE-master/solver.py", line 201, in train
    self.viz_traverse()
  File "/home/ye/python_scripts/Beta-VAE-master/solver.py", line 417, in viz_traverse
    save_image(tensor=gifs[i][j].cpu(),
TypeError: save_image() got an unexpected keyword argument 'filename'
1%|▏ | 10000/1500000.0 [03:54<9:41:25, 42.71it/s]

How to tune hyper-parameters

Hi WonKwang,

First, thanks for the great implementation of beta-VAE.

Is there any method or intuition for choosing the beta, gamma, and C_max hyper-parameters?

Thanks
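The run script names (e.g. run_dsprites_B_gamma100_z10.sh) suggest that gamma and C_max belong to the capacity-controlled objective of Burgess et al. (2018), "Understanding disentangling in β-VAE". Its loss has two parts:

- The loss is reconstruction + gamma · |KL − C|.
- The target capacity C grows linearly from 0 to C_max.

Roughly, C_max bounds how many nats of information the latent code may carry by the end of training, and gamma sets how strictly the KL is held to the current target.

The minimal sketch below is written independently of solver.py. The function names and the `c_stop_iter` schedule-length parameter are assumptions.

```python
def capacity(global_iter, c_max, c_stop_iter):
    """Target KL capacity C, increased linearly from 0 to c_max over c_stop_iter iterations."""
    return min(c_max, c_max * global_iter / c_stop_iter)


def objective_b_loss(recon_loss, total_kld, global_iter, gamma, c_max, c_stop_iter):
    """recon + gamma * |KL - C|: penalize the KL for deviating from the current capacity."""
    c = capacity(global_iter, c_max, c_stop_iter)
    return recon_loss + gamma * abs(total_kld - c)
```

Because `abs` also works on PyTorch tensors, the same function can take tensor losses during training.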

Loss curve

I am training on a custom dataset with model H and z_dim set to 256. After 100,000 training steps, the loss is still as high as 400. Is this normal? Is the high loss caused by z_dim being too large?

Here is my loss curve:
[Screenshot 2023-03-03 15-46-08]

And these are my parameters:
[Screenshot 2023-03-03 15-58-32]
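Whether 400 is normal depends on the dataset and on how the reconstruction term is reduced, so the raw number alone is hard to judge. A more direct check on z_dim = 256 is to see how many latent dimensions the model actually uses.

Per dimension, the KL between N(μ_j, σ_j²) and N(0, 1) is −½(1 + log σ_j² − μ_j² − σ_j²). Dimensions whose batch-averaged KL stays near zero carry almost no information. The helper names and the 0.01-nat threshold below are assumptions.

```python
import torch


def dimension_wise_kld(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mu_j, sigma_j^2) || N(0, 1)) per latent dimension, averaged over the batch.
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return klds.mean(0)


def count_active_dims(mu: torch.Tensor, logvar: torch.Tensor, threshold: float = 0.01) -> int:
    # Dimensions whose average KL exceeds `threshold` nats are counted as "active".
    return int((dimension_wise_kld(mu, logvar) > threshold).sum())
```

If only a handful of the 256 dimensions are active, the extra dimensions are mostly unused rather than being the cause of the high loss.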

About the quality of recon_img

Hello, I extracted your BetaVAE_H model and loss function and trained the model on CIFAR-10. But after 10000 epochs of training, the quality of recon_img is still very poor. Is there anything I haven't taken into account? Please help me. The code I use is as follows:
```python
# beta_vae.py
import torch
from torch import nn
from torch.nn import init


def reparametrize(mu, logvar):
    # z = mu + sigma * eps, with eps ~ N(0, I)
    std = logvar.div(2).exp()
    eps = torch.randn_like(std)
    return mu + std * eps


def kaiming_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)


class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)


class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017)."""

    def __init__(self, z_dim=10, nc=3):
        super(BetaVAE_H, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),  # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),  # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),  # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),  # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256 * 1 * 1)),  # B, 256
            nn.Linear(256, z_dim * 2),  # B, z_dim*2
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),  # B, 256
            View((-1, 256, 1, 1)),  # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),  # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),  # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
        )

        self.weight_init()

    def weight_init(self):
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x):
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)

        return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)


# training script (imports BetaVAE_H from beta_vae.py)
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from beta_vae import BetaVAE_H


def recon_loss(x, x_recon):
    x_recon = torch.sigmoid(x_recon)
    rec_loss = F.mse_loss(x_recon, x)  # default reduction: mean over every pixel in the batch
    return rec_loss


def kld_loss(mu, logvar):
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)  # summed over z_dim, averaged over the batch
    return total_kld


def train(epochs=1000, batch_size=128, z_dim=32, device='cuda:2', lr=1e-4, beta=10):
    dataset = datasets.CIFAR10(root='../dataset/cifar10', train=True, transform=transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor()
    ]), download=True)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = BetaVAE_H(z_dim=z_dim, nc=3).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        loss, r_loss, k_loss = (0, 0, 0)
        for idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            x_recon, mu, logvar = model(images)
            rec_loss = recon_loss(images, x_recon)
            kl_loss = kld_loss(mu, logvar)

            beta_vae_loss = rec_loss + beta * kl_loss
            optimizer.zero_grad()
            beta_vae_loss.backward()
            loss += beta_vae_loss.item()
            r_loss += rec_loss.item()
            k_loss += kl_loss.item()
            optimizer.step()
```
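One likely contributor, offered as a hedged reading of the code rather than a confirmed diagnosis, is a mismatch in scale between the two loss terms:

- `recon_loss` calls F.mse_loss with its default mean reduction, so the reconstruction term is an average per pixel. That makes it 3 × 64 × 64 = 12288 times smaller than the per-image sum of squared errors.
- `kld_loss` sums the KL over all z_dim dimensions.

With beta = 10, the KL term can therefore dominate, pushing the posterior toward the prior and blurring reconstructions. One way to rebalance is to put both terms on a per-image scale: sum the squared error over pixels, then average over the batch. The function name below is hypothetical, and beta may need retuning after the change.

```python
import torch
import torch.nn.functional as F


def recon_loss_per_image(x: torch.Tensor, x_recon_logits: torch.Tensor) -> torch.Tensor:
    # Squared error summed over all C*H*W values of each image, then averaged
    # over the batch, so it matches the per-image scale of a KL summed over z_dim.
    x_recon = torch.sigmoid(x_recon_logits)
    return F.mse_loss(x_recon, x, reduction='sum') / x.size(0)
```

Usage in the loop above would be `beta_vae_loss = recon_loss_per_image(images, x_recon) + beta * kl_loss`.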

How to prevent KL loss collapse

Hello! I noticed that your KL divergence curve becomes flat after about 100k iterations. When I train VAEs for other tasks, I find that the KL divergence converges more easily than the reconstruction loss, and training usually suffers from KL collapse: the KL converges to a value close to zero. How do you keep the KL divergence curve level at a non-zero value, rather than collapsing, with beta > 1?
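One commonly used remedy, which is not part of this repository, is the "free bits" constraint of Kingma et al. (2016). Each latent dimension's batch-averaged KL is clamped from below at a threshold λ. The optimizer then gains nothing by pushing that dimension's KL below λ.

The sketch assumes `mu` and `logvar` have shape (batch, z_dim). The 0.5-nat threshold is an arbitrary example value.

```python
import torch


def free_bits_kld(mu: torch.Tensor, logvar: torch.Tensor, free_bits: float = 0.5) -> torch.Tensor:
    # Per-sample, per-dimension KL(q(z_j|x) || N(0, 1)), shape (batch, z_dim).
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    # Average over the batch, then clamp each dimension from below at `free_bits` nats.
    # Below the threshold the clamp has zero gradient, so the KL penalty stops
    # pushing that dimension further toward the prior.
    per_dim = klds.mean(0)
    return torch.clamp(per_dim, min=free_bits).sum()
```

Usage would be `loss = rec_loss + beta * free_bits_kld(mu, logvar)`.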

Supporting image size larger than 64*64

Hello, I would like to know whether there is a specific reason why the image size is fixed. If not, do you know what modifications would be required to support, say, 512×512 images?

Need Help on Visdom

Hi,
I tried to run your implementation on my machine, but it does not seem to work for me. Could you please suggest how I can solve this?

  1. I am going to begin with the CelebA dataset.
  2. I extracted the dataset to the folder D:...\Beta-VAE-master\data\CelebA\Training
  3. I deleted most of the photos to keep the size manageable and check whether it runs (around 40 photos are left).
  4. I installed visdom, torch, and torchvision.
  5. I started visdom (python -m visdom.server).
  6. I opened http://localhost:8097/ in my browser.
  7. python main.py --dataset celeba --seed 1 --lr 1e-4 --beta1 0.9 --beta2 0.999 --objective H --model H --batch_size 64 --z_dim 10 --max_iter 1.5e3 --beta 10 --viz_name celeba_H_beta10_z10

Even though no error is shown, there is no output either on http://localhost:8097/ or in the folder D:...\Beta-VAE-master\outputs\celeba_H_beta10_z10
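No error combined with no output suggests one of two things: the script could not reach the visdom server, or the run ended before any visualization step fired.

On the second point, the tracebacks on this page show the logging and visualization code first running at iteration 10000. If that is the default interval, a run with --max_iter 1.5e3 would finish before producing any output.

As a first check on the server, the hypothetical helper below only tests whether something is listening on the visdom port (8097 by default). It does not verify that the listener is actually a visdom server.

```python
import socket


def visdom_port_open(host: str = "localhost", port: int = 8097, timeout: float = 1.0) -> bool:
    """Return True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # includes ConnectionRefusedError and timeouts
        return False
```

Running `print(visdom_port_open())` in the same environment as main.py while `python -m visdom.server` is running should print True.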

Tensor dimension bug in visualization code

Hi
I just started using your code to play around with beta-VAE and it's great. Unfortunately, there is one bug when I try to turn on visualization:

Traceback (most recent call last):
  File "main.py", line 65, in <module>
    main(args)
  File "main.py", line 24, in main
    net.train()
  File "/home/tony/Github/Beta-VAE/solver.py", line 178, in train
    self.viz_lines()
  File "/home/tony/Github/Beta-VAE/solver.py", line 230, in viz_lines
    klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()
RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 2 and 1 at /pytorch/aten/src/TH/generic/THTensorMath.c:3577

Steps to reproduce:
python -m visdom.server
sh run_dsprites_B_gamma100_z10.sh

I guess this is quite easy to fix, and it doesn't appear in an earlier version (commit 66fcd41), but I'm very new to PyTorch, so it's hard for me to find a fix.

Thanks again for the code and I hope this can be fixed easily,
Tony

P.S.: Would it be possible for you to add a license so that we can use your code to benchmark other models in our research (which might get published at some point)? That would be great!
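torch.cat only joins tensors with the same number of dimensions. The message "got 2 and 1" says that `dim_wise_klds` is 2-D while at least one of the other records is 1-D.

A hedged fix is to reshape the 1-D records into columns before concatenating. The shapes below are assumptions: one row per recorded step, with per-dimension KLs in the columns. The helper name is hypothetical.

```python
import torch


def concat_klds(dim_wise_klds: torch.Tensor, mean_klds: torch.Tensor,
                total_klds: torch.Tensor) -> torch.Tensor:
    # dim_wise_klds: (n_steps, z_dim); mean_klds, total_klds: (n_steps,) or (n_steps, 1).
    # reshape(-1, 1) turns each record into a column so all three are 2-D.
    return torch.cat([dim_wise_klds,
                      mean_klds.reshape(-1, 1),
                      total_klds.reshape(-1, 1)], dim=1)
```

Because reshape(-1, 1) leaves (n_steps, 1) inputs unchanged, the same call works whether or not a given PyTorch version keeps the trailing dimension.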

Output of encoder is 2 * z_dim

Hi,

Could you please explain why the encoder's output is z_dim * 2 and not just z_dim?

Thanks for the very clean repo.
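The encoder parameterizes a diagonal Gaussian posterior q(z|x), which needs two numbers per latent dimension: a mean and a log-variance. Its last layer therefore outputs z_dim * 2 values. BetaVAE_H.forward slices them into `mu` and `logvar` and samples z with the reparameterization trick.

Below is an equivalent sketch using Tensor.split; the function name is hypothetical. Predicting the log-variance rather than the variance keeps σ² = exp(logvar) positive without any constraint on the network output.

```python
import torch


def split_and_sample(distributions: torch.Tensor, z_dim: int):
    # First z_dim columns: means of q(z|x); last z_dim columns: log-variances.
    mu, logvar = distributions.split(z_dim, dim=1)
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)  # reparameterization trick
    return z, mu, logvar
```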

self.net.train()? Is it a valid function?

While reading solver.py, I found that lines 431 and 433 call self.net.train() and self.net.eval(), which seem to be undefined: they are defined neither in model.py nor in the PyTorch API. However, no error occurred when I ran training. Can anyone enlighten me?
Thanks a bunch

```python
def net_mode(self, train):
    if not isinstance(train, bool):
        raise TypeError('Only bool type is supported. True or False')

    if train:
        self.net.train()  # line 431 of solver.py
    else:
        self.net.eval()   # line 433 of solver.py
```
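`train()` and `eval()` are methods of `torch.nn.Module`. The models in model.py (for example BetaVAE_H) subclass `nn.Module`, so they inherit both methods rather than defining them. Each call sets the `training` flag on the module and all its submodules, which changes the behavior of layers such as Dropout and BatchNorm. The small example below uses an arbitrary toy network.

```python
import torch
from torch import nn

# Any nn.Module subclass, such as BetaVAE_H, inherits train() and eval().
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

net.eval()                                   # sets .training = False on net and all submodules
eval_flags = [m.training for m in net.modules()]
x = torch.ones(2, 4)
with torch.no_grad():
    same = torch.equal(net(x), net(x))       # Dropout is the identity in eval mode

net.train()                                  # sets .training = True again
train_flags = [m.training for m in net.modules()]
```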
