import torch
# Toy dimensions for a small CRF forward/backward workload.
batch_size = 2
sequence_size = 3
num_labels = 5

# Use the GPU when one is available, otherwise fall back to the CPU so the
# script also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = torch.tensor(
    [[0, 2, 3], [1, 4, 1]], dtype=torch.long, device=device
)  # (batch_size, sequence_size)

# Allocate the scores directly on the target device so `hidden` is a leaf
# tensor. `torch.randn(..., requires_grad=True).cuda()` returns a non-leaf
# copy, and that copy's `.grad` is never filled in by `backward()`.
hidden = torch.randn(
    (batch_size, sequence_size, num_labels), requires_grad=True, device=device
)  # (batch_size, sequence_size, num_labels)
from TorchCRF import CRF
# Float mask (1 = real token, 0 = padding), placed on the same device as
# `labels` instead of a hard-coded `.cuda()`. The CRF multiplies it with the
# float scores (`h_t * mask_t` in the traceback), so it stays float32.
mask = torch.tensor(
    [[1, 1, 1], [1, 1, 0]], dtype=torch.float, device=labels.device
)  # (batch_size, sequence_size)
def myCRF(hidden, mask, labels, n_iter=1000):
    """Build a CRF and run repeated forward/backward passes (profiling workload).

    Args:
        hidden: emission scores; the script builds this as
            (batch_size, sequence_size, num_labels). The label count is read
            from its last dimension.
        mask: per-position mask passed straight through to the CRF.
        labels: gold label indices passed straight through to the CRF.
        n_iter: number of forward/backward passes (defaults to the original
            hard-coded 1000).

    Returns:
        None. Gradients accumulate into ``hidden.grad`` (if ``hidden`` is a
        leaf) and into the CRF parameters across iterations, because nothing
        zeroes them between passes.
    """
    # A freshly constructed CRF keeps its parameters (e.g. the transition
    # scores) on the CPU. Move it to the inputs' device; otherwise the forward
    # pass fails with "expected backend CPU ... but got backend CUDA".
    crf = CRF(hidden.size(-1)).to(hidden.device)
    for _ in range(n_iter):
        a = crf(hidden, labels, mask)
        a.mean().backward()
Traceback (most recent call last):
File "/media/jdd/d/py_proj/events/event_distribute4/torchcrf.py", line 38, in <module>
cProfile.run('myCRF(hidden, mask, labels)')
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/cProfile.py", line 16, in run
return _pyprofile._Utils(Profile).run(statement, filename, sort)
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/profile.py", line 55, in run
prof.run(statement)
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/cProfile.py", line 95, in run
return self.runctx(cmd, dict, dict)
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/cProfile.py", line 100, in runctx
exec(cmd, globals, locals)
File "<string>", line 1, in <module>
File "/media/jdd/d/py_proj/events/event_distribute4/torchcrf.py", line 35, in myCRF
a = crf(hidden, labels, mask)
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 49, in forward
log_numerator = self._compute_numerator_log_likelihood(h, labels, mask)
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 206, in _compute_numerator_log_likelihood
) for t in range(calc_range)])
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 206, in <listcomp>
) for t in range(calc_range)])
File "/home/jdd/.conda/envs/py3.6.8/lib/python3.6/site-packages/TorchCRF/__init__.py", line 257, in _calc_trans_score_for_num_llh
return h_t * mask_t + trans_t * mask_t1
RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float