Giter VIP home page Giter VIP logo

seqgan_tensorflow's Introduction

SeqGAN_tensorflow

This code is used to reproduce the result of synthetic data experiments in "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" (Yu et.al). It replaces the original tensor array implementation with higher level tensorflow API for better flexibility.

Introduction

The baisc idea of SeqGAN is to regard sequence generator as an agent in reinforcement learning. To train this agent, it applies REINFORCE (Williams, 1992) algorithm to train the generator and a discriminator is trained to provide the reward. To calculate the reward of partially generated sequence, Monte-Carlo sampling is used to rollout the unfinished sequence to get the estimated reward. seqgan

Some works based on training method used in SeqGAN:

  • Recurrent Topic-Transition GAN for Visual Paragraph Generation (Liang et.al, ICCV 2017)
  • Towards Diverse and Natural Image Descriptions via a Conditional GAN (Dai et.al, ICCV 2017)
  • Show, Adapt and Tell: Adversarial Training of Cross-domain Image Captioner (Chen et.al, ICCV 2017)
  • Adversarial Ranking for Language Generation (Lin et.al, NIPS 2017)
  • Long Text Generation via Adversarial Training with Leaked Information (Guo et.al, AAAI 2018)

Prerequisites

  • Python 2.7
  • Tensorflow 1.3

Run the code

Simply run python train.py will start the training process. It will first pretrain the generator and discriminator then start adversarial training.

Results

The output in experiment.log would be something similar to below, which is close to reported result in original implementation

pre-training...
epoch:	0	nll:	10.1971
epoch:	5	nll:	9.4694
epoch:	10	nll:	9.2169
epoch:	15	nll:	9.17986
epoch:	20	nll:	9.16206
epoch:	25	nll:	9.1344
epoch:	30	nll:	9.12127
epoch:	35	nll:	9.0948
epoch:	40	nll:	9.10186
epoch:	45	nll:	9.10108
epoch:	50	nll:	9.0971
epoch:	55	nll:	9.11246
epoch:	60	nll:	9.1182
epoch:	65	nll:	9.10095
epoch:	70	nll:	9.09244
epoch:	75	nll:	9.08816
epoch:	80	nll:	9.10319
epoch:	85	nll:	9.08916
epoch:	90	nll:	9.08348
epoch:	95	nll:	9.09661
epoch:	100	nll:	9.10361
epoch:	105	nll:	9.11718
epoch:	110	nll:	9.10492
epoch:	115	nll:	9.1038
adversarial training...
epoch:	0	nll:	9.09558
epoch:	5	nll:	9.03083
epoch:	10	nll:	8.96725
epoch:	15	nll:	8.91415
epoch:	20	nll:	8.87554
epoch:	25	nll:	8.82305
epoch:	30	nll:	8.76805
epoch:	35	nll:	8.73597
epoch:	40	nll:	8.71933
epoch:	45	nll:	8.71653
epoch:	50	nll:	8.71746
epoch:	55	nll:	8.7036
epoch:	60	nll:	8.68666
epoch:	65	nll:	8.68931
epoch:	70	nll:	8.68588
epoch:	75	nll:	8.69977
epoch:	80	nll:	8.69636
epoch:	85	nll:	8.69916
epoch:	90	nll:	8.6969
epoch:	95	nll:	8.71021
epoch:	100	nll:	8.72561
epoch:	105	nll:	8.71369
epoch:	110	nll:	8.71723
epoch:	115	nll:	8.72388
epoch:	120	nll:	8.71293
epoch:	125	nll:	8.70667
epoch:	130	nll:	8.70341
epoch:	135	nll:	8.69929
epoch:	140	nll:	8.69793
epoch:	145	nll:	8.67705
epoch:	150	nll:	8.65372

Note: Part of this code (dataloader, discriminator, target LSTM) is based on original implementation by Lantao Yu. Many thanks to his code

seqgan_tensorflow's People

Contributors

chenchengkuan avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar

seqgan_tensorflow's Issues

真实数据数字和汉字对照表

你好,我想看一下这个模型的生成效果。我注意到你的生成文件都是数字,那相应的汉字对照表能否发一下呢?谢谢~

关于build_adversarial_network中loss的问题

你好,有个问题想请教一下:
self.gen_loss_adv = -tf.reduce_sum(
tf.reduce_sum(
tf.one_hot(tf.to_int32(tf.reshape(self.input_seqs_adv, [-1])), self.num_emb, 1.0, 0.0) * tf.log(
tf.clip_by_value(tf.reshape(self.softmax_list_reshape, [-1, self.num_emb]), 1e-20, 1.0)
), 1) * tf.reshape(self.rewards, [-1]))
请问下面这一项的作用是什么?对应原论文中的哪部分?
tf.clip_by_value(tf.reshape(self.softmax_list_reshape, [-1, self.num_emb]), 1e-20, 1.0)
我大概理解这个self.softmax_list_reshape里存放了input进入lstm+softmax后的输出,所以大概是一个概率值。
我的理解是这个loss对执行了的action动作的reward进行了求和,但好像不需要用到上面一项(虽然进行反向传播的话好像是需要的,我有点混乱),我结合原文看了半天,但还是没找到这部分对应的理解

About real dataset

你好,我使用真实的语料集作为训练数据,但是发现生成文本的质量很不好,请问有什么在训练上的建议嘛?谢谢~

关于target_lstm中的参数问题

在做文本训练的时候,嵌入维度及词汇表维度改变,相应的self.g_embeddings = tf.Variable(self.params[0])就会发生改变,还有后面涉及到的params中的参数,这个参数读取也得改变,但是其读取的维度已经是固定的,如果自己随机初始化这些参数会有什么影响?已有真实数据的话,感觉这里这个target_lstm是计算损失还有的作用。还有就是在训练的时候,损失越小但生成样本的质量却越来越差(从长度、分布情况与真实样本来看),不知道是什么问题

generator

您好,有些疑问,生成模型中def build_sample_network(self):函数的输入是什么?不太理解函数里面的sample_word?麻烦您讲讲

关于SeqGAN的疑惑

您好,我在学习研究您这个关于SeqGAN的实现,现在有一些疑惑,还请您指点一下。
代码里面的target_lstm是用于生成正样本和衡量Generator与oracle model的相似度的?如果我只是需要利用SeqGAN这个模型生成新的文本,那是否我可以不使用target_lstm,只需将我现有的文本作为正样本,将其放进模型中进行训练?
谢谢!

question about reward and loss function

您好,相较于原版,你的这个版本简直太清晰了!点赞👍
我有一个关于generator的loss function的问题,在原版的SeqGAN里有 repo
我看到你的代码里也是取的正确类别的accuracy:
reward_allseq = np.concatenate((reward_rollout_seq, reward_last_tok), axis=0)[:,1]
跟原版一致。
请问你怎么看待这个问题呢?

questions about rollout and discriminator

Q1. it seems that you call tf.get_variable_scope().reuse_variables() function many times in the scope "teller" in rollout.py, in my understanding the first time you call this function makes the reuse flag True in the whole scope, why call it many times?
Q2. are the weights of the word embedding in discriminator different from that in the generator? can i reuse the word embedding of generator in the discriminator?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.