TensorFlow implementations of a VAE (Variational Autoencoder) [1] and a CVAE (Conditional Variational Autoencoder) [2], with experiments on the MNIST dataset.
This repository includes the following three types of CVAE:
- 3 CNN: encoder (CNN x 3 + FC x 1) and decoder (CNN x 3 + FC x 1)
- 2 CNN: encoder (CNN x 2 + FC x 1) and decoder (CNN x 2 + FC x 1)
- 3 FC: encoder (FC x 3) and decoder (FC x 3)
At first, I implemented the 2 CNN model, which uses convolutions with a fixed stride (2 x 2) and kernel size (4 x 4). A problem with this model is that the FC (fully connected) layer has far more trainable variables than any CNN layer: with a latent dimension of 20, the FC layer has 20 x 512 = 10240 trainable variables, while the largest CNN layer has 4 x 16 x 32 = 2048. It might be difficult to learn such a large FC layer, and the effect of the CNN layers could vanish. The 3 CNN model, on the other hand, has a relatively small FC layer with only 100 trainable variables, while its CNN layers have 6336 (first layer) and 8192 (second and third layers).
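To make the parameter comparison concrete, here is a quick back-of-the-envelope calculation (a sketch only: biases are ignored, and the conv channel counts below are illustrative placeholders, not the repository's exact configuration):

```python
# Rough weight counts for the layers compared above (biases ignored).

def fc_params(d_in, d_out):
    # A fully connected layer is a dense d_in x d_out weight matrix.
    return d_in * d_out

def conv_params(k_h, k_w, c_in, c_out):
    # A convolution has one k_h x k_w kernel per (input, output) channel pair.
    return k_h * k_w * c_in * c_out

latent_dim = 20
print(fc_params(512, latent_dim))  # 10240: FC mapping 512 features to the latent code
print(conv_params(4, 4, 16, 32))   # 8192: a 4x4 conv, 16 -> 32 channels (illustrative)
```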
The FC-based model was implemented to see how much better the 3 CNN model is than an FC-based CVAE, but it turns out that the FC model itself behaves quite well as a generative model (more on this in a later section).
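For reference, all of these models optimize the same objective, the ELBO from [1]; for the CVAE, the label is additionally fed to both encoder and decoder [2]. Below is a minimal sketch of the loss and the reparameterization trick, written in TensorFlow 2.x for illustration; the repository's own implementation may use a different API version and naming.

```python
import tensorflow as tf

def reparameterize(z_mean, z_log_var):
    """Sample z = mean + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    eps = tf.random.normal(tf.shape(z_mean))
    return z_mean + tf.exp(0.5 * z_log_var) * eps

def negative_elbo(x, x_logits, z_mean, z_log_var):
    """Bernoulli reconstruction term plus the closed-form Gaussian KL term.

    x: images in [0, 1], shape (batch, 28, 28, 1); x_logits: decoder output logits.
    """
    recon = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_logits),
        axis=[1, 2, 3])
    kl = -0.5 * tf.reduce_sum(
        1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
    return tf.reduce_mean(recon + kl)
```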
Clone the repository:

```
git clone https://github.com/asahi417/ConditionalVariationalAutoEncoder.git cvae
cd cvae
```
To train the 3 CNN model on MNIST:

```
python train.py cvae_cnn3 -n 2 -c 1 -l 0.001 -e 400
```
then plot some figures with:

```
python plot.py cvae_cnn3 -n 2
```
The trained model is saved under `./log` and the figures under `./figure`.
Run `python train.py -h` and `python plot.py -h` to see details about the options.
Let's look at some results on MNIST with the 3 CNN model.
Here is the latent space of the CVAE.
In this space it seems hard to distinguish the individual digits, so the space can be regarded as a common feature space for handwritten digits.
On the other hand, the latent space of the VAE is shown below.
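The repository's `plot.py` produces these figures; purely as an illustration of how such a scatter plot can be made, here is a sketch assuming a hypothetical `encode(images)` helper that returns the 2-dimensional latent means of a trained model:

```python
import matplotlib.pyplot as plt

def plot_latent_space(encode, images, labels):
    # encode(images) -> (n_samples, 2) array of latent means (hypothetical helper).
    z = encode(images)
    plt.figure(figsize=(6, 6))
    sc = plt.scatter(z[:, 0], z[:, 1], c=labels, cmap="tab10", s=4)
    plt.colorbar(sc, label="digit")
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.title("Latent space colored by digit label")
    plt.show()
```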
To find the best stride and layer depth, a simple deep CNN model for classification was implemented.
It consists of four CNN layers, each of which includes max pooling and dropout.
On MNIST classification, this model achieves over 98% validation accuracy.
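As a sketch of what such a classifier could look like (written with tf.keras for illustration; the filter counts, kernel sizes, and dropout rates are assumptions, not the repository's exact settings):

```python
import tensorflow as tf

def build_classifier(num_classes=10):
    model = tf.keras.Sequential([tf.keras.layers.Input(shape=(28, 28, 1))])
    for filters in (32, 32, 64, 64):  # four CNN layers, as described above
        model.add(tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu"))
        model.add(tf.keras.layers.MaxPooling2D(2, padding="same"))  # max pooling
        model.add(tf.keras.layers.Dropout(0.25))                    # dropout
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(num_classes, activation="softmax"))
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```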
- [1] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).
- [2] Kingma, Diederik P., et al. "Semi-supervised learning with deep generative models." Advances in Neural Information Processing Systems. 2014.