Comments (6)
After further inspection @IvanYashchuk, I still think these are two slightly different things. In the PR the output is ready-to-run Python code for the fusion, while the method you explained provides similar information but is missing the strides and the code to run the fusions. However, I agree with you that it might be better to move the code from that PR to the examine.py file. Say, making something similar to get_fusion_symbols but that returns the repro code for the fusions, what do you think?
Even better, I think I can get the information about the inputs from the trace using your technique, eliminating the need to modify the nvfuserex_impl.py file at all.
from lightning-thunder.
An nvFuser fusion region in the Thunder trace is represented as a BoundSymbol. You can get all bound symbols by accessing TraceCtx.bound_symbols, and for each bound symbol you can access its inputs with BoundSymbol.args, BoundSymbol.kwargs, or BoundSymbol.flat_args.
All nvFuser bound symbols have BoundSymbol.is_fusion set to True, and their names start with nvFusion, so you can filter out all other symbols with a simple list comprehension:
nvfuser_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.name.startswith("nvFusion")]
There's also the thunder.examine.get_fusion_symbols function, which does the same using is_fusion:
lightning-thunder/thunder/examine/__init__.py, line 207 in dd42bb3
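To make the two filtering approaches concrete without a GPU or a Thunder installation, here is a minimal sketch using stand-in dataclasses; `Symbol` and `BoundSymbol` below are illustrative mock-ups, not the real Thunder classes:

```python
# Minimal stand-ins for Thunder's Symbol/BoundSymbol -- illustrative
# mock-ups only, not the real classes from thunder.core.symbol.
from dataclasses import dataclass

@dataclass
class Symbol:
    name: str
    is_fusion: bool = False

@dataclass
class BoundSymbol:
    sym: Symbol
    args: tuple = ()

bound_symbols = [
    BoundSymbol(Symbol("unpack_trivial")),
    BoundSymbol(Symbol("nvFusion0", is_fusion=True), args=("x",)),
    BoundSymbol(Symbol("python_del")),
]

# The two filters from the discussion select the same symbols:
by_name = [b for b in bound_symbols if b.sym.name.startswith("nvFusion")]
by_flag = [b for b in bound_symbols if b.sym.is_fusion]
assert by_name == by_flag
print([b.sym.name for b in by_name])  # ['nvFusion0']
```

Filtering on is_fusion is the more robust choice, since it does not depend on the naming convention.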
Here's an example session using bound symbols info to retrieve information on inputs:
In [1]: import torch
In [2]: import thunder
In [3]: @thunder.jit
   ...: def func(x):
   ...:     t1 = thunder.prims.var(x, (0, 1), correction=1)
   ...:     t2 = thunder.prims.add(t1, t1)
   ...:     return t2
   ...:
In [4]: x = torch.randn(512, 512, device="cuda")
In [5]: out = func(x)
In [6]: thunder.last_traces(func)[-1]
Out[6]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cuda:0 f32[512, 512]"
  [t2] = nvFusion0(x)
    # t1 = prims.var(x, (0, 1), correction=1)  # t1: "cuda:0 f32[]"
    # t2 = prims.add(t1, t1)  # t2: "cuda:0 f32[]"
  del x
  return t2
In [7]: import thunder.examine
In [8]: thunder.examine.get_fusion_symbols(thunder.last_traces(func)[-1])
Out[8]:
[[t2] = nvFusion0(x)
# t1 = prims.var(x, (0, 1), correction=1) # t1: "cuda:0 f32[]"
# t2 = prims.add(t1, t1) # t2: "cuda:0 f32[]"]
In [9]: trace = thunder.last_traces(func)[-1]
In [10]: nvfuser_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.name.startswith("nvFusion")]
In [11]: nvfuser_symbols
Out[11]:
[[t2] = nvFusion0(x)
# t1 = prims.var(x, (0, 1), correction=1) # t1: "cuda:0 f32[]"
# t2 = prims.add(t1, t1) # t2: "cuda:0 f32[]"]
In [12]: nvfuser_symbols[0].args
Out[12]: (x,)
In [13]: nvfuser_symbols[0].args[0].shape
Out[13]: (512, 512)
In [14]: nvfuser_symbols[0].args[0].dtype
Out[14]: float32
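The session above reads shape and dtype straight off the fusion's proxy arguments. Collecting that metadata into sample-input specs can be sketched as follows; `TensorStandIn` and `sample_input_specs` are hypothetical names mimicking the attributes seen in the session, not part of Thunder's API:

```python
# Sketch: turn proxy metadata from a fusion's args into sample-input
# specs. TensorStandIn mimics the shape/dtype/device attributes seen
# in the session above; the real objects are Thunder tensor proxies.
from dataclasses import dataclass

@dataclass
class TensorStandIn:
    shape: tuple
    dtype: str
    device: str

def sample_input_specs(args):
    """Collect (shape, dtype, device) for every tensor-like argument."""
    return [(a.shape, a.dtype, a.device) for a in args
            if hasattr(a, "shape")]

args = (TensorStandIn((512, 512), "float32", "cuda:0"),)
specs = sample_input_specs(args)
print(specs)  # [((512, 512), 'float32', 'cuda:0')]
# With torch available, each spec could be turned into a real tensor,
# e.g. torch.randn(shape, dtype=..., device=...), to rerun the fusion.
```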
I propose to add a mechanism to retrieve input information for a fusion definition.
Given the information above what mechanism are you planning to add?
Thanks for the comment! To see what mechanism I came up with before you had a chance to comment, please check out the linked PR #388.
With this added context, I'll check out how I can reuse the examine mechanism in my PR and update this issue.
What is the goal?
Depending on the goal, strides information might be either unnecessary or essential to have.
Let's think of debugging scenarios; here are the examples I could come up with:
1. thunder.jit(fn)(inputs) worked, but the result is incorrect or performance is slow, and we'd like to rerun only one specific fusion region that happens to be nvFuser's, since this issue is about improving the Thunder+nvFuser experience. We can query the FusionDefinition, we can print nvFuser's representation, and we can call FusionDefinition.execute. Having the strides of the actual inputs for this fusion region is must-have information, because the contiguity of a tensor is a static property of a given FusionDefinition instance that also affects performance.
2. thunder.jit(fn)(inputs) is not used; instead, the Thunder Trace object is constructed manually following the steps usually happening in thunder.jit, and now, before we even attempt to run the full execution trace, we want to test individual fusion regions. At this time we don't know what the strides could be in a real program because we're executing in isolation. We can create sample inputs using the shape, type, and device information from the trace and pass them to our FusionDefinitionWrapper that has a stride info cache. Allowing users to specify different memory layouts for inputs could be beneficial for testing the performance and correctness of similar but different FusionDefinitions.
That's all specific to nvFuser as a FusingExecutor. How can this be extended to run any slice of a trace that involves any FusingExecutor and/or OperatorExecutor ops?
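The first scenario hinges on contiguity being derivable from shape and strides. As a minimal sketch of that relationship (the standard row-major stride rule, not any Thunder or nvFuser API):

```python
def contiguous_strides(shape):
    # Row-major (C-contiguous) strides in elements: the stride of
    # dimension i is the product of all dimension sizes to its right.
    strides = []
    acc = 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))

def is_contiguous(shape, strides):
    # Simplified check; real frameworks also special-case size-0 and
    # size-1 dimensions, which is glossed over here.
    return strides == contiguous_strides(shape)

shape = (512, 512)
print(contiguous_strides(shape))       # (512, 1)
print(is_contiguous(shape, (512, 1)))  # True
print(is_contiguous(shape, (1, 512)))  # False -- e.g. a transposed view
```

Two inputs with identical shape and dtype but different strides, like the transposed view above, would thus require distinct FusionDefinition instances.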
I am not sure what exactly is desired, but we could add a kwarg to nvFuser's FusionDefinition::execute() method, such as make_repro, used as FusionDefinition.execute(inputs, make_repro=True). The issue is that we don't understand the inputs unless we see them through the execute() method, as we have only a symbolic view of tensors in our definition.