
self-attention-gan-tensorflow's Introduction

Self-Attention-GAN-Tensorflow

Simple Tensorflow implementation of "Self-Attention Generative Adversarial Networks" (SAGAN)

Requirements

  • Tensorflow 1.8
  • Python 3.6

Related works

Summary

Framework

(framework figure)

Code

def attention(self, x, ch):
    f = conv(x, ch // 8, kernel=1, stride=1, sn=self.sn, scope='f_conv')  # [bs, h, w, c']
    g = conv(x, ch // 8, kernel=1, stride=1, sn=self.sn, scope='g_conv')  # [bs, h, w, c']
    h = conv(x, ch, kernel=1, stride=1, sn=self.sn, scope='h_conv')  # [bs, h, w, c]

    # N = h * w
    s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]

    beta = tf.nn.softmax(s)  # attention map

    o = tf.matmul(beta, hw_flatten(h))  # [bs, N, c]
    gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

    o = tf.reshape(o, shape=x.shape)  # [bs, h, w, c]
    x = gamma * o + x  # residual connection; gamma starts at 0

    return x
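
Both versions rely on an hw_flatten helper that collapses the spatial grid into a sequence axis. A minimal sketch of what it is assumed to do (the real helper lives in the repository's ops.py; this version assumes a statically known batch size):

def hw_flatten(x):
    # [bs, h, w, c] -> [bs, h*w, c]; -1 tells tf.reshape to infer h*w
    return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])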

Code2 (Google Brain)

def attention_2(self, x, ch):
    batch_size, height, width, num_channels = x.get_shape().as_list()

    f = conv(x, ch // 8, kernel=1, stride=1, sn=self.sn, scope='f_conv')  # [bs, h, w, c']
    f = max_pooling(f)  # [bs, h/2, w/2, c']

    g = conv(x, ch // 8, kernel=1, stride=1, sn=self.sn, scope='g_conv')  # [bs, h, w, c']

    h = conv(x, ch // 2, kernel=1, stride=1, sn=self.sn, scope='h_conv')  # [bs, h, w, c/2]
    h = max_pooling(h)  # [bs, h/2, w/2, c/2]

    # N = h * w
    s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N/4]

    beta = tf.nn.softmax(s)  # attention map

    o = tf.matmul(beta, hw_flatten(h))  # [bs, N, c/2]
    gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

    o = tf.reshape(o, shape=[batch_size, height, width, num_channels // 2])  # [bs, h, w, c/2]
    o = conv(o, ch, kernel=1, stride=1, sn=self.sn, scope='attn_conv')  # [bs, h, w, c]
    x = gamma * o + x  # residual connection; gamma starts at 0

    return x
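
This Google Brain variant differs from attention() in three ways: f and h are max-pooled so the key/value sequence is four times shorter, h uses ch // 2 channels, and a final 1x1 convolution (attn_conv) projects the output back to ch channels. The max_pooling helper is assumed to be a plain stride-2 spatial pool, roughly:

def max_pooling(x):
    # Halve h and w, so the attention matrix s becomes [bs, N, N/4]
    # instead of [bs, N, N], cutting its memory roughly 4x.
    return tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='SAME')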

Usage

dataset

> python download.py celebA
  • mnist and cifar10 are loaded through keras
  • For your own dataset, arrange the images like this:
├── dataset
   └── YOUR_DATASET_NAME
       ├── xxx.jpg (name, format doesn't matter)
       ├── yyy.png
       └── ...

train

  • python main.py --phase train --dataset celebA --gan_type hinge

test

  • python main.py --phase test --dataset celebA --gan_type hinge

Results

ImageNet

(generated ImageNet samples)

CelebA (100K iterations, hinge loss)

(generated CelebA samples)

Author

Junho Kim

self-attention-gan-tensorflow's People

Contributors

taki0112

self-attention-gan-tensorflow's Issues

Asking about the results of testing

Hi, I have some questions:
1. Are the results shown in this project from the training stage? It seems testing just reuses the same data as training.
2. Could you also show the test input images and the corresponding test results? When I tested the model on new data that was not in your training dataset, it returned what looks like a random image; I don't see any relation between the test input and the output.
3. Could you upload your model at 100K iterations?
Thanks

Need a clarification (code)

What is the significance of the modules below, and why are they created separately? (Just to know.)

  1. attention()
  2. google_attention()

Similarly, I cannot figure out why the following two are created separately (I am trying to understand the code line by line):
def up_resblock()
def down_resblock()

The implementation is very helpful.
Is there any way this can be extended to 3D? Any code references or suggestions?

Thank You.

Spectral normalization on the attention layer

Hi, I notice that in your implementation there is no spectral normalization on the attention layer. According to the spectral normalization paper, it seems to make more sense to normalize the spectral norm of every layer so that the overall Lipschitz constant stays below 1. Is there any reason not to apply spectral normalization to the attention layer? Thanks!
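
For context, spectral normalization estimates the largest singular value of each weight matrix with a step of power iteration and divides the weight by it. A minimal TF 1.x sketch of the usual scheme (written from the SN paper, not copied from this repository's ops.py):

def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w_mat = tf.reshape(w, [-1, w_shape[-1]])  # [k*k*c_in, c_out]
    u = tf.get_variable("u", [1, w_shape[-1]],
                        initializer=tf.random_normal_initializer(),
                        trainable=False)

    def l2_norm(v, eps=1e-12):
        return v / (tf.norm(v) + eps)

    u_hat = u
    for _ in range(iteration):
        v_hat = l2_norm(tf.matmul(u_hat, tf.transpose(w_mat)))  # [1, k*k*c_in]
        u_hat = l2_norm(tf.matmul(v_hat, w_mat))                # [1, c_out]

    sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat))  # largest singular value
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_mat / sigma, w_shape)
    return w_norm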

About SN & BN

Hi,

Great work!
I'm wondering whether applying BN with scale=True makes the network no longer a Lipschitz-1 function
(which should be the target of SN?).

In the SN paper, we treat the entire network as one function and calculate its
Lipschitz constant by multiplying the Lipschitz constants of its components.
Since BN with scale=True introduces an additional scaling parameter,
we should also take it into consideration when calculating the spectral norm, right?
If so, applying BN after SN seems to destroy the work done by SN?
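
(As a one-line version of the reasoning above: Lip(f ∘ g) ≤ Lip(f) · Lip(g), so even if SN bounds every conv layer by 1, a learnable BN scale gamma multiplies the overall bound by max|gamma|.)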

How to use in 3D conv?

The paper and the code both apply the spectral-norm constraint on w for 2-D convolutions; how should w be handled for a 3-D convolution?
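
Spectral normalization itself is rank-agnostic: the kernel is flattened to a 2-D matrix before the power iteration, so a 3-D kernel can be handled the same way. A hedged sketch, assuming a [k, k, k, c_in, c_out] kernel layout:

# w has shape [k, k, k, c_in, c_out] for a 3-D convolution
w_shape = w.shape.as_list()
w_mat = tf.reshape(w, [-1, w_shape[-1]])  # [k*k*k*c_in, c_out]
# then run the same power iteration on w_mat as in the 2-D case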

iterations vs epochs

The iteration count is soft, not hard-coded:
self.iteration = len(self.data) // self.batch_size
which is then used to control the repetition, with
self.epoch = 200K

which layers should add attention and TTUR

Thank you for sharing!

If I use a VGG + deconvolution structure, which layers should I add attention to? Do all 16 layers of VGG need an attention structure?

Can you briefly point out where the two-timescale update rule (TTUR) is used in your code?
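
On the TTUR part: in SAGAN it just means the discriminator gets a larger learning rate than the generator (the paper uses 4e-4 vs 1e-4 with Adam beta1=0, beta2=0.9). A minimal sketch of how that typically looks in TF 1.x (variable names here are illustrative, not this repository's):

d_optim = tf.train.AdamOptimizer(4e-4, beta1=0.0, beta2=0.9).minimize(d_loss, var_list=d_vars)
g_optim = tf.train.AdamOptimizer(1e-4, beta1=0.0, beta2=0.9).minimize(g_loss, var_list=g_vars)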

Poor results on mini-imagenet

I used your network to train on the mini-imagenet dataset, but the generated images were very poor and of little value. How can I improve this?

Possibly wrong implementation of self-attention module?

I think your implementation is wrong. The tensor shapes were messed up:

f = conv(x, ch // 8, kernel=1, stride=1, scope='f_conv') # [bs H W C//8]
g = conv(x, ch // 8, kernel=1, stride=1, scope='g_conv') # [bs H W C//8]
h = conv(x, ch, kernel=1, stride=1, scope='h_conv') # [bs H W C]
s = tf.matmul(g, f, transpose_b=True) # [bs H W W]
attention_shape = s.shape
s = tf.reshape(s, shape=[attention_shape[0], -1, attention_shape[-1]]) # [bs, HW, W]
beta = tf.nn.softmax(s, axis=1) # [bs, HW, W]
beta = tf.reshape(beta, shape=attention_shape) # bs H W W
o = tf.matmul(beta, h) # [bs H W C]

The self-attention module is essentially a non-local module from:
X. Wang, R. Girshick, A. Gupta, and K. He. Non-local neural networks. In CVPR, 2018.
Based on Figure 2 of above paper, attention_shape should be [bs, H*W, H*W].
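
For comparison, the repository code applies hw_flatten before the matmul, which is exactly what produces the [bs, H*W, H*W] map; a shape walkthrough under that assumption:

g_flat = hw_flatten(g)                           # [bs, H*W, C//8]
f_flat = hw_flatten(f)                           # [bs, H*W, C//8]
s = tf.matmul(g_flat, f_flat, transpose_b=True)  # [bs, H*W, H*W]
beta = tf.nn.softmax(s)                          # softmax over the last axis
o = tf.matmul(beta, hw_flatten(h))               # [bs, H*W, C]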

the benefit of hinge loss


I found citations 13, 16, and 30, but I do not know the exact principle of the hinge loss.
I am confused about why we don't use the WGAN loss function instead.
Is it because the hinge loss performs better than the WGAN loss?
Thanks
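
For reference, the hinge adversarial loss those citations describe is short enough to state directly; a minimal TF sketch (real_logits and fake_logits are illustrative names):

d_loss = tf.reduce_mean(tf.nn.relu(1.0 - real_logits)) \
       + tf.reduce_mean(tf.nn.relu(1.0 + fake_logits))
g_loss = -tf.reduce_mean(fake_logits)

Unlike the WGAN critic loss, the hinge loss stops pushing once a sample is classified with margin 1; SAGAN adopts it following SN-GAN, where it trained stably with spectral normalization.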

Setting of Gamma

When do you set gamma to something that is not 0? It seems to me that you do not use attention at all.

Trying to share variable discriminator/D_logit/kernel, but specified shape (32768, 1) and found shape (73728, 1)

When I train on my custom dataset of 940*940*3 pictures, the following error occurs. However, when I resize the pictures to 64*64*3, it works.
By the way, I know the default img_size is 128; when I train on my 940 dataset, I do set img_size to 940!

Following is the error:

Traceback (most recent call last):
  File "main.py", line 110, in <module>
    main()
  File "main.py", line 91, in main
    gan.build_model()
  File "C:\Users\tt\workspace\GAN\SAGAN.py", line 244, in build_model
    fake_logits = self.discriminator(fake_images, reuse=True)
  File "C:\Users\tt\workspace\GAN\SAGAN.py", line 157, in discriminator
    x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit')
  File "C:\Users\tt\workspace\GAN\ops.py", line 70, in fully_conneted
    initializer=weight_init, regularizer=weight_regularizer)
  File "C:\Users\tt\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1317, in get_variable
    constraint=constraint)
  File "C:\Users\tt\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1079, in get_variable
    constraint=constraint)
  File "C:\Users\tt\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 425, in get_variable
    constraint=constraint)
  File "C:\Users\tt\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 394, in _true_getter
    use_resource=use_resource, constraint=constraint)
  File "C:\Users\tt\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 738, in _get_single_variable
    found_var.get_shape()))
ValueError: Trying to share variable discriminator/D_logit/kernel, but specified shape (32768, 1) and found shape (73728, 1).

Results shown on the GitHub page from your trained model?

Hi,

I noticed the excellent ImageNet results on your GitHub page. These results look similar to the ones in the paper. Were these results generated by your model? If not, could you please also post some ImageNet results from your model?

self attention in my own network

Hello, I am trying to add attention to my network:

def attention(x, ch=128):
    with tf.variable_scope("conv_f"):
        f = conv(x, 1, ch // 8, 1)  # [bs, h, w, c']
    with tf.variable_scope("conv_g"):
        g = conv(x, 1, ch // 8, 1)  # [bs, h, w, c']
    with tf.variable_scope("conv_h"):
        h = conv(x, 1, ch, 1)  # [bs, h, w, c]

    # N = h * w
    s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]

    beta = tf.nn.softmax(s)  # attention map

    o = tf.matmul(beta, hw_flatten(h))  # [bs, N, C]
    gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

    o = tf.reshape(o, shape=[batch_size, height, width, num_channels // 2])  # [bs, h, w, C]
    o = conv(o, 1, ch, 1)
    x = gamma * o + x

    return x

def hw_flatten(x):
    return tf.reshape(x, shape=[x.shape[0], [-1], x.shape[-1]])

I got this error:

"ValueError: Cannot reshape a tensor with 16777216 elements to shape [32,64,64,64] (8388608 elements) for 'generator/encoder_5/Reshape_3' (op: 'Reshape') with input shapes: [32,4096,128], [4] and with input tensors computed as partial shapes: input[1] = [32,64,64,64]."

what does "ch" mean?

Hi, I saw that you use a ch variable in the generator and discriminator. Could you explain what ch means?
(screenshots of the generator and discriminator code)

Thanks

ImportError: cannot import name 'prefetch_to_device'

WARNING:tensorflow:From E:\Users\Raytine\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
Using TensorFlow backend.
Traceback (most recent call last):
  File "F:/Image-Forgery/niubiSelf-Attention-GAN-Tensorflow-master (1)/Self-Attention-GAN-Tensorflow-master/main.py", line 1, in <module>
    from SAGAN import SAGAN
  File "F:\Image-Forgery\niubiSelf-Attention-GAN-Tensorflow-master (1)\Self-Attention-GAN-Tensorflow-master\SAGAN.py", line 4, in <module>
    from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
ImportError: cannot import name 'prefetch_to_device'

Problem with mnist

I tried to train it on mnist with the default params, but the results are not good at all. Do I need to change the params to make it work for mnist?

self-attention

s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]

(screenshot of the attention formula from the paper)
According to the paper, the transposed matrix f(x)^T should come first.

def hw_flatten(x) :

It seems you want to reshape the tensor. x.shape[0] is the batch size and x.shape[-1] is the channel count. What does -1 mean? Since the authors didn't publish their code and I have only read the paper, I don't understand why you want to reshape the tensor.
Thanks!
