TensorFlow re-implementation of "Generative Adversarial Network for Abstractive Text Summarization" (AAAI-18) [1].
- Python 3 (tested on Python 3.6)
- TensorFlow >= 1.4 (tested on TensorFlow 1.4.1)
- numpy
- tqdm
- scikit-learn (sklearn)
- rouge
- pyrouge
Install the dependencies with the Python package manager of your choice (pip or conda). The code has been tested on Ubuntu 16.04.
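Before running anything, you can check which of the listed dependencies are importable. The sketch below is a hypothetical helper, not part of this repository. It uses importlib.util.find_spec, maps scikit-learn to its import name sklearn, and does not check version numbers (e.g. TensorFlow >= 1.4).

```python
import importlib.util

# Import names of the dependencies listed above (assumption: scikit-learn
# is imported as "sklearn"). Versions are NOT checked here.
REQUIRED_MODULES = ["tensorflow", "numpy", "tqdm", "sklearn", "rouge", "pyrouge"]


def missing_modules(names):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing dependencies:", ", ".join(missing))
    else:
        print("All dependencies are importable.")
```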
-
Dataset
Please follow the instructions here for downloading and preprocessing the CNN/DailyMail dataset. Then store the data files train.bin, val.bin and test.bin, together with the vocabulary file vocab, in your data directory, e.g. ./data/.
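If you want to sanity-check the preprocessed files, the sketch below counts the examples in a .bin file. It assumes the framing read by the pointer-generator data loader [2]: each record is an 8-byte length (struct format "q") followed by a serialized tf.Example of that length. iter_records and count_records are hypothetical helpers; decoding the tf.Example payloads themselves would need TensorFlow and is not shown.

```python
import struct


def iter_records(path):
    """Yield raw serialized records from a length-prefixed .bin file.

    Assumes each record is an 8-byte native-endian signed length ('q')
    followed by that many payload bytes (a serialized tf.Example).
    """
    with open(path, "rb") as reader:
        while True:
            len_bytes = reader.read(8)
            if not len_bytes:
                break  # clean end of file
            if len(len_bytes) < 8:
                raise ValueError("truncated length header")
            (str_len,) = struct.unpack("q", len_bytes)
            payload = reader.read(str_len)
            if len(payload) < str_len:
                raise ValueError("truncated record payload")
            yield payload


def count_records(path):
    """Count the examples in a .bin file such as ./data/train.bin."""
    return sum(1 for _ in iter_records(path))
```

For example, count_records("./data/val.bin") returns the number of validation examples, assuming the framing above.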
-
Prepare negative samples for discriminator
You can download the pre-generated discriminator training data, discriminator_train_data.npz, from Dropbox. Alternatively, you can prepare the negative samples yourself by following the steps below.
First, pretrain the generator for some steps:
python3 main.py --mode=pretrain --data_path=./data/train.bin --vocab_path=./data/vocab --log_root=./log --restore_best_model=False
After pretraining for some steps, stop the process, then restore the model for training (note: set restore_best_model to True for this step):
python3 main.py --mode=pretrain --data_path=./data/train.bin --vocab_path=./data/vocab --log_root=./log --restore_best_model=True
Second, decode the training data using the pretrained generator:
python3 main.py --mode=decode --data_path=./data/train.bin --vocab_path=./data/vocab --log_root=./log --single_pass=True
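The decode step writes its output to a directory under the log root, shown as ./log/decode_xxxx in the gen_sample.py command. As a convenience, the hypothetical helper below picks the most recently modified directory whose name starts with decode_. Both the prefix and the assumption that the newest directory is the one you want are guesses, so check the result before passing it to --decode_dir.

```python
import glob
import os


def latest_decode_dir(log_root="./log", prefix="decode_"):
    """Return the most recently modified directory under log_root whose
    name starts with prefix, or None if there is none.

    Assumption: decode output directories are named "decode_<something>",
    mirroring the ./log/decode_xxxx placeholder in this README.
    """
    candidates = [
        path
        for path in glob.glob(os.path.join(log_root, prefix + "*"))
        if os.path.isdir(path)  # ignore plain files that match the prefix
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
```

For example, print(latest_decode_dir("./log")) prints a candidate value for --decode_dir, or None if no decode directory exists yet.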
Finally, generate the .npz file containing both positive and negative samples:
python3 gen_sample.py --data_dir=./data --decode_dir=./log/decode_xxxx --vocab_path=./data/vocab
Here, ./log/decode_xxxx stands for the decode directory created in the previous step.
After that, discriminator_train_data.npz is generated in data_dir.
-
Train the full model
python3 main.py --mode=train --data_path=./data/train.bin --vocab_path=./data/vocab --log_root=./log --pretrain_dis_data_path=./data/discriminator_train_data.npz --restore_best_model=False
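Before pointing --pretrain_dis_data_path at discriminator_train_data.npz (whether downloaded or generated with gen_sample.py), you can list what the archive contains. This sketch makes no assumption about the array names gen_sample.py writes; describe_npz is a hypothetical helper that reports the name, shape and dtype of every stored array. If the file stores object arrays, np.load needs allow_pickle=True, which you should only enable for files you trust.

```python
import numpy as np


def describe_npz(path, allow_pickle=False):
    """Return {array_name: (shape, dtype_string)} for every array in an .npz file.

    No array names are assumed; whatever the archive contains is listed.
    allow_pickle=False refuses object arrays; pass True only for trusted files.
    """
    with np.load(path, allow_pickle=allow_pickle) as archive:
        return {
            name: (archive[name].shape, str(archive[name].dtype))
            for name in archive.files
        }
```

For example, describe_npz("./data/discriminator_train_data.npz") shows how many positive and negative samples the file holds, if they are stored as separate arrays.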
-
Decode
python3 main.py --mode=decode --data_path=./data/test.bin --vocab_path=./data/vocab --log_root=./log --single_pass=True
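The decode step above produces summaries for the test set, and the dependency list includes rouge and pyrouge for scoring them. For a quick sanity check on a single example, the sketch below computes a simplified ROUGE-1 F1 from unigram overlap. It is an illustrative approximation (lowercasing, whitespace tokenization, no stemming), not the official ROUGE-1.5.5 scoring that pyrouge wraps, so use pyrouge for reportable numbers.

```python
from collections import Counter


def rouge1_f1(hypothesis, reference):
    """Simplified ROUGE-1 F1 between two whitespace-tokenized strings.

    Lowercases and splits on whitespace only; no stemming or stopword
    handling, so scores will differ from pyrouge / ROUGE-1.5.5.
    """
    hyp = Counter(hypothesis.lower().split())
    ref = Counter(reference.lower().split())
    # Clipped unigram matches: each token counts at most min(hyp, ref) times.
    overlap = sum((hyp & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)
```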
-
References
[1] "Generative Adversarial Network for Abstractive Text Summarization" (AAAI-18)
[2] https://github.com/abisee/pointer-generator
[3] https://github.com/LantaoYu/SeqGAN