ericjang / draw
TensorFlow Implementation of "DRAW: A Recurrent Neural Network For Image Generation"
License: Apache License 2.0
Not an issue, just curious.
grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])
mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19
If N = 3 and delta = 1, then grid_i will be [0 1 2]. With integer division, grid_i - N / 2 is grid_i - 1 = [-1 0 1], so grid_i - N / 2 - 0.5 is [-1.5 -0.5 0.5]. But I think [-1 0 1] is the reasonable value, since the mean locations would then be [gx-1, gx, gx+1]. Why subtract 0.5? Is this just following the paper, or am I missing something?
thanks
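For what it's worth, the arithmetic above is easy to check directly. This sketch uses the same hypothetical values (N = 3, delta = 1, gx = 0) and shows how the result of eq. 19 depends on whether N / 2 uses integer or true division:

```python
import numpy as np

# Worked check of eq. 19 with hypothetical values N=3, delta=1, gx=0.
N, delta, gx = 3, 1.0, 0.0
grid_i = np.arange(N, dtype=np.float32)           # [0. 1. 2.]
# Python 2 style integer division (N // 2 == 1), as read above:
mu_x_int = gx + (grid_i - N // 2 - 0.5) * delta   # [-1.5 -0.5  0.5]
# True division (N / 2 == 1.5):
mu_x_float = gx + (grid_i - N / 2 - 0.5) * delta  # [-2. -1.  0.]
print(mu_x_int, mu_x_float)
```

So under Python 2 semantics the centres land at [gx-1.5, gx-0.5, gx+0.5], exactly the values worked out above.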
FYI, README.md has a missing word. You can visualize the results by running the script:

python plot_data.py <prefix> <output_data>

For example:

python plot_data.py myattn /tmp/draw/draw_data.npy
How do I open the .npy file?
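A .npy file is NumPy's binary array format and can be read back with `np.load`. A minimal round-trip demo (the file name here is illustrative; the real file from the example above would be /tmp/draw/draw_data.npy):

```python
import numpy as np

# Save an array to .npy and load it back; np.load returns a NumPy array.
demo = np.arange(6, dtype=np.float32).reshape(2, 3)
np.save("demo.npy", demo)
loaded = np.load("demo.npy")
print(loaded.shape, loaded.dtype)  # (2, 3) float32
```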
Hi Eric,
thanks for this implementation! It's been hugely useful in replicating the paper. One thing I noticed is that in the filterbank function you're squaring the entire exponent rather than just the numerator, which is what you want.
I.e. your filters are slightly off from equations 24 and 25 in the paper.
Thanks!
Raza
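To make the report concrete, here is a hedged NumPy sketch of the two readings of eq. 24 for a single filterbank row; `filterbank_row` and its arguments are illustrative, not the repo's code:

```python
import numpy as np

# Compare squaring the whole fraction vs. only the numerator (a - mu_x).
# Only the second form matches eq. 24's Gaussian filter.
def filterbank_row(a, mu_x, sigma2):
    wrong = np.exp(-np.square((a - mu_x) / (2 * sigma2)))  # squares everything
    right = np.exp(-np.square(a - mu_x) / (2 * sigma2))    # matches eq. 24
    return wrong, right

a = np.arange(5, dtype=np.float32)
wrong, right = filterbank_row(a, mu_x=2.0, sigma2=1.0)
```

Both versions peak at 1 when a equals mu_x, but they fall off at different rates away from the centre, so the attention windows end up slightly wider or narrower than intended.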
I have been playing with this code for a few days. I can reproduce the GIF animation shown on the first page of this repository. However, this other implementation based on Theano (https://github.com/jbornschein/draw) achieves much (subjectively) nicer results (look at their GIF animation). I have tried using their parameters (like T=64 and read_window=2) in the TensorFlow code, but I was unable to reproduce results that look as nice. Do you have any idea why there is such a difference, and how we could achieve results like that with this TensorFlow code?
By niceness I mean the animation looks more realistic, which probably means the model learns something closer to the actual causal process of human handwriting.
Thanks for sharing this elegant DRAW model!
However, I found that in the KL divergence computation in draw.py, line 191, the last term should be 0.5 instead of 0.5*T, according to the paper's equation 11.
Even though this constant term won't affect the optimization process, I think you may get a different but more reasonable loss curve; as written, it is possible to get a negative KL divergence.
Thanks!
It is about 70, lower than most reported results.
I also found one bug. I think you were misled by Eq. (12); that equation computes each element of the vector z.
kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*.5 # each kl term is (1xminibatch)
should be
kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma-1,1) # each kl term is (1xminibatch)
# kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*z_size*.5 # alternatively
Otherwise the KL term will blow up with large z_size.
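A quick NumPy check of the corrected term (shapes and values are hypothetical, chosen only to exercise the formula; this is not the repo's code) confirms both points: keeping the -1 inside the per-dimension sum makes each per-sample KL non-negative, and it is algebraically identical to subtracting z_size * 0.5 outside the sum:

```python
import numpy as np

# KL(N(mu, sigma^2) || N(0, I)) summed over z_size latent dimensions.
rng = np.random.default_rng(0)
batch, z_size = 4, 10
mu = rng.normal(size=(batch, z_size))
logsigma = 0.1 * rng.normal(size=(batch, z_size))
mu2, sigma2 = mu ** 2, np.exp(2 * logsigma)

kl = 0.5 * np.sum(mu2 + sigma2 - 2 * logsigma - 1, 1)                # corrected
kl_alt = 0.5 * np.sum(mu2 + sigma2 - 2 * logsigma, 1) - z_size * .5  # equivalent
print(kl)
```

Subtracting T * 0.5 per step instead would shift every KL term down by a constant that grows with T, which is why the reported loss can go negative.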
Another issue is that the MNIST data in your code is not binarized, but that won't make much difference.
It may not be a problem, but I am curious why x_hat (which involves the true data) is also used at prediction time. After training, I would expect the model to generate data independently, not by reading the true data.
Details as follows:
Read x as well as x_hat
x = filter_img(x, Fx, Fy, gamma, read_n)  # batch x (read_n*read_n)
x_hat = filter_img(x_hat, Fx, Fy, gamma, read_n)
After the training:
canvases = sess.run(cs, feed_dict)  # generate some examples
canvases = np.array(canvases)       # T x batch x img_size
It seems that x_hat is also fed into the model here, but x_hat contains the true data.
Thanks!
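For reference, unconditional generation in the paper samples z_t from the prior and runs only the decoder, so neither x nor x_hat is read. A minimal toy sketch of that sampling loop, where decoder_step and its linear map are hypothetical stand-ins (not the repo's LSTM decoder or write operation):

```python
import numpy as np

# Toy generation loop: z_t ~ N(0, I), decoder writes onto a canvas,
# final image is sigmoid(canvas). No true data is ever read.
rng = np.random.default_rng(0)
T, batch, z_size, img_size = 10, 2, 10, 28 * 28
W = rng.normal(scale=0.01, size=(z_size, img_size))  # toy "decoder" weights

def decoder_step(z):
    return z @ W  # stand-in for the decoder RNN + write operation

canvas = np.zeros((batch, img_size))
for t in range(T):
    z = rng.standard_normal((batch, z_size))  # prior sample, no encoder
    canvas = canvas + decoder_step(z)         # accumulate onto the canvas
samples = 1 / (1 + np.exp(-canvas))           # sigmoid of the final canvas
```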