Conditional Variational Autoencoder

Implementation of CVAE (Conditional Variational Autoencoder) and VAE (Variational Autoencoder) in TensorFlow, with experiments on the MNIST dataset.
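As a rough illustration of what the two models optimize, the sketch below writes out the standard VAE objective (Kingma and Welling, 2013) in NumPy. The CVAE variant is obtained by concatenating a one-hot digit label to both the encoder input and the latent code. This is not this repository's TensorFlow code: the Bernoulli reconstruction term, the concatenation-based conditioning, and the helper names (`one_hot`, `reparameterize`, `gaussian_kl`, `bernoulli_nll`, `cvae_loss`) are assumptions made for illustration, and the encoder and decoder outputs are stand-in constants.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Encode integer digit labels as one-hot condition vectors."""
    return np.eye(num_classes)[labels]

def reparameterize(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def gaussian_kl(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)) for each example."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=1)

def bernoulli_nll(x, x_prob, eps=1e-7):
    """Reconstruction term: negative log-likelihood of binary pixels x (assumed Bernoulli)."""
    x_prob = np.clip(x_prob, eps, 1.0 - eps)
    return -np.sum(x * np.log(x_prob) + (1.0 - x) * np.log(1.0 - x_prob), axis=1)

def cvae_loss(x, x_prob, mu, log_var):
    """Negative ELBO averaged over the batch."""
    return np.mean(bernoulli_nll(x, x_prob) + gaussian_kl(mu, log_var))

rng = np.random.default_rng(0)
x = (rng.random((4, 784)) > 0.5).astype(float)    # stand-in for 4 binarized 28 x 28 images
c = one_hot(np.array([0, 1, 2, 3]))               # condition: the digit labels
encoder_input = np.concatenate([x, c], axis=1)    # CVAE encoder sees [x, c]
mu, log_var = np.zeros((4, 2)), np.zeros((4, 2))  # stand-in encoder outputs, 2-D latent
z = reparameterize(mu, log_var, rng)
decoder_input = np.concatenate([z, c], axis=1)    # CVAE decoder sees [z, c]
x_prob = np.full((4, 784), 0.5)                   # stand-in decoder output
loss = cvae_loss(x, x_prob, mu, log_var)
```

For a plain VAE the loss is the same; only the concatenation of `c` is dropped.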

Model

This repository includes the following three types of CVAE:

  1. 3 CNN: encoder (CNN x 3 + FC x 1) and decoder (CNN x 3 + FC x 1)
  2. 2 CNN: encoder (CNN x 2 + FC x 1) and decoder (CNN x 2 + FC x 1)
  3. 3 FC: encoder (FC x 3) and decoder (FC x 3)
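To make the "3 FC" variant concrete, here is a minimal NumPy forward pass of an encoder built from three FC layers. It takes the image and a one-hot label and returns the mean and log-variance of the latent code. The hidden widths (512 and 256), the ReLU activations, the He-style random initialization, and every name in the code are hypothetical; the README does not state them.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 20

def dense(n_in, n_out):
    """Random (He-style) weights and zero biases for one FC layer."""
    return rng.standard_normal((n_in, n_out)) * np.sqrt(2.0 / n_in), np.zeros(n_out)

# Hypothetical widths: (784 pixels + 10 label units) -> 512 -> 256 -> 2 * latent_dim.
layers = [dense(784 + 10, 512), dense(512, 256), dense(256, 2 * latent_dim)]

def fc3_encoder(x, c):
    """FC x 3 encoder: returns mean and log-variance of q(z | x, c)."""
    h = np.concatenate([x, c], axis=1)
    for i, (w, b) in enumerate(layers):
        h = h @ w + b
        if i < len(layers) - 1:  # ReLU on hidden layers only; the output layer stays linear
            h = np.maximum(h, 0.0)
    mu, log_var = np.split(h, 2, axis=1)
    return mu, log_var

x = rng.random((8, 784))                     # stand-in batch of flattened images
c = np.eye(10)[rng.integers(0, 10, size=8)]  # random one-hot labels
mu, log_var = fc3_encoder(x, c)
```

A matching decoder would mirror this with three FC layers mapping `[z, c]` back to 784 pixel probabilities.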

At first, I implemented the 2 CNN model, which uses convolutional layers with a fixed stride (2 x 2) and kernel size (4 x 4). However, this model has a problem: the number of trainable variables in the FC (fully connected) layer is much larger than in the CNN layers. More precisely, if the latent dimension is 20, the FC layer has 20 x 512 = 10240 trainable variables, while the biggest CNN layer has 4 x 16 x 32 = 2048. Such a large FC layer might be difficult to train, and the effect of the CNN layers could vanish. On the other hand, the 3 CNN model has relatively few trainable variables in its FC layer (100), while its CNN layers have 6336 (first layer) and 8192 (second and third layers).
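Trainable-variable counts like these follow from the kernel shapes. An FC layer stores an `n_in x n_out` weight matrix, and a 2-D convolution stores a `kernel_h x kernel_w x c_in x c_out` kernel (the shapes used by `tf.keras.layers.Dense` and `tf.keras.layers.Conv2D`). The helpers below are a hypothetical sketch that counts weights, with biases excluded by default. The convolution's channel counts are illustrative, not taken from this repository.

```python
def dense_weights(n_in, n_out, bias=False):
    """Trainable variables of an FC layer: an n_in x n_out kernel (+ n_out biases)."""
    return n_in * n_out + (n_out if bias else 0)

def conv2d_weights(kernel_h, kernel_w, c_in, c_out, bias=False):
    """Trainable variables of a 2-D convolution: a kh x kw x c_in x c_out kernel (+ c_out biases)."""
    return kernel_h * kernel_w * c_in * c_out + (c_out if bias else 0)

# FC layer between a 512-dimensional feature vector and a 20-dimensional latent code.
fc = dense_weights(512, 20)          # 10240, matching 20 x 512 in the text
# Hypothetical first conv layer: 4 x 4 kernel on a 1-channel MNIST image, 16 filters.
conv = conv2d_weights(4, 4, 1, 16)   # 256
```

Helpers like these make it easy to check how much of a model's capacity sits in the FC layer before changing the latent dimension or the number of filters.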

The 3 FC model was implemented to see how much better the 3 CNN model is than an FC-based CVAE. However, it seems that the 3 FC model also behaves very well as a generative model (more on this in a later section).

How to use

Clone the repository

git clone https://github.com/asahi417/ConditionalVariationalAutoEncoder.git cvae
cd cvae

To train the 3 CNN model on MNIST, run

python train.py cvae_cnn3 -n 2 -c 1 -l 0.001 -e 400

Then plot some graphs with

python plot.py cvae_cnn3 -n 2

The trained model is saved in ./log and the figures in ./figure. Run python train.py -h and python plot.py -h for details about the options.

Results of Each Model on MNIST

Here are some results of the 3 CNN model on MNIST.

Reconstruction


reconstruction (3CNN)

Generation from Random Latent Variables


Generated digit
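Generating a digit with a CVAE amounts to drawing latent vectors from the standard normal prior and pairing them with the label of the desired digit. The sketch below builds such decoder inputs in NumPy. It assumes the condition is concatenated to `z` (a common way to condition a CVAE), and `make_generation_inputs` as well as the commented-out `decoder` are hypothetical names, not part of this repository.

```python
import numpy as np

def make_generation_inputs(digit, n_samples, latent_dim, num_classes=10, seed=0):
    """Build decoder inputs [z, one_hot(digit)] with z drawn from the N(0, I) prior."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_samples, latent_dim))  # random latent variables
    c = np.zeros((n_samples, num_classes))
    c[:, digit] = 1.0                                  # same requested digit for every sample
    return np.concatenate([z, c], axis=1)

inputs = make_generation_inputs(digit=7, n_samples=5, latent_dim=20)
# images = decoder(inputs)  # `decoder` stands for a trained CVAE decoder (hypothetical)
```

Because the label is fixed and only `z` varies, each sample should show the same digit in a different handwriting style.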

2-D Latent Space

Here is the 2-D latent space of the CVAE.


2d latent space (CVAE)

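In this space it seems hard to distinguish the individual digits. Presumably this is because the digit label is given to the model as a condition, so the latent variables do not need to encode the class; this space can therefore be regarded as a feature space shared by all handwritten digits.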

For comparison, the latent space of the VAE is shown below.


2d latent space (VAE)

Appendix

To find the best stride and number of layers, a simple deep CNN model for classification has been implemented. It consists of four CNN layers, each with max pooling and dropout.
On MNIST classification, this model achieves over 98% validation accuracy.


learning log
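When choosing stride and depth, a useful first check is how quickly the 28 x 28 MNIST input shrinks. The helper below applies TensorFlow's output-size rules for `"same"` padding (`ceil(n / stride)`) and `"valid"` padding (`ceil((n - kernel + 1) / stride)`). The layer list is a hypothetical stack of four 2 x 2 max-pooling steps with stride 2, since the README does not give the classifier's exact configuration.

```python
import math

def out_size(n, kernel, stride, padding="same"):
    """Spatial output size of a conv/pool layer under TensorFlow's padding rules."""
    if padding == "same":
        return math.ceil(n / stride)
    if padding == "valid":
        return math.ceil((n - kernel + 1) / stride)
    raise ValueError(f"unknown padding: {padding}")

def trace_sizes(n, layers):
    """Feature-map size before and after each (kernel, stride, padding) layer."""
    sizes = [n]
    for kernel, stride, padding in layers:
        n = out_size(n, kernel, stride, padding)
        sizes.append(n)
    return sizes

# Hypothetical stack: four 2 x 2 max-pooling steps with stride 2 on a 28 x 28 image.
print(trace_sizes(28, [(2, 2, "same")] * 4))   # [28, 14, 7, 4, 2]
print(trace_sizes(28, [(2, 2, "valid")] * 4))  # [28, 14, 7, 3, 1]
```

With `"valid"` pooling the feature map is already 1 x 1 after four stride-2 layers, so four such layers is roughly the maximum depth for this input size.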
