
Comments (6)

riccardofelluga avatar riccardofelluga commented on June 2, 2024 1

After further inspection @IvanYashchuk, I still think these are two slightly different things. In the PR, the output is ready-to-run Python code for the fusion, while the method you explained yields similar information but is missing the strides and the code to run the fusions. However, I agree with you that it might be better to move the code from that PR to the examine.py file, say, by making something similar to get_fusion_symbols that returns the repro code for the fusions. What do you think?

Even better, I think I can get the information about the inputs from the trace using your technique, eliminating the need to modify the nvfuserex_impl.py file at all.
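To make the idea concrete, here is a minimal sketch of what such a repro-generating helper could look like. Everything here is hypothetical: `FakeProxy` is a stand-in for the shape/dtype/device metadata a trace's tensor proxies carry, and `make_repro_source` is an illustrative name, not an existing thunder function.

```python
from dataclasses import dataclass

# Stand-in for the metadata a trace's tensor proxies carry
# (shape, dtype, device); thunder itself is not imported here.
@dataclass
class FakeProxy:
    name: str
    shape: tuple
    dtype: str
    device: str

def make_repro_source(fusion_name: str, args: list) -> str:
    """Hypothetical helper: emit a ready-to-run snippet that
    recreates the inputs of one fusion region."""
    lines = ["import torch", ""]
    for a in args:
        lines.append(
            f'{a.name} = torch.randn({", ".join(map(str, a.shape))}, '
            f'dtype=torch.{a.dtype}, device="{a.device}")'
        )
    lines.append(f"# TODO: call the extracted {fusion_name} definition on these inputs")
    return "\n".join(lines)

repro = make_repro_source("nvFusion0", [FakeProxy("x", (512, 512), "float32", "cuda:0")])
print(repro)
```

The key point is that everything needed to rebuild the inputs (except strides) is already available on the trace's proxies, so no executor-side changes are strictly required.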

from lightning-thunder.

IvanYashchuk avatar IvanYashchuk commented on June 2, 2024

An nvFuser fusion region in the Thunder trace is represented as a BoundSymbol. You can get all bound symbols by accessing TraceCtx.bound_symbols, and for each bound symbol you can access its inputs with BoundSymbol.args, BoundSymbol.kwargs, or BoundSymbol.flat_args.
All nvFuser bound symbols have BoundSymbol.is_fusion set to True and their names start with nvFusion, so you can filter out all other symbols with a simple list comprehension:

nvfuser_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.name.startswith("nvFusion")]

There's also the thunder.examine.get_fusion_symbols function, which does the same using is_fusion:

def get_fusion_symbols(trace: TraceCtx, warn_if_fusions_unavailable: bool = True) -> list[BoundSymbol]:

Here's an example session using bound symbol info to retrieve information about the inputs:

In [1]: import torch

In [2]: import thunder

In [3]: @thunder.jit
   ...: def func(x):
   ...:     t1 = thunder.prims.var(x, (0, 1), correction=1)
   ...:     t2 = thunder.prims.add(t1, t1)
   ...:     return t2
   ...: 

In [4]: x = torch.randn(512, 512, device="cuda")

In [5]: out = func(x)

In [6]: thunder.last_traces(func)[-1]
Out[6]: 
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cuda:0 f32[512, 512]"
  [t2] = nvFusion0(x)
    # t1 = prims.var(x, (0, 1), correction=1)  # t1: "cuda:0 f32[]"
    # t2 = prims.add(t1, t1)  # t2: "cuda:0 f32[]"
  del x
  return t2

In [7]: import thunder.examine

In [8]: thunder.examine.get_fusion_symbols(thunder.last_traces(func)[-1])
Out[8]: 
[[t2] = nvFusion0(x)
   # t1 = prims.var(x, (0, 1), correction=1)  # t1: "cuda:0 f32[]"
   # t2 = prims.add(t1, t1)  # t2: "cuda:0 f32[]"]

In [9]: trace = thunder.last_traces(func)[-1]

In [10]: nvfuser_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.name.startswith("nvFusion")]

In [11]: nvfuser_symbols
Out[11]: 
[[t2] = nvFusion0(x)
   # t1 = prims.var(x, (0, 1), correction=1)  # t1: "cuda:0 f32[]"
   # t2 = prims.add(t1, t1)  # t2: "cuda:0 f32[]"]

In [12]: nvfuser_symbols[0].args
Out[12]: (x,)

In [13]: nvfuser_symbols[0].args[0].shape
Out[13]: (512, 512)

In [14]: nvfuser_symbols[0].args[0].dtype
Out[14]: float32

I propose to add a mechanism to retrieve input information for a fusion definition.

Given the information above, what mechanism are you planning to add?


riccardofelluga avatar riccardofelluga commented on June 2, 2024

Thanks for the comment! To see what mechanism I came up with before you had a chance to comment, please check out the linked PR #388.

With this added context, I'll check out how I can reuse the examine mechanism in my PR and update this issue.


IvanYashchuk avatar IvanYashchuk commented on June 2, 2024

What is the goal?
Depending on the goal, strides information might be unnecessary or essential to have.

Let's think of debugging scenarios; here are the examples I could come up with:

  1. thunder.jit(fn)(inputs) worked, but the result is incorrect or performance is slow, and we'd like to rerun only one specific fusion region, which happens to be an nvFuser region since this issue is about improving the Thunder+nvFuser experience. We can query the FusionDefinition, print nvFuser's representation, and call FusionDefinition.execute. Having stride information for the actual inputs of this fusion region is must-have information, because the contiguity of a tensor is a static property of a given FusionDefinition instance that also affects performance.
  2. thunder.jit(fn)(inputs) is not used; instead, the Thunder Trace object is constructed manually, following the steps that usually happen in thunder.jit, and now, before we even attempt to run the full execution trace, we want to test individual fusion regions. At this point we don't know what the strides could be in a real program, because we're executing in isolation. We can create sample inputs using the shape, type, and device information from the trace and pass them to our FusionDefinitionWrapper, which has a stride-info cache. Allowing users to specify different memory layouts for inputs could be beneficial for testing the performance and correctness of similar but different FusionDefinitions.
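For scenario 2, one plausible default when the real program's layout is unknown is to assume contiguous (row-major) inputs and compute the corresponding strides from the trace's shape metadata. The sketch below is illustrative only: `sample_input_spec` is a hypothetical helper, not thunder or nvFuser API.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides in elements: the layout a
    freshly created sample tensor (e.g. torch.randn) would have."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))

def sample_input_spec(shape, dtype, device, strides=None):
    """Hypothetical helper: describe a sample input for a fusion region,
    defaulting to contiguous strides when the real layout is unknown.
    A caller could override `strides` to test non-contiguous layouts."""
    return {
        "shape": tuple(shape),
        "dtype": dtype,
        "device": device,
        "strides": tuple(strides) if strides is not None else contiguous_strides(tuple(shape)),
    }

spec = sample_input_spec((512, 512), "float32", "cuda:0")
print(spec["strides"])  # → (512, 1)
```

Letting users pass explicit strides, as in the last parameter above, is what would make it possible to exercise different memory layouts against otherwise identical FusionDefinitions.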

That's all specific to nvFuser as a FusingExecutor. How can this be extended to run any slice of a trace that involves any FusingExecutor and/or OperatorExecutor ops?


csarofeen avatar csarofeen commented on June 2, 2024

CC @kevinstephano


kevinstephano avatar kevinstephano commented on June 2, 2024

I am not sure what exactly is desired, but we could add a kwarg to nvFuser's FusionDefinition.execute() method, such as make_repro, as in FusionDefinition.execute(inputs, make_repro=True). The issue is that we don't understand inputs unless we see them through the execute() method, since we have only a symbolic view of tensors in our definition.

