bagan's Introduction

BAGAN

Keras implementation of Balancing GAN (BAGAN) applied to the MNIST example.

The framework is meant as a data-augmentation tool for imbalanced image-classification datasets in which some classes are underrepresented. The generative model used to sample new images for the minority class is trained in three steps: a) training a preliminary autoencoder, b) initializing the generative adversarial framework with the pre-trained autoencoder modules, and c) fine-tuning the generative model in adversarial mode.

Throughout these steps, the generative model learns from all available data, including both minority and majority classes. This enables the model to automatically figure out whether, and which, features from over-represented classes can be used to draw new images for under-represented classes. For example, in a traffic sign recognition problem, all warning signs share the same external triangular shape. BAGAN can easily learn the triangular shape from any warning sign in the majority classes and reuse this pattern to draw other under-represented warning signs.

The application of this approach toward fairness enhancement and bias mitigation in deep-learning AI systems is currently an active research topic.

Example results

The German Traffic Sign Recognition benchmark is an example of an imbalanced dataset. It is composed of 43 classes, where the minority class appears 210 times and the majority class appears 2250 times.

Here we show representative sample images generated with BAGAN for the three least represented classes.

[Figure: sample images generated with BAGAN for the three least represented classes]

Refer to the original work (https://arxiv.org/abs/1803.09655) for a comparison with other state-of-the-art approaches.

The code in this repository runs on the MNIST dataset. The dataset is originally balanced, so before training BAGAN we force class imbalance by selecting a target class and removing a significant portion of its instances from the training dataset. The following figure shows 0-image samples generated when dropping 97.5% of 0-images from the training set before training.
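The forced-imbalance step can be sketched with NumPy as follows. This is a minimal illustration, not the repository's own code: the function name force_imbalance, its parameters, and the toy data standing in for MNIST are all assumptions made for the example.

```python
import numpy as np

def force_imbalance(images, labels, target_class, keep_fraction, seed=0):
    """Keep only `keep_fraction` of the `target_class` samples; keep all other classes."""
    rng = np.random.default_rng(seed)
    target_idx = np.flatnonzero(labels == target_class)
    other_idx = np.flatnonzero(labels != target_class)
    n_keep = int(round(keep_fraction * target_idx.size))
    kept_target = rng.choice(target_idx, size=n_keep, replace=False)
    keep = np.sort(np.concatenate([other_idx, kept_target]))
    return images[keep], labels[keep]

# Toy stand-in for MNIST (hypothetical data): 200 blank "images" per class.
labels = np.repeat(np.arange(10), 200)
images = np.zeros((labels.size, 28, 28), dtype=np.uint8)

# Dropping 97.5% of the 0-images means keeping 2.5% of them.
sub_images, sub_labels = force_imbalance(
    images, labels, target_class=0, keep_fraction=0.025
)
class_counts = np.bincount(sub_labels, minlength=10)
print(class_counts)
```

With real data, images and labels would be the MNIST training arrays; the per-class counts printed at the end play the same role as the "Class counters" line in the training log.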

alt text

Running the MNIST example

This software has been tested on tensorflow-1.5.0.

To execute BAGAN for the MNIST example, run the command: ./run.sh

A directory named: res_MNIST_dmode_uniform_gmode_uniform_unbalance_0.05_epochs_150_lr_0.000050_seed_0 will be generated and results will be stored there.
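To locate the results of a run with other settings, it can help to see how such a name might be assembled. The sketch below only reproduces the name shown above; the function result_dir_name and its parameter names are guesses read off the directory name, and the actual script may build the name differently.

```python
def result_dir_name(dataset="MNIST", dmode="uniform", gmode="uniform",
                    unbalance=0.05, epochs=150, lr=0.00005, seed=0):
    """Hypothetical helper: format run settings into a results directory name."""
    return (
        f"res_{dataset}_dmode_{dmode}_gmode_{gmode}"
        f"_unbalance_{unbalance}_epochs_{epochs}_lr_{lr:f}_seed_{seed}"
    )

name = result_dir_name()
print(name)
```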

After training, that directory contains: a set of h5 files storing the model weights; a set of csv files storing the loss values measured at each epoch; a set of npy files storing the means and covariances of the distributions used by the class-conditional latent-vector generator; and a set of cmp_class_<X>_epoch_<Y>.png files showing example images obtained during training.
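One plausible reading of the stored means and covariances is that each class is modeled as a multivariate normal distribution in latent space, from which new latent vectors are drawn. The sketch below illustrates that idea under stated assumptions: the latent vectors are random placeholders rather than encoder outputs, and the function names are hypothetical, not taken from the repository.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder latent vectors (in practice these would come from the encoder).
latent_dim = 8
latents = rng.normal(size=(600, latent_dim))
labels = np.arange(600) % 3  # three toy classes, 200 samples each

def fit_class_gaussians(latents, labels):
    """Estimate a (mean, covariance) pair per class from latent vectors."""
    stats = {}
    for c in np.unique(labels):
        z = latents[labels == c]
        stats[int(c)] = (z.mean(axis=0), np.cov(z, rowvar=False))
    return stats

def sample_latent(stats, c, n, rng):
    """Draw n latent vectors for class c from its fitted normal distribution."""
    mean, cov = stats[c]
    return rng.multivariate_normal(mean, cov, size=n)

stats = fit_class_gaussians(latents, labels)
z_new = sample_latent(stats, c=0, n=16, rng=rng)
print(z_new.shape)
```

Latent vectors sampled this way would then be passed to the generator to draw images of the requested class.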

The file cmp_class_<X>_epoch_<Y>.png shows images obtained after training the GAN for Y epochs with class X as the minority class. There are three rows per class: 1) real samples, 2) autoencoded reconstructed samples, 3) randomly generated samples. Note that in BAGAN, after the initial autoencoder training, the generative model is trained in adversarial mode and the autoencoder loss is no longer taken into account. Thus, during adversarial training the autoencoded images may start to deteriorate (row 2 may no longer match row 1), whereas the generated images (row 3) will improve in quality.

For more information about the available execution options, run: python ./bagan_train.py --help

bagan's People

Contributors

cristianomalossi, minlee077, ova-mariani, ptran1203, stevemart

bagan's Issues

Choice of cnn architecture

The network structure used in the BAGAN paper is ResNet-18, but that is not the one used here. How can the network structure be modified to adapt it to other, larger datasets?

Not getting proper outputs for MNIST

I am running the code exactly as downloaded from the GitHub repository, in a Google Colab notebook. During GAN training the discriminator loss decreases and the generator loss increases. The generated images are not meaningful. Please help me solve this issue.

I also want to use a custom multiclass dataset with many minority classes and perform augmentation for all of them. Please help.

Quality of the Final Classification

@stevemart @ova-mariani Great paper, thank you!

Can you provide the results of "Quality of the Final Classification" not only as figures but also as plain numbers? It would be great to get the results, for example as CSV, for each case: GAN, ACGAN, BAGAN, Plain and Mirror.

Problem with initialisation of GAN.

After the line print("BAGAN autoenc initialized, init gan"), an error occurs when the function self.init_gan() is called. The error occurs because execution enters the except block of init_gan(), since the function self._get_lst_bck_name() returns None.

Is this supposed to happen? For now, I have just called self.backup_point(0) to initialize the weights of the GAN. Is this all right? This is just the initialization of the GAN, right?

RecursionError

Got the following error when running the script in run.sh.
RecursionError: maximum recursion depth exceeded while getting the str of an object.

Is normalization for CIFAR10 correct?

Hi,
Thanks for your great work!

I would like to suggest a small correction.

In BatchGenerator, you normalize the images with
img = img / 255 - 0.5

but in the function save_image_array(), you recover them with
img = (img * 127.5 + 127.5).astype(np.uint8)

This leads to a small difference between the two.

  • image from generator: [image]
  • image from dataset: [image]

It seems we want to normalize between -1 and 1.
In BatchGenerator, it should be: img = (img - 127.5) / 127.5

Thanks
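The mismatch described in this issue can be checked numerically. The sketch below is illustrative only: it applies the two normalization formulas quoted above to a few sample pixel values and then applies the recovery formula from save_image_array(); the helper name denormalize is an assumption standing in for that line.

```python
import numpy as np

pixels = np.array([0, 64, 128, 255], dtype=np.float64)

def denormalize(x):
    """Recovery step quoted from save_image_array()."""
    return (x * 127.5 + 127.5).astype(np.uint8)

current = pixels / 255 - 0.5          # maps [0, 255] to [-0.5, 0.5]
proposed = (pixels - 127.5) / 127.5   # maps [0, 255] to [-1, 1]

roundtrip_current = denormalize(current)
roundtrip_proposed = denormalize(proposed)

# With the current normalization, black (0) comes back as 63 and
# white (255) as 191, so the saved dataset images lose contrast.
print("current :", roundtrip_current)
# With the proposed normalization, values come back within one
# intensity level of the originals (truncation may cost one level).
print("proposed:", roundtrip_proposed)
```

The round trip with the current formula compresses the range to roughly [63, 191], which matches the washed-out appearance of dataset images compared with generator outputs.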

No output after running ./run.sh

/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from float to np.floating is deprecated. In future, it will be treated as np.float64 == np.dtype(float).type.
from ._conv import register_converters as _register_converters
Using TensorFlow backend.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
Executing BAGAN.
read input data...
WARNING:tensorflow:From /data2/CZY/data/bagan/BAGAN/rw/batch_generator_mnist.py:23: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting dataset/mnist/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting dataset/mnist/train-labels-idx1-ubyte.gz
Extracting dataset/mnist/t10k-images-idx3-ubyte.gz
Extracting dataset/mnist/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Extracting dataset/mnist/train-images-idx3-ubyte.gz
Extracting dataset/mnist/train-labels-idx1-ubyte.gz
Extracting dataset/mnist/t10k-images-idx3-ubyte.gz
Extracting dataset/mnist/t10k-labels-idx1-ubyte.gz
input data loaded...
Extracting dataset/mnist/train-images-idx3-ubyte.gz
Extracting dataset/mnist/train-labels-idx1-ubyte.gz
Extracting dataset/mnist/t10k-images-idx3-ubyte.gz
Extracting dataset/mnist/t10k-labels-idx1-ubyte.gz
Required GAN for class 0
Class counters: [309, 6179, 5470, 5638, 5307, 4987, 5417, 5715, 5389, 5454]
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py:1290: calling reduce_mean (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py:1205: calling reduce_prod (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
uratio set to: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
dratio set to: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
gratio set to: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
BAGAN init_autoenc
BAGAN: training autoencoder
2019-01-09 13:45:09.014380: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-01-09 13:45:11.457403: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1344] Found device 0 with properties:
name: GeForce GTX 1080 Ti major: 6 minor: 1 memoryClockRate(GHz): 1.582
pciBusID: 0000:03:00.0
totalMemory: 10.91GiB freeMemory: 10.75GiB
2019-01-09 13:45:11.457487: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1423] Adding visible gpu devices: 0
2019-01-09 13:45:11.812727: I tensorflow/core/common_runtime/gpu/gpu_device.cc:911] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-01-09 13:45:11.812780: I tensorflow/core/common_runtime/gpu/gpu_device.cc:917] 0
2019-01-09 13:45:11.812793: I tensorflow/core/common_runtime/gpu/gpu_device.cc:930] 0: N
2019-01-09 13:45:11.814107: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1041] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10409 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:03:00.0, compute capability: 6.1)

The log stops here. Should I just keep waiting?

Output of BAGAN for CIFAR-10 is just solid colors.

@stevemart @ova-mariani
plot_class_0_epoch_0

I've tried to modify the BAGAN architecture for CIFAR-10, but unfortunately, the output of the BAGAN is random. I even used the same learning rate as mentioned in ACGAN. I've tried out different learning rates as well but to no avail.

I have made the requisite changes in all the related files to enable BAGAN to be trained on a multi-channel input. Could you guys mention the hyper-parameters that you used for CIFAR-10 so that I may be able to get to the bottom of this anomaly?
