Comments (4)
I got a more complete example to work, in which I apply the soft_sort operator to a parameter of an nn.Module on a CUDA device. I don't know how general my approach is, or whether it reliably solves the issue in all scenarios, but I figured I would share it anyway:
import torch
import torch.nn as nn
from fast_soft_sort import pytorch_ops


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([[0.1, 0.7, 0.2]]))

    def forward(self, x):
        # soft_sort only runs on the CPU, so copy the parameter to the host,
        # sort it there, and copy the result back to the GPU.
        w_sorted = pytorch_ops.soft_sort(self.weight.cpu()).cuda()
        return 2.0 * w_sorted * x


net = Net()
net = net.cuda()
x = torch.tensor([[4.0, 5.0, 6.0]]).cuda()
y = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64).cuda()
yhat = net(x)
loss = nn.functional.mse_loss(yhat, y)
loss.backward()
print("Loss = {:.2f}".format(loss))
print(f"Grad = {net.weight.grad}")
Still curious to see whether there is a better way to handle that internally.
from fast-soft-sort.
At the moment we do not have a GPU implementation of the projection operators, which is the cause of the error. We decided not to do this conversion implicitly because we want the user to be aware that a device copy is necessary. If you want that behavior, can you write a small utility function like

def soft_sort(array):
    return pytorch_ops.soft_sort(array.cpu()).cuda()

and then use it as a drop-in replacement? Would that work?
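The utility above hard-codes `.cuda()`, so it always returns a tensor on the default GPU. A slightly more general variant, sketched below, remembers the input's device and copies the result back there. The name `on_cpu` is hypothetical and not part of fast-soft-sort; the sketch only assumes that the input and the result expose the usual tensor methods `.device`, `.cpu()`, and `.to()`.

```python
from typing import Any, Callable


def on_cpu(op: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a CPU-only tensor op so its result lands on the input's device.

    `on_cpu` is an illustrative helper, not a fast-soft-sort function.
    """

    def wrapped(tensor, *args, **kwargs):
        device = tensor.device                      # remember where the input lives
        result = op(tensor.cpu(), *args, **kwargs)  # explicit copy to the host
        return result.to(device)                    # copy the result back
    return wrapped
```

With this in place, `soft_sort = on_cpu(pytorch_ops.soft_sort)` could be used in place of the function above. The device copies are still there, just written once in the wrapper, so the cost the maintainers wanted to keep visible remains explicit at the point where the op is wrapped.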
Thank you for your help! Yes, it does work indeed. :)
I have been working on a pure PyTorch implementation here: https://github.com/teddykoker/torchsort, complete with the isotonic regression code written as a C++ extension. The CPU implementation is much faster (see benchmarks), and I am working on the CUDA implementation which should be done soon.