
SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates

Code associated with the paper SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates

This code contains the loss functions derived from the following metrics.
  • Spearman correlation
  • Mean Average Precision
  • Recall
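All three metrics are computed from a ranking (a sort) of the scores. The rank function is piecewise constant, so its gradient is zero almost everywhere, which is why the losses rely on a sorter as a differentiable surrogate. The plain-Python sketch below only shows how the metrics themselves are evaluated on hard ranks. The function names are illustrative and are not part of the sodeep package, and the sketch assumes there are no ties in the scores.

```python
def hard_ranks(values):
    """1-based ranks, 1 = smallest value (assumes no ties)."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0] * len(values)
    for rank, idx in enumerate(order, start=1):
        ranks[idx] = rank
    return ranks


def spearman(x, y):
    """Spearman correlation: Pearson correlation computed on the ranks."""
    rx, ry = hard_ranks(x), hard_ranks(y)
    n = len(x)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    vx = sum((a - mx) ** 2 for a in rx)
    vy = sum((b - my) ** 2 for b in ry)
    return cov / (vx * vy) ** 0.5


def _by_decreasing_score(scores):
    """Item indices sorted from highest to lowest score."""
    return sorted(range(len(scores)), key=lambda i: -scores[i])


def average_precision(scores, relevant):
    """AP of the ranking by decreasing score; `relevant` holds one bool per item."""
    hits, precision_sum = 0, 0.0
    for pos, idx in enumerate(_by_decreasing_score(scores), start=1):
        if relevant[idx]:
            hits += 1
            precision_sum += hits / pos  # precision at each relevant position
    return precision_sum / hits if hits else 0.0


def recall_at_k(scores, relevant, k):
    """Fraction of the relevant items that appear in the top-k by score."""
    total = sum(relevant)
    top_k = _by_decreasing_score(scores)[:k]
    return sum(1 for idx in top_k if relevant[idx]) / total if total else 0.0
```

For example, with scores [0.9, 0.8, 0.7, 0.6] and relevance [True, False, True, False], the relevant items sit at positions 1 and 3, so AP is (1 + 2/3) / 2 and Recall@1 is 0.5.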

It also contains the code to train the approximation of the rank function (synthetic data generation, model architecture, training script).
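As a rough illustration of what synthetic training data for a sorter can look like, the sketch below draws random score vectors and pairs each one with its ranks normalized to [0, 1]. The uniform distribution, the normalization and the name `make_batch` are assumptions made for this example; the repository's generator may use other distributions or target formats.

```python
import random


def make_batch(batch_size, seq_len, rng):
    """Hypothetical generator: random score vectors and their normalized ranks.

    Each target entry is rank / (seq_len - 1), so the smallest input maps to 0.0
    and the largest to 1.0 (ties are practically impossible with uniform floats).
    """
    denom = max(seq_len - 1, 1)
    inputs, targets = [], []
    for _ in range(batch_size):
        x = [rng.uniform(-1.0, 1.0) for _ in range(seq_len)]
        order = sorted(range(seq_len), key=lambda i: x[i])  # indices, smallest first
        r = [0.0] * seq_len
        for rank, idx in enumerate(order):
            r[idx] = rank / denom
        inputs.append(x)
        targets.append(r)
    return inputs, targets
```

A sorter trained on such pairs learns to map a vector of scores to its rank vector, e.g. `make_batch(32, 100, random.Random(0))` for the default sequence length of 100.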

Author and contact: Martin Engilberge

Main dependencies

This code is written in Python. To use it you will need:

  • Python 3.7
  • PyTorch 1.1
  • Numpy
  • TensorboardX

Getting started

To be able to use the loss function, the first step is to train a sorter to approximate the ranking function.

python train.py

Once the training is finished, the sorter's checkpoints are stored in the weights folder. By default the model used in the paper (lstm_large) will be selected and the sequence length will be set to 100. More models are present in model.py and can be selected with the argument -m.

python train.py -m gruc -n model_gruc

The GRU-based models were developed after the publication of the paper and may perform better. The sorter_exact model does not need to be trained and can be used as a reliable baseline.

By default, the training scripts use the GPU. To switch to CPU mode, uncomment device = torch.device("cpu") at the beginning of the script.

Using the loss function

Once you have trained a sorter, or if you decide to use the algorithmic one (sorter_exact), you can use the loss functions.

There are four losses:

# Spearman correlation based loss
SpearmanLoss(sorter_type, seq_len=None, sorter_state_dict=None)
# Mean average precision based loss
MapRankingLoss(sorter_type, seq_len=None, sorter_state_dict=None)
# Multimodal rank-based loss
RankLoss(sorter_type, seq_len=None, sorter_state_dict=None)
# Hard negative multimodal rank-based loss
RankHardLoss(sorter_type, seq_len=None, sorter_state_dict=None, margin=0.2)
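For scores without ties, the Spearman correlation can be written in terms of the rank differences d_i as rho = 1 - 6 * sum(d_i^2) / (n(n^2 - 1)), which is the same as rho = 1 - 6 * MSE / (n^2 - 1) where MSE is the mean squared rank difference. Minimizing a squared error between predicted and ground-truth ranks therefore maximizes rho. The sketch below checks this identity numerically in plain Python; it illustrates the relationship only and is not taken from the SpearmanLoss implementation.

```python
import random


def pearson(a, b):
    """Pearson correlation of two equal-length sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / (va * vb) ** 0.5


def rank_mse(rank_a, rank_b):
    """Mean squared difference between two rank vectors."""
    return sum((x - y) ** 2 for x, y in zip(rank_a, rank_b)) / len(rank_a)


def spearman_from_mse(rank_a, rank_b):
    """Spearman rho via rho = 1 - 6 * MSE / (n^2 - 1); valid when there are no ties."""
    n = len(rank_a)
    return 1 - 6 * rank_mse(rank_a, rank_b) / (n * n - 1)


if __name__ == "__main__":
    rng = random.Random(0)
    gt = list(range(1, 11))  # ground-truth ranks 1..10
    pred = gt[:]
    rng.shuffle(pred)        # a random permutation of the same ranks
    # Both values are the Spearman correlation of the two rankings.
    print(spearman_from_mse(gt, pred), pearson(gt, pred))
```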

Each loss function can take three main arguments:

  • sorter_type: the type of sorter model to use
  • seq_len: the sequence length the sorter was trained on
  • sorter_state_dict: the state dict containing the sorter's weights

The function load_sorter is provided to load all the required arguments at once from a sorter checkpoint:

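import sys
sys.path.append("/path/to/sodeep/folder/")
import torch
from sodeep import load_sorter, SpearmanLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sorter_checkpoint_path = "/path/to/sorter/checkpoint"  # placeholder: a checkpoint from the weights folder

# load_sorter's result is unpacked into (sorter_type, seq_len, sorter_state_dict)
criterion = SpearmanLoss(*load_sorter(sorter_checkpoint_path))
criterion.to(device)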
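The `*` in `SpearmanLoss(*load_sorter(sorter_checkpoint_path))` is ordinary Python argument unpacking: whatever sequence load_sorter returns is spread over the positional parameters sorter_type, seq_len and sorter_state_dict. The sketch below reproduces the pattern with hypothetical stubs (`load_sorter_stub`, `LossStub`) so it runs without the repository. The return order is inferred from the loss constructor signatures, and the values are dummies.

```python
def load_sorter_stub(checkpoint_path):
    """Hypothetical stand-in for sodeep.load_sorter; NOT the real implementation.

    Assumed to return the loss constructor's positional arguments in order:
    (sorter_type, seq_len, sorter_state_dict).
    """
    return "lstm_large", 100, {"layer.weight": [0.0]}  # dummy values


class LossStub:
    """Hypothetical stand-in with the same constructor signature as SpearmanLoss."""

    def __init__(self, sorter_type, seq_len=None, sorter_state_dict=None):
        self.sorter_type = sorter_type
        self.seq_len = seq_len
        self.sorter_state_dict = sorter_state_dict


# The * operator spreads the returned tuple over the positional parameters,
# so this call is equivalent to LossStub("lstm_large", 100, {...}).
criterion = LossStub(*load_sorter_stub("/path/to/sorter/checkpoint"))
```

Passing a single string instead, as in `LossStub("exa")`, leaves seq_len and sorter_state_dict at their None defaults, which mirrors how the exact sorter is selected.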

To use the programmed (exact) sorter instead, the syntax is as follows:

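import sys
sys.path.append("/path/to/sodeep/folder/")
import torch
from sodeep import SpearmanLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# "exa" selects the exact, algorithmic sorter: no trained weights are needed
criterion = SpearmanLoss("exa")
criterion.to(device)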
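One way to see how a rank can be made differentiable without a learned model is a pairwise soft rank: each element's rank is approximated by counting, through a sigmoid, how many other elements it exceeds. The sketch below is a generic construction offered as an illustration. We are assuming, not asserting, that it resembles what the exact sorter ("exa" / sorter_exact) does, and the `temperature` parameter is our own.

```python
import math


def stable_sigmoid(z):
    """Logistic function that avoids overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_rank(x, temperature=10.0):
    """Pairwise-sigmoid soft rank: 1 + sum over j != i of sigmoid(t * (x_i - x_j)).

    As temperature grows this approaches the hard 1-based rank (1 = smallest);
    at temperature 0 every element gets the mean rank (n + 1) / 2.
    """
    n = len(x)
    return [
        1.0 + sum(stable_sigmoid(temperature * (x[i] - x[j])) for j in range(n) if j != i)
        for i in range(n)
    ]
```

For well-separated values and a large temperature, `soft_rank([0.1, 3.0, -2.0], temperature=50.0)` is very close to the hard ranks [2, 3, 1]. Because sigmoid(z) + sigmoid(-z) = 1, the soft ranks always sum to n(n + 1) / 2, just like hard ranks; the temperature trades gradient smoothness against fidelity to the true ranking.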

The model.py file also contains an UpdatingWrapper, which can be used to keep updating the sorter on real data while it is used in a loss. For stability reasons, it might be necessary to use the proposed loss in combination with another loss.

On some regression tasks, we noticed that initializing training with an L1 loss for a couple of epochs was required before switching to the SpearmanLoss.
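The L1 warm-up described above can be expressed as a simple per-epoch weighting between the two losses. In the sketch below, the number of warm-up epochs (2, standing in for "a couple of epochs"), the residual L1 weight of 0.1 and the function names are assumptions. In a real training loop, `l1_value` would come from an L1 loss such as torch.nn.L1Loss and `rank_value` from SpearmanLoss; the arithmetic works the same on PyTorch tensors as on floats.

```python
def loss_weights(epoch, warmup_epochs=2, l1_weight_after=0.1):
    """Hypothetical schedule: pure L1 during warm-up, then the rank loss plus a small L1 term.

    Returns (l1_weight, rank_loss_weight) for the given 0-based epoch.
    """
    if epoch < warmup_epochs:
        return 1.0, 0.0
    return l1_weight_after, 1.0  # keep some L1 for stability after warm-up


def combined_loss(l1_value, rank_value, epoch, **schedule):
    """Weighted sum of an L1 loss value and a rank-based loss value for this epoch."""
    w_l1, w_rank = loss_weights(epoch, **schedule)
    return w_l1 * l1_value + w_rank * rank_value
```

For example, `combined_loss(0.5, 0.2, epoch=0)` uses only the L1 term, while from epoch 2 onward the rank-based term dominates. Keeping a small L1 weight after the warm-up is one way to combine the proposed loss with another loss for stability.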

Reference

If you found this code useful, please cite the following paper:

@inproceedings{engilberge2019sodeep,
	title={SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates},
	author={Engilberge, Martin and Chevallier, Louis and P{\'e}rez, Patrick and Cord, Matthieu},
	booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
	year={2019}
}

License

This project is licensed under the terms of the BSD 3-Clause Clear license. By downloading this program, you commit to comply with the license as stated in the LICENSE.md file.
