Joint_learn

PyTorch implementation of the models below.

Model Architecture

1. JointBERT

The architecture follows monologg/JointBERT; see that repository for details.
- Predict intent and slot at the same time from one BERT model (= joint model)
- total_loss = intent_loss + coef * slot_loss (change coef with the --slot_loss_coef option)
- To use a CRF layer, pass the --use_crf option
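
The joint objective above can be sketched in PyTorch as follows. This is a minimal illustration rather than the repository's code: the function name `joint_loss`, the tensor shapes, and the choice to skip padded slot positions through `pad_label_id` are assumptions, and the CRF variant is not shown.

```python
import torch
import torch.nn.functional as F


def joint_loss(intent_logits, intent_labels, slot_logits, slot_labels,
               slot_loss_coef=1.0, pad_label_id=0):
    """Return (total_loss, intent_loss, slot_loss).

    intent_logits: (batch, num_intents)
    intent_labels: (batch,)
    slot_logits:   (batch, seq_len, num_slots)
    slot_labels:   (batch, seq_len); padded positions hold pad_label_id (assumed)
    """
    # Sentence-level classification loss.
    intent_loss = F.cross_entropy(intent_logits, intent_labels)
    # Token-level classification loss, flattened over batch and time.
    slot_loss = F.cross_entropy(
        slot_logits.reshape(-1, slot_logits.size(-1)),
        slot_labels.reshape(-1),
        ignore_index=pad_label_id,  # assumption: padding does not contribute
    )
    # total_loss = intent_loss + coef * slot_loss
    total_loss = intent_loss + slot_loss_coef * slot_loss
    return total_loss, intent_loss, slot_loss
```

A larger `slot_loss_coef` shifts the optimization toward slot filling, a smaller one toward intent detection.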

2. AttnSeq2Seq

An RNN encoder-decoder with an attention mechanism.

- Encoder part
 1. A bidirectional RNN (LSTM) encodes the source sentences.
 2. The final hidden state of the backward LSTM is used to compute the decoder's initial hidden state.

- Attention part
 1. A neural network scores the decoder's last hidden state against the encoder hidden states.
 2. A softmax over the scores gives the attention weights.

- Decoder part
 1. To predict the current slot, feed in the last hidden state, the last predicted label, the aligned encoder hidden state, and the context vector.
 2. The current context vector is computed from the last hidden state and the encoder hidden states.
 3. The intent classification is computed from the initial decoder hidden state and its context vector.
 4. total_loss = intent_loss + coef * slot_loss
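
The encoder and attention steps above might look roughly like the sketch below. It assumes an additive (Bahdanau-style) scoring network and a single-layer LSTM; the class names `Encoder` and `AdditiveAttention`, the layer sizes, and the `tanh` projection for the decoder's initial state are illustrative choices, not taken from the repository.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Bidirectional LSTM encoder (hypothetical layer sizes)."""

    def __init__(self, vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim, batch_first=True,
                            bidirectional=True)
        self.init_dec = nn.Linear(hid_dim, hid_dim)

    def forward(self, src):
        # outputs: (batch, src_len, 2 * hid_dim); h_n: (2, batch, hid_dim)
        outputs, (h_n, _) = self.lstm(self.embed(src))
        # h_n[1] is the final hidden state of the backward direction.
        dec_init = torch.tanh(self.init_dec(h_n[1]))
        return outputs, dec_init


class AdditiveAttention(nn.Module):
    """Scores a decoder state against every encoder state."""

    def __init__(self, dec_dim, enc_dim, attn_dim):
        super().__init__()
        self.w_dec = nn.Linear(dec_dim, attn_dim, bias=False)
        self.w_enc = nn.Linear(enc_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, dec_hidden, enc_outputs):
        # dec_hidden: (batch, dec_dim); enc_outputs: (batch, src_len, enc_dim)
        energy = torch.tanh(self.w_dec(dec_hidden).unsqueeze(1)
                            + self.w_enc(enc_outputs))
        scores = self.v(energy).squeeze(-1)        # (batch, src_len)
        weights = torch.softmax(scores, dim=-1)    # attention weights
        # Context vector: attention-weighted sum of encoder states.
        context = torch.bmm(weights.unsqueeze(1), enc_outputs).squeeze(1)
        return context, weights
```

In a full model, the context computed from the initial decoder state would feed the intent classifier, and a fresh context would be computed at every decoding step for slot prediction. A production version would also mask padded source positions before the softmax.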

Dependencies

  • python>=3.7
  • torch==1.5.1
  • seqeval==1.2.2
  • transformers==4.3.0
  • pytorch-crf==0.7.2

Dataset

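|       | Train  | Dev | Test | Intent Labels | Slot Labels |
| ----- | ------ | --- | ---- | ------------- | ----------- |
| ATIS  | 4,478  | 500 | 893  | 21            | 120         |
| Snips | 13,084 | 700 | 700  | 7             | 72          |
| SMP   | 4,623  | 199 | 199  | 60            | 311         |
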
  • The number of labels is based on the train dataset.
  • UNK is added for intent and slot labels that appear only in the dev and test datasets.
  • PAD is added for slot labels.
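
As a rough illustration of the label handling described above, the following sketch builds label vocabularies from the train split only and maps unseen labels to UNK. The function names `build_label_maps` and `encode_slots` and the ordering of the special labels are assumptions made for this example; the repository may store its label files differently.

```python
def build_label_maps(train_intents, train_slot_seqs):
    """Build label-to-id maps from the train split only."""
    # UNK covers intents that appear only in dev/test.
    intent_labels = ["UNK"] + sorted(set(train_intents))
    # PAD fills padded positions; UNK covers slots unseen in train.
    slot_set = {slot for seq in train_slot_seqs for slot in seq}
    slot_labels = ["PAD", "UNK"] + sorted(slot_set)
    intent2id = {label: i for i, label in enumerate(intent_labels)}
    slot2id = {label: i for i, label in enumerate(slot_labels)}
    return intent2id, slot2id


def encode_slots(slot_seq, slot2id, max_len):
    """Map slot labels to ids, truncating or padding to max_len."""
    ids = [slot2id.get(slot, slot2id["UNK"]) for slot in slot_seq][:max_len]
    return ids + [slot2id["PAD"]] * (max_len - len(ids))
```

For example, a slot label that never occurs in the train split is encoded with the UNK id, and sequences shorter than `max_len` are filled with the PAD id.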

Train & Evaluation (Usage)

$ python main.py --task {task_name} \
                 --model_type {model_type} \
                 --model_dir {model_dir_name} \
                 --do_train --do_eval \
                 --use_crf

#For ATIS
$ python main.py --task atis \
                 --model_type joint_bert \
                 --model_dir ./atis_model/joint_bert \
                 --do_train --do_eval 

#For Snips
$ python main.py --task snips \
                 --model_type joint_bert \
                 --model_dir ./snips_model/joint_bert \
                 --do_train --do_eval
                 
#For smp
$ python main.py --task smp \
                 --model_type zh_joint_AttnS2S \
                 --model_dir ./smp_model/joint_AttnS2S \
                 --do_train --do_eval 

Prediction

$ python predict.py --input_file {Input_file} --output_file {Output_file} --model_dir {Model_dir} --model_type {Model_type}

Default hyperparameter settings

  • BERT: lr=1e-4, warm_steps=248, max_norm=1, train_epochs=10, dropout=0.1
  • Seq2Seq+Attention: lr=1e-3, train_epochs=10, dropout=0.1
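
To show how the BERT settings could fit together in a training step, here is a hedged sketch using plain PyTorch. The source does not state the optimizer or the schedule shape, so AdamW, the linear decay after warm-up, the stand-in `nn.Linear` model, and `total_steps=1000` are all assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def linear_warmup_decay(warmup_steps, total_steps):
    """LR multiplier: linear ramp to 1.0, then linear decay to 0.0 (assumed shape)."""
    def factor(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return factor


model = nn.Linear(8, 2)  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # lr=1e-4
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, linear_warmup_decay(warmup_steps=248, total_steps=1000))


def train_step(inputs, labels):
    """One optimization step with gradient clipping at max_norm=1."""
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs), labels)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()  # advance the warm-up/decay schedule once per step
    return loss.item()
```

With this schedule the learning rate starts at zero, reaches 1e-4 at step 248, and falls back to zero at the last step.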

Results

  • Each model was run for 5 to 10 epochs (the best result is recorded)
  • Only tested with the uncased model
  • 248 warm-up steps gave the best results
  • The Seq2Seq model's teacher forcing ratio is 0.5
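
|       | Model             | Intent acc (%) | Slot F1 (%) | Sentence acc (%) |
| ----- | ----------------- | -------------- | ----------- | ---------------- |
| Snips | BERT              | 99.14          | 96.11       | 99.0             |
|       | BERT + CRF        | 98.85          | 95.81       | 98.71            |
|       | Seq2Seq+Attention | 98.14          | 95.07       | 97.57            |
| ATIS  | BERT              | 97.4           | 98.07       | 97.40            |
|       | BERT + CRF        | 97.8           | 98.07       | 97.80            |
|       | Seq2Seq+Attention | 98.2           | 96.8        | 98.20            |
| SMP   | Seq2Seq+Attention | 95.2           | 90.7        | 84.63            |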
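
The README does not define the metrics, so the sketch below only shows one common reading: intent accuracy as the share of correctly classified sentences, and sentence accuracy as the share of sentences where the intent and every slot label are both correct. The function names are hypothetical; slot F1 is usually computed at the entity level with a library such as seqeval (listed in the dependencies) and is not reimplemented here.

```python
def intent_accuracy(intent_preds, intent_golds):
    """Fraction of sentences whose predicted intent matches the gold intent."""
    correct = sum(p == g for p, g in zip(intent_preds, intent_golds))
    return correct / len(intent_golds)


def sentence_accuracy(intent_preds, intent_golds, slot_preds, slot_golds):
    """Fraction of sentences with a correct intent AND a fully correct slot sequence."""
    correct = sum(
        ip == ig and sp == sg
        for ip, ig, sp, sg in zip(intent_preds, intent_golds,
                                  slot_preds, slot_golds)
    )
    return correct / len(intent_golds)
```

Under this definition sentence accuracy can never exceed intent accuracy, since a sentence with a wrong intent is counted as wrong regardless of its slots.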

Contributors

qeksveyay, yinghao1019
