
Comments (5)

JackCaoG commented on July 30, 2024

technically there is; you can look at our dynamo implementation, where we:

  1. execute the tracing
    xla_out = xla_model(*xla_args)
  2. compute the hash + warm up the cache (compilation)
    graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
    if dynamo_debug:
        print("Graph Hash: ", graph_hash)
    # compile and cache the graph rooted at the tensors in 'args_and_out'
    torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
  3. execute the cached graph with the input
    res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)

Dynamo is supposed to do what you expected; it handles the input ordering, output ordering, functionalization of the graph, etc. If you use these APIs directly you need to be very careful.
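
Put together, here is a minimal, self-contained sketch of those three steps for a forward-only pass (a toy nn.Linear model is assumed; reusing the cached device-data nodes as graph inputs in step 3 is only valid for replaying the same inputs, which is what the dynamo bridge's GraphInputMatcher generalizes):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn

device = xm.xla_device()
xla_model = nn.Linear(5, 3).to(device)  # assumed toy model
xla_args = (torch.rand(10, 5, device=device),)
xm.mark_step()

# 1. execute the tracing
xla_out = xla_model(*xla_args)

# 2. compute the hash + warm up the cache (compilation)
args_and_out = [xla_out]
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])

# 3. execute the cached graph; the device-data nodes feeding the
# traced tensors serve as the graph inputs for this replay
_, graph_input = torch_xla._XLAC._get_tensors_xla_device_data_node(args_and_out)
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
print(res[0])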


mars1248 commented on July 30, 2024

Thank you very much for your answer. I have successfully run the forward calculation of the model by following your tips and referring to this unit test: https://github.com/pytorch/xla/blob/08e63e32af9eee71e8cd13d672f3200ee3356ab4/test/dynamo/test_graph_input_matcher.py
However, I do not know how to add the backward calculation and the optimizer state update.


JackCaoG commented on July 30, 2024

technically you can do

loss = fwd(input)
loss.backward()
optimizer.step()
# 'all_parameter_gradient' is a placeholder for the list of every parameter's gradient
graph_hash = torch_xla._XLAC._get_graph_hash([loss] + all_parameter_gradient)

From the XLA perspective there is no fwd and bwd; you just need to pass all of the outputs (in this case the gradients) and it will use those as roots to construct the whole graph.
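
Concretely, a short sketch of collecting those roots (a toy model and optimizer are assumed; 'roots' is just an illustrative name):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn

device = xm.xla_device()
model = nn.Linear(5, 3).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.rand(10, 5, device=device)
xm.mark_step()

loss = model(x).sum()
loss.backward()
optimizer.step()
# the loss plus every parameter gradient serve as the graph roots
roots = [loss] + [p.grad for p in model.parameters() if p.grad is not None]
graph_hash = torch_xla._XLAC._get_graph_hash(roots)
# compile and cache the graph rooted at those tensors
torch_xla._XLAC._xla_warm_up_cache(roots, [])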


mars1248 commented on July 30, 2024

Thank you very much for your reply. I went through the whole process, but I found that the parameters were not updated, so the loss stays the same (in my case, res[0]). I constructed a minimal test that reproduces this problem. @JackCaoG


import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn
from torch.utils._pytree import tree_map_only
from torch_xla.core.dynamo_bridge import GraphInputMatcher
from torch_xla.amp import syncfree

class M(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(5, 3)

  def forward(self, x):
    return self.linear(x)

  def get_example_inputs(self):
    return (torch.rand(10, 5),)

xla_dev = xm.xla_device()
model = M().to(device=xla_dev)
optimizer = syncfree.AdamW(model.parameters(), lr=0.01)
inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
                        model.get_example_inputs())

xm.mark_step()
args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}
output = model(*inputs).sum()
output.backward()
found_inf = torch.isnan(output).to(torch.float32).to(xla_dev)
optimizer.step(found_inf=found_inf)
opt_state = []
for name, p in model.named_parameters():
    if p.grad is not None:
        opt_state.append(p)
        opt_state.append(p.grad)
    else:
        print(name, "no grad")
output_list = [output] + opt_state
xla_graph_hash = torch_xla._XLAC._get_graph_hash(output_list)
torch_xla._XLAC._xla_warm_up_cache(output_list, [])
(
    graph_input_tensor_ids,
    graph_input_xla_values,
) = torch_xla._XLAC._get_tensors_xla_device_data_node(output_list)
xla_args_tensor_ids = set(
    tree_map_only(torch.Tensor,
                    lambda input: torch_xla._XLAC._xla_get_tensor_id(input),
                    inputs))
graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                        graph_input_tensor_ids,
                                        graph_input_xla_values,
                                        xla_args_tensor_ids)
for i in range(3):
    graph_input = graph_input_matcher(inputs)
    res = torch_xla._XLAC._run_cached_graph(xla_graph_hash, graph_input)
    print(res[0])

I think the code below is logically the same as the code above, yet the loss changes for the code below but not for the code above:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn
from torch.utils._pytree import tree_map_only
from torch_xla.core.dynamo_bridge import GraphInputMatcher
from torch_xla.amp import syncfree

class M(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(5, 3)

  def forward(self, x):
    return self.linear(x)

  def get_example_inputs(self):
    return (torch.rand(10, 5),)

xla_dev = xm.xla_device()
model = M().to(device=xla_dev)
optimizer = syncfree.AdamW(model.parameters(), lr=0.01)
inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
                        model.get_example_inputs())

xm.mark_step()
args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}
for i in range(3):
    output = model(*inputs).sum()
    output.backward()
    found_inf = torch.isnan(output).to(torch.float32).to(xla_dev)
    optimizer.step(found_inf=found_inf)
    optimizer.zero_grad()
    xm.mark_step()
    print("debug ", output)


mars1248 commented on July 30, 2024

@JackCaoG @dewitt @sprt @ezyang
Hello, I have located the root cause of the problem in the above test: only the placeholder is assigned, while the parameters themselves are never assigned, so although new_param is computed, the parameters do not change.
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L816-L817
But I don't know how to fix this problem. Could you give me some ideas?
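
One direction I am considering, as an untested sketch that continues the repro above: also register the parameters in tensor_id_to_arg_idx so that GraphInputMatcher feeds the current parameter tensors on each run, then copy the cached graph's outputs back into the parameters between iterations. I do not know whether this is the intended fix.

# untested sketch, reusing model, inputs, graph_input_tensor_ids,
# graph_input_xla_values, xla_args_tensor_ids, xla_graph_hash and
# opt_state from the repro above
params = list(model.parameters())
all_args = list(inputs) + params
tensor_id_to_arg_idx = {
    torch_xla._XLAC._xla_get_tensor_id(a): i for i, a in enumerate(all_args)
}
graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                        graph_input_tensor_ids,
                                        graph_input_xla_values,
                                        xla_args_tensor_ids)
for i in range(3):
    graph_input = graph_input_matcher(all_args)
    res = torch_xla._XLAC._run_cached_graph(xla_graph_hash, graph_input)
    # res is assumed to follow output_list = [output] + opt_state, where
    # opt_state interleaves [param, grad, param, grad, ...]
    with torch.no_grad():
        for p, new_p in zip(params, res[1::2]):
            p.copy_(new_p)
    xm.mark_step()  # materialize the copies before the next run
    print(res[0])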

