
albert-pytorch's Introduction

ALBERT-Pytorch

A simple implementation of ALBERT (A Lite BERT for Self-supervised Learning of Language Representations) in PyTorch. This implementation is based on the clean dhlee347/pytorchic-bert code.

Please note that I haven't checked downstream performance yet (i.e., fine-tuning); I have only verified that the SOP (sentence-order prediction) and MLM (masked language model with n-gram masking) losses decrease during pre-training.

CAUTION: Fine-tuning tasks are not supported yet!

File Overview

The main Python files in this repository are:

  • tokenization.py : Tokenizers adopted from the original Google BERT's code
  • models.py : Model classes for a general transformer
  • optim.py : A custom optimizer (BertAdam class) adopted from Hugging Face's code
  • train.py : A helper class for training and evaluation
  • utils.py : Several utility functions
  • pretrain.py : Example code for pre-training the transformer

PreTraining

Pre-training can be unit-tested with the WikiText-2 dataset on a single GPU (t2.xlarge). Parallel multi-GPU and CPU training are also supported.

$ CUDA_LAUNCH_BLOCKING=1 python pretrain.py \
            --data_file './data/wiki.train.tokens' \
            --vocab './data/vocab.txt' \
            --train_cfg './config/pretrain.json' \
            --model_cfg './config/albert_unittest.json' \
            --max_pred 75 --mask_prob 0.15 \
            --mask_alpha 4 --mask_beta 1 --max_gram 3 \
            --save_dir './saved' \
            --log_dir './logs'
			
cuda (1 GPUs)
Iter (loss=19.162): : 526it [02:25,  3.58it/s]
Epoch 1/25 : Average Loss 18.643
Iter (loss=12.589): : 524it [02:24,  3.63it/s]
Epoch 2/25 : Average Loss 13.650
Iter (loss=9.610): : 523it [02:24,  3.62it/s]
Epoch 3/25 : Average Loss 9.944
Iter (loss=10.612): : 525it [02:24,  3.60it/s]
Epoch 4/25 : Average Loss 9.018
Iter (loss=9.547): : 527it [02:25,  3.66it/s]
...
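
The run above uses a single GPU. One common way to get the parallel multi-GPU (or CPU fallback) behaviour mentioned above is torch.nn.DataParallel; the snippet below is only a sketch and may differ from what train.py actually does (the stand-in model is hypothetical):

import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # hypothetical stand-in for the ALBERT model built in pretrain.py
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # splits each input batch across the visible GPUs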

TensorboardX logs both pre-training losses, loss_lm + loss_sop.

# to use TensorboardX
$ pip install -U protobuf tensorflow
$ pip install tensorboardX
$ tensorboard --logdir logs # expose http://server-ip:6006/
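
A minimal sketch of how the two scalars could be written with tensorboardX (the `history` list and tag names are illustrative, not the repo's exact logging code):

from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir='./logs')
history = [(18.6, 0.69), (13.7, 0.52)]  # hypothetical (loss_lm, loss_sop) values per step
for step, (loss_lm, loss_sop) in enumerate(history):
    writer.add_scalars('pretrain/loss', {'loss_lm': loss_lm, 'loss_sop': loss_sop}, step)
writer.close()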

Key ideas in ALBERT, introduced with code:

  1. SOP (sentence-order prediction) loss : In the original BERT, negative (is-not-next) pairs are built by randomly picking the second sentence from the corpus; ALBERT instead builds negative examples from the same two consecutive segments but with their order swapped.

    is_next = rand() < 0.5 # whether token_b actually follows token_a

    tokens_a = self.read_tokens(self.f_pos, len_tokens, True)
    seek_random_offset(self.f_neg)
    # f_next = self.f_pos if is_next else self.f_neg  # original BERT NSP sampling
    f_next = self.f_pos # for SOP, tokens_b is always the segment that really follows tokens_a
    tokens_b = self.read_tokens(f_next, len_tokens, False)

    if tokens_a is None or tokens_b is None: # end of file
        self.f_pos.seek(0, 0) # reset file pointer
        return

    # SOP, sentence-order prediction: swap the segments to build a negative example
    instance = (is_next, tokens_a, tokens_b) if is_next \
        else (is_next, tokens_b, tokens_a)
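
    How the `is_next` flag is consumed downstream is only sketched here: assuming a pytorchic-bert-style training step in which the model returns masked-LM logits and a two-way [CLS] classifier (the names below are illustrative), the SOP label is scored with an ordinary cross-entropy loss and added to the MLM loss.

    import torch
    import torch.nn as nn

    def pretrain_loss(logits_lm, logits_sop, masked_ids, masked_weights, is_next):
        # MLM loss, averaged over the masked positions only
        loss_lm = nn.CrossEntropyLoss(reduction='none')(logits_lm.transpose(1, 2), masked_ids)
        loss_lm = (loss_lm * masked_weights.float()).mean()
        # SOP loss: binary classification of segment order from the [CLS] representation
        loss_sop = nn.CrossEntropyLoss()(logits_sop, is_next.long())
        return loss_lm + loss_sop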
  2. Cross-Layer Parameter Sharing : ALBERT shares parameters across layers in both the attention and FFN (feed-forward network) sub-layers to reduce the number of parameters.

    class Transformer(nn.Module):
        """ Transformer with Self-Attentive Blocks """
        def __init__(self, cfg):
            super().__init__()
            self.embed = Embeddings(cfg)
            # Original BERT: each layer has its own block (no parameter sharing)
            # self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])

            # ALBERT: a single set of attention/FFN parameters shared across all layers
            self.n_layers = cfg.n_layers
            self.attn = MultiHeadedSelfAttention(cfg)
            self.proj = nn.Linear(cfg.hidden, cfg.hidden)
            self.norm1 = LayerNorm(cfg)
            self.pwff = PositionWiseFeedForward(cfg)
            self.norm2 = LayerNorm(cfg)
            # self.drop = nn.Dropout(cfg.p_drop_hidden)

        def forward(self, x, seg, mask):
            h = self.embed(x, seg)

            # the same block is applied n_layers times, reusing the same weights
            for _ in range(self.n_layers):
                a = self.attn(h, mask)
                h = self.norm1(h + self.proj(a))   # residual connection around attention
                h = self.norm2(h + self.pwff(h))   # residual connection around the FFN

            return h
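
    To see why sharing matters, here is a rough back-of-envelope count (biases and LayerNorm parameters ignored; a BERT-base-like configuration with hidden size 768, 12 layers, and an FFN of 4x the hidden size is assumed):

    H, n_layers = 768, 12
    attn = 4 * H * H             # Q, K, V and output projection matrices
    ffn = 2 * H * (4 * H)        # the two position-wise feed-forward matrices
    per_block = attn + ffn       # ~7.1M parameters per block
    print(per_block * n_layers)  # without sharing: ~85M encoder parameters
    print(per_block)             # with cross-layer sharing: ~7.1M, reused 12 times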
  3. Factorized Embedding Parameterization : ALBERT factorizes the embedding matrix (VxD) into two smaller matrices (VxE and ExD), so the vocabulary lookup no longer scales with the hidden size.

    class Embeddings(nn.Module):
        "The embedding module from word, position and token_type embeddings."
        def __init__(self, cfg):
            super().__init__()
            # Original BERT Embedding
            # self.tok_embed = nn.Embedding(cfg.vocab_size, cfg.hidden) # token embedding

            # factorized embedding
            self.tok_embed1 = nn.Embedding(cfg.vocab_size, cfg.embedding)
            self.tok_embed2 = nn.Linear(cfg.embedding, cfg.hidden)

            self.pos_embed = nn.Embedding(cfg.max_len, cfg.hidden) # position embedding
            self.seg_embed = nn.Embedding(cfg.n_segments, cfg.hidden) # segment(token type) embedding
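
    For example, with V = 30,000, D = 768, and E = 128, the token-embedding parameters drop from about 23.0M (VxD) to about 3.9M (VxE + ExD). A sketch of the corresponding forward pass (following the usual pytorchic-bert Embeddings.forward, so the exact shapes here are an assumption) looks like this:

        def forward(self, x, seg):
            seq_len = x.size(1)
            pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
            pos = pos.unsqueeze(0).expand_as(x)      # (S,) -> (B, S)
            e = self.tok_embed2(self.tok_embed1(x))  # V x E lookup, then E x D projection
            return e + self.pos_embed(pos) + self.seg_embed(seg)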
  4. n-gram MLM : MLM targets are selected with n-gram masking (Joshi et al., 2019). As in the paper, spans of up to 3 tokens (3-grams) are used. The code is adapted from the XLNet implementation; a sketch of the idea follows.
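
    A minimal sketch of the idea (not the repo's exact routine, which is driven by the --max_gram / --mask_alpha / --mask_beta flags): span lengths n are drawn with probability proportional to 1/n, so shorter spans dominate, and contiguous n-grams are masked until the budget of roughly mask_prob * len(tokens) positions is reached.

    import random

    def sample_mask_positions(tokens, mask_prob=0.15, max_gram=3):
        num_to_mask = max(1, int(round(len(tokens) * mask_prob)))
        pvals = [1.0 / n for n in range(1, max_gram + 1)]
        total = sum(pvals)
        pvals = [p / total for p in pvals]            # p(n) ~ 1/n, normalised
        masked = set()
        while len(masked) < num_to_mask:
            n = random.choices(range(1, max_gram + 1), weights=pvals)[0]
            start = random.randrange(0, max(1, len(tokens) - n))
            masked.update(range(start, start + n))    # mask a contiguous n-gram
        return sorted(masked)[:num_to_mask]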

Not implemented yet

  • In the paper, the authors use a batch size of 4096 with the LAMB optimizer and a learning rate of 0.00176 (You et al., 2019), training all models for 125,000 steps.

Author

  • Tae Hwan Jung (Jeff Jung) @graykode, Kyung Hee Univ. CE (undergraduate).
  • Author Email : [email protected]

albert-pytorch's People

Contributors

dhlee347, graykode


albert-pytorch's Issues

How to train on a custom dataset?

Hey there, sadly I don't clearly understand how the wiki.train.tokens and vocab.txt files are produced. Could you please show me the steps to reproduce the same train.tokens and vocab.txt files for a custom dataset? (I am trying to apply this to a different language.)

Out of memory error

I'm running classification on the MRPC dataset. In trainer.train, the call trainer.train(get_loss, model_file, True) allows only three arguments, not four, so I can't use the pretrain file.

Also, it runs out of memory:
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
RuntimeError: CUDA out of memory. Tried to allocate 48.00 MiB (GPU 0; 4.00 GiB total capacity; 3.02 GiB already allocated; 43.35 MiB free; 223.00 KiB cached)
Iter (loss=X.XXX): 0%| | 0/115 [00:00<?, ?it/s]

Please help.

I'm using cfg.hidden instead of cfg.dim and a dropout probability of 0.5.
