

CausalGAN/CausalBEGAN in Tensorflow

Tensorflow implementation of CausalGAN: Learning Causal Implicit Generative Models with Adversarial Training

Top: Random samples from do(Bald=1); Bottom: Random samples from cond(Bald=1)


Top: Random samples from do(Mustache=1); Bottom: Random samples from cond(Mustache=1)


Requirements

Getting Started

First, download the CelebA dataset with:

$ apt-get install p7zip-full # ubuntu
$ brew install p7zip # Mac
$ pip install tqdm
$ python download.py

Usage

The CausalGAN/CausalBEGAN code factorizes into two components that can be trained or loaded independently: the causal_controller module specifies the model that learns a causal generative model over labels, and the causal_dcgan or causal_began module learns a GAN over images given those labels. We refer to training the causal controller over labels as "pretraining" (--is_pretrain=True) and to training a GAN over images given labels as "training" (--is_train=True).

To train a causal implicit model over the labels, and then over the image given the labels, use:

$ python main.py --causal_model big_causal_graph --is_pretrain True --model_type began --is_train True

where "big_causal_graph" is one of the causal graphs specified by the keys in the causal_graphs dictionary in causal_graph.py.
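A causal graph over labels can be thought of as a mapping from each node to its list of parents; the causal controller then has to generate labels in an order where every parent comes before its children. The sketch below only illustrates that idea: the dictionary format, the `toy_graph` edges and the `topological_order` helper are hypothetical, not the actual contents of causal_graph.py.

```python
from collections import deque

# Hypothetical causal graph: each label maps to its list of parents.
# These edges are illustrative only; see causal_graph.py for the real graphs.
toy_graph = {
    'Young': [],
    'Male': [],
    'Eyeglasses': ['Young'],
    'Bald': ['Male', 'Young'],
    'Mustache': ['Male', 'Young'],
    'Smiling': ['Male', 'Young'],
}

def topological_order(graph):
    """Order nodes so every parent precedes its children (Kahn's algorithm)."""
    indegree = {node: len(parents) for node, parents in graph.items()}
    children = {node: [] for node in graph}
    for node, parents in graph.items():
        for parent in parents:
            children[parent].append(node)
    # Start from the root nodes (no parents).
    queue = deque(sorted(n for n, d in indegree.items() if d == 0))
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in children[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    if len(order) != len(graph):
        raise ValueError('graph contains a cycle, so it is not a DAG')
    return order

print(topological_order(toy_graph))
```

A cycle makes the function raise, which matches the requirement that a causal graph be a DAG.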

Alternatively, one can first train a causal implicit model over labels only with the following command:

$ python main.py --causal_model big_causal_graph --is_pretrain True

One can then train a conditional generative model for the images given the trained causal generative model for the labels (causal controller), which yields a causal implicit generative model for the image and the labels, as suggested in [arXiv link to the paper]:

$ CC_MODEL_PATH='./logs/celebA_0810_191625_0.145tvd_bcg/controller/checkpoints/CC-Model-20000'
$ python main.py --causal_model big_causal_graph --pt_load_path $CC_MODEL_PATH --model_type began --is_train True

Instead of loading the model piecewise, once image training has been run at least once, the entire joint model can be loaded more simply by specifying the model directory:

$ python main.py --causal_model big_causal_graph --load_path ./logs/celebA_0815_170635 --model_type began --is_train True 

To launch TensorBoard on the most recently created model (port 6006 must be free), run:

$ python tboard.py

To interact with an already trained model I recommend the following procedure:

ipython
In [1]: %run main --causal_model big_causal_graph --load_path './logs/celebA_0815_170635' --model_type 'began'

For example, to sample N=22 interventional images from do(Smiling=1) (as long as your causal graph includes a "Smiling" node):

In [2]: sess.run(model.G,{cc.Smiling.label:np.ones((22,1)), trainer.batch_size:22})

Conditional sampling is most efficiently done through two session calls: the first, via cc.sample_label, samples the labels, and the second feeds those sampled labels to the generator to get images. See trainer.causal_sampling for a more extensive example. Note that it is also possible to combine conditioning and intervention during sampling:

In [3]: lab_samples=cc.sample_label(sess,do_dict={'Bald':1}, cond_dict={'Mustache':1},N=22)

will sample all labels from the joint distribution conditioned on Mustache=1 and do(Bald=1). These label samples can be turned into image samples as follows:

In [4]: feed_dict={cc.label_dict[k]:v for k,v in lab_samples.items()}
In [5]: feed_dict[trainer.batch_size]=22
In [6]: images=sess.run(trainer.G,feed_dict)
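One way to read a combined do/cond query like the one above: intervened nodes are clamped to their values (cutting them off from their parents), all other nodes are sampled from their mechanisms in causal order, and samples that violate the conditions are discarded. The sketch below illustrates this with a hypothetical three-node model (Male → Bald, Male → Mustache) and made-up probabilities. `toy_sample_label` is not part of the repo, and it is an assumption that cc.sample_label uses rejection sampling internally.

```python
import random

# Hypothetical mechanisms P(node=1 | parents); the numbers are made up.
def p_male(labels):
    return 0.4

def p_bald(labels):
    return 0.2 if labels['Male'] else 0.01

def p_mustache(labels):
    return 0.3 if labels['Male'] else 0.01

# Nodes listed in a causal (topological) order with their mechanisms.
MECHANISMS = [('Male', p_male), ('Bald', p_bald), ('Mustache', p_mustache)]

def toy_sample_label(do_dict=None, cond_dict=None, N=22, rng=None,
                     max_tries=1000000):
    """Draw N joint label samples under do(do_dict), conditioned on cond_dict.

    Intervened nodes are clamped and their mechanism is ignored; conditioning
    is done by rejection, dropping samples that disagree with cond_dict.
    """
    do_dict = do_dict or {}
    cond_dict = cond_dict or {}
    rng = rng or random.Random(0)
    samples = []
    for _ in range(max_tries):
        labels = {}
        for name, mechanism in MECHANISMS:
            if name in do_dict:
                labels[name] = do_dict[name]
            else:
                labels[name] = int(rng.random() < mechanism(labels))
        if all(labels[k] == v for k, v in cond_dict.items()):
            samples.append(labels)
            if len(samples) == N:
                return samples
    raise RuntimeError('conditioning event too rare for rejection sampling')

samples = toy_sample_label(do_dict={'Bald': 1}, cond_dict={'Mustache': 1}, N=22)
```

Every returned sample has Bald=1 and Mustache=1, and most have Male=1, because conditioning on Mustache carries information back to its parent, while clamping Bald does not change Male. Rejection sampling gets slow when the conditioning event is rare.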

Configuration

Since this code base controls the training of three different models (CausalController, CausalGAN, and CausalBEGAN), many configuration options are available. To keep things manageable, there are 4 files with configurations specific to different parts of the model. Not all configuration combinations are tested; the default parameters are guaranteed to work.

Configuration files:
./config.py: generic data and scheduling
./causal_controller/config.py: specific to CausalController
./causal_dcgan/config.py: specific to CausalGAN
./causal_began/config.py: specific to CausalBEGAN

For convenience, the configurations used are saved in 4 .json files in the model directory for future reference.
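The training log in the issues below shows these files as params.json, cc_params.json, dcgan_params.json and began_params.json. A minimal sketch for reading them back, assuming those names (the `load_saved_configs` helper is hypothetical, not part of the repo):

```python
import json
import os

# File names as they appear in the training log; treat them as an assumption
# if your version of the code saves configs under different names.
CONFIG_FILES = ['params.json', 'cc_params.json', 'dcgan_params.json', 'began_params.json']

def load_saved_configs(model_dir):
    """Return {file name: parsed JSON} for every saved config found in model_dir."""
    configs = {}
    for fname in CONFIG_FILES:
        path = os.path.join(model_dir, fname)
        if os.path.exists(path):
            with open(path) as f:
                configs[fname] = json.load(f)
    return configs
```

For example, `load_saved_configs('./logs/celebA_0815_170635')` would show which settings a saved model was trained with before reloading it with --load_path.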

Results

Causal Controller convergence

We show convergence in total variation distance (TVD) for Causal Graph 1 (big_causal_graph in causal_graph.py), a completed version of Causal Graph 1 (complete_big_causal_graph in causal_graph.py), and an edge-reversed version of the complete Causal Graph 1 (reverse_big_causal_graph in causal_graph.py). We could get reasonable marginals with a complete DAG containing all 40 nodes, but TVD becomes very difficult to measure at that size. We show TVD convergence for 9 nodes for the two complete graphs. When the graph is incomplete, there is a "TVD gap", but convergence is still reasonable.
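TVD between two discrete distributions is half the L1 distance between their probability mass functions, here over joint label configurations. The sketch below estimates it from two sets of binary label samples; `empirical_tvd` is a hypothetical helper, and this is not necessarily how the repo computes TVD.

```python
from collections import Counter

def empirical_tvd(samples_p, samples_q):
    """Empirical TVD between two sets of binary label vectors.

    Each sample is a sequence of 0/1 labels; the estimate is
    0.5 * sum over joint configurations x of |p(x) - q(x)|.
    """
    p = Counter(map(tuple, samples_p))
    q = Counter(map(tuple, samples_q))
    n_p, n_q = len(samples_p), len(samples_q)
    support = set(p) | set(q)
    return 0.5 * sum(abs(p[x] / n_p - q[x] / n_q) for x in support)

real = [(0, 0), (0, 1), (1, 0), (1, 1)]
fake = [(0, 0), (0, 0), (1, 1), (1, 1)]
print(empirical_tvd(real, fake))  # 0.5
```

This also shows why 40 nodes are hard: 9 binary labels have 2^9 = 512 joint configurations, which a modest sample can cover, while 40 labels have 2^40 (about 1.1 × 10^12), so most configurations are never observed and the empirical estimate becomes unreliable.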


Conditional vs Interventional Sampling:

We trained a causal implicit generative model assuming we are given the following causal graph over labels. For the images below, the effect of conditioning or intervening can be reasoned about from the graph structure. For example, conditioning on Mustache=1 should yield more male faces, whereas intervening should not (since an intervention disconnects the edges from the node's parents).
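To make the Mustache example concrete, here is a worked calculation with hypothetical numbers (not estimated from CelebA) for a two-node graph Male → Mustache:

```latex
% Hypothetical two-node graph Male -> Mustache (numbers made up for illustration)
P(\mathrm{Male}=1) = 0.4, \quad
P(\mathrm{Mustache}=1 \mid \mathrm{Male}=1) = 0.3, \quad
P(\mathrm{Mustache}=1 \mid \mathrm{Male}=0) = 0.01

% Conditioning: Bayes' rule lets the observation flow back to the parent
P(\mathrm{Male}=1 \mid \mathrm{Mustache}=1)
  = \frac{0.3 \cdot 0.4}{0.3 \cdot 0.4 + 0.01 \cdot 0.6}
  = \frac{0.12}{0.126} \approx 0.952

% Intervening: do() cuts the edge Male -> Mustache, so the parent is unaffected
P(\mathrm{Male}=1 \mid do(\mathrm{Mustache}=1)) = P(\mathrm{Male}=1) = 0.4
```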

CausalGAN Conditioning vs Intervening

For each label, images were randomly sampled by either intervening (top row) or conditioning (bottom row) on label=1.

Labels shown (one image grid each): Bald, Mouth Slightly Open, Mustache, Narrow Eyes, Smiling, Eyeglasses, Wearing Lipstick

CausalBEGAN Conditioning vs Intervening

For each label, images were randomly sampled by either intervening (top row) or conditioning (bottom row) on label=1.

Labels shown (one image grid each): Bald, Mouth Slightly Open, Mustache, Narrow Eyes, Smiling, Eyeglasses, Wearing Lipstick

CausalGAN Generator output (10x10) (randomly sampled label)


CausalBEGAN Generator output (10x10) (randomly sampled label)


This repo was originally forked from two other repositories.

Related works

Authors

Christopher Snyder / @22csnyder
Murat Kocaoglu / @mkocaoglu


causalgan's Issues

train a causal implicit model using multiple GPUs

Hi @mkocaoglu Murat, I am trying to first train a causal implicit model using your source code. It runs well with the default single GPU. However, when I use more GPUs, e.g.
python main.py --causal_model big_causal_graph --is_pretrain True --num_gpu 8
I get the error below. Please help me! Thanks a lot!

setting up pretrain: CausalController
Traceback (most recent call last):
File "main.py", line 89, in <module>
trainer=get_trainer()
File "main.py", line 76, in get_trainer
trainer=Trainer(config,cc_config,model_config)
File "/home/qw/projects/CausalGAN/trainer.py", line 60, in __init__
self.cc.build_pretrain(label_queue)
File "/home/qw/projects/CausalGAN/causal_controller/CausalController.py", line 179, in build_pretrain
grad_cost,self.dcc_slopes=Grad_Penalty(real_inputs,fake_inputs,self.DCC,self.config)
File "/home/qw/projects/CausalGAN/causal_controller/models.py", line 45, in Grad_Penalty
interpolates = alpha*real_data + ((1-alpha)*fake_data)#Could do more if not fixed batch_size
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 866, in binary_op_wrapper
return func(x, y, name=name)
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 1131, in _mul_dispatch
return gen_math_ops.mul(x, y, name=name)
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 5042, in mul
"Mul", x=x, y=y, name=name)
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
op_def=op_def)
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1792, in __init__
control_input_ops)
File "/home/qw/anaconda3/envs/py27/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1631, in _create_c_op
raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 16 and 128 for 'mul_1' (op: 'Mul') with input shapes: [16,1], [128,9].
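For reference, the failing line builds gradient-penalty interpolates, `alpha*real_data + (1-alpha)*fake_data`, which requires `alpha` to have one row per example in the batch. The error pairs an `alpha` of shape [16,1] with data of shape [128,9]; since 16 = 128/8, this suggests `alpha` was sized for a per-GPU batch while the data was not split. The NumPy sketch below (hypothetical names; not a patch for the repo, whose multi-GPU code is not shown here) reproduces the broadcasting failure and shows the shape `alpha` needs.

```python
import numpy as np

def interpolate(real_data, fake_data, rng):
    """Random points on the lines between real and fake samples.

    alpha gets one row per example, taken from the actual batch, so it
    broadcasts across the label dimension of real_data and fake_data.
    """
    batch_size = real_data.shape[0]
    alpha = rng.uniform(size=(batch_size, 1))
    return alpha * real_data + (1 - alpha) * fake_data

rng = np.random.default_rng(0)
real = rng.uniform(size=(128, 9))
fake = rng.uniform(size=(128, 9))
print(interpolate(real, fake, rng).shape)  # (128, 9)

# An alpha sized for a smaller batch reproduces the mismatch in the traceback above.
bad_alpha = rng.uniform(size=(16, 1))
try:
    bad_alpha * real
except ValueError as err:
    print('broadcast error:', err)
```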

can't execute the download.py

The download link in download.py does not work. Could you provide a new download link for the dataset you trained your code on?

Problem Running the Demo on the CPU

Hello Murat,

I tried to run the code using 0 GPUs, and it gave me errors on both BEGAN and DCGAN. I am not able to test on the default 1 GPU because I don't have any. Do you know what might be causing this?

The command I entered was the following:

python main.py --causal_model big_causal_graph --is_pretrain True --model_type dcgan --is_train True --num_gpu 0

And I got the following output:

C:\Users\kayan\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\dtypes.py:455: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\Users\kayan\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\dtypes.py:456: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Users\kayan\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\dtypes.py:457: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Users\kayan\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\dtypes.py:458: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Users\kayan\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\dtypes.py:459: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Users\kayan\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\dtypes.py:462: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
tf: resetting default graph!
Loaded ./config.py
Loaded ./causal_controller/config.py
Loaded ./causal_dcgan/config.py
Loaded ./causal_began/config.py
saving config because load path not given
[*] MODEL dir: logs\celebA_0728_103822
[*] PARAM path: logs\celebA_0728_103822\params.json
[*] PARAM path: logs\celebA_0728_103822\cc_params.json
[*] PARAM path: logs\celebA_0728_103822\dcgan_params.json
[*] PARAM path: logs\celebA_0728_103822\began_params.json
setting up CausalController
causal graph size: 9
setting up data
setup pretrain
setting up pretrain: CausalController
causalcontroller has 58 summaries
WARNING:CausalGAN.rec_loss_coff= 0.0
Filling queue with 202 Celeb images before starting to train. I don't know how long this will take
2021-07-28 10:38:32.942722: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:38:32.942937: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE2 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:38:32.943087: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:38:32.943198: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:38:32.943308: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:38:32.943489: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:38:32.944325: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:38:32.945032: W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
Traceback (most recent call last):
File "main.py", line 87, in <module>
trainer=get_trainer()
File "main.py", line 74, in get_trainer
trainer=Trainer(config,cc_config,model_config)
File "C:\Users\kayan\OneDrive\Desktop\CausalGAN\trainer.py", line 95, in __init__
self.model.build_train_op()
File "C:\Users\kayan\OneDrive\Desktop\CausalGAN\causal_dcgan\CausalGAN.py", line 287, in build_train_op
.minimize(self.g_loss, var_list=self.g_vars)
AttributeError: 'CausalGAN' object has no attribute 'g_loss'

For the BEGAN, I ran:

python main.py --causal_model big_causal_graph --is_pretrain True --model_type began --is_train True --num_gpu 0

and got the following error:

C:\Users\kayan\anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\dtypes.py:458: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\Users\kayan\anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\dtypes.py:459: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Users\kayan\anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\dtypes.py:460: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Users\kayan\anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\dtypes.py:461: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Users\kayan\anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\dtypes.py:462: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Users\kayan\anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\dtypes.py:465: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
tf: resetting default graph!
Loaded ./config.py
Loaded ./causal_controller/config.py
Loaded ./causal_dcgan/config.py
Loaded ./causal_began/config.py
saving config because load path not given
[*] MODEL dir: logs\celebA_0728_100704
[*] PARAM path: logs\celebA_0728_100704\params.json
[*] PARAM path: logs\celebA_0728_100704\cc_params.json
[*] PARAM path: logs\celebA_0728_100704\dcgan_params.json
[*] PARAM path: logs\celebA_0728_100704\began_params.json
setting up CausalController
causal graph size: 9
setting up data
setup pretrain
setting up pretrain: CausalController
causalcontroller has 58 summaries
Filling queue with 202 Celeb images before starting to train. I don't know how long this will take
2021-07-28 10:07:14.554201: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:07:14.554405: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE2 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:07:14.554999: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:07:14.555208: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:07:14.555410: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:07:14.555604: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:07:14.555802: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2021-07-28 10:07:14.556000: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\36\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
Traceback (most recent call last):
File "main.py", line 87, in <module>
trainer=get_trainer()
File "main.py", line 74, in get_trainer
trainer=Trainer(config,cc_config,model_config)
File "C:\Users\kayan\OneDrive\Desktop\CausalGAN\trainer.py", line 95, in __init__
self.model.build_train_op()
File "C:\Users\kayan\OneDrive\Desktop\CausalGAN\causal_began\CausalBEGAN.py", line 278, in build_train_op
g_optim = self.g_optimizer.apply_gradients(g_grads, global_step=self.step)
File "C:\Users\kayan\anaconda3\envs\tf2\lib\site-packages\tensorflow\python\training\optimizer.py", line 423, in apply_gradients
raise ValueError("No variables provided.")
ValueError: No variables provided.

Undefined names can raise NameErrors at runtime

Undefined names can raise NameError at runtime.

flake8 testing of https://github.com/mkocaoglu/CausalGAN on Python 2.7.13

$ flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics

./causal_dcgan/utils.py:147:16: F821 undefined name 'x'
        return x.astype(np.uint8)
               ^
./causal_dcgan/utils.py:149:18: F821 undefined name 'x'
        return ((x+1)/2*255).astype(np.uint8)
                 ^
./figure_scripts/sample.py:157:51: F821 undefined name 'list_labels'
    for name, lab, dfl in zip(model.cc.node_names,list_labels,list_d_fake_labels):
                                                  ^
./figure_scripts/sample.py:157:63: F821 undefined name 'list_d_fake_labels'
    for name, lab, dfl in zip(model.cc.node_names,list_labels,list_d_fake_labels):
                                                              ^
./figure_scripts/sample.py:442:47: F821 undefined name 'nsamples'
            outputs[k]=np.vstack(outputs[k])[:nsamples]
                                              ^
./figure_scripts/sample.py:507:41: F821 undefined name 'nsamples'
                if len(completed_rows)>=nsamples:
                                        ^
./figure_scripts/sample.py:585:40: F821 undefined name 'nsamples'
                outputs[k]=outputs[k][:nsamples]
                                       ^
./figure_scripts/sample.py:592:32: F821 undefined name 'nsamples'
            values=outputs[k][:nsamples]
                               ^
