
deformingautoencoders-pytorch's Introduction

DeformingAutoencoders-pytorch

PyTorch code for DAE and IntrinsicDAE

Project:

http://www3.cs.stonybrook.edu/~cvl/dae.html

Usage:

Requirements: PyTorch

To train a DAE, run

python train_DAE_CelebA.py --dirDataroot=[path_to_root_of_training_data] --dirCheckpoints=[path_to_checkpoints] --dirImageoutput=[path_to_output directory for training] --dirTestingoutput=[path_to_output directory for testing]

To train an IntrinsicDAE, run

python train_IntrinsicDAE_CelebA.py --dirDataroot=[path_to_root_of_training_data] --dirCheckpoints=[path_to_checkpoints] --dirImageoutput=[path_to_output directory for training] --dirTestingoutput=[path_to_output directory for testing]

Set --useDense=True (default) for a DenseNet-like encoder/decoder (no skip connections over the bottleneck latent representation); set --useDense=False for a smaller encoder-decoder architecture.
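For example, a possible invocation with illustrative paths (the directories below are placeholders, not paths shipped with the repository):

python train_DAE_CelebA.py --dirDataroot=./data/celeba --dirCheckpoints=./checkpoints/DAE_CelebA --dirImageoutput=./output/train_images --dirTestingoutput=./output/test_images --useDense=True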

Dataset: CelebA (http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). A Google Drive link to a cropped and resized version of CelebA: https://drive.google.com/open?id=1ueB8BJxid2rZbvh3RaoZ9lDdlKH4B-pL

Place the training images in [path_to_root_of_training_data]/celeba_split/img_00. (Split the dataset into multiple subsets if desired; see the illustrative layout below.)
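An illustrative layout (the additional subset directory img_01 below is only an example, not a required name):

    [path_to_root_of_training_data]/
        celeba_split/
            img_00/    # training images go here
            img_01/    # optional additional subset
            ...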

Checkpoints: Some example checkpoints can be found at: https://drive.google.com/drive/folders/1A2Qj1NhzVU5XSjeilKhjWwAgNWvlRyuA?usp=sharing

Three examples are provided:

  1. DAE for CelebA with the Dense encoder/decoder, where opt.idim = 8 (./DAE_CelebA_idim8)
  2. DAE for CelebA with the Dense encoder/decoder, where opt.idim = 16 (./DAE_CelebA_idim16)
  3. IntrinsicDAE for CelebA with the Dense encoder/decoder (./IntrinsicDAE_CelebA)

If using the code, please cite:

Deforming Autoencoders: Unsupervised Disentangling of Shape and Appearance, Zhixin Shu, Mihir Sahasrabudhe, Riza Alp Guler, Dimitris Samaras, Nikos Paragios, and Iasonas Kokkinos. European Conference on Computer Vision (ECCV), 2018.

Update 12-13-2018

  1. Previous models trained with batch normalization suffer from batch bias at test time. Replacing all nn.BatchNorm2d() layers (DAENet.py) with nn.InstanceNorm2d() layers (DAENet_InstanceNorm.py) fixes the problem at testing time (see the sketch after this list).
  2. Fixed a bug in getBasedGrid(): the previous code contained an error in integer-to-float conversion.
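As a rough illustration of the substitution in item 1 (a sketch only; the helper name make_norm is not part of this repository):

    import torch.nn as nn

    # Sketch of the normalization swap described in the 12-13-2018 update.
    # InstanceNorm2d normalizes each sample independently, so test-time results
    # no longer depend on the statistics of the test batch.
    def make_norm(num_features, use_instance_norm=True):
        if use_instance_norm:
            return nn.InstanceNorm2d(num_features)  # as in DAENet_InstanceNorm.py
        return nn.BatchNorm2d(num_features)         # as in the original DAENet.py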

deformingautoencoders-pytorch's People

Contributors

zhixinshu

deformingautoencoders-pytorch's Issues

The models trained on 128x128 dataset

Hi, @zhixinshu
I want to train the model with 128x128 inputs on the CelebA dataset.
Thus, I added additional layers to both the encoder and decoder.
The code is as follows:

import torch.nn as nn
# DenseBlockEncoder, DenseTransitionBlockEncoder, DenseBlockDecoder, and
# DenseTransitionBlockDecoder are the building blocks defined in DAENet.py.
from DAENet import (DenseBlockEncoder, DenseTransitionBlockEncoder,
                    DenseBlockDecoder, DenseTransitionBlockDecoder)

class waspDenseEncoder(nn.Module):
    def __init__(self, opt, ngpu=1, nc=1, ndf = 32, ndim = 128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Sigmoid, f_args=[]):
        super(waspDenseEncoder, self).__init__()
        self.ngpu = ngpu
        self.ndim = ndim

        self.main = nn.Sequential(
                # input is (nc) x 128 x 128
                nn.BatchNorm2d(nc),
                nn.ReLU(True),
                nn.Conv2d(nc, ndf, 4, stride=2, padding=1),

                # state size. (ndf) x 64 x 64
                DenseBlockEncoder(ndf, 6),
                DenseTransitionBlockEncoder(ndf, ndf*2, 2, activation=activation, args=args),

                # state size. (ndf*2) x 32 x 32
                DenseBlockEncoder(ndf*2, 12),
                DenseTransitionBlockEncoder(ndf*2, ndf*4, 2, activation=activation, args=args),

                # state size. (ndf*4) x 16 x 16
                DenseBlockEncoder(ndf*4, 24),
                DenseTransitionBlockEncoder(ndf*4, ndf*8, 2, activation=activation, args=args),

                # state size. (ndf*8) x 8 x 8
                DenseBlockEncoder(ndf*8, 24),
                DenseTransitionBlockEncoder(ndf*8, ndf*16, 2, activation=activation, args=args),
                
                # state size. (ndf*16) x 4 x 4
                DenseBlockEncoder(ndf*16, 16),
                DenseTransitionBlockEncoder(ndf*16, ndim, 4, activation=activation, args=args),
                f_activation(*f_args),
        )

    def forward(self, input):
        output = self.main(input).view(-1,self.ndim)
        return output   

class waspDenseDecoder(nn.Module):
    def __init__(self, opt, ngpu=1, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Hardtanh, f_args=[0,1]):
        super(waspDenseDecoder, self).__init__()
        self.ngpu   = ngpu
        self.main   = nn.Sequential(
            # input is Z, going into convolution
            nn.BatchNorm2d(nz),
            activation(*args),
            nn.ConvTranspose2d(nz, ngf * 16, 4, 1, 0, bias=False),

            # state size. (ngf*16) x 4 x 4
            DenseBlockDecoder(ngf*16, 16),
            DenseTransitionBlockDecoder(ngf*16, ngf*8),
            
            # state size. (ngf*8) x 8 x 8
            DenseBlockDecoder(ngf*8, 24),
            DenseTransitionBlockDecoder(ngf*8, ngf*4),

            # state size. (ngf*4) x 16 x 16
            DenseBlockDecoder(ngf*4, 24),
            DenseTransitionBlockDecoder(ngf*4, ngf*2),

            # state size. (ngf*2) x 32 x 32
            DenseBlockDecoder(ngf*2, 12),
            DenseTransitionBlockDecoder(ngf*2, ngf),

            # state size. (ngf) x 64 x 64
            DenseBlockDecoder(ngf, 6),
            DenseTransitionBlockDecoder(ngf, ngf),

            # state size (ngf) x 128 x 128
            nn.BatchNorm2d(ngf),
            activation(*args),
            nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
            f_activation(*f_args),
        )
    def forward(self, inputs):
        return self.main(inputs)
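As a quick sanity check of the 128x128 variants above (illustrative only; opt is unused by these constructors, so None is passed as a placeholder, and the expected shapes assume the DAENet blocks behave as the comments indicate):

    import torch

    enc = waspDenseEncoder(None, nc=3, ndf=32, ndim=128)
    dec = waspDenseDecoder(None, nz=128, nc=3, ngf=32)
    x = torch.rand(2, 3, 128, 128)
    z = enc(x)                      # expected shape: (2, 128)
    y = dec(z.view(2, 128, 1, 1))   # expected shape: (2, 3, 128, 128)
    print(z.shape, y.shape)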

Although the training process appears normal, the texture is weird.
Texture: (attached image: iter_225200_tex0_)
Output: (attached image: iter_225200_output0_)

Could you give me some advice on training with 128x128 images?
Many thanks!

About the adversarial loss

Hi, @zhixinshu
It's really a marvelous work.

But I am still confused about page 7 of your paper. There is an adversarial loss in equation 7.
However, it is not implemented in your code.
Could you please provide a complete version of your code?

Many thanks.

Some questions about the dataset celebA

Hello,

I have some questions about the dataset you used.

I'm sorry that I am not familiar with celebA dataset.

Is the train/eval/test split you used the same as the official one? http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

Did you split your training set into further subsets? How and why?

How is the cropped and resized version you provided (https://drive.google.com/open?id=1ueB8BJxid2rZbvh3RaoZ9lDdlKH4B-pL) different from the official cropped images?

Thank you,

Best,
Ahyun

L1 and L2 losses for image reconstruction

Hello, in your paper you mention that the image reconstruction loss is an L2 loss, but in the code you use an L1 loss. Could I ask which loss works better for your image reconstruction task?
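For reference, a minimal sketch of the two losses being compared, using the standard PyTorch formulations rather than the repository's training code:

    import torch
    import torch.nn.functional as F

    # Random placeholder tensors, not data from this repository.
    recon  = torch.rand(4, 3, 64, 64)
    target = torch.rand(4, 3, 64, 64)
    l1_loss = F.l1_loss(recon, target)   # mean absolute error (as used in the code)
    l2_loss = F.mse_loss(recon, target)  # mean squared error (as stated in the paper)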

Previous checkpoint not compatible with latest code

Hi @zhixinshu

Thanks for sharing the code. I did a clean clone and downloaded the pretrained model. However, the code rejects the model. It seems that in the last few commits there were some updates to the network. I am wondering if you have a new pretrained model that is compatible with the latest code, or did I miss something here?

Thank you.

Missing key(s) in state_dict:
"encoder.main.1.layers.0.2.weight",
"encoder.main.1.layers.1.2.weight",
"encoder.main.1.layers.2.2.weight",
"encoder.main.1.layers.3.2.weight",
"encoder.main.1.layers.4.2.weight",
"encoder.main.1.layers.5.2.weight",
"encoder.main.2.main.2.weight",
"encoder.main.3.layers.6.2.weight",
"encoder.main.3.layers.7.2.weight",
"encoder.main.3.layers.8.2.weight",
"encoder.main.3.layers.9.2.weight",
"encoder.main.3.layers.10.2.weight",
"encoder.main.3.layers.11.2.weight",
"encoder.main.5.layers.12.2.weight",
"encoder.main.5.layers.13.2.weight",
"encoder.main.5.layers.14.2.weight",
"encoder.main.5.layers.15.2.weight",
"encoder.main.5.layers.16.2.weight",
"encoder.main.5.layers.17.2.weight",
"encoder.main.5.layers.18.2.weight",
"encoder.main.5.layers.19.2.weight",
"encoder.main.5.layers.20.2.weight",
"encoder.main.5.layers.21.2.weight",
"encoder.main.5.layers.22.2.weight",
"encoder.main.5.layers.23.2.weight".
Unexpected key(s) in state_dict:
"encoder.main.10.main.0.weight",
"encoder.main.10.main.0.bias",
"encoder.main.10.main.0.running_mean",
"encoder.main.10.main.0.running_var",
"encoder.main.10.main.2.weight",
"encoder.main.0.running_mean",
"encoder.main.0.running_var",
"encoder.main.2.weight",
"encoder.main.2.bias",
"encoder.main.3.layers.0.0.weight",
"encoder.main.3.layers.0.0.bias",
"encoder.main.3.layers.1.0.weight",
"encoder.main.3.layers.1.0.bias",
"encoder.main.3.layers.2.0.weight",
"encoder.main.3.layers.2.0.bias",
"encoder.main.3.layers.3.0.weight",
"encoder.main.3.layers.3.0.bias",
"encoder.main.3.layers.4.0.weight",
"encoder.main.3.layers.4.0.bias",
"encoder.main.3.layers.5.0.weight",
"encoder.main.3.layers.5.0.bias",
"encoder.main.4.main.0.weight",
"encoder.main.4.main.0.bias",
"encoder.main.5.layers.0.0.weight",
"encoder.main.5.layers.0.0.bias",
"encoder.main.5.layers.1.0.weight",
"encoder.main.5.layers.1.0.bias",
"encoder.main.5.layers.2.0.weight",
"encoder.main.5.layers.2.0.bias",
"encoder.main.5.layers.3.0.weight",
"encoder.main.5.layers.3.0.bias",
"encoder.main.5.layers.4.0.weight",
"encoder.main.5.layers.4.0.bias",
"encoder.main.5.layers.5.0.weight",
"encoder.main.5.layers.5.0.bias",
"encoder.main.5.layers.6.0.weight",
"encoder.main.5.layers.6.0.bias",
"encoder.main.5.layers.7.0.weight",
"encoder.main.5.layers.7.0.bias",
"encoder.main.5.layers.8.0.weight",
"encoder.main.5.layers.8.0.bias",
"encoder.main.5.layers.9.0.weight",
"encoder.main.5.layers.9.0.bias",
"encoder.main.5.layers.10.0.weight",
"encoder.main.5.layers.10.0.bias",
"encoder.main.5.layers.11.0.weight",
"encoder.main.5.layers.11.0.bias",
"encoder.main.6.main.0.weight",
"encoder.main.6.main.0.bias",
"encoder.main.7.layers.16.0.weight",
.....

size mismatch for encoder.main.0.weight: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([32, 3, 4, 4]).
size mismatch for encoder.main.0.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([32]).
size mismatch for encoder.main.3.layers.0.0.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
size mismatch for encoder.main.3.layers.0.0.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
size mismatch for encoder.main.3.layers.0.2.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for encoder.main.3.layers.1.0.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
size mismatch for encoder.main.3.layers.1.0.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
size mismatch for encoder.main.3.layers.1.2.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for encoder.main.3.layers.2.0.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).

.....
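A minimal sketch for diagnosing such a mismatch by diffing the checkpoint keys against the current model's keys (model and checkpoint_path are placeholders, not names from this repository):

    import torch

    state = torch.load(checkpoint_path, map_location='cpu')
    ckpt_keys  = set(state.keys())
    model_keys = set(model.state_dict().keys())
    print('missing from checkpoint:  ', sorted(model_keys - ckpt_keys)[:10])
    print('unexpected in checkpoint: ', sorted(ckpt_keys - model_keys)[:10])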

DenseBlockEncoder and DenseBlockDecoder seem incorrect

In https://github.com/zhixinshu/DeformingAutoencoders-pytorch/blob/master/DAENet.py#L260, the forward function of the DenseBlockEncoder class is given as:

  def forward(self, inputs):
       outputs = []

       for i, layer in enumerate(self.layers):
           if i > 0:
               next_output = 0
               for no in outputs:
                   next_output = next_output + no 
               outputs.append(next_output)
           else:
               outputs.append(layer(inputs))
       return outputs[-1]

It seems that only the first layer is ever applied: for i > 0, the loop never calls layer(...). Thus, outputs[-1] is just a scalar multiple of self.layers[0](inputs), which does not look like a valid dense block. The same problem also occurs in the DenseBlockDecoder class (see the sketch below).
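For comparison, a sketch of what the expected DenseNet-style forward pass might look like here (the reporter's presumed intent, not code from the repository):

    def forward(self, inputs):
        # Presumed intent: every layer is applied, each one consuming the
        # running sum of all previous outputs (dense connectivity via addition).
        outputs = [self.layers[0](inputs)]
        for layer in self.layers[1:]:
            outputs.append(layer(sum(outputs)))
        return outputs[-1]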

Could you kindly explain how you chose the parameter for the integrator?

Hi, thank you for the helpful code.

In the following code, I wonder why you chose the parameter "1.2"? Is there any motivation behind it?
https://github.com/zhixinshu/DeformingAutoencoders-pytorch/blob/master/DAENet.py#L484

self.warping = self.integrator(self.diffentialWarping)-1.2

Is this parameter introduced so that parts of the deformation field can take negative values?
And is the model performance sensitive to this number?

Thanks so much.

pretrained models

Hi,
Thanks for your amazing work. Would you like to provide some pretrained models on faces so that we can play with them directly?
