
xlnet-pytorch's Introduction

XLNet-Pytorch arxiv:1906.08237

Simple XLNet implementation with Pytorch Wrapper!

This example shows how the XLNet architecture works during pre-training, using a small batch size (=1).

Usage

$ git clone https://github.com/graykode/xlnet-Pytorch && cd xlnet-Pytorch

# Install the pretrained BERT tokenizer (used here as the subword tokenizer in place of SentencePiece)
$ pip install pytorch_pretrained_bert

$ python main.py --data ./data.txt --tokenizer bert-base-uncased \
   --seq_len 512 --reuse_len 256 --perm_size 256 \
   --bi_data True --mask_alpha 6 --mask_beta 1 \
   --num_predict 85 --mem_len 384 --num_epoch 100

You can also run the code in Google Colab.

  • Hyperparameters for Pretraining in Paper.

Options

  • --data (String) : .txt file to train on. Multiline text is fine. Each file becomes one batch tensor. Default : data.txt

  • --tokenizer (String) : Currently uses the huggingface/pytorch-pretrained-BERT tokenizer as the subword tokenizer (I'll switch it to SentencePiece soon). You can choose from bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased. Default : bert-base-uncased

  • --seq_len (Integer) : Sequence length. Default : 512

  • --reuse_len (Integer) : Number of tokens that can be reused as memory. Could be half of seq_len. Default : 256

  • --perm_size (Integer) : Length of the longest permutation. Could be set to reuse_len. Default : 256

  • --bi_data (Boolean) : Whether to create bidirectional data. If bi_data is True, the batch size should be an even number. Default : False

  • --mask_alpha (Integer) : Number of tokens that form a group. Default : 6

  • --mask_beta (Integer) : Number of tokens to mask within each group. Default : 1

  • --num_predict (Integer) : Number of tokens to predict. In the paper, this corresponds to partial prediction. Default : 85

  • --mem_len (Integer) : Number of steps to cache in the Transformer-XL architecture. Default : 384

  • --num_epoch (Integer) : Number of epochs. Default : 100
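
As a rough illustration of how the options above could be declared, here is a minimal argparse sketch. It is not the repository's actual main.py: the helper names str2bool and build_parser are hypothetical, and only the flag names, tokenizer choices, and defaults are taken from the list above. The explicit boolean parser is an assumption made so that a command line such as `--bi_data False` is not silently read as True, which is what a plain `type=bool` would do.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings (hypothetical helper).

    argparse's type=bool would treat any non-empty string, including
    'False', as True, so the text is interpreted explicitly here.
    """
    lowered = value.lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Declare the documented options with their documented defaults."""
    parser = argparse.ArgumentParser(description="XLNet pre-training options (illustrative)")
    parser.add_argument("--data", type=str, default="data.txt")
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="bert-base-uncased",
        choices=[
            "bert-base-uncased",
            "bert-large-uncased",
            "bert-base-cased",
            "bert-large-cased",
        ],
    )
    parser.add_argument("--seq_len", type=int, default=512)
    parser.add_argument("--reuse_len", type=int, default=256)
    parser.add_argument("--perm_size", type=int, default=256)
    parser.add_argument("--bi_data", type=str2bool, default=False)
    parser.add_argument("--mask_alpha", type=int, default=6)
    parser.add_argument("--mask_beta", type=int, default=1)
    parser.add_argument("--num_predict", type=int, default=85)
    parser.add_argument("--mem_len", type=int, default=384)
    parser.add_argument("--num_epoch", type=int, default=100)
    return parser


# Parse an explicit argument list so the sketch runs without a real command line.
example_args = build_parser().parse_args(["--bi_data", "True", "--num_epoch", "10"])
print(example_args.bi_data, example_args.num_epoch, example_args.seq_len)
```

Options that are not passed fall back to the defaults listed above, so `example_args.seq_len` is 512 here.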

What is XLNet?

XLNet is a new unsupervised language representation learning method based on a novel generalized permutation language modeling objective. Additionally, XLNet employs Transformer-XL as the backbone model, exhibiting excellent performance for language tasks involving long context.

Model   MNLI   QNLI   QQP    RTE    SST-2   MRPC   CoLA   STS-B
BERT    86.6   92.3   91.3   70.4   93.2    88.0   60.6   90.0
XLNet   89.8   93.9   91.8   83.8   95.6    89.2   63.6   91.8

Keywords in XLNet

  1. How did XLNet benefit from Auto-Regression and Auto-Encoding models?

    • Auto-Regression Model
    • Auto-Encoding Model
  2. Permutation Language Modeling with Partial Prediction

    • Permutation Language Modeling

    • Partial Prediction

  3. Two-Stream Self-Attention with Target-Aware Representation

    • Two-Stream Self-Attention

    • Target-Aware Representation
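
To make the keywords above a little more concrete, the following is a small pure-Python sketch of the attention visibility implied by a factorization order. It is an illustration of the idea only, not code from this repository: the function names permutation_mask and partial_targets are hypothetical, and the real implementation works on tensors. Under a given order, a position may attend only to positions that come earlier in that order. The query stream excludes the position itself, while the content stream includes it. Partial prediction keeps only the last few positions of the order as prediction targets.

```python
def permutation_mask(order, include_self=False):
    """Return mask[i][j] == True when position i may attend to position j.

    `order` is a factorization order: a permutation of range(len(order)).
    include_self=False mimics the query stream (a token cannot see itself);
    include_self=True mimics the content stream (a token also sees itself).
    """
    # rank[pos] is the step at which `pos` appears in the factorization order.
    rank = {pos: step for step, pos in enumerate(order)}
    n = len(order)
    if include_self:
        return [[rank[j] <= rank[i] for j in range(n)] for i in range(n)]
    return [[rank[j] < rank[i] for j in range(n)] for i in range(n)]


def partial_targets(order, num_predict):
    """Return the positions to predict: the last `num_predict` in the order."""
    if num_predict <= 0:
        return []
    return sorted(order[-num_predict:])


# Example: 4 tokens, factorization order 2 -> 0 -> 3 -> 1.
order = [2, 0, 3, 1]
query_mask = permutation_mask(order)
content_mask = permutation_mask(order, include_self=True)
targets = partial_targets(order, num_predict=2)

for row in query_mask:
    print(["x" if visible else "." for visible in row])
print("targets:", targets)
```

In this example position 2 comes first in the order, so its query-stream row is empty, while position 1 comes last and can see positions 0, 2, and 3.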

Author

  • Because the original repository is licensed under Apache 2.0, this repository is released under the same license.
  • Tae Hwan Jung(Jeff Jung) @graykode, Kyung Hee Univ CE(Undergraduate).
  • Author Email : [email protected]


xlnet-pytorch's Issues

Low accuracy on sample task

Hi author,

Any idea why the performance accuracy is so low on the sample task you provide?

I added the following lines to test accuracy:
output = logits.transpose(1, 2)
predicts = F.softmax(output, dim=1).argmax(dim=1)
print((predicts == target).sum().item() / len(target))

The accuracy for the task is ~0.01% after 100 epochs.

Error and general question

First the error — I get this both when trying to run the notebook locally (ubuntu 18.04) and from Colab:

Traceback (most recent call last):
  File "main.py", line 89, in <module>
    num_predict=args.num_predict)
  File "/content/xlnet-Pytorch/data_utils.py", line 345, in make_permute
    reuse_len)
  File "/content/xlnet-Pytorch/data_utils.py", line 292, in _local_perm
    non_mask_tokens = (~is_masked) & non_func_tokens
RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #2 'other' in call to _th_and

Any ideas?

For the general question; I really want to pretrain from scratch with my own small corpus. Any tips on how I might go about doing that?

Thanks

how to do inference?

I want to compute sentence embeddings using your PyTorch code, but I can't find how to prepare the test input data or any inference code.

RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #2 'other'

If I run the code with default arguments (and use data.txt from the repository) I get the following message:

Traceback (most recent call last):
  File "C:/Users/matej/git/xlnet-Pytorch/main.py", line 89, in <module>
    num_predict=args.num_predict)
  File "C:\Users\matej\git\xlnet-Pytorch\data_utils.py", line 345, in make_permute
    reuse_len)
  File "C:\Users\matej\git\xlnet-Pytorch\data_utils.py", line 292, in _local_perm
    non_mask_tokens = (~is_masked) & non_func_tokens
RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #2 'other'

I use Python 3.6.9 and PyTorch 1.2.0.

Re-implementation Performance

Hello authors,

Thank you for your re-implementation. I look forward to using it. I just wanted to confirm your reimplementation results because I did not see any in the README/Colab.

Have you reproduced the XLNet results (any of Squad/IMDB/GLUE/etc.) by finetuning with PyTorch?

Model name 'bert-large-uncased' was not found in model name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese).

Can you help me?
D:\xlnet-Pytorch-master>python main.py
Model name 'bert-large-uncased' was not found in model name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese). We assumed 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt' was a path or url but couldn't find any file associated to this path or url.
Traceback (most recent call last):
  File "main.py", line 58, in <module>
    model = xlnet.XLNet(n_token=len(sp.vocab), n_layer=6, n_head=4, d_head=8,
AttributeError: 'NoneType' object has no attribute 'vocab'

Batch training

@graykode Can you explain how batch training would be conducted? For example, what if we had multiple input files for training data?

Currently, training is done using only a single data file. For multiple data files, would data_utils._create_data have to return a batch of features?

Runtime Error in colab

[image: screenshot of the error]
When I run the code on Colab, I get the error above. I wonder what I did wrong and what environment the code requires. Thank you very much!

TypeError: can't convert np.ndarray of type numpy.bool_

Traceback (most recent call last):
  File "main.py", line 89, in <module>
    num_predict=args.num_predict)
  File "/home/hemengge/xlnet-pytorch/data_utils.py", line 335, in make_permute
    is_masked = torch.Tensor(feature.pop("is_masked"))
TypeError: can't convert np.ndarray of type numpy.bool_. The only supported types are: double, float, float16, int64, int32, and uint8

Confusion about the relative position embedding with attn_type='bi' but bsz=1

The default setting is to use the bidirectional data, attn_type='bi', but bsz=1.
But in this function,

def relative_positional_encoding(self, qlen, klen, d_model, clamp_len, attn_type,

it appears that bidirectional data only works when bsz % 2 == 0. However, by default bsz = 1.
I am confused: if bsz = 1, is the setting of beg and end in the following code correct?

xlnet-Pytorch/xlnet.py

Lines 380 to 387 in cb793a1

if attn_type == 'bi':
    # beg, end = klen - 1, -qlen
    beg, end = klen, -qlen
elif attn_type == 'uni':
    # beg, end = klen - 1, -1
    beg, end = klen, -1
else:
    raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

Could anyone help me with this confusion?
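
One way to look at the question is to reproduce the quoted branch in isolation and inspect the positions it yields. The sketch below is not the repository's code: it assumes the position sequence is a descending integer range from beg to end (exclusive), and relative_positions is a hypothetical standalone function. In the quoted lines, beg and end depend only on attn_type, qlen, and klen, so the batch size does not enter at that point.

```python
def relative_positions(qlen, klen, attn_type):
    """Return the relative positions implied by the quoted beg/end values.

    Assumption: positions are generated as a descending range from `beg`
    down to (but not including) `end`.
    """
    if attn_type == "bi":
        beg, end = klen, -qlen
    elif attn_type == "uni":
        beg, end = klen, -1
    else:
        raise ValueError("Unknown `attn_type` {}.".format(attn_type))
    return list(range(beg, end, -1))


# Small example: 3 query positions attending over 5 key positions.
bi_positions = relative_positions(qlen=3, klen=5, attn_type="bi")
uni_positions = relative_positions(qlen=3, klen=5, attn_type="uni")
print(len(bi_positions), bi_positions)
print(len(uni_positions), uni_positions)
```

Under this assumption the 'bi' branch yields klen + qlen positions and the 'uni' branch yields klen + 1, regardless of bsz.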
