Comments (9)
When you said you are dumping the optimized HLO IR, do you use torch.export? I believe if you use export, @qihqi and @lsy323 figured out how to do the parameter mapping.
from xla.
Thanks for your reply!
I dump the HloModule by setting the environment variable:
XLA_FLAGS="--xla_dump_to=mlp_hlo_graph \
--xla_dump_hlo_as_text \
--xla_hlo_graph_addresses \
--xla_eliminate_hlo_implicit_broadcast \
--xla_dump_hlo_as_proto \
--xla_hlo_graph_sharding_color \
"
For example, in the txt file, the first line is:
HloModule SyncTensorsGraph.51, is_scheduled=true, entry_computation_layout={(s64[2]{0}, s64[2,16]{1,0}, s64[], s64[16]{0}, s64[16,16]{1,0}, /*index=5*/s64[16]{0}, s64[16,4]{1,0}, s64[4]{0})->(s64[2]{0})}
I want to know which tensor (in the PyTorch model, input, or somewhere else) corresponds to which parameter in the entry_computation_layout.
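Before trying to map parameters back to tensors, the entry_computation_layout line itself can be parsed mechanically to list the parameter slots and their shapes. A minimal sketch in plain Python (the regex and helper names are my own, not part of torch_xla):

```python
import re

# The first line of the HLO text dump, as shown above.
header = ("HloModule SyncTensorsGraph.51, is_scheduled=true, "
          "entry_computation_layout={(s64[2]{0}, s64[2,16]{1,0}, s64[], "
          "s64[16]{0}, s64[16,16]{1,0}, /*index=5*/s64[16]{0}, "
          "s64[16,4]{1,0}, s64[4]{0})->(s64[2]{0})}")

def parse_layout(line):
    """Return (parameter_shapes, result_shapes) from an HLO header line."""
    m = re.search(r"entry_computation_layout=\{\((.*)\)->\((.*)\)\}", line)
    shape_re = re.compile(r"(\w+)\[([\d,]*)\]")  # e.g. "s64[2,16]" -> ("s64", "2,16")
    def shapes(s):
        s = re.sub(r"/\*.*?\*/", "", s)  # drop /*index=N*/ markers
        return [(dtype, tuple(int(d) for d in dims.split(",")) if dims else ())
                for dtype, dims in shape_re.findall(s)]
    return shapes(m.group(1)), shapes(m.group(2))

params, results = parse_layout(header)
for i, (dtype, shape) in enumerate(params):
    print(f"parameter {i}: {dtype}{list(shape)}")
```

This only recovers the slot order and shapes; which tensor fills each slot is the open question discussed below.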
Yeah, if you dump with XLA_FLAGS there is no easy way to do the parameter mapping back to torch tensors.
I can tell you how it works, though. In PyTorch/XLA we first identify the set of tensors that have pending IR and make them the outputs of the HLO. Then we run a post-order traversal from these outputs, find all of the input XLAData, and use them to construct and execute the HLO.
In this process we don't really know which XLAData corresponds to which PyTorch Python tensor; the Python tensor might not exist anymore. For example:
a = torch.tensor(100, device=torch_xla.device())
b = a + 5
a = b + 3
To calculate b you need the original value of a, but the name a now points to different data.
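The mechanism described above can be sketched with a toy IR. This is a pure-Python illustration (the Node class and collect_device_data helper are invented for the example, not torch_xla's real classes): live tensors become outputs, and a post-order walk from them collects the device-data leaves that become HLO parameters, even when no Python variable points at them anymore.

```python
class Node:
    """A toy IR node: either a device_data leaf or an op over other nodes."""
    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, tuple(inputs), value

def collect_device_data(outputs):
    """Post-order traversal from the outputs, gathering device_data leaves."""
    seen, leaves = set(), []
    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for inp in node.inputs:
            visit(inp)
        if node.op == "device_data":
            leaves.append(node)
    for out in outputs:
        visit(out)
    return leaves

# a = torch.tensor(100, ...); b = a + 5; a = b + 3
a = Node("device_data", value=100)                 # original `a`
b = Node("add", [a, Node("constant", value=5)])
a = Node("add", [b, Node("constant", value=3)])    # Python name `a` rebound

# The old device data is still an input of the graph, but no Python
# variable points at it anymore -- hence no easy tensor<->parameter map.
params = collect_device_data([a, b])
print([p.value for p in params])
```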
I found a way to output all the parameter values in the entry_computation_layout.
I used the handle to identify each parameter. The handle can be dumped in ir_builder.h, inside MakeDeviceData (for non-scalar tensors); for scalar tensors, I dumped the handle in torch_xla/csrc/ops/device_data.cpp, in the DeviceData::DeviceData() constructor.
To dump the tensor data, in torch_xla/csrc/tensor_util.cpp, in the TensorToXlaData function, I dump the at::Tensor and can link it to the handles dumped as above.
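The handle-based bookkeeping described above amounts to joining two log streams on the handle. A purely illustrative sketch (the handle values and log contents are invented, not real dump output):

```python
# Pretend these were dumped from MakeDeviceData / DeviceData::DeviceData():
# one entry per parameter, handle -> parameter index.
handle_to_param = {0x7F01: 0, 0x7F02: 1, 0x7F03: 2}

# Pretend these were dumped from TensorToXlaData: handle -> tensor data.
handle_to_tensor = {0x7F02: [1, 2, 3], 0x7F01: [100], 0x7F03: [7]}

# Join the two dumps on the handle to recover parameter index -> tensor.
param_to_tensor = {handle_to_param[h]: t for h, t in handle_to_tensor.items()}
print(param_to_tensor)
```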
Because the new backend is for cryptography use, I have to dump all the intermediate values, including the inputs. So for each intermediate value in the HLO IR module, do you know if there is a good way to get its value?
The only way I am aware of is to set all of them to be outputs; otherwise the compiler is allowed to optimize the intermediate values away, which is an important optimization.
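Why intermediates disappear unless they are outputs can be seen with a toy dead-code-elimination pass (my own toy graph representation, not XLA's actual optimizer): any node not reachable from the outputs is free to be dropped.

```python
def live_nodes(graph, outputs):
    """graph: name -> list of input names. Return names reachable from outputs."""
    live, stack = set(), list(outputs)
    while stack:
        n = stack.pop()
        if n not in live:
            live.add(n)
            stack.extend(graph[n])
    return live

# exp_x feeds sum_x feeds log_s; tmp is an intermediate nothing consumes.
graph = {
    "x": [],
    "exp_x": ["x"],
    "sum_x": ["exp_x"],
    "log_s": ["sum_x"],
    "tmp": ["exp_x"],   # computed but not needed for any output
}

print(sorted(live_nodes(graph, ["log_s"])))         # tmp may be eliminated
print(sorted(live_nodes(graph, ["log_s", "tmp"])))  # tmp survives only as an output
```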
So how to set them to be the output? Is there an option setting for that?
I am not aware of a way. It works like this: you first specify the outputs, and PyTorch/XLA generates the HLO for you. You can use
xla/torch_xla/core/xla_model.py
Lines 1227 to 1228 in ea2a6f7
to specify the output tensors. However, in our lowering each torch op maps to multiple HLO ops. For example, with
y = torch.logsumexp(x)
the logsumexp will get lowered into a sequence of HLO ops. I am not aware of a way to output the values of all of those intermediate HLO ops.
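As an illustration of one torch op becoming several primitive ops, here is the standard numerically stable decomposition of logsumexp, written in plain Python (a sketch of the idea, not the actual HLO lowering):

```python
import math

def logsumexp(xs):
    """log(sum(exp(x))), decomposed into max / sub / exp / sum / log / add,
    mirroring how one high-level op lowers to several primitive ops."""
    m = max(xs)                            # reduce-max (for numerical stability)
    shifted = [x - m for x in xs]          # subtract (broadcast)
    exps = [math.exp(x) for x in shifted]  # elementwise exp
    s = sum(exps)                          # reduce-sum
    return m + math.log(s)                 # log, then add back the max

print(logsumexp([1.0, 2.0, 3.0]))
```

Each intermediate here (m, shifted, exps, s) corresponds to a separate HLO-level value; only y would normally be a graph output.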
OK, thanks! Because I have to get both the input and the output of non-linear ops like exp, div, etc., I'll have to find a way to dump each of them.