snakeztc / neuraldialog-zsdg

PyTorch codebase for zero-shot dialog generation (SIGDIAL 2018), released by Tiancheng Zhao (Tony) from the Dialog Research Center, LTI, CMU.

Home Page: https://www.cs.cmu.edu/~tianchez

License: Apache License 2.0

neuraldialog-zsdg's Introduction

Zero-shot Dialog Generation (ZSDG) for End-to-end Neural Dialog Models

Codebase for Zero-Shot Dialog Generation with Cross-Domain Latent Actions, published as a long paper at SIGDIAL 2018. Reference information is at the end of this page. Presentation slides can be found here.

This work won the best paper award at SIGDIAL 2018.

If you use any source code or datasets included in this toolkit in your work, please cite the following paper. The BibTeX entry is listed below:

@article{zhao2018zero,
  title={Zero-Shot Dialog Generation with Cross-Domain Latent Actions},
  author={Zhao, Tiancheng and Eskenazi, Maxine},
  journal={arXiv preprint arXiv:1805.04803},
  year={2018}
}

Requirements

python 2.7
pytorch >= 0.3.0.post4
numpy
nltk

Datasets

The data folder contains three datasets:

Getting Started

The following scripts implement 4 different models, including:

  • Baseline: standard attentional encoder-decoder and encoder with pointer-sentinel-mixture decoder (see the paper for details).
  • Our Models: cross-domain Action Matching training applied to the two baseline systems above.

Training

Run the following to experiment on the SimDial dataset:

python simdial-zsdg.py

Run the following to experiment on the Stanford Multi-Domain Dataset:

python stanford-zsdg.py

Switching Model

The hyperparameters are exactly the same for the two scripts above. To train different models, use the following configurations. The examples use simdial-zsdg.py, but they apply equally to stanford-zsdg.py.

For the baseline model with attention decoder:

python simdial-zsdg.py --action_match False --use_ptr False

For baseline model with pointer-sentinel mixture decoder:

python simdial-zsdg.py --action_match False --use_ptr True    

For the action matching model with attention decoder:

python simdial-zsdg.py --action_match True --use_ptr False

For the action matching model with pointer-sentinel mixture decoder:

python simdial-zsdg.py --action_match True --use_ptr True

Hyperparameters

The following are some of the key hyperparameters:

  • action_match: whether to use the proposed Action Matching (AM) algorithm for training.
  • target_example_cnt: the number of seed responses from each domain used for training.
  • use_ptr: whether to use the pointer-sentinel-mixture decoder.
  • black_domains: the domains that are excluded from training.
  • black_ratio: the fraction of training data from black_domains that is excluded. Range=[0,1], where 1 means 100% of that training data is removed.
  • forward_only: whether to use an existing model (True) or train a new one (False).
  • load_sess: the path to the existing model.
  • rnn_cell: the type of RNN cell; LSTM and GRU are supported.
  • dropout: the dropout probability.
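
To make black_domains and black_ratio concrete, here is a minimal sketch of one way such filtering could work. It is not the repository's implementation: the function name filter_black_domains, the (domain, dialog) tuple layout, the example domain names, and the choice to drop a seeded random subset are all assumptions made for illustration.

```python
import random


def filter_black_domains(dialogs, black_domains, black_ratio, seed=0):
    """Drop a fraction `black_ratio` of the dialogs whose domain is excluded.

    dialogs: list of (domain, dialog) tuples (assumed layout, not the repo's).
    black_domains: collection of domain names to exclude.
    black_ratio: float in [0, 1]; 1 removes every dialog from those domains.
    """
    if not 0.0 <= black_ratio <= 1.0:
        raise ValueError("black_ratio must be in [0, 1]")

    rng = random.Random(seed)
    # Indices of all dialogs that belong to an excluded domain.
    black_idx = [i for i, (domain, _) in enumerate(dialogs)
                 if domain in black_domains]
    # Number of those dialogs to drop, rounded to the nearest integer.
    n_drop = int(round(black_ratio * len(black_idx)))
    dropped = set(rng.sample(black_idx, n_drop))
    return [d for i, d in enumerate(dialogs) if i not in dropped]


if __name__ == "__main__":
    # Hypothetical toy data: four dialogs in each of two domains.
    data = [("restaurant", i) for i in range(4)] + [("movie", i) for i in range(4)]
    kept = filter_black_domains(data, {"movie"}, black_ratio=0.5)
    print(len(kept))  # 6: half of the four "movie" dialogs were dropped
```

With black_ratio set to 1 the excluded domains disappear from the training set entirely, which corresponds to the zero-shot setting described in the list above.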

Test an existing model

All trained models and log files are saved to the log folder. To run an existing model:

  • Set the forward_only argument to True.
  • Set the load_sess argument to the path of the model folder in log.
  • Run the script.
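
These steps amount to a single command line. The sketch below assumes that forward_only and load_sess are passed in the same "--name value" form as the flags in the Switching Model examples. The helper name eval_command and the placeholder folder name are hypothetical and do not come from the repository.

```python
def eval_command(load_sess, script="simdial-zsdg.py"):
    """Build the argument list for running a saved model without training.

    load_sess: path to the model folder under the log directory.
    Assumes the flags use the same `--name value` form as the training flags.
    """
    return ["python", script, "--forward_only", "True", "--load_sess", load_sess]


if __name__ == "__main__":
    # "<model-folder>" is a placeholder; substitute a real folder from log.
    print(" ".join(eval_command("<model-folder>")))
```

The returned list can be printed and pasted into a shell, or handed to subprocess.call if the run should be launched from Python.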

neuraldialog-zsdg's People

Contributors

ishalyminov, snakeztc

neuraldialog-zsdg's Issues

View size is not compatible with input tensor's size when --use_ptr is True

ZeroShotPtrHRED (
  (embedding): Embedding(1899, 200, padding_idx=0), parameters=379800
  (utt_encoder): RnnUttEncoder(
    (embedding): Embedding(1899, 200, padding_idx=0)
    (rnn): EncoderRNN(
      (input_dropout): Dropout(p=0.0)
      (rnn): GRU(201, 256, batch_first=True, dropout=0.3, bidirectional=True)
    )
  ), parameters=1084824
  (ctx_encoder): EncoderRNN(
    (input_dropout): Dropout(p=0.0)
    (rnn): LSTM(512, 512, batch_first=True, dropout=0.3)
  ), parameters=2101248
  (policy): Linear(in_features=512, out_features=512, bias=True), parameters=262656
  (connector): LinearConnector(
    (linear_h): Linear(in_features=512, out_features=512, bias=True)
    (linear_c): Linear(in_features=512, out_features=512, bias=True)
  ), parameters=525312
  (decoder): DecoderPointerGen(
    (input_dropout): Dropout(p=0.3)
    (rnn): LSTM(200, 512, batch_first=True, dropout=0.3)
    (embedding): Embedding(1899, 200, padding_idx=0)
    (attention): Attention(
      (linear_out): Linear(in_features=1024, out_features=512, bias=True)
      (dec_w): Linear(in_features=512, out_features=512, bias=True)
      (attn_w): Linear(in_features=512, out_features=512, bias=True)
      (query_w): Linear(in_features=512, out_features=1, bias=True)
    )
    (project): Linear(in_features=512, out_features=1899, bias=True)
  ), parameters=3867396
  (nll_loss): NLLEntropy(), parameters=0
  (l2_loss): L2Loss(), parameters=0
) Total Parameters=8221236
**** Training Begins ****
**** Epoch 0/50 ****
Number of left over sample 16
Train begins with 238 batches
Train add with 22 warm up batches
Traceback (most recent call last):
  File "stanford-zsdg.py", line 161, in <module>
    main(config)
  File "stanford-zsdg.py", line 142, in main
    train(model, train_feed, valid_feed, test_feed, config, evaluator, gen=hred_utils.generate)
  File "/NeuralDialog-ZSDG/zsdg/main.py", line 105, in train
    loss = model(batch, mode=TEACH_FORCE)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/NeuralDialog-ZSDG/zsdg/models/models.py", line 499, in forward
    mode=mode, gen_type=gen_type)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/NeuralDialog-ZSDG/zsdg/enc2dec/decoders.py", line 443, in forward
    decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed)
  File "/NeuralDialog-ZSDG/zsdg/enc2dec/decoders.py", line 362, in forward_step
    rnn_softmax = F.softmax(self.project(output.view(-1, self.hidden_size)), dim=1)
RuntimeError: invalid argument 2: View size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /pytorch/aten/src/THC/generic/THCTensor.c:276
$ python -c "import torch; print torch.__version__"
0.4.0
$ python 
Python 2.7.12

invalid argument 2: view size is not compatible with input tensor's size and stride

**** Training Begins ****
**** Epoch 0/50 ****
Train init with 1521 batches with 0 left over samples
Train add with 30 warm up batches
/home/yuanzhuo.wyz/.conda/envs/yizhen27/lib/python2.7/site-packages/torch/nn/functional.py:995: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
  warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
Traceback (most recent call last):
  File "simdial-zsdg.py", line 174, in <module>
    main(config)
  File "simdial-zsdg.py", line 156, in main
    train(model, train_feed, valid_feed, test_feed, config, evaluator, gen=hred_utils.generate)
  File "/home/yuanzhuo.wyz/humanlike/neuraldialog-ZSDG/zsdg/main.py", line 105, in train
    loss = model(batch, mode=TEACH_FORCE)
  File "/home/yuanzhuo.wyz/.conda/envs/yizhen27/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yuanzhuo.wyz/humanlike/neuraldialog-ZSDG/zsdg/models/models.py", line 499, in forward
    mode=mode, gen_type=gen_type)
  File "/home/yuanzhuo.wyz/.conda/envs/yizhen27/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yuanzhuo.wyz/humanlike/neuraldialog-ZSDG/zsdg/enc2dec/decoders.py", line 443, in forward
    decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed)
  File "/home/yuanzhuo.wyz/humanlike/neuraldialog-ZSDG/zsdg/enc2dec/decoders.py", line 362, in forward_step
    rnn_softmax = F.softmax(self.project(output.view(-1, self.hidden_size)), dim=1)
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /opt/conda/conda-bld/pytorch_1532571898140/work/aten/src/THC/generic/THCTensor.cpp:226
(yizhen27)
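
Both reports fail on the same line, output.view(-1, self.hidden_size) in decoders.py, and the error message itself names the usual remedy: call .contiguous() before .view(), for example output.contiguous().view(-1, self.hidden_size). The reports above do not confirm that this resolves the problem in this codebase, so treat it as a suggestion to try. As a rough pure-Python sketch of what "non-contiguous" means (the helper names are hypothetical, the sizes 4 and 3 are arbitrary, and PyTorch's real compatibility check for view is more permissive than this one), compare the strides of a dense row-major layout with those of a transposed view:

```python
def row_major_strides(sizes):
    """Strides (in elements) of a densely packed row-major tensor."""
    strides = []
    step = 1
    for size in reversed(sizes):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(sizes, strides):
    """True if the strides match a dense row-major layout.

    Dimensions of size 1 are skipped because their stride never affects
    which element is addressed.
    """
    expected = row_major_strides(sizes)
    return all(
        size == 1 or stride == want
        for size, stride, want in zip(sizes, strides, expected)
    )


def transpose(sizes, strides, dim0, dim1):
    """Swap two dimensions without moving data, as a strided view does."""
    sizes, strides = list(sizes), list(strides)
    sizes[dim0], sizes[dim1] = sizes[dim1], sizes[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return tuple(sizes), tuple(strides)


if __name__ == "__main__":
    sizes = (4, 3, 512)                       # 512 matches the hidden size in the log
    strides = row_major_strides(sizes)        # (1536, 512, 1)
    print(is_contiguous(sizes, strides))      # True
    t_sizes, t_strides = transpose(sizes, strides, 0, 1)
    print(t_sizes, t_strides)                 # (3, 4, 512) (512, 1536, 1)
    print(is_contiguous(t_sizes, t_strides))  # False
```

A transpose is only one common way to obtain such a layout; the reports do not show which operation produces it here. Calling .contiguous() copies the data into a dense layout so that a subsequent .view() can reinterpret the shape safely.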
