m92vyas / implementing_attention_mechanism_language_translation

Bahdanau Attention Mechanism | Tensorflow Custom Layers/Model/Loss Function/Metrics | LSTM | Encoder | Decoder | Cross-Attention | Language Translation | Bleu Score | Dropout

Bahdanau Attention Mechanism Implementation for Language Translation.

Introduction:

  • The raw data consists of 362861 Italian-to-English sentence pairs.
  • The sentences were preprocessed: the English decoder input sentences were marked with a 'start' token and the decoder output sentences with an 'end' token.
  • 307077 sentence pairs were used for training, 54191 for validation and 1088 as the test dataset.
  • Both the Italian and English sentences were tokenized, and a maximum sequence length of 20 tokens was selected. This gave a final vocabulary of 13335 English tokens and 27402 Italian tokens.
  • A dataset loader was written to return the encoder sequence and the decoder input/output sequences.
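
The preprocessing steps above can be sketched in plain Python. This is a minimal illustration and not the repository's loader: the helper names (`build_vocab`, `encode`, `make_example`), placing 'start' at the front and 'end' at the back, and right-padding with id 0 are all assumptions made for the example.

```python
MAX_LEN = 20  # maximum sequence length stated above


def build_vocab(sentences):
    """Map each word to an integer id; id 0 is reserved for padding (assumption)."""
    vocab = {}
    for sentence in sentences:
        for word in sentence.split():
            if word not in vocab:
                vocab[word] = len(vocab) + 1
    return vocab


def encode(sentence, vocab, max_len=MAX_LEN):
    """Convert a sentence to ids, truncate to max_len, and right-pad with 0."""
    ids = [vocab[word] for word in sentence.split()][:max_len]
    return ids + [0] * (max_len - len(ids))


def make_example(italian, english, ita_vocab, eng_vocab):
    """Return (encoder sequence, decoder input, decoder output) for one pair."""
    decoder_in = "start " + english   # assumed placement: token at the front
    decoder_out = english + " end"    # assumed placement: token at the back
    return (encode(italian, ita_vocab),
            encode(decoder_in, eng_vocab),
            encode(decoder_out, eng_vocab))


ita = "tom non è un fisico"
eng = "tom is not a physician"
ita_vocab = build_vocab([ita])
eng_vocab = build_vocab(["start " + eng + " end"])
enc_seq, dec_in, dec_out = make_example(ita, eng, ita_vocab, eng_vocab)
```

With this layout the decoder output is the decoder input shifted left by one position, which is what teacher forcing needs: at each step the model sees the previous target word and is asked to predict the next one.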

Model:

image ref: https://guillaumegenthial.github.io/sequence-to-sequence.html

Encoder Layer:

  • Italian tokens are mapped to vectors of the given dimension by an embedding layer.
  • Embedding output dimensions: [batch, max_len, embed-size]
  • The LSTM output at each time step is kept, since later layers use these outputs to compute the cross-attention scores.
  • LSTM output dimensions: [batch, max_len, lstm-units]

Attention Mechanism Layer:

  • The decoder input is transformed to match the encoder output dimensions. Attention weights are then calculated from dot-product similarity scores between the transformed decoder input and each encoder output, and the weighted sum of the encoder hidden-state vectors is returned as the context vector for the decoder.
  • Context vector dimensions: [batch,encoder_lstm_units]
  • Attention weights dimensions: [batch,max_len,1]
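
As a rough sketch of the computation described above, the following plain-Python function handles a single example (no batch axis). It assumes the decoder input has already been projected to the encoder's unit size, and the names `softmax` and `dot_attention` are illustrative only, not taken from the repository.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_attention(query, encoder_outputs):
    """Dot-product attention for one example.

    query:           [units]           (decoder input, assumed already projected)
    encoder_outputs: [max_len][units]  (one vector per encoder time step)
    Returns (context [units], weights [max_len]).
    """
    # Similarity score between the query and every encoder time step.
    scores = [sum(q * h for q, h in zip(query, h_t)) for h_t in encoder_outputs]
    weights = softmax(scores)
    # Context vector: attention-weighted sum of the encoder outputs.
    units = len(query)
    context = [sum(w * h_t[i] for w, h_t in zip(weights, encoder_outputs))
               for i in range(units)]
    return context, weights


encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query = [1.0, 0.0]
context, weights = dot_attention(query, encoder_outputs)
```

Applied to every example in a batch, the context vectors stack to [batch, encoder_lstm_units] and the weights to [batch, max_len], with a trailing axis of size 1 added for the weights to match the shape listed above.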

Decoder Encoder Cross Attention Layer:

  • It performs cross attention between the embedded decoder input (which can be GloVe vectors) and the encoder outputs using the attention mechanism layer above, concatenates the attention-weighted result with the embedded decoder input and passes it to an LSTM layer, followed by a dense layer with as many units as the output vocabulary size. The decoder input is passed one word at a time across the batch (in matrix form), i.e. cross attention is performed one embedded token at a time over the whole batch.
  • Final output dimensions: [batch,tar_vocab_size]

Decoder Layer:

  • It applies the Decoder Encoder Cross Attention Layer at every time step and returns the logits for the full decoder input length.
  • Final Logits output shape: [batch,max_len,tar_vocab_size]
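
The step-by-step structure can be illustrated without any deep learning library. In the sketch below, `one_step` is a hypothetical stand-in for the cross-attention layer (embedding, attention, LSTM and dense layer together), and `toy_step` is a dummy that exists only so the loop can run; neither name comes from the repository.

```python
def decode_all_steps(one_step, decoder_inputs, state):
    """Run a one-token-at-a-time decoder over the full input length.

    decoder_inputs: [batch][max_len] token ids.
    one_step(tokens, state) -> (logits [batch][vocab], new_state).
    Returns logits shaped [batch][max_len][vocab].
    """
    batch = len(decoder_inputs)
    max_len = len(decoder_inputs[0])
    all_logits = [[] for _ in range(batch)]
    for t in range(max_len):
        # One token per example: the t-th column of the batch.
        tokens_t = [sequence[t] for sequence in decoder_inputs]
        step_logits, state = one_step(tokens_t, state)
        for b in range(batch):
            all_logits[b].append(step_logits[b])
    return all_logits


def toy_step(tokens, state, vocab_size=4):
    """Dummy step: one-hot 'logits' on the input token; state counts the steps."""
    logits = [[1.0 if v == tok else 0.0 for v in range(vocab_size)]
              for tok in tokens]
    return logits, state + 1


decoder_inputs = [[1, 2, 3], [3, 0, 0]]
logits = decode_all_steps(toy_step, decoder_inputs, state=0)
```

Each call yields a [batch, tar_vocab_size] slice, and collecting the slices along the time axis gives the [batch, max_len, tar_vocab_size] shape listed above.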

Final Translation Model:

  • The dataset generator passes the appropriate data to the encoder and decoder blocks, and the final logits are returned for the whole batch.

Custom Loss Function and Metric:

  • The custom loss function and metric ignore padded positions (token id zero), so padding does not contribute to the loss or the metric.
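
One way to express this masking, shown here for a single sequence in plain Python rather than as the repository's TensorFlow code, is to skip every position whose target id is the padding id and average over the rest. The function name and the choice of cross-entropy on logits are assumptions for the sketch.

```python
import math


def masked_loss_and_accuracy(targets, logits, pad_id=0):
    """Cross-entropy and accuracy averaged over non-padded positions only.

    targets: [max_len] integer ids, with pad_id marking padding.
    logits:  [max_len][vocab] unnormalised scores.
    """
    total_loss, correct, count = 0.0, 0, 0
    for target, row in zip(targets, logits):
        if target == pad_id:
            continue  # padded position: contributes nothing
        peak = max(row)
        log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in row))
        total_loss += log_sum_exp - row[target]  # -log softmax(row)[target]
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == target)
        count += 1
    if count == 0:
        return 0.0, 0.0
    return total_loss / count, correct / count


targets = [1, 2, 0]  # the last position is padding
logits_a = [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0], [9.0, 0.0, 0.0]]
logits_b = [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0], [0.0, 9.0, 0.0]]
loss_a, acc_a = masked_loss_and_accuracy(targets, logits_a)
loss_b, acc_b = masked_loss_and_accuracy(targets, logits_b)
```

Here `logits_a` and `logits_b` differ only at the padded position, so both calls return the same loss and accuracy.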

Training:

  • The following hyperparameters were chosen for training the model:

    • encoder_inputs_length = 20
    • decoder_inputs_length = 20
    • vocab_size_ita = vocab_size_ita
    • vocab_size_eng = vocab_size_eng
    • embedding_dim_enc = 100
    • embedding_dim_dec = 100
    • enc_units = 128
    • dec_units = 128
    • lstm_dropout = 0.2
    • recurrent_dropout = 0.2
    • optimizer = tf.keras.optimizers.Adam()
  • After 70 epochs the validation accuracy is 0.86% (the model was not trained further due to resource constraints)

  • Some translated sentences from the test dataset:

    • Italian: vedo cosavete fatto lì

      English True: i see what you did there

      Model Translation: i i see what you have done there

    • Italian: tom non è un fisico

      English True: tom is not a physician

      Model Translation: tom tom is not a physician

    • Italian: cè un costo di consegna

      English True: is there a delivery charge

      Model Translation: there there is a charge of the delivery

    • Italian: è un tizio strano

      English True: he is a strange guy

      Model Translation: it it is a strange guy

    • Italian: tutti qua sanno che non mangiamo la carne di maiale

      English True: everyone here knows that we do not eat pork

      Model Translation: everyone everyone here knows we do not eat pork

  • Average test data BLEU score: 0.4451662890214658

  • Average test data cumulative 4-gram BLEU score: 1.479362713798278e-231
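
The two figures above differ by many orders of magnitude. As a hedged illustration of how a cumulative 4-gram score can collapse, the sketch below implements sentence-level BLEU without smoothing in plain Python. It is not the scoring code used for the reported numbers, and whether this effect explains them is an assumption.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def modified_precision(reference, hypothesis, n):
    """Clipped n-gram precision of the hypothesis against one reference."""
    hyp_counts = Counter(ngrams(hypothesis, n))
    ref_counts = Counter(ngrams(reference, n))
    total = sum(hyp_counts.values())
    if total == 0:
        return 0.0
    clipped = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
    return clipped / total


def bleu(reference, hypothesis, max_n=4):
    """Cumulative BLEU up to max_n with uniform weights and no smoothing."""
    precisions = [modified_precision(reference, hypothesis, n)
                  for n in range(1, max_n + 1)]
    if min(precisions) == 0.0:
        return 0.0  # a single zero precision zeroes the geometric mean
    log_mean = sum(math.log(p) for p in precisions) / max_n
    ref_len, hyp_len = len(reference), len(hypothesis)
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity * math.exp(log_mean)


reference = "is there a delivery charge".split()
hypothesis = "there there is a charge of the delivery".split()
unigram_score = bleu(reference, hypothesis, max_n=1)
cumulative_4gram = bleu(reference, hypothesis, max_n=4)
```

For this test sentence five of the eight hypothesis words appear in the reference, so the unigram score is well above zero, while no 2-gram or longer n-gram matches, so the unsmoothed cumulative 4-gram score is 0.0.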
