
textsum-gan's Introduction

GAN for Text Summarization

A TensorFlow re-implementation of Generative Adversarial Network for Abstractive Text Summarization.

Requirements

  • Python3
  • TensorFlow >= 1.4 (tested on TensorFlow 1.4.1)
  • numpy
  • tqdm
  • sklearn
  • rouge
  • pyrouge

You can use the Python package manager of your choice (pip/conda) to install the dependencies. The code has been tested on Ubuntu 16.04.

Quickstart

  • Dataset

    Please follow the instructions here to download and preprocess the CNN/DailyMail dataset. Then place the data files train.bin, val.bin, and test.bin and the vocabulary file vocab in a data directory of your choice, e.g., ./data/ (a quick sanity-check sketch for these files follows this list).

  • Prepare negative samples for discriminator

    You can download the prepared discriminator training data discriminator_train_data.npz from Dropbox and place it in the same data directory, e.g., ./data/.

  • Train the full model

    python3 main.py --mode=train --data_path=./data/train.bin --vocab_path=./data/vocab --log_root=./log --pretrain_dis_data_path=./data/discriminator_train_data.npz --restore_best_model=False
    
  • Decode

    python3 main.py --mode=decode --data_path=./data/test.bin --vocab_path=./data/vocab --log_root=./log --single_pass=True
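
Before training, it can help to sanity-check the data directory. The sketch below is not part of the repo; it assumes the standard pointer-generator .bin layout (an 8-byte length prefix followed by a serialized tf.Example with article and abstract features), and it makes no assumption about the array names inside the .npz beyond listing them.

    # Hedged sanity check for ./data/, assuming the pointer-generator .bin format.
    import struct
    import numpy as np
    from tensorflow.core.example import example_pb2

    with open("./data/train.bin", "rb") as reader:
        str_len = struct.unpack("q", reader.read(8))[0]           # length of the first record
        example = example_pb2.Example.FromString(reader.read(str_len))
        print(example.features.feature["article"].bytes_list.value[0][:200])
        print(example.features.feature["abstract"].bytes_list.value[0][:200])

    disc_data = np.load("./data/discriminator_train_data.npz")    # plain NumPy archive
    print(disc_data.files)                                        # names of the stored arrays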
    

textsum-gan's People

Contributors

iwangjian, mayank-k91


textsum-gan's Issues

ValueError: Cannot feed value of shape (4, 100) for Tensor 'Placeholder:0', which has shape '(4, 1)'

I got an error:

File "main.py", line 210, in <module>
tf.compat.v1.app.run()
File "/home/eric/.local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/home/eric/.local/lib/python3.6/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/home/eric/.local/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "main.py", line 204, in main
decoder.decode()
File "/home/eric/Documents/textrank_summarization/textsum-gan/decode.py", line 108, in decode
best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
File "/home/eric/Documents/textrank_summarization/textsum-gan/beam_search.py", line 104, in run_beam_search
enc_states, dec_in_state = model.run_encoder(sess, batch)
File "/home/eric/Documents/textrank_summarization/textsum-gan/generator.py", line 510, in run_encoder
self.global_step], feed_dict) # run the encoder
File "/home/eric/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
run_metadata_ptr)
File "/home/eric/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1149, in _run
str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (4, 100) for Tensor 'Placeholder:0', which has shape '(4, 1)'

My TensorFlow version is 1.14.0, and I'm on an Ubuntu 18.04 system.
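
For context, here is a minimal illustration (not code from this repo) of this class of error under TF 1.x: a placeholder declared with static shape (4, 1) rejects a (4, 100) feed at Session.run. Only the shapes mirror the traceback; everything else is illustrative. Since the report is on TF 1.14 while the repo targets TF 1.4.x, checking whether the same command works under TF 1.4.1 may help narrow down whether the graph or the feed is at fault.

    import numpy as np
    import tensorflow as tf  # TF 1.x

    ph = tf.placeholder(tf.int32, shape=[4, 1])   # declared static shape (4, 1)
    out = ph + 1

    with tf.Session() as sess:
        print(sess.run(out, feed_dict={ph: np.zeros((4, 1), np.int32)}))   # OK
        sess.run(out, feed_dict={ph: np.zeros((4, 100), np.int32)})        # ValueError: Cannot feed value of shape (4, 100) ...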

Decoding problem: ValueError

I trained the model exactly as described in the README, but only for 145 training steps. I then interrupted training to try decoding, using the decode command from the README verbatim. However, it raises the ValueError shown below.

Loading checkpoint ./log/train/model.ckpt-145
INFO:tensorflow:Restoring parameters from ./log/train/model.ckpt-145
INFO:tensorflow:Wrote example 0 to file
INFO:tensorflow:Wrote example 1 to file
INFO:tensorflow:Wrote example 2 to file
INFO:tensorflow:Wrote example 3 to file
INFO:tensorflow:Wrote example 4 to file
Traceback (most recent call last):
File "main.py", line 202, in
tf.app.run()
File "/home/amax/anaconda3/envs/tf14/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "main.py", line 196, in main
decoder.decode()
File "/home/amax/zhj/textsum-gan/decode.py", line 108, in decode
best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
File "/home/amax/zhj/textsum-gan/beam_search.py", line 132, in run_beam_search
prev_coverage=prev_coverage)
File "/home/amax/zhj/textsum-gan/generator.py", line 580, in decode_onestep
results = sess.run(to_return, feed_dict=feed) # run the decoder step
File "/home/amax/anaconda3/envs/tf14/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 889, in run
run_metadata_ptr)
File "/home/amax/anaconda3/envs/tf14/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1096, in _run
% (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (3, 256) for Tensor 'seq2seq/reduce_final_st/Relu:0', which has shape '(4, 256)'

The error message confuses me. Could you please help me with this problem?
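
One possible reading (an assumption, not a confirmed diagnosis): the decode graph expects feeds with a leading dimension of beam_size (4 here), but at this step only 3 live hypotheses were passed to seq2seq/reduce_final_st/Relu. The sketch below shows a generic workaround for such fixed-shape beam-search feeds, padding the state batch back to beam_size; the array names are hypothetical.

    import numpy as np

    beam_size, hidden_dim = 4, 256
    dec_states = np.random.rand(3, hidden_dim).astype(np.float32)   # only 3 live hypotheses

    # Hypothetical workaround: repeat the last state so the feed matches the graph's
    # fixed (beam_size, hidden_dim) shape; the padded rows are ignored downstream.
    if dec_states.shape[0] < beam_size:
        pad = np.repeat(dec_states[-1:], beam_size - dec_states.shape[0], axis=0)
        dec_states = np.concatenate([dec_states, pad], axis=0)

    assert dec_states.shape == (beam_size, hidden_dim)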

Segmentation fault (core dumped) when importing Batcher in the main.py

Hi developer,

when I try to run the command
"python main.py --mode=train --data_path=./data/finished_files/train.bin --vocab_path =./data /finished_files/vocab --log_root=./log --pretrain_dis_data_path=./data/discriminator_train_data.npz --restore_best_model=False
",
I have encountered "Segmentation fault (core dumped) ", occuring at "from batcher import Batcher' in the main.py.
No other information is given.

So could you kindly help me out? I would really appreciate it!

Bests,

Qiana
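
A hedged diagnostic sketch (not from the repo): a segfault at from batcher import Batcher most often originates in the TensorFlow import that batcher.py triggers (for example a wheel built for CPU instructions or CUDA libraries the machine lacks), so importing the heavy dependencies one at a time usually isolates the culprit.

    # Import the main dependencies individually; whichever import crashes is the one
    # whose installation needs fixing. Only standard package attributes are used here.
    import numpy as np
    print("numpy", np.__version__)

    import tensorflow as tf   # if the segfault happens here, reinstall a compatible TF build
    print("tensorflow", tf.__version__)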

Hello

Is there a PyTorch version of the code?

Does the code run for everyone else?

Why does running python3 main.py --mode=train --data_path=./data/train.bin --vocab_path=./data/vocab --log_root=./log --pretrain_dis_data_path=./data/discriminator_train_data.npz --restore_best_model=False on the command line immediately throw this exception?

sys.exit(main(argv))
File "main.py", line 218, in main
raise ValueError("The 'mode' flag must be one of pretrain/train/decode")
ValueError: The 'mode' flag must be one of pretrain/train/decode

Not getting expected results

I'm trying to run the code on the CNN/DailyMail dataset following the instructions in the README, but the loss doesn't seem to be decreasing, and when I try to decode I get output like the following. This is after more than 1,000,000 training steps.

, , , a , , peterson , , .
, , the , .
in .
, .
.
in , , in , .
a , .
to , .

Do you have any idea where the problem could be?

main.py

I'm getting the following error:

Traceback (most recent call last):
  File "main.py", line 210, in <module>
    tf.app.run()
  File "/home/game-of-codes/anaconda3/envs/py3/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "main.py", line 192, in main
    generator_batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)
  File "/home/game-of-codes/Work/Projects/TS Experiments/textsum-gan/batcher.py", line 250, in __init__
    self._example_queue = queue.Queue(BATCH_QUEUE_MAX * self._hps.batch_size)
TypeError: unsupported operand type(s) for *: 'int' and 'Flag'
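
The trace shows BATCH_QUEUE_MAX * self._hps.batch_size failing because batch_size is an absl Flag object rather than an int. A plausible cause, assuming main.py builds its hps namedtuple from FLAGS.__flags the way the original pointer-generator code does: on TensorFlow versions newer than roughly 1.5, that dict maps flag names to Flag objects instead of raw values, so each entry needs to be dereferenced with .value. A hedged sketch of that pattern (the flag names are illustrative):

    # Hedged sketch, not the repo's exact main.py.
    from collections import namedtuple
    import tensorflow as tf  # TF 1.x

    tf.app.flags.DEFINE_integer("batch_size", 16, "minibatch size")        # illustrative flags
    tf.app.flags.DEFINE_integer("max_enc_steps", 400, "max encoder steps")
    FLAGS = tf.app.flags.FLAGS

    hparam_list = ["batch_size", "max_enc_steps"]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # pylint: disable=protected-access
        if key in hparam_list:
            # On absl-backed TF, val is a Flag object; older TF stored the raw value.
            hps_dict[key] = val.value if hasattr(val, "value") else val

    hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
    print(100 * hps.batch_size)  # an int again, so queue sizing like BATCH_QUEUE_MAX * batch_size works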

How to make a test on specific document ?

Hi, after following the instructions in the README file, I want to generate a summary for a single test document.
How can I do that? Thank you.

I'm running on:

  • OS: Ubuntu 20.04
  • Python: 3.6.9
  • Tensorflow: 1.4.1
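
One possible approach, sketched under the assumption that this repo keeps the pointer-generator data format (each example is an 8-byte length prefix followed by a serialized tf.Example with article and abstract features): write the document into its own .bin file and then decode it with --single_pass=True. The file name and placeholder text below are illustrative.

    # Hedged sketch: package one tokenized document as a pointer-generator-style .bin.
    import struct
    from tensorflow.core.example import example_pb2

    article = "your tokenized , lowercased article text here ."
    abstract = "<s> a reference summary sentence here . </s>"   # <s>/</s> mark sentence boundaries

    tf_example = example_pb2.Example()
    tf_example.features.feature["article"].bytes_list.value.extend([article.encode()])
    tf_example.features.feature["abstract"].bytes_list.value.extend([abstract.encode()])
    serialized = tf_example.SerializeToString()

    with open("./data/single_test.bin", "wb") as writer:
        writer.write(struct.pack("q", len(serialized)))
        writer.write(struct.pack("%ds" % len(serialized), serialized))

Decoding would then look like python3 main.py --mode=decode --data_path=./data/single_test.bin --vocab_path=./data/vocab --log_root=./log --single_pass=True.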

Model saving for adversarial training

Hi @iwangjian ,
Sorry for asking such a stupid question, but I cannot find where the model checkpoints are saved during adversarial training. Looking forward to your reply.

I just noticed that the saving is handled by tf.train.Supervisor. Please delete this issue.

I can't open the Dropbox link

The README says: "You can download the generated data discriminator_train_data.npz for discriminator from dropbox. Meanwhile, you can follow the instructions below to prepare negative samples by yourself." However, clicking the Dropbox link reports that it is invalid. Could you please share the link again?

Hello, I'd like to ask about the logical structure of the code

First, a bidirectional LSTM produces the hidden states (this is the encoder); then the decoder, also an LSTM, outputs the summary. Are this generated (fake) summary and the real summary then both fed into the CNN discriminator to obtain a score that is used to update the discriminator?
I'd also like to ask: is the train.bin file a binary representation of the original articles, and where are the reference summaries stored?

Error in generator.py(line 320)

Dear Wang,

I found your article very exciting, so I decided to download your package.
I have successfully evaluated the decoder and after that I decided to begin a new training session.
After starting the training phase, the first 100 stories were successfully read
(see I0905 11:05:25.497097 7352 attention_decoder.py:151] Adding attention_decoder timestep 99 of 100) .

But then I get the following error message (attached below).
How can I proceed?
Thank you in advance for your advice!

Best Regards,

Istvan

*File "k:\summarization\gan_summarization\textsum-gan\generator.py", line 320, in _add_seq2seq
loss_with_reward = self.D_reward*tf.stack(sample_loss_per_step, axis=1)roll_mask

File "g:\Anaconda353\lib\site-packages\tensorflow\python\ops\math_ops.py", line 884, in binary_op_wrapper
return func(x, y, name=name)
File "g:\Anaconda353\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1180, in _mul_dispatch
return gen_math_ops.mul(x, y, name=name)
File "g:\Anaconda353\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 6878, in mul
"Mul", x=x, y=y, name=name)
File "g:\Anaconda353\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "g:\Anaconda353\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "g:\Anaconda353\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
op_def=op_def)
File "g:\Anaconda353\lib\site-packages\tensorflow\python\framework\ops.py", line 2027, in init
control_input_ops)
File "g:\Anaconda353\lib\site-packages\tensorflow\python\framework\ops.py", line 1867, in _create_c_op
raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 32 and 100 for 'seq2seq/loss/mul_201' (op: 'Mul') with input shapes: [32,100], [32,100,1].
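
For reference, a minimal sketch (neither the repo's code nor its official fix) reproducing the shape complaint: multiplying a [32, 100] tensor by a [32, 100, 1] tensor fails TensorFlow broadcasting, while squeezing the trailing singleton dimension makes the shapes compatible. The tensor names below are placeholders.

    import tensorflow as tf  # TF 1.x graph mode

    reward_weighted_loss = tf.zeros([32, 100])    # e.g. D_reward * stacked per-step losses
    mask_with_extra_dim = tf.zeros([32, 100, 1])  # a mask carrying a trailing singleton dimension

    # reward_weighted_loss * mask_with_extra_dim
    #   -> ValueError: Dimensions must be equal, but are 32 and 100 for ... 'Mul' ...

    ok = reward_weighted_loss * tf.squeeze(mask_with_extra_dim, axis=-1)  # both [32, 100]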

Hello, I'd like to ask why the loss function in discriminator.py is a cross-entropy rather than min_φ −E_{Y∼p_data}[log D_φ(Y)] − E_{Y∼G_θ}[log(1 − D_φ(Y))]

The code is as follows:
self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")  # (?, 2): the two columns represent the probability of being fake and of being real
self.ypred_for_auc = tf.nn.softmax(self.scores)  # (?, 2)
self.predictions = tf.argmax(self.scores, 1, name="predictions")  # (?,): 0 means predicted fake, 1 means predicted real

        # CalculateMean cross-entropy loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

Is there an error in this piece of code?
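
For context, this is a standard identity rather than anything specific to this repo: with one-hot real/fake labels, the two-class softmax cross-entropy above is (up to the L2 term and the class balance in the batch) the sample estimate of exactly that objective,

    -\frac{1}{N}\sum_{i=1}^{N}\left[\, y_i \log D_\phi(Y_i) + (1 - y_i)\log\left(1 - D_\phi(Y_i)\right) \right]
    \;\approx\; -\,\mathbb{E}_{Y\sim p_{\mathrm{data}}}\left[\log D_\phi(Y)\right] - \mathbb{E}_{Y\sim G_\theta}\left[\log\left(1 - D_\phi(Y)\right)\right],

where y_i = 1 for human-written summaries, y_i = 0 for generated ones, and D_φ(Y) is the softmax probability assigned to the "real" class.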

time for pretraining step

The repo mentions that the pretraining step should run for some time; please specify after how much time I should interrupt it.

Also, I can't use the provided pretrained npz file, since I'm planning to train on my own custom dataset.

parser.add_argument('--decode_dir', required=True, help="root of the decoded directory").

In the gen_sample.py file, the line parser.add_argument('--decode_dir', required=True, help="root of the decoded directory")
refers to a decode directory with two child directories (reference and decoded). I don't understand how to create these directories properly. Please help me.

I ran the gen_sample.py file with python3 gen_sample.py --data_dir data/ --decode_dir decode_dir --vocab_path data/vocab and got this output:

vocab length: 199869
positive samples: 0
negative samples: 0
file saved: data/discriminator_train_data.npz
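
A hedged reading of the expected layout (an assumption based on the help text and the zero counts above, not on documented behaviour): decode_dir should contain reference/ and decoded/ subdirectories that already hold reference and decoded summary files, typically produced by a prior decode run; pointing gen_sample.py at empty directories would yield exactly the 0 positive / 0 negative samples shown.

    # Illustrative directory layout only; the files inside must come from a decode run.
    import os

    decode_dir = "decode_dir"                      # the value passed via --decode_dir
    for sub in ("reference", "decoded"):
        os.makedirs(os.path.join(decode_dir, sub), exist_ok=True)

    print(sorted(os.listdir(decode_dir)))          # ['decoded', 'reference'], still empty here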
