Comments (8)
Here's how I think we could use `propagate_xla_data` for solving this problem. Note that this is not a solution, but an initial idea. In summary, whenever it's called inside the dispatch of an in-place operation, we would need to (a rough sketch in Python follows the list):
- Get the original XLA tensor (the one with `alias_id == unique_id`) that holds the original shared buffer
  - Doable, since we already keep track of these ids
- Check whether it indeed shares the buffer (e.g. was created by the DLPack API)
  - We could add a flag to `XLATensor` for that
- Use the DLPack API for creating a CUDA tensor out of the XLA tensor
- Copy the contents of the new output to the storage of the CUDA tensor
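Putting those steps together, here is a minimal Python-like sketch of the idea. The `get_base_tensor` helper and the `shares_external_buffer` flag are hypothetical (the real implementation would live in C++ inside the dispatch of the in-place op):

```python
import torch
import torch.utils.dlpack
import torch_xla.utils.dlpack as xdlpack

def propagate_to_shared_buffer(original, new_output):
    # 1. Get the XLA tensor with alias_id == unique_id, i.e. the one
    #    holding the original shared buffer (hypothetical helper).
    base = get_base_tensor(original)  # hypothetical

    # 2. Only act if it really shares its buffer, e.g. because it was
    #    created through the DLPack API (hypothetical XLATensor flag).
    if not base.shares_external_buffer:
        return

    # 3. Use the DLPack API to view the XLA buffer as a CUDA tensor.
    cuda_view = torch.utils.dlpack.from_dlpack(xdlpack.to_dlpack(base))

    # 4. Copy the new output into the CUDA tensor's storage.
    cuda_view.copy_(new_output)
```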
This, however, won't work. Once we call `torch._sync` on the original XLA tensor, we will run the in-place operation again, which might give incorrect results.
On another note, we could use this (`propagate_xla_data`) for warning the user that they are not really modifying the underlying storage. Basically, check whether the tensor we are calling the in-place operation on shares storage (again, with a new `XLATensor` flag).
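As a rough illustration of that warning, assuming the same hypothetical `shares_external_buffer` flag as above:

```python
import warnings

def warn_if_shared(tensor):
    # Hypothetical check: the in-place op won't mutate the shared
    # storage, because functionalization swaps in a fresh value instead.
    if getattr(tensor, "shares_external_buffer", False):  # hypothetical
        warnings.warn(
            "In-place operation on a DLPack-shared XLA tensor does not "
            "modify the underlying shared storage.")
```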

---

This behavior should be the result of our functionalization pass. @alanwaketan to confirm the expected behavior. Either way, let's have a `dlpack` documentation/tutorial that goes through example use cases and fully explains the correct behavior @ysiraichi.

---

Thanks for the issue. I checked the buffer pointer at more places:

```python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.utils.dlpack as xdlpack
>>>
>>> t0 = torch.arange(10, device=xm.xla_device())
>>> xm.mark_step(wait=True)
>>>
>>> capsule = xdlpack.to_dlpack(t0)
>>> t1 = xdlpack.from_dlpack(capsule)
>>> print(torch_xla._XLAC._unsafe_buffer_pointer(t0) == torch_xla._XLAC._unsafe_buffer_pointer(t1))
True
>>>
>>> t0[0] = 100
>>> xm.mark_step()
>>>
>>> print(torch_xla._XLAC._unsafe_buffer_pointer(t0) == torch_xla._XLAC._unsafe_buffer_pointer(t1))
True
>>> print(t0.eq(t1).all().item())
False
>>>
>>> print(torch_xla._XLAC._unsafe_buffer_pointer(t0) == torch_xla._XLAC._unsafe_buffer_pointer(t1))
False
```
Could you elaborate on "That's because even though functionalization emulates views and mutation, PyTorch/XLA doesn't really have the concept of views and can't mutate a given tensor"? Do you mean that when we do `t0[0] = 100`, the underlying PJRT buffer is not mutated, hence `t1` is not updated, even though `t0` and `t1` share the same storage? Let me also look into what torch_xla does when we do `t0[0] = 100`.

---

Yes, exactly. In summary, a functionalized lazy tensor is composed of:
```
Tensor(
  impl=FunctionalTensorWrapper(
    value=Tensor(
      impl=XLATensorImpl(
        tensor=XLATensor(handle or tensor_data or ir_value)
      )
    )
  )
)
```
Suppose `t0` and `t1` share the same storage using the DLPack API. Whenever an in-place operation is called, e.g. `t0.add_(1)`, the functionalization layer actually calls the functional variant (`XLANativeFunctions::add`), which generates a new `XLATensor`. Later, that is wrapped by a new `FunctionalTensorWrapper` (let's call it `temp`). In the end, the functionalization layer replaces the `FunctionalTensorWrapper::value` of `t0` by the one inside `temp`. Thus, `t0` ends up with the updated value, while `t1` remains with the old one.
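A toy Python model of that swap (these are not the real C++ classes, just an illustration of the mechanism):

```python
class XLATensor:
    """Stand-in for the real XLATensor (handle / tensor_data / ir_value)."""

class FunctionalTensorWrapper:
    def __init__(self, value):
        self.value = value

shared = XLATensor()
t0 = FunctionalTensorWrapper(shared)
t1 = FunctionalTensorWrapper(shared)  # "shares" storage via DLPack

# t0.add_(1): the functional variant produces a brand-new XLATensor,
# which gets wrapped in a temporary FunctionalTensorWrapper...
temp = FunctionalTensorWrapper(XLATensor())
# ...and functionalization replaces t0's value with temp's value:
t0.value = temp.value

assert t0.value is not t1.value  # t1 still points at the old value
```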

---

Try this: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L2703

---

Hmm. Not sure I get it. Could you explain a bit more?

---

That's a helper that lets us bridge information through the intermediate tensors created by functionalization for in-place ops.

---

When we do the in-place op `t0[0] = 100`, I see `XLANativeFunctions::_propagate_xla_data` invoked twice, by:

- `at::functionalization::fill__Scalar`
- `at::functionalization::copy_`

in sequence. So it seems the helper is already being used?
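That pair of calls is consistent with how the indexing assignment decomposes; roughly (a sketch, not the exact dispatch trace):

```python
# t0[0] = 100 is roughly equivalent to:
view = t0.select(0, 0)  # take a view of element 0
view.fill_(100)         # in-place fill on the view
#   -> functionalization runs at::functionalization::fill__Scalar,
#      then propagates the updated view back into the base tensor,
#      which shows up as at::functionalization::copy_
```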