
spacecutter's Introduction

spacecutter

spacecutter is a library for implementing ordinal regression models in PyTorch. The library consists of models and loss functions. It is recommended to use skorch to wrap the models to make them compatible with scikit-learn.

Installation

pip install spacecutter

Usage

Models

Define any PyTorch model you want that generates a single, scalar prediction value. This will be our predictor model. This model can then be wrapped with spacecutter.models.OrdinalLogisticModel, which converts the output of the predictor from a single number to an array of ordinal class probabilities. The following example shows how to do this for a two-layer neural network predictor on a problem with three ordinal classes.

import numpy as np
import torch
from torch import nn

from spacecutter.models import OrdinalLogisticModel


X = np.array([[0.5, 0.1, -0.1],
              [1.0, 0.2, 0.6],
              [-2.0, 0.4, 0.8]],
             dtype=np.float32)

y = np.array([0, 1, 2]).reshape(-1, 1)

num_features = X.shape[1]
num_classes = len(np.unique(y))

predictor = nn.Sequential(
    nn.Linear(num_features, num_features),
    nn.ReLU(),
    nn.Linear(num_features, 1)
)

model = OrdinalLogisticModel(predictor, num_classes)

y_pred = model(torch.as_tensor(X))

print(y_pred)

# tensor([[0.2325, 0.2191, 0.5485],
#         [0.2324, 0.2191, 0.5485],
#         [0.2607, 0.2287, 0.5106]], grad_fn=<CatBackward>)
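
Each row of y_pred is a probability distribution over the three ordinal classes. If you want hard class labels rather than probabilities, take the argmax over the class dimension (a small usage sketch, not a dedicated library call):

labels = y_pred.argmax(dim=1)
print(labels)  # tensor([2, 2, 2]) for the untrained model above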

Training

It is recommended to use skorch to train spacecutter models. The following shows how to train the model from the previous section using cumulative link loss with skorch:

from skorch import NeuralNet

from spacecutter.callbacks import AscensionCallback
from spacecutter.losses import CumulativeLinkLoss

skorch_model = NeuralNet(
    module=OrdinalLogisticModel,
    module__predictor=predictor,
    module__num_classes=num_classes,
    criterion=CumulativeLinkLoss,
    train_split=None,
    callbacks=[
        ('ascension', AscensionCallback()),
    ],
)

skorch_model.fit(X, y)

Note that we must add the AscensionCallback. This ensures that the ordinal cutpoints stay in ascending order. While ideally this constraint would be factored directly into the model optimization, spacecutter currently hacks an SGD-compatible solution by utilizing a post-backwards-pass callback to clip the cutpoint values.
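
Conceptually, the clipping step just forces the cutpoints back into non-decreasing order after every update. A minimal sketch of the idea (not the library's exact implementation; it assumes the cutpoints are available as a 1-D parameter tensor):

import torch

def ascend_cutpoints(cutpoints: torch.Tensor) -> None:
    # In place: each cutpoint becomes at least as large as its predecessor.
    # Call this after optimizer.step(); add a small positive margin if you
    # need strictly increasing cutpoints.
    with torch.no_grad():
        cutpoints.copy_(torch.cummax(cutpoints, dim=0).values)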

spacecutter's People

Contributors

ethanrosenthal, sumanthratna


spacecutter's Issues

Using TensorRT

Hello,
Firstly, thanks a lot for making this useful repo public.
My question is about converting an OrdinalLogisticModel-wrapped model to its TensorRT equivalent.
I am unable to do it, and I think it has something to do with the LogisticCumulativeLink at the end of the wrapper class.
Have you personally tried converting the model to TensorRT for faster inference?
If so, would you be kind enough to let me know?
Thank you

why use deepcopy in models.py

This seems strange to me: why do you use deepcopy there? I have tested that the model works without deepcopy, so is there any reason for using it?

line 77: self.predictor = deepcopy(predictor)
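
For what it's worth, one plausible reason (my own guess, not the maintainer's reply): deepcopy gives OrdinalLogisticModel its own copy of the predictor's parameters, so training the wrapper does not silently mutate the module instance that was passed in. A small sketch of the difference:

import copy
import torch
from torch import nn

predictor = nn.Linear(3, 1)

shared = predictor                # no deepcopy: both names refer to the same parameters
owned = copy.deepcopy(predictor)  # deepcopy: an independent set of parameters

with torch.no_grad():
    shared.weight.zero_()

print(predictor.weight.abs().sum().item())  # 0.0 -- changed through `shared`
print(owned.weight.abs().sum().item())      # > 0 -- the deep copy is untouched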

Hi! I have a problem

Actually, I don't want to use skorch (I want more flexibility), so how should I change my training code? I don't know how to add the callback to my training loop. Thanks very much!
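
For readers with the same question, a minimal sketch of a plain PyTorch training loop (my own assumption about the interface: it reaches the cutpoints through model.link.cutpoints, which matched the module layout in the version of spacecutter I looked at, and it reuses model, X, and y from the README example above):

import torch
from spacecutter.losses import CumulativeLinkLoss

loss_fn = CumulativeLinkLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

X_t = torch.as_tensor(X)
y_t = torch.as_tensor(y).long()

for epoch in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(X_t), y_t)
    loss.backward()
    optimizer.step()
    # Hand-rolled replacement for AscensionCallback: keep cutpoints ascending.
    with torch.no_grad():
        cutpoints = model.link.cutpoints  # assumed attribute path; check your installed version
        for i in range(1, cutpoints.shape[0]):
            cutpoints[i].clamp_(min=cutpoints[i - 1].item())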

There's a problem with my program.

RuntimeError: The size of tensor a (5) must match the size of tensor b (6) at non-singleton dimension 1
I'm working on a six-class classification task with labels [0, 1, 2, 3, 4, 5]. I found that the number of cutpoints is num_classes - 1 in the code, but then the dimensions of the cutpoints and X are not equal. I've seen your example, but I still don't know how to solve it.
thank you!
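
For anyone hitting the same error, a quick sanity check on the shapes involved (a general sketch; the original predictor is not shown, so the exact cause here is an assumption). With six classes the model learns 5 cutpoints, and those cutpoints are broadcast against the predictor's output, which must be a single scalar per sample, i.e. shape (batch_size, 1). A predictor that emits six values per sample produces exactly this 5-vs-6 mismatch.

import torch
from torch import nn
from spacecutter.models import OrdinalLogisticModel

num_features = 10   # placeholder feature count
num_classes = 6     # labels 0..5  ->  5 learned cutpoints

predictor = nn.Sequential(
    nn.Linear(num_features, 1)   # the predictor must emit ONE scalar per sample
)
model = OrdinalLogisticModel(predictor, num_classes)

X = torch.randn(4, num_features)
print(model(X).shape)   # torch.Size([4, 6]) -- one probability per class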

Skorch no longer supports torch 0.4.1

Hi! Thanks for your work; I think it's useful and inspiring. However, skorch may no longer support torch 0.4.1. When I try to run 'from skorch.callbacks import Callback, ProgressBar', I get 'ImportWarning: Skorch depends on a newer version of PyTorch (at least 1.1.0, not 0.4.1). Visit https://pytorch.org for installation details'. Maybe spacecutter should upgrade its torch version requirement? Thanks a lot!

Under OrdinalModule

Hi, could you help me understand the piece of code below?
Why do we subtract the elements in link_mat and then concatenate them?
Isn't just cutpoints - X sufficient?

sigmoids = cutpoints - X
link_mat = sigmoids[:, 1:] - sigmoids[:, :-1]
link_mat = torch.cat((
        sigmoids[:, [0]],
        link_mat,
        (1 - sigmoids[:, [-1]])
    ),
    dim=1
)

1. When does this AscensionCallback get called: at the start or end of every batch, or of every epoch?
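
A sketch of the arithmetic behind that snippet (my reading of the cumulative link model, not a maintainer reply). In the library the subtraction is passed through a sigmoid (hence the variable name), giving cumulative probabilities P(y <= k) = sigmoid(cutpoint_k - X); cutpoints - X on its own would only give the cumulative curve. Individual class probabilities are the differences between consecutive cumulative probabilities, and the concatenation handles the first and last classes:

import torch

cutpoints = torch.tensor([-1.0, 1.0])    # 3 classes -> 2 cutpoints
X = torch.tensor([[0.3]])                # one scalar prediction per sample

sigmoids = torch.sigmoid(cutpoints - X)          # P(y <= 0), P(y <= 1)
link_mat = sigmoids[:, 1:] - sigmoids[:, :-1]    # P(y == 1)
probs = torch.cat((sigmoids[:, [0]],             # P(y == 0)
                   link_mat,
                   1 - sigmoids[:, [-1]]),       # P(y == 2)
                  dim=1)

print(probs, probs.sum(dim=1))   # the three entries form a valid distribution summing to 1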

dtype error

Hi,

I am trying to train an OrdinalLogit model and have stripped down the model to this:

pred_dosages_tensor = torch.tensor(pred_dosages)
true_dosages_tensor = torch.tensor(true_dosages, dtype=torch.long)
predictor = torch.nn.Sequential()
num_classes = len(np.unique(true_dosages))

scaling = NeuralNet(
    module=OrdinalLogisticModel,
    module__predictor=predictor,
    module__num_classes=num_classes,
    criterion=CumulativeLinkLoss,
    train_split=None,
    callbacks=[
        ('ascension', AscensionCallback()),
    ],
)
scaling.fit(true_dosages_tensor, pred_dosages_tensor)

However, when I try to run this code, I get the following error:

File "/home/unix/ssadhuka/.conda/envs/shuvomenv/lib/python3.7/site-packages/spacecutter/losses.py", line 68, in cumulative_link_loss
    likelihoods = torch.clamp(torch.gather(y_pred, 1, y_true), eps, 1 - eps)
RuntimeError: gather_out_cpu(): Expected dtype int64 for index

EDIT: Resolved

Fix:
In losses.py, change

likelihoods = torch.clamp(torch.gather(y_pred, 1, y_true), eps, 1 - eps)

to

likelihoods = torch.clamp(torch.gather(y_pred, 1, y_true.to(torch.int64)), eps, 1 - eps)
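
An alternative that avoids patching the installed package (an assumption about the setup, since only the stripped-down snippet is shown) is to cast whichever tensor is passed to fit as the target to int64 beforehand; inputs and targets below are placeholders for the variables in the snippet above:

targets = targets.to(torch.int64)
scaling.fit(inputs, targets)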

Docs error?

Dear Torch friends

Perhaps I am missing something, but I think the docstring of the _reduction function is not correct. Should it be:

def _reduction(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Reduce loss
    Parameters
    ----------
    loss : torch.Tensor, [batch_size, 1]
        Batch losses.
    reduction : str
        Method for reducing the loss. Options include 'elementwise_mean',
        'none', and 'sum'.
    Returns
    -------
    loss : torch.Tensor
        Reduced loss.
    """
    if reduction == 'elementwise_mean':
        return loss.mean()
    elif reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    else:
        raise ValueError(f'{reduction} is not a valid reduction')

?
