
textcnn-conv-deconv-pytorch's Introduction

[WIP] textcnn-conv-deconv-pytorch

Text convolution-deconvolution auto-encoder and classification model in PyTorch.
A PyTorch implementation of Deconvolutional Paragraph Representation Learning (NIPS 2017).
This repository is still under development.

Requirements

  • Python 3
  • PyTorch >= 0.3
  • numpy

Usage

Train

Paragraph reconstruction

Download the data (Hotel reviews), then run the following command.

$ python main_reconstruction.py -data_path=path/to/hotel_reviews.p 

Specify the path to the downloaded data with -data_path.
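Judging by the .p extension, the file is a Python pickle. A quick way to inspect it before training (a sketch only; the structure of the unpickled object is not documented here, so just the type is printed):

import pickle

with open("path/to/hotel_reviews.p", "rb") as f:
    data = pickle.load(f)
print(type(data))  # inspect the top-level structure before training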

Other parameters:

usage: main_reconstruction.py [-h] [-lr LR] [-epochs EPOCHS]
                              [-batch_size BATCH_SIZE]
                              [-lr_decay_interval LR_DECAY_INTERVAL]
                              [-log_interval LOG_INTERVAL]
                              [-test_interval TEST_INTERVAL]
                              [-save_interval SAVE_INTERVAL]
                              [-save_dir SAVE_DIR] [-data_path DATA_PATH]
                              [-shuffle SHUFFLE] [-sentence_len SENTENCE_LEN]
                              [-embed_dim EMBED_DIM]
                              [-kernel_sizes KERNEL_SIZES] [-tau TAU]
                              [-use_cuda] [-enc_snapshot ENC_SNAPSHOT]
                              [-dec_snapshot DEC_SNAPSHOT]

text convolution-deconvolution auto-encoder model

optional arguments:
  -h, --help            show this help message and exit
  -lr LR                initial learning rate
  -epochs EPOCHS        number of epochs for training
  -batch_size BATCH_SIZE
                        batch size for training
  -lr_decay_interval LR_DECAY_INTERVAL
                        how many epochs to wait before decreasing the learning rate
  -log_interval LOG_INTERVAL
                        how many steps to wait before logging training status
  -test_interval TEST_INTERVAL
                        how many epochs to wait before testing
  -save_interval SAVE_INTERVAL
                        how many epochs to wait before saving
  -save_dir SAVE_DIR    where to save the snapshot
  -data_path DATA_PATH  data path
  -shuffle SHUFFLE      shuffle data every epoch
  -sentence_len SENTENCE_LEN
                        how many tokens in a sentence
  -embed_dim EMBED_DIM  dimensionality of the embeddings
  -kernel_sizes KERNEL_SIZES
                        kernel size to use for convolution
  -tau TAU              temperature parameter
  -use_cuda             whether to use CUDA
  -enc_snapshot ENC_SNAPSHOT
                        filename of encoder snapshot
  -dec_snapshot DEC_SNAPSHOT
                        filename of decoder snapshot
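For example, a run that overrides a few of these defaults might look like the following (the flag values are illustrative, not recommended settings):

$ python main_reconstruction.py -data_path=path/to/hotel_reviews.p -lr=0.001 -epochs=20 -batch_size=32 -use_cuda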

Semi-supervised sequence classification

Run the following command.

$ python main_classification.py -data_path=path/to/trainingdata -label_path=path/to/labeldata

Specify the training data and label data with the -data_path and -label_path arguments.
Both files must have the same number of lines, and tokens in the training data must be separated by whitespace. An example of that format appears below.
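As an illustration (the file names and contents are made up), each line of the training file is one whitespace-separated sentence, and the same line of the label file holds its class:

train.txt:
the room was clean and the staff were friendly
terrible service and a noisy room

labels.txt:
1
0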

Other parameters:

usage: main_classification.py [-h] [-lr LR] [-epochs EPOCHS]
                              [-batch_size BATCH_SIZE]
                              [-lr_decay_interval LR_DECAY_INTERVAL]
                              [-log_interval LOG_INTERVAL]
                              [-test_interval TEST_INTERVAL]
                              [-save_interval SAVE_INTERVAL]
                              [-save_dir SAVE_DIR] [-data_path DATA_PATH]
                              [-label_path LABEL_PATH] [-separated SEPARATED]
                              [-shuffle SHUFFLE] [-sentence_len SENTENCE_LEN]
                              [-mlp_out MLP_OUT] [-dropout DROPOUT]
                              [-embed_dim EMBED_DIM]
                              [-kernel_sizes KERNEL_SIZES] [-tau TAU]
                              [-use_cuda] [-enc_snapshot ENC_SNAPSHOT]
                              [-dec_snapshot DEC_SNAPSHOT]
                              [-mlp_snapshot MLP_SNAPSHOT]

text convolution-deconvolution auto-encoder model

optional arguments:
  -h, --help            show this help message and exit
  -lr LR                initial learning rate
  -epochs EPOCHS        number of epochs for training
  -batch_size BATCH_SIZE
                        batch size for training
  -lr_decay_interval LR_DECAY_INTERVAL
                        how many epochs to wait before decreasing the learning rate
  -log_interval LOG_INTERVAL
                        how many steps to wait before logging training status
  -test_interval TEST_INTERVAL
                        how many steps to wait before testing
  -save_interval SAVE_INTERVAL
                        how many epochs to wait before saving
  -save_dir SAVE_DIR    where to save the snapshot
  -data_path DATA_PATH  data path
  -label_path LABEL_PATH
                        label path
  -separated SEPARATED  how the text data is separated
  -shuffle SHUFFLE      shuffle the data every epoch
  -sentence_len SENTENCE_LEN
                        how many tokens in a sentence
  -mlp_out MLP_OUT      number of classes
  -dropout DROPOUT      the probability for dropout
  -embed_dim EMBED_DIM  dimensionality of the embeddings
  -kernel_sizes KERNEL_SIZES
                        kernel size to use for convolution
  -tau TAU              temperature parameter
  -use_cuda             whether to use CUDA
  -enc_snapshot ENC_SNAPSHOT
                        filename of encoder snapshot
  -dec_snapshot DEC_SNAPSHOT
                        filename of decoder snapshot
  -mlp_snapshot MLP_SNAPSHOT
                        filename of mlp classifier snapshot
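For example, a classification run on a four-class dataset such as AG News (mentioned in the issues below) might look like this; the flag values are illustrative:

$ python main_classification.py -data_path=path/to/trainingdata -label_path=path/to/labeldata -mlp_out=4 -dropout=0.5 -use_cuda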

Reference

Deconvolutional Paragraph Representation Learning.
Yizhe Zhang, Dinghan Shen, Guoyin Wang, Zhe Gan, Ricardo Henao, Lawrence Carin.
NIPS 2017. arXiv:1708.04729 [cs.CL]


textcnn-conv-deconv-pytorch's Issues

Issues running your model

Hello,

First of all, let me thank you for putting this together. I was very curious about the paper, but their TF implementation is rather poor and very hard to understand. Yours is very clean and makes a lot more sense!

I ran your model with the default parameters in reconstruction mode on the Hotel dataset on a single Tesla K80 machine. It took 20+ hours to train for 10 epochs, and the model didn't converge (see below). The loss never moved below 22,000.

I have a few questions:

  1. Is there something that I am doing wrong? Are there any parameters that need to be specified to make the model work? I checked the defaults for the parameters and they looked in line with the paper.

  2. You use log softmax as the loss function for the deconvolutional model, and I assume that's why the model is taking so long to train. I know that's what the paper recommends, but have you tried using adaptive softmax instead? (See the sketch after these questions.)

  3. What are your thoughts on seeding the embedding matrix with pre-learned embeddings? I am curious whether using L2-normalized GloVe embeddings would speed up the training.

  4. I also tried to train jointly with a classifier using the AG News dataset, but the MLP classifier is unhappy about the dimensions it gets.

h = encoder(feature)
print(h.shape)
prob = decoder(h)
log_prob = mlp(h.squeeze())

h.shape is torch.Size([64, 500, 5, 1]).
The last dimension gets squeezed, but the resulting [64, 500, 5] tensor is not compatible with the 500x300 FC layer:

RuntimeError: size mismatch, m1: [32000 x 5], m2: [500 x 300] at /Users/soumith/minicondabuild3/conda-bld/pytorch_1518385717421/work/torch/lib/TH/generic/THTensorMath.c:1434
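For what it's worth, the mismatch can be reproduced in isolation. The sketch below assumes the leftover length dimension of 5 should be pooled away before the classifier; whether pooling is the intended fix here is an assumption:

import torch

h = torch.randn(64, 500, 5, 1)       # encoder output shape from the report
fc = torch.nn.Linear(500, 300)       # the 500x300 FC layer

# h.squeeze() yields [64, 500, 5]; Linear then sees 5 input features per
# row instead of 500, which produces "m1: [32000 x 5], m2: [500 x 300]".
pooled = h.squeeze(3).max(dim=2)[0]  # max-pool over the length dim -> [64, 500]
out = fc(pooled)                     # [64, 300], no size mismatch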

I would greatly appreciate any guidance you could give me on these!
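Regarding questions 2 and 3, here is a rough sketch of both ideas. Note the assumptions: nn.AdaptiveLogSoftmaxWithLoss requires a PyTorch release newer than 0.3, glove_matrix is a hypothetical pre-loaded numpy array of shape [vocab_size, embed_dim], and the sizes and cutoffs are made up for illustration:

import torch
import torch.nn as nn

vocab_size, embed_dim = 20000, 300  # assumed sizes, for illustration

# Question 3: seed the embedding matrix with pre-learned vectors.
embed = nn.Embedding(vocab_size, embed_dim)
# embed.weight.data.copy_(torch.from_numpy(glove_matrix))  # hypothetical GloVe matrix

# Question 2: adaptive softmax spends less compute on rare words than a
# full log softmax over the whole vocabulary.
asm = nn.AdaptiveLogSoftmaxWithLoss(embed_dim, vocab_size, cutoffs=[2000, 10000])
hidden = torch.randn(64, embed_dim)           # stand-in for decoder features
target = torch.randint(0, vocab_size, (64,))  # stand-in for gold token ids
_, loss = asm(hidden, target)                 # mean negative log-likelihood
loss.backward()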

======= RESULTS ==========

Input Sentence:
stayed two nights in this hotel for our 20th anniversary . the location is fantastic , near great shopping , restaraunts and entertainment . the staff was great . the bed was the most comfortable i have ever slept in . i wanted to take it home with me ! the rooms and halls were quiet and peaceful . the bathroom was incredible , sparkling marble , huge space , impecably clean . the only down side was how expensive it was to park our car . yikes ! over all we could not have asked for a better hotel and we will definately stay here again . it was worth every penny . END_TOKEN

Output Sentence:
ricca raggiungibile duur raggiungibile uhr tasse toujours nuestro nogal tren dava frequentato bagno altre uhr krijg salir krap toujours krijg uhr deve l'albergo misma uhr frequentato quand atencion standaard frequentato uhr avere uhr cambiare arredamento precios preso gevraagd bekommt dotate interessante parken l'albergo z'n uhr accanto uhr raggiungibile uhr stanze uhr krijg uhr spazi aeropuerto kwamen uhr mocht ruido frequentato uhr avere bekommt all'arrivo salir totalmente uhr zentral bekommt spettacolare l'albergo llegamos dava frequentato servizio pesar bekommt metropolitana serviable stanze salir relativamente jahre relativamente bekommt arrivati passa z'n uhr trova naechte necesario suis raam l'albergo necesario l'albergo z'n servizio hemos l'albergo enkel aeropuerto citta foi zoek nostro estar salir avere l'albergo heerlijk verkennen andando salir particolarmente trova pagamento trova trovate trova acondicionado trova frigorifero trova trovate trova trovate trova acondicionado trova frigorifero trova trovate trova trovate trova acondicionado trova trovate trova trovate trova trovate trova acondicionado trova trovate trova trovate trova trovate trova acondicionado trova trovate trova trovate trova trovate trova acondicionado trova trovate trova trovate trova trovate trova acondicionado trova trovate trova trovate trova trovate trova acondicionado trova trovate trova trovate trova trovate trova acondicionado trova trovate trova trovate trova trovate trova strasse trova trovate trova trovate trova trovate trova strasse trova trovate trova trovate trova trovate trova trova

Epoch: 10
Steps: 108920
Loss: 22058.16015625
Eval
Evaluation - loss: 683.1286144549368 Rouge1: 1.5889867148342671e-06 Rouge2: 0.0
Finish!!!

Need help with understanding the code

Hi @ymym3412 ,

Thank you for the great work. Sorry if I sound ignorant, but I am trying to understand thoroughly how the code works.

May I ask what the following lines in the main_reconstruction script do? Thanks a lot for your help in advance.

t1 = args.sentence_len + 2 * (args.filter_shape - 1)        # input length after padding (filter_shape - 1) on each side
t2 = int(math.floor((t1 - args.filter_shape) / 2) + 1)      # feature-map length after conv 1; "2" means stride size
t3 = int(math.floor((t2 - args.filter_shape) / 2) + 1) - 2  # feature-map length after conv 2, minus 2
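For concreteness, here is a worked example of those computations. They appear to be the sequence lengths after padding and after each stride-2 convolution; sentence_len=60 and filter_shape=5 are assumed values for illustration, not necessarily the repository defaults:

import math

sentence_len, filter_shape = 60, 5
t1 = sentence_len + 2 * (filter_shape - 1)             # 60 + 8 = 68
t2 = int(math.floor((t1 - filter_shape) / 2) + 1)      # floor(63 / 2) + 1 = 32
t3 = int(math.floor((t2 - filter_shape) / 2) + 1) - 2  # floor(27 / 2) + 1 - 2 = 12
print(t1, t2, t3)  # 68 32 12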

Issue with the reconstruction run.

The output of main_reconstruction.py does not appear to be in English. I am not sure where the error is, since the lang parameter in util.transform_id2word is set to en. Please let me know if I am missing something.
