
onlineminingtripletloss's People

Contributors

dependabot[bot], f-fl0, negation

onlineminingtripletloss's Issues

Problem with deterministic indexing

Hello.

PyTorch has an unresolved issue with indexing when deterministic behavior is enabled: pytorch/pytorch#61032

As a result, the following minimal example raises an error:

Code:

import os
import torch
import torch.nn.functional as F
from online_triplet_loss.losses import batch_hard_triplet_loss
from torch import nn

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ':4096:8'
print('pytorch threads ', torch.get_num_threads())
print('version', torch.version.__version__)

torch.set_deterministic(True)  # deprecated in favor of torch.use_deterministic_algorithms(True)

dev = torch.device('cuda')

model = nn.Embedding(1000, 512).to(dev)
labels = torch.randint(high=1000, size=(256,)).to(dev)  # 256 random labels drawn from 1000 classes

embeddings = F.normalize(model(labels))
print('Labels:', labels.cpu())
print('Embeddings:', embeddings.cpu())
loss = batch_hard_triplet_loss(labels, embeddings, margin=0.5, device=dev)
print('Loss:', loss.cpu())
loss.backward()

Output:

pytorch threads  6
version 1.9.0+cu111
/home/alex/.local/lib/python3.6/site-packages/torch/__init__.py:472: UserWarning: torch.set_deterministic is deprecated and will be removed in a future release. Please use torch.use_deterministic_algorithms instead
  "torch.set_deterministic is deprecated and will be removed in a future "
Labels: tensor([928, 203, 388, 709, 982, 593, 339, 820, 264, 838,  69, 136, 571, 493,
        145,  38, 650, 424,  27, 864, 766, 419, 764, 199, 487, 487, 217, 669,
        103, 248, 770, 889, 574, 983, 384, 307, 188, 947, 215, 339, 889, 133,
        640, 541, 457, 798, 969, 107, 426, 294, 823, 875, 877,  33, 522, 834,
        244, 372, 289,  23, 647, 543, 832,  47, 370, 278, 997, 800, 240, 224,
        365, 548,  19, 132, 296, 730, 105, 722, 444, 643, 598, 477, 655, 753,
        644, 223, 265, 300, 381,  96, 157, 428, 882, 414, 307, 258, 539, 440,
        448, 697, 674, 507, 379, 630, 817, 441, 363, 333, 193,  35, 261, 237,
        496,  22, 442, 184, 906, 798, 934, 554, 228, 431, 491, 876, 455, 341,
        405, 738, 761, 471, 399, 949, 325,  74, 100, 530, 136, 981, 239, 924,
        940, 656,  46, 618,  43,   7, 327,  30, 357, 381, 514, 290, 273, 268,
        329, 328, 175, 610, 706, 266, 540, 789, 359, 722,  11, 268, 529, 649,
        648, 868, 271, 399,  42,  19, 610, 944, 547, 772, 104, 550, 427, 348,
        927, 553, 865,  39, 243, 878, 547, 179, 123, 335, 470, 115, 380, 885,
        646, 166, 765, 697, 169, 664, 837, 527, 776, 255, 177, 702, 123, 157,
        731,  16,  29, 278, 294, 949, 772, 148, 788, 552, 177, 174, 838, 350,
        491, 453, 292, 692, 525, 393, 704, 599, 226, 269, 939,  26, 459, 458,
        589, 463, 352, 372,  69, 210,  91, 732, 401, 461, 188, 569, 987, 847,
        847, 689, 616, 193])
Embeddings: tensor([[ 0.0209, -0.0439, -0.0527,  ..., -0.1167,  0.0228, -0.0133],
        [ 0.0438, -0.0272, -0.0811,  ..., -0.0136, -0.0501, -0.0121],
        [-0.1156,  0.0446, -0.0168,  ...,  0.0205,  0.0586, -0.0488],
        ...,
        [ 0.0235, -0.0020, -0.0030,  ...,  0.0379,  0.1423, -0.0285],
        [-0.0275,  0.0005,  0.0295,  ...,  0.0124, -0.0216, -0.0346],
        [ 0.0283, -0.0131,  0.0502,  ..., -0.0324,  0.0144, -0.0048]],
       grad_fn=<CopyBackwards>)
Traceback (most recent call last):
  File "/snap/pycharm-community/248/plugins/python-ce/helpers/pydev/pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/snap/pycharm-community/248/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/mnt/sdb/baga.py", line 21, in <module>
    loss = batch_hard_triplet_loss(labels, embeddings, margin=0.5, device=dev)
  File "/home/x/.local/lib/python3.6/site-packages/online_triplet_loss/losses.py", line 143, in batch_hard_triplet_loss
    tl[tl < 0] = 0
RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel()INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Indexing.cu":253, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor2561

The error does not occur with torch.set_deterministic(False), but I need reproducibility.

The issue might be fixed by using ReLU instead of tl[tl < 0] = 0.
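A minimal sketch of the suggested change, on a CPU tensor with made-up values (tl here is only a stand-in for the tensor named in the traceback). Boolean-mask assignment such as tl[tl < 0] = 0 goes through advanced-indexing assignment (index_put_), and the assert in the traceback comes from the CUDA indexing kernel in Indexing.cu. F.relu and torch.clamp(min=0) compute the same values out of place with no mask indexing at all. That this avoids the bug on a particular PyTorch version is an assumption to confirm by rerunning the reproducer above with the change applied.

```python
import torch
import torch.nn.functional as F

# Illustrative per-anchor values of
# (hardest positive distance - hardest negative distance + margin).
# These numbers are made up, not output from the library.
tl = torch.tensor([0.3, -0.2, 0.0, 1.5, -0.7])

# Pattern from the traceback: in-place assignment through a boolean mask,
# which is handled by advanced-indexing assignment (index_put_).
masked = tl.clone()
masked[masked < 0] = 0

# Out-of-place alternatives that use no boolean-mask indexing.
relu_version = F.relu(tl)
clamp_version = torch.clamp(tl, min=0)
```

Both alternatives also keep autograd simple: entries below zero get zero gradient, just as the masked assignment gives them.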

License?

Hi Joakim,

This is a great implementation of batch-hard and batch-all triplet sampling, and I would like to use it in my work and even contribute to it. Would it be possible for you to add a license to this project?

Rohit

hard triplet convergence

I use a triplet loss between data from two modalities to reduce the distance between samples of the same class across the two modalities and to increase the distance between samples of different classes across them. However, when I used the batch_all loss, the validation set loss did not change; now, with the batch-hard loss, the validation set loss still does not change. What is the reason? I found some answers saying that triplet loss is difficult to converge. What do you do to deal with triplet loss convergence?
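For reference, the two losses discussed in this issue are commonly written as below. Notation: N samples in the batch, labels y_i, embeddings f_i, a distance D, margin m, and [x]_+ = max(x, 0). This follows the batch-hard and batch-all formulations of Hermans et al. (2017). Whether this library normalizes batch-all exactly this way is an assumption: some implementations divide by the number of triplets with positive loss instead of by |T|.

```latex
\mathcal{L}_{\mathrm{BH}} = \frac{1}{N}\sum_{a=1}^{N}
  \Big[\, m + \max_{p \neq a,\; y_p = y_a} D(f_a, f_p)
            - \min_{n:\; y_n \neq y_a} D(f_a, f_n) \Big]_+

\mathcal{L}_{\mathrm{BA}} = \frac{1}{|\mathcal{T}|}\sum_{(a,p,n)\in\mathcal{T}}
  \big[\, m + D(f_a, f_p) - D(f_a, f_n) \big]_+ ,
\qquad
\mathcal{T} = \{(a,p,n) : a \neq p,\; y_p = y_a,\; y_n \neq y_a\}
```

One consequence worth checking: if the embeddings collapse to a single point, every D is 0, every bracket equals m, and both losses sit at exactly m. A validation loss that stays fixed at the margin value is therefore consistent with collapsed embeddings.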

Example usage is confusing

From the README, the example usage is given as

from triplet_loss import batch_hard_triplet_loss

labels = torch.randint(5) # our five labels

embeddings = model(labels)

loss = batch_hard_triplet_loss(labels, embeddings, margin=0.2)
loss.backward()
# and so on

Is this repository oriented around models which produce embeddings of labels? Or is the line model(labels) purely evocative?

It seems that a more standard usage would be to provide some kind of data (such as an image) and receive embeddings; then the triplet loss operates on the embeddings and labels together to determine the loss.
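A sketch of that more standard usage, under stated assumptions. The encoder, the image shape, and the function batch_hard_triplet_loss_sketch are all hypothetical. The function is a small self-contained stand-in so the example runs without the package; it is not this library's implementation. In real use one would call the library's batch_hard_triplet_loss(labels, embeddings, margin=...) as in the first issue above. The key point is that the model sees the data, while the labels go only to the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def batch_hard_triplet_loss_sketch(labels, embeddings, margin=0.2):
    """Hypothetical batch-hard triplet loss, NOT the library's implementation."""
    dist = torch.cdist(embeddings, embeddings)  # (B, B) Euclidean distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    # Hardest positive: farthest sample sharing the anchor's label (0 if none).
    hardest_pos = dist.masked_fill(~(same & not_self), 0.0).max(dim=1).values
    # Hardest negative: closest sample with a different label.
    hardest_neg = dist.masked_fill(same, float("inf")).min(dim=1).values
    # Hinge at zero with relu, then average over anchors.
    return F.relu(hardest_pos - hardest_neg + margin).mean()


# Hypothetical encoder: maps 1x28x28 images to 64-d embeddings.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64))

images = torch.randn(32, 1, 28, 28)   # a batch of input data (random here)
labels = torch.randint(0, 5, (32,))   # one class id per image, 5 classes

embeddings = F.normalize(encoder(images), dim=1)  # the model sees data, not labels
loss = batch_hard_triplet_loss_sketch(labels, embeddings, margin=0.2)
loss.backward()
```

Read this way, model(labels) in the README works only as shorthand: any network that maps a batch of inputs to a (batch_size, dim) tensor fits, and labels is a 1-D tensor of class ids aligned with that batch.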

RuntimeError while using batch_hard_triplet_loss()

I'm trying to use batch_hard_triplet_loss() as my loss function, but I keep getting the following error:

[Screenshot from 2023-02-12 01-02-14 showing the error; image not reproduced]

The embedding shape is (batch_size, 1000).

I tried batch sizes from 2^4 up to 2^15 and keep getting the same error. I also tried running both on the GPU and without it, and the error persists. Could you please help me with this?

You can find the full code here
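The screenshot is not available here, so the actual error is unknown. One frequent source of RuntimeErrors in batch losses is mismatched inputs, for example passing one-hot labels of shape (batch_size, num_classes) where a 1-D tensor of class ids is expected (the shape the first issue's reproducer uses). That this is the cause in this report is purely an assumption. Below is a hypothetical helper, check_triplet_inputs, that fails early with a readable message.

```python
import torch


def check_triplet_inputs(labels: torch.Tensor, embeddings: torch.Tensor) -> None:
    """Hypothetical sanity check to run before a batch triplet loss."""
    if embeddings.dim() != 2:
        raise ValueError(
            f"embeddings must be 2-D (batch_size, dim), got {tuple(embeddings.shape)}"
        )
    if labels.dim() != 1:
        # e.g. one-hot labels of shape (batch_size, num_classes)
        raise ValueError(f"labels must be 1-D (batch_size,), got {tuple(labels.shape)}")
    if labels.shape[0] != embeddings.shape[0]:
        raise ValueError(
            f"batch size mismatch: {labels.shape[0]} labels vs "
            f"{embeddings.shape[0]} embeddings"
        )
    if labels.dtype.is_floating_point:
        raise ValueError("labels should be an integer tensor of class ids")
    if labels.device != embeddings.device:
        raise ValueError(f"labels on {labels.device} but embeddings on {embeddings.device}")
```

If one-hot labels are the issue, labels.argmax(dim=1) converts them to the expected 1-D class ids.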
