Comments (5)
This is something that I said in the May meeting (I think!) but I want to make sure it's durably recorded here: if for some reason you wanted to do CPU/CUDA computation with XLA, I think it would be substantially more idiomatic and natural for the alternate backend to be able to operate directly on traditional CPU/CUDA tensors.
This presents a UX tension. An XLATensor which holds onto a jax.Array object is a plausible point in the design space, but when that jax.Array is CPU/CUDA, it is duplicative with traditional CPU/CUDA tensors... but not entirely: if you have a CPU tensor that internally holds a jax.Array, you suddenly get more expressivity because of the interoperability with JAX (though not in every case, e.g., the point @Chillee raised that is mentioned here).
So, this is what I'm hoping to see:
- XLATensor is a Python tensor subclass which wraps a jax.Array. In the limit, it is a full, eager-mode-compatible translation layer that translates PyTorch API calls into equivalent JAX API calls: if you have some PyTorch code and you want to jax.jit it, it should work as long as that code works when passed XLATensors instead. Things like jax.grad would not work with backward hooks, but this is simply "as designed" (and will need to be emphasized in user documentation; it's worth noting that the interaction here is quite similar to the interaction of PyTorch autograd and functorch grad, cc @zou3519).
- The XLA Dynamo backend will promote plain CPU/CUDA tensors to XLATensors with relatively little runtime overhead (in particular, it shouldn't be necessary to copy the tensors into XLA's workspace)
- Dynamo can deal with a passed-in XLATensor a la Option 2. It can be handled specially, similarly to FakeTensor.
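The translation-layer idea in the first bullet can be sketched concretely. The snippet below is a minimal, hypothetical illustration of the pattern (wrap an array, forward torch-style operator calls into a JAX-style numeric API); numpy stands in for jax.numpy, whose API it mirrors, and `WrappedTensor` plus its op table are invented names, not the real XLATensor implementation (which would hook `__torch_dispatch__`).

```python
import numpy as np

# Hypothetical sketch of the translation-layer pattern: a tensor-like
# wrapper that forwards torch-style operator calls to a JAX-style
# numeric API. numpy stands in for jax.numpy (which mirrors numpy).
class WrappedTensor:
    def __init__(self, array):
        self.array = array  # in the real design, this would be a jax.Array

    # Each supported torch-style op maps to its jax.numpy equivalent.
    _OPS = {
        "add": np.add,
        "mul": np.multiply,
        "matmul": np.matmul,
    }

    def _call(self, op, other):
        return WrappedTensor(self._OPS[op](self.array, other.array))

    def __add__(self, other):
        return self._call("add", other)

    def __mul__(self, other):
        return self._call("mul", other)

    def __matmul__(self, other):
        return self._call("matmul", other)

a = WrappedTensor(np.ones((2, 2)))
b = WrappedTensor(np.full((2, 2), 3.0))
c = (a + b) * b  # torch-style operators, executed by the JAX-style backend
print(c.array[0, 0])  # -> 12.0
```

The real subclass would intercept all torch ops generically rather than enumerating them, but the wrap/translate/return shape of the design is the same.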
from xla.
I'm interested in hearing more about the requirements and plan for interop with JAX transforms.
- Do you want jax transforms to work over the following function? (this seems reasonable to do even without torch.compile because when doing JAX tracing, XLATensor2 will desugar into JAX ops).
- Should torch.compile work over the following function?
- Can we mix and match torch.compile and JAX transforms?
```python
def f(jax_array_1, jax_array_2):
    # wrap jax_array_1, jax_array_2 into XLATensor2
    # (wrap/unwrap and some_torch_function are placeholder names)
    t1, t2 = wrap(jax_array_1), wrap(jax_array_2)
    # call torch ops on the wrapped tensors
    out = some_torch_function(t1, t2)
    # return the unwrapped jax.Array result
    return unwrap(out)
```
Relevant issue when trying to trace through XLATensor2 without using `traceable_tensor_subclasses`: pytorch/pytorch#128160. We discussed offline that it's not ideal to require `traceable_tensor_subclasses` for this case.
Hi @ezyang, totally agree on the 3 points listed. Promoting CUDA tensors to XLATensor seems simple with dlpack; we'll pursue this direction.
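For reference, the zero-copy property DLPack gives you can be demonstrated with numpy alone (numpy 1.22+, torch, and jax all speak the same `__dlpack__` protocol); the real promotion would use something like `jax.dlpack.from_dlpack` on a torch CUDA tensor, but the sharing behavior is the same idea:

```python
import numpy as np

# Zero-copy exchange via the DLPack protocol. numpy stands in here for
# the torch -> jax hand-off; the underlying mechanism (__dlpack__) is
# the same protocol a torch/jax exchange would use.
src = np.arange(4, dtype=np.float32)

# np.from_dlpack consumes the producer's DLPack capsule; no copy is made.
dst = np.from_dlpack(src)

# Because the buffer is shared, a write through one view is visible
# through the other.
src[0] = 42.0
print(dst[0])  # -> 42.0
```

This is why promotion to XLATensor shouldn't require copying the tensor into XLA's workspace: the consumer adopts the producer's buffer in place.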
> I'm interested in hearing more about the requirements and plan for interop with JAX transforms.
> - Do you want jax transforms to work over the following function? (this seems reasonable to do even without torch.compile because when doing JAX tracing, XLATensor2 will desugar into JAX ops).
> - Should torch.compile work over the following function?
> - Can we mix and match torch.compile and JAX transforms?
>
> ```python
> def f(jax_array_1, jax_array_2):
>     t1, t2 = wrap(jax_array_1), wrap(jax_array_2)  # wrap into XLATensor2
>     out = some_torch_function(t1, t2)              # call torch ops
>     return unwrap(out)                             # unwrap the result
> ```
Yes: jax transforms should work on this function.
> Can we mix and match torch.compile and JAX transforms?

I never dared to wish for this, to be honest. Right now, if I call a jax function from a torch program using `call_jax`, I don't expect dynamo to be able to handle it.

Although, if we somehow make this `call_jax` into a custom_op (a la https://colab.sandbox.google.com/drive/1xCh5BNHxGnutqGLMHaHwm47cbDL9CB1g), it should just work? The catch is that higher-order ops (since `call_jax` takes callables as input) cannot use this custom op API; they need to use `HigherOrderOps`, which IS traced through by dynamo. Presumably, dynamo needs to trace in order to figure out the returned dtype and shape, so another approach would be to make `HigherOrderOps` not traced by having the implementer declare the returned shape/dtype.
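The "declare the returned shape/dtype" alternative can be sketched as a tiny registry: the implementer of an opaque callable registers an abstract-eval rule, so a tracer can compute output metadata without ever tracing the body. All names below are hypothetical illustrations, not dynamo or JAX APIs.

```python
# Hypothetical sketch: instead of tracing through an opaque callable,
# the implementer registers an "abstract eval" that derives the output
# shape/dtype from the input shapes/dtypes alone.
_ABSTRACT_EVAL = {}

def register_abstract_eval(fn, abstract_fn):
    _ABSTRACT_EVAL[fn] = abstract_fn

def infer_output_spec(fn, *input_specs):
    # A tracer would call this instead of running (or tracing) fn's body.
    return _ABSTRACT_EVAL[fn](*input_specs)

# An opaque callable (imagine its body is a jax function dynamo can't trace).
def opaque_matmul(a, b):
    raise NotImplementedError("body never runs during tracing")

# Its declared metadata rule: (m, k) @ (k, n) -> (m, n), dtype preserved.
# A spec here is a (shape, dtype) pair.
register_abstract_eval(
    opaque_matmul,
    lambda a, b: ((a[0][0], b[0][1]), a[1]),
)

spec = infer_output_spec(opaque_matmul, ((4, 8), "float32"), ((8, 2), "float32"))
print(spec)  # -> ((4, 2), 'float32')
```

This is essentially what jax's own abstract evaluation does for its primitives; the question in the thread is whether a `HigherOrderOp` could opt into the same contract instead of being traced through.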