Comments (5)
Technically there is; you can look at our dynamo implementation, where we:
- execute the tracing (xla/torch_xla/core/dynamo_bridge.py, line 337 at 08e63e3)
- compute the hash + warm up the cache (compilation) (xla/torch_xla/core/dynamo_bridge.py, lines 395 to 399 at 08e63e3)
- execute the hash with input (xla/torch_xla/core/dynamo_bridge.py, line 497 at 08e63e3)
Dynamo is supposed to do what you expected; it handles the input ordering, output ordering, functionalization of the graph, etc. If you use these APIs directly you need to be very careful.
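For reference, a minimal sketch of that three-step flow using the same private _XLAC bindings exercised later in this thread (model, xla_inputs, and the GraphInputMatcher plumbing that produces graph_input are assumed, not spelled out):

# 1. Tracing: running the model on XLA tensors only records a lazy graph.
output = model(*xla_inputs)

# 2. Hash + warm up: hash the pending graph and compile it into the cache.
graph_hash = torch_xla._XLAC._get_graph_hash([output])
torch_xla._XLAC._xla_warm_up_cache([output], [])

# 3. Execute: replay the compiled graph with concrete device inputs
#    (graph_input is what GraphInputMatcher produces, as shown below).
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)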
Thank you very much for your answer. I have successfully run the forward calculation of the model following your tips and this unit test: https://github.com/pytorch/xla/blob/08e63e32af9eee71e8cd13d672f3200ee3356ab4/test/dynamo/test_graph_input_matcher.py
But I do not know how to add the backward calculation and the optimizer state update.
Technically you can do:
loss = fwd(input)
loss.backward()
optimizer.step()
graph_hash = torch_xla._XLAC._get_graph_hash([loss] + [all_parameter_gradient])
From the XLA perspective there is no fwd and bwd; you just need to pass all of the outputs (in this case the gradients) and it will use those as roots to construct the whole graph.
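Concretely, collecting those gradients might look like this (a sketch; loss is the traced output and backward() has already run):

# The loss plus every parameter gradient serve as the roots from which
# XLA reconstructs the whole forward + backward (+ step) graph.
all_parameter_gradients = [
    p.grad for p in model.parameters() if p.grad is not None
]
graph_hash = torch_xla._XLAC._get_graph_hash([loss] + all_parameter_gradients)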
Thank you very much for your reply. I went through the whole process, but I found that the parameters were not updated, resulting in the same loss (in my case, res[0]). I constructed a minimal unit test that reproduces this problem. @JackCaoG
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn
from torch.utils._pytree import tree_map_only
from torch_xla.core.dynamo_bridge import GraphInputMatcher
from torch_xla.amp import syncfree


class M(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(5, 3)

  def forward(self, x):
    return self.linear(x)

  def get_example_inputs(self):
    return (torch.rand(10, 5),)


xla_dev = xm.xla_device()
model = M().to(device=xla_dev)
optimizer = syncfree.AdamW(model.parameters(), lr=0.01)
inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
                       model.get_example_inputs())
xm.mark_step()

# Map each input tensor id to its argument index for GraphInputMatcher.
args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}

# Trace one full training step: forward, backward, optimizer update.
output = model(*inputs).sum()
output.backward()
found_inf = torch.isnan(output).to(torch.float32).to(xla_dev)
optimizer.step(found_inf=found_inf)

# Use the loss plus every (parameter, gradient) pair as graph roots.
opt_state = []
for name, p in model.named_parameters():
  if p.grad is not None:
    opt_state.append(p)
    opt_state.append(p.grad)
  else:
    print(name, "no grad")
output_list = [output] + opt_state

# Hash the graph, compile it into the cache, and capture its inputs.
xla_graph_hash = torch_xla._XLAC._get_graph_hash(output_list)
torch_xla._XLAC._xla_warm_up_cache(output_list, [])
(
    graph_input_tensor_ids,
    graph_input_xla_values,
) = torch_xla._XLAC._get_tensors_xla_device_data_node(output_list)
xla_args_tensor_ids = set(
    tree_map_only(torch.Tensor,
                  lambda input: torch_xla._XLAC._xla_get_tensor_id(input),
                  inputs))
graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                        graph_input_tensor_ids,
                                        graph_input_xla_values,
                                        xla_args_tensor_ids)

# Replay the cached graph; res[0] is the loss and it never changes.
for i in range(3):
  graph_input = graph_input_matcher(inputs)
  res = torch_xla._XLAC._run_cached_graph(xla_graph_hash, graph_input)
  print(res[0])
I think the code below is logically equivalent to the code above, yet the loss changes for the code below but not for the code above:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn
from torch.utils._pytree import tree_map_only
from torch_xla.core.dynamo_bridge import GraphInputMatcher
from torch_xla.amp import syncfree


class M(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(5, 3)

  def forward(self, x):
    return self.linear(x)

  def get_example_inputs(self):
    return (torch.rand(10, 5),)


xla_dev = xm.xla_device()
model = M().to(device=xla_dev)
optimizer = syncfree.AdamW(model.parameters(), lr=0.01)
inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
                       model.get_example_inputs())
xm.mark_step()
args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}

# Re-trace every iteration and materialize it with mark_step(); here the
# parameters are updated in place and the loss changes as expected.
for i in range(3):
  output = model(*inputs).sum()
  output.backward()
  found_inf = torch.isnan(output).to(torch.float32).to(xla_dev)
  optimizer.step(found_inf=found_inf)
  optimizer.zero_grad()
  xm.mark_step()
  print("debug ", output)
@JackCaoG @dewitt @sprt @ezyang
Hello, I have located the root cause of the problem in the above unit test: only the placeholder gets assigned, while the parameters themselves are not, so although new_param is computed, the parameters never change.
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L816-L817
But I don't know how to fix this problem. Could you give me some ideas?
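One possible direction (a sketch of my own, mirroring how the dynamo bridge treats parameters, not a verified fix): register the parameters as graph arguments before hashing and warming up the cache, so GraphInputMatcher resolves them from the live args on every replay instead of reusing the device data captured at trace time, then copy the values returned by _run_cached_graph back into the parameters between iterations. Every name below refers to the first repro above; params and all_args are new, and the indexing assumes every parameter received a gradient.

# Hypothetical rework of the first repro's argument mapping and replay
# loop. The parameters must be added to the arg mapping *before* the
# trace/hash/warm-up so the matcher maps them instead of snapshotting them.
params = list(model.parameters())
all_args = list(inputs) + params
args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(a) for a in all_args
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}
# ... trace, build output_list, hash, warm up, and construct
# graph_input_matcher exactly as above, passing all_args wherever the
# original code passed inputs ...
for i in range(3):
  graph_input = graph_input_matcher(all_args)
  res = torch_xla._XLAC._run_cached_graph(xla_graph_hash, graph_input)
  print(res[0])  # the loss
  # res mirrors output_list: [loss, p1, p1.grad, p2, p2.grad, ...].
  # Write the updated parameter values back in place so the next replay
  # sees them through the argument mapping.
  with torch.no_grad():
    for j, p in enumerate(params):
      p.copy_(res[1 + 2 * j])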