Comments (6)
For more context on this: it would be nice to be able to pass a torch.Tensor (or JAX array) living on the GPU directly to the selectors, instead of having to move the data back to main CPU memory.
A first pass would be to make sure all the function calls are compatible with the PyTorch API, but given the heavy use of Python for loops in the selector code, that might not improve performance much. The second step would then be to rewrite the selector code to use more high-level operations and launch larger GPU kernels, which would hopefully improve performance.
This is mostly unrelated to the autograd part of PyTorch, so even if we need to .detach() the tensors before passing them, that would be fine with me. I would mostly like to be able to keep the data in GPU memory.
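To illustrate the second step, here is a minimal sketch of what replacing Python for loops with high-level operations could look like. It is written with NumPy so it runs anywhere, and the function names are hypothetical; this is not code from skmatter. Both functions compute, for every sample, the squared distance to the closest already-selected sample.

```python
import numpy as np


def min_sq_distances_loop(X, selected):
    """Loop version: one tiny operation per (sample, selected) pair."""
    out = np.full(X.shape[0], np.inf)
    for i in range(X.shape[0]):
        for j in selected:
            d = np.sum((X[i] - X[j]) ** 2)
            out[i] = min(out[i], d)
    return out


def min_sq_distances_vectorized(X, selected):
    """Vectorized version: a few large array operations via broadcasting."""
    S = X[selected]                       # shape (n_selected, n_features)
    diff = X[:, None, :] - S[None, :, :]  # shape (n_samples, n_selected, n_features)
    return (diff**2).sum(axis=-1).min(axis=1)


rng = np.random.default_rng(0)
X = rng.random((50, 8))
selected = [0, 7, 21]
loop_result = min_sq_distances_loop(X, selected)
vec_result = min_sq_distances_vectorized(X, selected)
```

On a GPU backend, the loop version would launch many small kernels, while the vectorized version would map to a few large ones, which is the kind of rewrite described above.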
from scikit-matter.
My ideal user-facing interface for this would be to be able to do something like this:
```python
import torch
from skmatter.feature_selection import CUR

X = torch.rand(300, 300, device="cuda")  # or device="mps" on Apple M1/M2
selector = CUR(n_to_select=4)
selector.fit(X)
Xr = selector.transform(X)
# Xr is a torch tensor, with device=X.device
```
A first step for this would be to add a test that tries to use skmatter with a torch tensor and checks where the code starts throwing errors.
Depending on the number of function calls (e.g. np.sum, …) that need to be updated, it might be interesting to use https://github.com/jcmgray/autoray to dispatch function calls to the right backend.
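As a rough illustration of what dispatching function calls to the right backend means, here is a hand-rolled sketch. It is not autoray's implementation or API; the helpers backend_of and do are hypothetical names, and the lookup assumes the array type lives in a module whose top-level name is the library itself (true for NumPy and PyTorch, not guaranteed for every library).

```python
import importlib

import numpy as np


def backend_of(x):
    """Guess the array library from the top-level module of the array's type."""
    return type(x).__module__.split(".")[0]


def do(name, x, *args, **kwargs):
    """Call function `name` from the library that owns `x`, e.g. do("sum", x)."""
    module = importlib.import_module(backend_of(x))
    return getattr(module, name)(x, *args, **kwargs)


X = np.arange(6.0).reshape(2, 3)
col_sums = do("sum", X, axis=0)  # resolves to np.sum because X is a NumPy array
```

A real dispatch library also has to paper over differences in function names and keyword arguments between backends, which this sketch does not attempt.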
This is on the back burner for now. If you are interested in getting skmatter to run on GPU, please voice your interest here!
It looks like sklearn now has experimental support for PyTorch/CuPy (and thus GPU data) using the array API: https://scikit-learn.org/stable/modules/array_api.html. We could use the same here!
We should also experiment with how the array API works with our selection methods. FPS is probably a good candidate because it does not use very complicated mathematical operations, so hopefully there is not much friction in making this work.
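To make the FPS idea concrete, here is a minimal sketch of farthest point sampling that only calls functions from an array namespace xp. It is not skmatter's FPS implementation; the function name, the xp parameter, and the initial argument are assumptions for illustration. The example runs with NumPy as the namespace; for other libraries the namespace would probably be obtained through a compatibility layer, as sklearn does.

```python
import numpy as np


def fps(X, n_to_select, xp=np, initial=0):
    """Farthest point sampling over the rows of X, using only functions from xp."""
    selected = [initial]
    # Squared distance from every row to the closest selected row so far.
    diff = X - X[initial]
    min_dist = xp.sum(diff * diff, axis=1)
    for _ in range(n_to_select - 1):
        new = int(xp.argmax(min_dist))  # row farthest from the current selection
        selected.append(new)
        diff = X - X[new]
        min_dist = xp.minimum(min_dist, xp.sum(diff * diff, axis=1))
    return selected


X = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [0.0, 5.0]])
picked = fps(X, n_to_select=3)  # → [0, 2, 3]
```

The remaining Python loop runs once per selected point rather than once per sample, so each iteration is a small number of whole-array operations.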
More info on this array API in sklearn: https://labs.quansight.org/blog/array-api-support-scikit-learn.