NAST

This repository contains the code and model outputs for the paper NAST: A Non-Autoregressive Generator with Word Alignment for Unsupervised Text Style Transfer (Findings of ACL 2021).

Outputs

We release the outputs of NAST under outputs.

YELP: outputs/YELP/{stytrans|latentseq}_{simple|learnable}/{pos2neg|neg2pos}.txt
GYAFC: outputs/GYAFC/{stytrans|latentseq}_{simple|learnable}/{fm2inf|inf2fm}.txt
  • {stytrans|latentseq} indicates the base model, i.e., StyTrans or LatentSeq.
  • {simple|learnable} indicates the two alignment strategies.
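The brace patterns above expand to 16 files in total. A minimal sketch that enumerates them, for example to check that a download is complete; the helper name `output_paths` is hypothetical, and only the directory layout comes from this README:

```python
from itertools import product

# Transfer directions per dataset, taken from the path patterns above.
DIRECTIONS = {
    "YELP": ("pos2neg", "neg2pos"),
    "GYAFC": ("fm2inf", "inf2fm"),
}


def output_paths(root="outputs"):
    """Expand outputs/{DATASET}/{base}_{alignment}/{direction}.txt into concrete paths."""
    paths = []
    for dataset, directions in DIRECTIONS.items():
        for base, alignment, direction in product(
            ("stytrans", "latentseq"), ("simple", "learnable"), directions
        ):
            paths.append(f"{root}/{dataset}/{base}_{alignment}/{direction}.txt")
    return paths


for path in output_paths()[:4]:
    print(path)
```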

Requirements

  • python >= 3.7
  • pytorch >= 1.7.1
  • cotk == 0.1.0 (pip install cotk==0.1.0)
  • transformers == 3.0.2
  • tensorboardX

Evaluation

The evaluation code is under eval.

We use six metrics in the paper:

  • PPL: The perplexity of the transferred sentences, evaluated by a fine-tuned GPT-2 base model.
  • Acc: The style accuracy of the transferred sentences, evaluated by a fine-tuned RoBERTa-base classifier.
  • SelfBLEU: The BLEU score between the source sentences and the transferred sentences.
  • RefBLEU: The BLEU score between the transferred sentences and the human references.
  • G2: Geometric mean of Acc and RefBLEU, sqrt(Acc * RefBLEU).
  • H2: Harmonic mean of Acc and RefBLEU, 2 * Acc * RefBLEU / (Acc + RefBLEU).

The code also provides three other metrics:

  • self_G2: Geometric mean of Acc and SelfBLEU, sqrt(Acc * SelfBLEU).
  • self_H2: Harmonic mean of Acc and SelfBLEU, 2 * Acc * SelfBLEU / (Acc + SelfBLEU).
  • Overall: G2 if human references are available, otherwise self_G2.
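A minimal sketch of how the combined metrics follow from Acc and the BLEU scores, assuming all three inputs are on the same 0-1 scale as in the example output. The harmonic mean is taken as 2ab/(a+b), which reproduces the h2 and self_h2 columns of that example. The function name is hypothetical and is not the evaluation script's API:

```python
import math


def combined_metrics(acc, self_bleu, ref_bleu=None):
    """Combine style accuracy with BLEU; pass ref_bleu=None when no references exist."""

    def gmean(a, b):
        return math.sqrt(a * b)

    def hmean(a, b):
        # Standard harmonic mean of two numbers; defined as 0 when both are 0.
        return 2 * a * b / (a + b) if a + b > 0 else 0.0

    out = {"self_g2": gmean(acc, self_bleu), "self_h2": hmean(acc, self_bleu)}
    if ref_bleu is not None:
        out["g2"] = gmean(acc, ref_bleu)
        out["h2"] = hmean(acc, ref_bleu)
    # Overall: G2 when references are available, otherwise self_G2.
    out["overall"] = out.get("g2", out["self_g2"])
    return out


# Rounded inputs from the test0 row of the example output.
m = combined_metrics(acc=0.862, self_bleu=0.629, ref_bleu=0.491)
print({name: round(value, 3) for name, value in m.items()})
```

Because the inputs here are already rounded to three decimals, the printed values can differ from the example table in the last digit.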

Data Preparation

The YELP data can be downloaded here and should be put under data/yelp.

We cannot provide the GYAFC data because of copyright issues. You can download the data and the human references yourself, and then preprocess the data into the same format as the YELP data. We use the Family & Relationships domain in all our experiments. The GYAFC data should be put under data/GYAFC_family.

Pretrained Classifier & Language Model

The evaluation code requires a pretrained classifier and a language model. We provide our pretrained models below.

Dataset  Classifier  Language Model
YELP     Link        Link
GYAFC    Link        Link

Download the models and put them under ./eval/model/.

See the training instructions in "Train your Classifier / Language Model" for how to train the classifier and language model yourself. Use the same classifier and language model to evaluate both NAST and the baselines; otherwise the results are not comparable.
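Before running the evaluation, it can help to confirm that the checkpoints sit where the scripts look for them. A minimal sketch, assuming the naming convention stated for --clsrestore and --lmrestore (a name MODELNAME is loaded from ./model/MODELNAME.model inside eval/) and the YELP default names; for GYAFC, pass the names of the files you actually downloaded. The helper `missing_models` is hypothetical:

```python
from pathlib import Path


def missing_models(eval_dir="eval", cls_name="cls_yelp_best", lm_name="lm_yelp_best"):
    """Return the expected classifier/LM checkpoint paths that do not exist yet.

    Assumes the convention MODELNAME -> <eval_dir>/model/MODELNAME.model.
    """
    model_dir = Path(eval_dir) / "model"
    expected = [model_dir / f"{name}.model" for name in (cls_name, lm_name)]
    return [path for path in expected if not path.is_file()]


for path in missing_models():
    print(f"missing: {path}")
```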

Usage

For YELP:

cd eval
python eval_yelp.py --test0 test0.out --test1 test1.out

test0.out and test1.out should contain your model's generated (transferred) outputs for the two test sets.

Other arguments (optional):

--allow_unk (Allow unknown tokens in generated outputs)
--dev0 dev0.out  (Evaluate the result on the dev set)
--dev1 dev1.out  (Evaluate the result on the dev set)
--datadir DATADIR (The data path, default: ../yelp_transfer_data)
--clsrestore MODELNAME (The file name of the pretrained classifier, default: cls_yelp_best. The corresponding path is ./model/MODELNAME.model)
--lmrestore MODELNAME (The file name of the pretrained language model, default: lm_yelp_best. The corresponding path is ./model/MODELNAME.model)
--cache  (Build a cache to speed up evaluation)

For GYAFC:

python eval_GYAFC.py --test0 test0.out --test1 test1.out

The other arguments are similar to those for YELP.
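When evaluating many output files, it may be convenient to build these command lines programmatically. A minimal sketch that uses only the flags listed above; the function `eval_command` is hypothetical, and it assumes the scripts are invoked from inside eval/:

```python
import sys


def eval_command(dataset, test0, test1, dev0=None, dev1=None, datadir=None,
                 clsrestore=None, lmrestore=None, allow_unk=False, cache=False):
    """Build the argument list for eval_yelp.py or eval_GYAFC.py."""
    script = {"yelp": "eval_yelp.py", "gyafc": "eval_GYAFC.py"}[dataset.lower()]
    # sys.executable runs the script with the current Python interpreter.
    cmd = [sys.executable, script, "--test0", test0, "--test1", test1]
    optional = (("--dev0", dev0), ("--dev1", dev1), ("--datadir", datadir),
                ("--clsrestore", clsrestore), ("--lmrestore", lmrestore))
    for flag, value in optional:
        if value is not None:
            cmd += [flag, value]
    if allow_unk:
        cmd.append("--allow_unk")
    if cache:
        cmd.append("--cache")
    return cmd
```

For example, calling `subprocess.run(eval_command("yelp", "test0.out", "test1.out", cache=True), check=True, cwd="eval")` from the repository root mirrors the `cd eval` step followed by the YELP command above.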

Example Outputs

domain  acc    self_bleu  ref_bleu  ppl      self_g2  self_h2  g2     h2     overall
test0   0.862  0.629      0.491     156.298  0.737    0.727    0.650  0.625  0.650
test1   0.910  0.638      0.633     88.461   0.762    0.750    0.759  0.747  0.759

You can find results of NAST here.
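To compare several runs, the whitespace-separated results table can be parsed into a dictionary. A minimal sketch, assuming the scripts print a header row followed by one row per domain, exactly as in the example output; `parse_results` is a hypothetical helper:

```python
def parse_results(text):
    """Parse a whitespace-separated results table into {domain: {metric: value}}."""
    rows = [line.split() for line in text.strip().splitlines() if line.strip()]
    header, body = rows[0], rows[1:]
    results = {}
    for row in body:
        # First column is the domain name; the remaining columns are numeric metrics.
        results[row[0]] = {name: float(value) for name, value in zip(header[1:], row[1:])}
    return results


example = """
domain  acc    self_bleu  ref_bleu  ppl      self_g2  self_h2  g2     h2     overall
test0   0.862  0.629      0.491     156.298  0.737    0.727    0.650  0.625  0.650
test1   0.910  0.638      0.633     88.461   0.762    0.750    0.759  0.747  0.759
"""
results = parse_results(example)
print(results["test1"]["overall"])
```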

Train your Classifier / Language Model

Training scripts:

cd eval
# train a classifier
python run_cls.py --name CLSNAME --dataid ../data/yelp --pos_weight 1 --cuda
# train a language model
python run_lm.py --name LMNAME --dataid ../data/yelp --cuda

Arguments:

  • name can be an arbitrary string; it is used to identify checkpoints and TensorBoard curves.
  • dataid specifies the data path.
  • pos_weight specifies the sample weight for label 1 (positive sentences in the YELP dataset). A value greater than 1 biases the model toward label 1. (For GYAFC, we use pos_weight=2.)
  • cuda enables GPU training.
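To see why a pos_weight above 1 biases the classifier toward label 1, consider a sample-weighted binary cross-entropy. This is a minimal sketch of the general idea, not the loss in run_cls.py: it assumes each label-1 sample's loss is multiplied by pos_weight and the total is normalized by the summed weights. The function `weighted_bce` is hypothetical:

```python
import math


def weighted_bce(probs, labels, pos_weight=1.0):
    """Mean binary cross-entropy in which every label-1 sample is weighted by pos_weight.

    probs are predicted P(label = 1), strictly between 0 and 1.
    The sum is normalized by the total sample weight (an illustrative choice).
    """
    total, weight_sum = 0.0, 0.0
    for p, y in zip(probs, labels):
        w = pos_weight if y == 1 else 1.0
        loss = -math.log(p) if y == 1 else -math.log(1.0 - p)
        total += w * loss
        weight_sum += w
    return total / weight_sum


# A classifier that leans toward label 1 (p = 0.9 for both samples)
# is penalized less when pos_weight = 2 than when pos_weight = 1.
print(weighted_bce([0.9, 0.9], [1, 0], pos_weight=1))
print(weighted_bce([0.9, 0.9], [1, 0], pos_weight=2))
```

Since errors on label-1 samples cost more, minimizing this loss pushes predictions toward label 1.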

See run_cls.py or run_lm.py for more arguments.

You can track the training process with TensorBoard; the logs are written to ./eval/tensorboard.

The trained model will be saved in ./eval/model.

Training: Style Transformer

Data Preparation

Follow the same data preparation instructions as for evaluation.

Use the Pretrained Classifier

The classifier is used for validation.

You can download a pretrained classifier or train one yourself. Then put it under ./styletransformer/model.

Train NAST

Simple Alignment:

cd styletransformer
python run.py --name MODELNAME --dataid ../data/yelp --clsrestore cls_yelp_best

Learnable Alignment:

cd styletransformer
python run.py --name MODELNAME --dataid ../data/yelp --clsrestore cls_yelp_best --use_learnable --pretrain_batch 1000

Arguments:

  • name can be an arbitrary string; it is used to identify checkpoints and TensorBoard curves.
  • dataid specifies the data path.
  • clsrestore specifies the name of the pretrained classifier.
  • use_learnable enables learnable alignment.
  • pretrain_batch specifies the number of pretraining steps, during which only the cycle loss is used.
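The pretrain_batch schedule can be pictured as a switch on the step counter. A minimal sketch under stated assumptions: the loss names follow the training log shown in the Issues section (f_slf_loss, f_cyc_loss, f_adv_loss), while the equal weighting after pretraining and the function `training_loss` are hypothetical, not taken from run.py:

```python
def training_loss(step, losses, pretrain_batch=0):
    """Combine generator losses, keeping only the cycle loss during pretraining.

    losses maps loss names to scalar values. Summing them with equal weight
    after pretraining is an illustrative assumption.
    """
    if step < pretrain_batch:
        # Pretraining phase: only the cycle-reconstruction loss is optimized.
        return losses["f_cyc_loss"]
    return sum(losses.values())


losses = {"f_slf_loss": 1.0, "f_cyc_loss": 2.0, "f_adv_loss": 0.5}
print(training_loss(10, losses, pretrain_batch=1000))
print(training_loss(1000, losses, pretrain_batch=1000))
```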

See run.py for more arguments.

You can track the training process with TensorBoard; the logs are written to ./styletransformer/tensorboard.

The trained model will be saved in ./styletransformer/model.

Todo

  • Add the implementation for LatentSeq

Acknowledgement & Related Repository

We thank DualRL for providing multiple human references and some baselines' outputs, StyIns for the other baselines' outputs, and StyTrans and LatentSeq for providing great base models.

Citing

Please cite our paper if you find the paper or the code helpful.

@inproceedings{huang2021NAST,
  author = {Fei Huang and Zikai Chen and Chen Henry Wu and Qihan Guo and Xiaoyan Zhu and Minlie Huang},
  title = {{NAST}: A Non-Autoregressive Generator with Word Alignment for Unsupervised Text Style Transfer},
  booktitle = {Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics: Findings},
  year = {2021}
}

Contributors

hzhwcmhf

Issues

Validation takes a very long time (10x)

I noticed that validation was unusually slow. To show this, I changed lines 84-96 of NAST/styletransformer/seq2seq.py to the following:

import time  # needed once for the timing lines below

while self.now_epoch < args.epochs:
    self.now_epoch += 1
    self.updateOtherWeights()

    st_bo = time.time()
    self.train(args.eval_steps)
    logging.info("took %.2f minutes", (time.time() - st_bo) / 60)  # <------------ added to measure train time

    st_bo = time.time()
    devloss_detail = self.evaluate("dev", hasref=False, write_output=False)
    self.devSummary(self.now_batch, devloss_detail)
    logging.info("epoch %d, evaluate dev", self.now_epoch)
    logging.info("took %.2f minutes", (time.time() - st_bo) / 60)  # <------------ added to measure val time

    st_bo = time.time()
    testloss_detail = self.evaluate("test", hasref=True, write_output=True)
    self.testSummary(self.now_batch, testloss_detail)
    logging.info("epoch %d, evaluate test", self.now_epoch)
    logging.info("took %.2f minutes", (time.time() - st_bo) / 60)  # <------------ added to measure test time

And the logs show:

Training start......
train_0 set restart, 2769 batches and 2 left
train_1 set restart, 4156 batches and 57 left
[iter 505] d_adv_loss: 2.7750  f_slf_loss: 6.8572  f_cyc_loss: 9.1048  f_adv_loss: 1.7901  f_slf_length_loss: 0.0000  f_cyc_length_loss: 0.0000  temp: 1.0000  f_slf_gen_error: 0.0000 f_cyc_gen_error: 0.0000
...
15:04:14 seq2seq.py[line:90] took 1.05 minutes # training
dev_0 set restart, 31 batches and 16 left
dev_1 set restart, 31 batches and 16 left
15:14:47 seq2seq.py[line:98] epoch 1, evaluate dev
15:14:47 seq2seq.py[line:99] took 10.55 minutes # validation
test_0 set restart, 7 batches and 52 left
15:17:12 seq2seq.py[line:106] epoch 1, evaluate test
15:17:12 seq2seq.py[line:107] took 2.42 minutes # test

So validation alone takes 10x longer than training. Do you have any idea why? @hzhwcmhf
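The repeated `st_bo = time.time()` bookkeeping in the snippet above can be factored into a small context manager. A minimal sketch using only the standard library; the name `log_duration` is hypothetical, and `time.perf_counter` is used because it is intended for measuring elapsed intervals:

```python
import logging
import time
from contextlib import contextmanager


@contextmanager
def log_duration(label):
    """Log how long the wrapped block took, in minutes, and expose it via a dict."""
    timing = {}
    start = time.perf_counter()
    try:
        yield timing
    finally:
        timing["minutes"] = (time.perf_counter() - start) / 60
        logging.info("%s took %.2f minutes", label, timing["minutes"])
```

Inside the training loop, a call such as `with log_duration("evaluate dev"): devloss_detail = self.evaluate("dev", hasref=False, write_output=False)` would replace each pair of timing lines.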

TypeError

There is a bug in styletransformer/seq2seq.py at line 350, which raises:
"{TypeError}can only concatenate str (not "list") to str"
I think result["gen"][i] should be a sentence rather than a token list.
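The error message is what Python 3 raises when a `str` is concatenated with a `list`. A minimal reproduction and one possible workaround, assuming result["gen"][i] holds a list of token strings and that a space-joined sentence is what the surrounding code expects; the token values and the `prefix` variable are made up for illustration:

```python
tokens = ["the", "food", "was", "great"]  # stands in for one entry of result["gen"]
prefix = "gen: "

try:
    line = prefix + tokens  # str + list raises TypeError, as in the report
except TypeError as err:
    print(err)

# Joining the tokens into a sentence first avoids the error.
line = prefix + " ".join(tokens)
print(line)
```

If the entries turn out to be token ids rather than strings, they would need to be converted back to words before joining.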

How to generate data?

Thanks for sharing your work.

When I run the following commands, there are no test0.out and test1.out files. How do I generate them?

cd eval
python eval_yelp.py --test0 test0.out --test1 test1.out

Code Release ETA

Hi,

Great work; the paper was a very interesting read. I'm just curious: what is the rough ETA for the code release?

Best,
Max

A problem when running run.py

When running run.py, I get IndexError: invalid index to scalar variable. It seems to be a problem with the dimensions of the weights, but I don't understand it. How can I solve this? The full error message is:
Traceback (most recent call last):
  File "/root/nas/NAST/styletransformer/run.py", line 161, in <module>
    run(*sys.argv[1:])
  File "/root/nas/NAST/styletransformer/run.py", line 157, in run
    main(args)
  File "/root/nas/NAST/styletransformer/main.py", line 71, in main
    model.train_process()
  File "/root/nas/NAST/styletransformer/seq2seq.py", line 90, in train_process
    devloss_detail = self.evaluate("dev", hasref=False, write_output=False)
  File "/root/nas/NAST/styletransformer/seq2seq.py", line 356, in evaluate
    result0 = inference(datakey, 0)
  File "/root/nas/NAST/styletransformer/seq2seq.py", line 341, in inference
    result = metric.close()
  File "/opt/conda/envs/torch/lib/python3.9/site-packages/cotk/metric/metric.py", line 242, in close
    res.update(metric.close())
  File "/root/nas/NAST/styletransformer/../utils/cotk_private/metric/name_changer.py", line 14, in close
    return {self.prefix + key: value for key, value in self.target.close().items()}
  File "/opt/conda/envs/torch/lib/python3.9/site-packages/cotk/metric/bleu.py", line 189, in close
    corpus_bleu(self.refs, self.hyps, weights=weights, smoothing_function=SmoothingFunction().method3),
  File "/opt/conda/envs/torch/lib/python3.9/site-packages/nltk/translate/bleu_score.py", line 200, in corpus_bleu
    weights[0][0]
IndexError: invalid index to scalar variable.
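The last frame shows NLTK probing `weights[0][0]`. A plausible explanation, not confirmed against this repository, is that cotk passes the BLEU weights as a NumPy array: indexing a NumPy scalar raises IndexError, whereas indexing a plain Python float raises TypeError. The sketch below mimics that probe; the assumption that NLTK catches only TypeError here (to wrap a single weight tuple in a list) is labeled in the code:

```python
import numpy as np


def probe_weights(weights):
    """Mimic the weights[0][0] probe from the traceback's last frame.

    Assumption: NLTK catches only TypeError here and wraps a single weight
    tuple in a list; any other exception propagates to the caller.
    """
    try:
        weights[0][0]
    except TypeError:
        weights = [weights]
    return weights


np_weights = np.array([0.25, 0.25, 0.25, 0.25])

try:
    probe_weights(np_weights)
except IndexError as err:
    # NumPy scalars raise IndexError, not TypeError, so the probe fails.
    print("IndexError:", err)

# Plain Python floats raise TypeError instead, so the probe succeeds.
print(probe_weights(np_weights.tolist()))
```

If this is the cause, converting the weights to plain Python floats (for example with `.tolist()`) before they reach `corpus_bleu` may be one workaround; it has not been tested against this code base.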

The settings for Hyper-Parameters

Hi, I'm wondering what the best settings are for the hyperparameters α, β1, β2, and γ in the overall objective. Also, why did you select β from {0.5, 1, 1.5, 3, 5, 10, 15}?
