
wgan-gaussian

An implementation of a Wasserstein Generative Adversarial Network (WGAN) that generates samples from a mixture of 5 different Gaussian distributions.

[figure: wgan_frames]

Unlike a vanilla GAN, the Wasserstein GAN is able to learn all of the distributions (modes) in the training data. The loss function is the Wasserstein distance (also called the Earth Mover's distance), which is minimized to bring the distribution of the generated data close to that of the real data.

Content:

  • WGAN Architecture
  • Generator Architecture
  • Critic Architecture
  • Wasserstein distance vs JS and KL divergences
  • Wasserstein distance as GAN loss function
  • How to train the model

Model Overview

Similar to a standard GAN, we have two neural networks: a generative model and a discriminative model. However, we call the discriminative model the Critic instead of the Discriminator because it is trained with a different loss function and outputs a score rather than a probability.

[figure: model overview]

Generator Architecture

It consists of an input layer of 2 neurons for the z vector, 3 hidden layers of 512 neurons each, and an output layer of 2 neurons. The activation function of the 3 hidden layers is ReLU, and the output layer is linear.

[figure: generator architecture]
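A minimal sketch of such a generator, assuming TensorFlow 1.x; the function name, variable scope, and layer API are illustrative and not necessarily identical to the code in wgan.py:

import tensorflow as tf  # assumes TensorFlow 1.x

def generator(z, reuse=False):
    # z: [batch, 2] latent vectors mapped to [batch, 2] generated points
    with tf.variable_scope("generator", reuse=reuse):
        h = tf.layers.dense(z, 512, activation=tf.nn.relu)   # hidden layer 1
        h = tf.layers.dense(h, 512, activation=tf.nn.relu)   # hidden layer 2
        h = tf.layers.dense(h, 512, activation=tf.nn.relu)   # hidden layer 3
        return tf.layers.dense(h, 2, activation=None)        # linear output layer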

Critic Architecture

It consists of an input layer of 2 neurons for the training data, 3 hidden layers of 512 neurons each with ReLU activation functions, and an output layer of 1 neuron with a linear activation function.

[figure: critic architecture]
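A corresponding sketch of the critic under the same assumptions (TensorFlow 1.x, illustrative names):

def critic(x, reuse=False):
    # x: [batch, 2] real or generated points; the output is an unbounded score, not a probability
    with tf.variable_scope("critic", reuse=reuse):
        h = tf.layers.dense(x, 512, activation=tf.nn.relu)   # hidden layer 1
        h = tf.layers.dense(h, 512, activation=tf.nn.relu)   # hidden layer 2
        h = tf.layers.dense(h, 512, activation=tf.nn.relu)   # hidden layer 3
        return tf.layers.dense(h, 1, activation=None)        # single linear output neuron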

Wasserstein distance vs Jensen–Shannon divergence & Kullback–Leibler divergence

The Wasserstein distance is a well-known distance metric for probability distributions. It is sometimes called the Earth Mover's distance and is studied in the field of optimal transport. It measures the optimal cost of transporting one distribution onto another (Solomon et al., 2014). Even when two distributions are supported on lower-dimensional manifolds that do not overlap, the Wasserstein distance still provides a meaningful and smooth measure of the distance between them. Other distance functions, in contrast, suffer from issues related to continuity. For example, the Kullback-Leibler divergence is infinite for two fully disjoint distributions. Another example is the Jensen-Shannon divergence, which is not differentiable at the point where the distributions fully overlap, i.e. it has a sudden jump at zero distance. Only the Wasserstein distance therefore provides a smooth measure, which makes it very helpful for stable learning and well suited to solving the stability issues that appear in normal GANs.
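A small numerical illustration of this contrast, using SciPy (not part of this repository): the Wasserstein distance between two disjoint point masses grows smoothly with their separation, while the KL divergence between distributions with non-overlapping support is infinite.

import numpy as np
from scipy.stats import wasserstein_distance, entropy

# P is a point mass at 0, Q is a point mass at theta: the two distributions never overlap.
for theta in (0.5, 1.0, 2.0):
    print(theta, wasserstein_distance([0.0], [theta]))  # grows smoothly: 0.5, 1.0, 2.0

# Discrete distributions with disjoint support: KL(P || Q) is infinite.
p = np.array([1.0, 0.0])
q = np.array([0.0, 1.0])
print(entropy(p, q))  # -> inf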

Wasserstein distance as GAN loss function

# d_real / d_fake: the critic's scores for a batch of real / generated samples
d_loss = tf.reduce_mean(d_real - d_fake)  # critic loss: minimizing it widens the gap between fake and real scores
g_loss = tf.reduce_mean(d_fake)           # generator loss: minimizing it pulls the fake scores back down

We clip the Critic's weights to enforce a Lipschitz constraint.

# theta_d is the list of the critic's trainable variables; every weight is clamped to [-0.01, 0.01]
clip_weights = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in theta_d]

We use the RMSprop optimizer instead of Adam because Adam can cause instability in the training process.
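A minimal sketch of how these pieces could fit together in an alternating training loop (assuming TensorFlow 1.x; theta_g, X, Z, sess, sample_real, sample_z, n_critic, num_steps, batch_size and the learning rate are illustrative names and values, not taken from wgan.py):

# RMSprop updates for the critic and the generator, restricted to their own variables
d_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(d_loss, var_list=theta_d)
g_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(g_loss, var_list=theta_g)

for step in range(num_steps):
    # train the critic several times per generator update, clipping its weights after each step
    for _ in range(n_critic):
        sess.run([d_opt, clip_weights],
                 feed_dict={X: sample_real(batch_size), Z: sample_z(batch_size)})
    # then take one generator step
    sess.run(g_opt, feed_dict={Z: sample_z(batch_size)})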

How to train the model

Run python wgan.py in the console to train the model for generating the 5 Gaussian distributions. The results for each epoch will be saved in the tf_wgan_results folder.

License

MIT License

Copyright (c) 2019

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
