Torch CRF

Implementation of CRF (Conditional Random Fields) in PyTorch

Requirements

  • python3 (>=3.6)
  • PyTorch (>=1.0)

Installation

$ pip install TorchCRF

Usage

>>> import torch
>>> from TorchCRF import CRF
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> batch_size = 2
>>> sequence_size = 3
>>> num_labels = 5
>>> mask = torch.ByteTensor([[1, 1, 1], [1, 1, 0]]).to(device)  # (batch_size, sequence_size)
>>> labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).to(device)  # (batch_size, sequence_size)
>>> hidden = torch.randn((batch_size, sequence_size, num_labels), device=device, requires_grad=True)
>>> crf = CRF(num_labels).to(device)
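The mask marks real tokens with 1 and padding with 0; in the example above each row is a run of ones followed by zeros, one row per sequence. A minimal pure-Python sketch of building such a mask from sequence lengths (the helper name lengths_to_mask is hypothetical and not part of TorchCRF):

```python
def lengths_to_mask(lengths, max_len=None):
    """Hypothetical helper: return a batch_size x max_len list of 0/1 flags.

    Position t of sequence i is 1 if t < lengths[i] (a real token), else 0 (padding).
    """
    if max_len is None:
        max_len = max(lengths)
    return [[1 if t < n else 0 for t in range(max_len)] for n in lengths]


print(lengths_to_mask([3, 2]))  # [[1, 1, 1], [1, 1, 0]]
```

With tensors, the same mask can be built without a Python loop as `torch.arange(max_len)[None, :] < lengths[:, None]` (with `lengths` a LongTensor), which yields a bool tensor; `.to(torch.uint8)` converts it to the ByteTensor form used in the example.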

Computing the log-likelihood (used in the forward pass)

>>> crf.forward(hidden, labels, mask)
tensor([-7.6204, -3.6124], device='cuda:0', grad_fn=<ThSubBackward>)
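forward returns one log-likelihood per sequence, so every value is at most 0. To show what such a value means, here is a minimal pure-Python sketch of a linear-chain CRF log-likelihood for a single sequence: the score of the gold label path minus the log-partition function computed with the forward algorithm. It assumes a transition matrix plus start/end scores and a mask made of leading ones; the function and parameter names are hypothetical, and this is not TorchCRF's actual code.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def crf_log_likelihood(h, labels, mask, trans, start, end):
    """Log-likelihood of one sequence under a linear-chain CRF (sketch).

    h[t][j]: emission score of label j at step t (seq_len x num_labels)
    labels[t]: gold label at step t; mask[t]: 1 for real tokens, 0 for padding
    trans[i][j]: score of moving from label i to label j
    start[j] / end[j]: scores for starting / ending on label j (assumed here)
    The mask is assumed to be ones followed by zeros, with at least one 1.
    """
    n = sum(mask)  # number of real tokens
    k = len(start)  # number of labels
    # Numerator: unnormalized score of the gold path.
    score = start[labels[0]] + h[0][labels[0]]
    for t in range(1, n):
        score += trans[labels[t - 1]][labels[t]] + h[t][labels[t]]
    score += end[labels[n - 1]]
    # Denominator: log of the summed exp-scores of all paths (forward algorithm).
    alpha = [start[j] + h[0][j] for j in range(k)]
    for t in range(1, n):
        alpha = [logsumexp([alpha[i] + trans[i][j] for i in range(k)]) + h[t][j]
                 for j in range(k)]
    log_z = logsumexp([alpha[j] + end[j] for j in range(k)])
    return score - log_z
```

For training, the negated mean of the returned tensor would typically serve as the loss, e.g. `loss = -crf(hidden, labels, mask).mean()`, followed by `loss.backward()`.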

Decoding (predict labels of sequences)

>>> crf.viterbi_decode(hidden, mask)
[[0, 2, 2], [4, 0]]
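viterbi_decode returns only the unmasked positions, which is why the second sequence above has two labels. A minimal pure-Python sketch of Viterbi decoding for one sequence, under the same hypothetical scoring as a linear-chain CRF with transition and start/end scores (an assumption, not TorchCRF's code): keep the best score ending in each label, remember which previous label produced it, then trace back from the best final label.

```python
def viterbi_decode(h, mask, trans, start, end):
    """Highest-scoring label path for one sequence (sketch).

    Arguments follow a linear-chain CRF: h[t][j] emissions, trans[i][j]
    transitions, start[j]/end[j] boundary scores. The mask is assumed to be
    ones followed by zeros with at least one 1; only real tokens are decoded.
    """
    n = sum(mask)
    k = len(start)
    score = [start[j] + h[0][j] for j in range(k)]
    backpointers = []
    for t in range(1, n):
        # For each current label j, the previous label that maximizes the score.
        best_prev = [max(range(k), key=lambda i: score[i] + trans[i][j]) for j in range(k)]
        score = [score[best_prev[j]] + trans[best_prev[j]][j] + h[t][j] for j in range(k)]
        backpointers.append(best_prev)
    last = max(range(k), key=lambda j: score[j] + end[j])
    path = [last]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    path.reverse()
    return path


# With all transition and boundary scores at 0, this reduces to a per-step argmax.
zeros = [[0.0, 0.0], [0.0, 0.0]]
print(viterbi_decode([[0.0, 1.0], [2.0, 0.0], [0.0, 3.0]], [1, 1, 0], zeros, [0.0, 0.0], [0.0, 0.0]))  # [1, 0]
```

Applying it to each sequence of a batch gives a ragged list of label paths, one per sequence, with lengths equal to the number of unmasked tokens.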

License

MIT

Contributors

andreabac3, raynardj, rikeda71, yanghaha11514

Issues

run test error

import cProfile

import torch

batch_size = 2
sequence_size = 3
num_labels = 5
labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).cuda()  # (batch_size, sequence_size)
hidden = torch.randn((batch_size, sequence_size, num_labels), requires_grad=True).cuda()

from TorchCRF import CRF
mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]]).cuda()  # (batch_size, sequence_size)

def myCRF(hidden, mask, labels):
    # Likely cause of the error below: the CRF's own parameters are never moved
    # to the GPU; CRF(num_labels).cuda() should avoid the CPU/CUDA mismatch.
    crf = CRF(num_labels)
    for _ in range(1000):
        a = crf(hidden, labels, mask)
        a.mean().backward()

cProfile.run('myCRF(hidden, mask, labels)')

Traceback (most recent call last):
  File "/media/jdd/d/py_proj/events/event_distribute4/torchcrf.py", line 38, in <module>
    cProfile.run('myCRF(hidden, mask, labels)')
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/cProfile.py", line 16, in run
    return _pyprofile._Utils(Profile).run(statement, filename, sort)
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/profile.py", line 55, in run
    prof.run(statement)
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/cProfile.py", line 95, in run
    return self.runctx(cmd, dict, dict)
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/cProfile.py", line 100, in runctx
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "/media/jdd/d/py_proj/events/event_distribute4/torchcrf.py", line 35, in myCRF
    a = crf(hidden, labels, mask)
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 49, in forward
    log_numerator = self._compute_numerator_log_likelihood(h, labels, mask)
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 206, in _compute_numerator_log_likelihood
    ) for t in range(calc_range)])
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 206, in <listcomp>
    ) for t in range(calc_range)])
  File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 257, in _calc_trans_score_for_num_llh
    return h_t * mask_t + trans_t * mask_t1
RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float

Support for multiple GPUs

This library does not support multiple GPUs yet.

I know a CRF layer is small in memory, but it is often used as the top layer of much larger models.

There is an error in the annotation

    def forward(
        self, h: FloatTensor, labels: LongTensor, mask: BoolTensor
    ) -> FloatTensor:
        """
        :param h: hidden matrix (seq_len, batch_size, num_labels)
        :param labels: answer labels of each sequence
                       in mini batch (seq_len, batch_size)
        :param mask: mask tensor of each sequence
                     in mini batch (seq_len, batch_size)
        :return: The log-likelihood (batch_size)
        """

The shapes in this function's docstring are wrong: h should be (batch_size, seq_len, num_labels), and labels and mask should be (batch_size, seq_len).

Performance issue

I greatly appreciate your work, both for its ease of use and for your commitment. I may be wrong, but the library seems very slow compared with other packages that do the same job.

I checked, and all tensor operations are performed on the GPU (GTX 1070).
tqdm reports about two seconds per iteration during training, which adds up to 2 hours per epoch. With other libraries and the same model, an epoch takes 15 minutes.

I can confirm that the mask and the CRF layer are on the GPU.

I also tried forcing everything onto the device with .to(device), but, as expected, nothing changed:
self.crflayer = CRF(hparams.num_classes, pad_idx=0).to(device)
self.model.crflayer.forward(outputs, goldLabels, mask).to(device)

How can I use TorchCRF with huggingface transformers?

When I use this package on top of a BERT model, I get an IndexError related to the padding token id. I think this is caused by the subtokens of an entity, which are labelled -100 by default. Should I remove the subtokens before the CRF layer? If so, how can I do it?
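The IndexError most likely comes from the -100 label being used as an index into tensors that only have num_labels entries. One possible approach, sketched here under assumptions (pack_first_subtokens is a hypothetical helper, not part of TorchCRF or transformers): keep only the positions that carry a real label, move them to the front of each sequence, and fill the remaining slots with a valid dummy label that the mask hides. Packing to the front keeps each mask row a run of ones followed by zeros, like the mask in the Usage example.

```python
def pack_first_subtokens(emissions, labels, ignore_index=-100, pad_label=0):
    """Hypothetical helper: drop ignore_index positions and left-pack the rest.

    emissions: batch x seq_len x num_labels scores (e.g. BERT outputs as lists)
    labels:    batch x seq_len label ids, ignore_index on subtokens/special tokens
    Returns (packed_emissions, packed_labels, mask) as nested lists.
    """
    num_labels = len(emissions[0][0])
    # Positions that carry a real label (e.g. the first subtoken of each word).
    kept = [[t for t, y in enumerate(seq) if y != ignore_index] for seq in labels]
    max_len = max(len(idx) for idx in kept)
    packed_h, packed_y, mask = [], [], []
    for h_seq, y_seq, idx in zip(emissions, labels, kept):
        pad = max_len - len(idx)
        # Labelled positions first, then zero emission vectors as padding.
        packed_h.append([h_seq[t] for t in idx] + [[0.0] * num_labels for _ in range(pad)])
        # pad_label only fills masked slots, so any index in [0, num_labels) works.
        packed_y.append([y_seq[t] for t in idx] + [pad_label] * pad)
        mask.append([1] * len(idx) + [0] * pad)
    return packed_h, packed_y, mask
```

With tensors, the same reordering could be done on the BERT outputs with torch.gather, using the kept indices padded to max_len. At prediction time the decoded labels then line up with the first subtoken of each word.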
