

lightning-thunder's Issues

CI fails to build `cuda 12.1 | torch 2.3 /test | cudnn FE v1.2`

๐Ÿ› Bug

The build_push CI job `cuda 12.1 | torch 2.3 /test | cudnn FE v1.2` failed, apparently because PyTorch bumped its Triton dependency to 2.3.0.
https://github.com/Lightning-AI/lightning-thunder/runs/23676066094

To Reproduce

82.88 The conflict is caused by:
82.88     The user requested triton==2.2.0
82.88     torch 2.3.0+cu121 depends on triton==2.3.0; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12"

@Borda

Handle returning NamedTuples from the JIT

NamedTuples have (some) support in the interpreter, but to return them from the JIT, we would need some more things:

  • ensure value tracking covers creation (populating the attribute_wrappers) - this is in interpreter,
  • when seeing NamedTuples being returned, add epilogue code to create the NamedTuple from its contents (which are available),

This is nontrivial as it requires return values to be routed through the epilogue, but I don't think that is grave.
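
A minimal sketch of the desired end-to-end behavior (illustrative only; NamedTuple returns are exactly what this issue asks for):

from collections import namedtuple

import torch
import thunder

Stats = namedtuple("Stats", ["mean", "std"])

def foo(a):
    return Stats(a.mean(), a.std())

jfoo = thunder.jit(foo)
# desired: the epilogue rebuilds the Stats NamedTuple from its tracked contents
print(jfoo(torch.randn(4)))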

This is probably an intermediate to advanced issue for people fond of the JIT / frontend bits.

How do I access the `ThunderModule` if I'm compiling a function?

🚀 Feature

Motivation

Sometimes the code requires that a ThunderModule is passed. However, if the user is compiling a function that takes the module as an argument, they have no way to get a reference to it.

For example, #96 implements a workaround for this issue with the no_sync context manager.

Pitch

Provide an API to get this reference. Maybe it's something like thunder.compile_data(jitted_function).module.

Additional context

The design might need to consider the presence of multiple ThunderModules.
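
A hypothetical usage sketch of the proposed API (the .module attribute is illustrative, not an existing Thunder attribute):

import torch
import thunder

def run(module, x):
    return module(x)

model = torch.nn.Linear(4, 4)
jitted = thunder.jit(run)
out = jitted(model, torch.randn(2, 4))

# proposed: recover the ThunderModule created for `model` from the jitted function
tmodule = thunder.compile_data(jitted).module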

Distributed Tests failing but CI is green

On the latest main (94c9494), the CI flow for distributed shows success: https://github.com/Lightning-AI/lightning-thunder/runs/23172744261

But looking at the log, a few tests have failed.

Sample

=================================== FAILURES ===================================
_ CompileDDPTest.test_fsdp_grad_parity_with_without_bucketing_executor_nvfuser_bucketing_block_zero2 _
/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_distributed.py:533: in wrapper
    self._join_processes(fn)
/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_distributed.py:752: in _join_processes
    self._check_return_codes(elapsed_time)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

Link to log: https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=196660&view=logs&j=47e66f3c-897a-5428-da11-bf5c7745762e&t=97be8351-284a-5dba-49eb-f9fe7c3ed1a2&l=811

cc @Borda

The `_FabricModule` cannot be jitted after #78

๐Ÿ› Bug

extensions/thunder/pretrain.py:146: in setup
    main(
extensions/thunder/pretrain.py:233: in main
    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
extensions/thunder/pretrain.py:253: in fit
    validate(fabric, model, val_dataloader, max_iters=2)  # sanity check
../nightly-env/lib/python3.10/site-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
extensions/thunder/pretrain.py:389: in validate
    loss = forward_and_loss(model, input_ids, targets)
../lightning-thunder/thunder/__init__.py:629: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:262: in cache_info_wrapper
    res = fn(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:504: in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
../lightning-thunder/thunder/__init__.py:175: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
../lightning-thunder/thunder/core/jit_ext.py:1430: in thunder_general_jit
    result = jfn(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6669: in fn_
    raise e
../lightning-thunder/thunder/core/interpreter.py:6632: in fn_2
    return fn(*args, **kwargs)
extensions/thunder/pretrain.py:371: in forward_and_loss
    logits = model(input_ids)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../lightning/src/lightning/fabric/wrappers.py:142: in forward
    with precision.forward_context():
../lightning/src/lightning/fabric/plugins/precision/half.py:54: in forward_context
    return self.tensor_init_context()
../lightning/src/lightning/fabric/plugins/precision/half.py:46: in tensor_init_context
    return _DtypeContextManager(self._desired_input_dtype)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    def __init__(self, dtype: torch.dtype) -> None:
>       self._previous_dtype: torch.dtype = torch.get_default_dtype()
E       NotImplementedError: Trying to call function torch.get_default_dtype, but it is not yet supported. Please file an issue requesting support. To find out which operations are not yet recongnized by `thunder.jit`, please run `examine` as per:
E       
E       from thunder.examine import examine
E       examine(<your thunder.jit callable argument>, ...)

../lightning/src/lightning/fabric/plugins/precision/utils.py:33: NotImplementedError

Jitting the _FabricModule is currently necessary to compile the joint forward and loss

To Reproduce

from lightning import Fabric
import torch
import thunder

fabric = Fabric(devices=1, precision="16-true")
model = torch.nn.Linear(1, 1, bias=False, device=fabric.device)
x = torch.randn(1, 1)
x = fabric.to_device(x)

fmodel = fabric.setup(model)
tmodel = thunder.jit(fmodel)

print(tmodel(x))

cc @nikitaved

Support `is_cuda`

Something like the following should work

import thunder
import torch

def foo(x):
    if not x.is_cuda:
        x = x.to('cuda')
    return x * x

x = torch.randn(3, device='cpu')
jit_foo = thunder.jit(foo)
o = jit_foo(x)

print(thunder.last_traces(jit_foo)[-1])

The above fails with:

  File "/home/kkalambarkar/lightning-thunder/thunder/core/proxies.py", line 1234, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method is_cuda

Support for CUDA kernels

🚀 Feature

Hi there 👋

From the main README I noticed that Thunder accepts custom kernels, but only ones written in Triton.
Is there a plan to support CUDA kernels?

Motivation

I'm only at the beginning of my custom-kernels journey, so I might misunderstand something.

From what I've seen online, there are many highly optimized CUDA kernels already available (since CUDA has been around for quite a while). Plus, there is a high chance that someone with a lot of experience writing CUDA kernels (but not Triton) wants to use Thunder (or even integrate it into an existing project).

I personally would like to write custom CUDA kernels for the LitGPT repo after I finish reading the PMPP book.

`TensorBase.cuda`

🚀 Feature

Implement Tensor.cuda that returns a cuda-backed copy of the given tensor.

Motivation

NeMo text-to-image model. It's plausible that the source tensor used in the model is a GPU tensor already, so we might be able to get by with just returning a tensor copy without worrying about movement between devices.
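
A minimal sketch of the desired behavior (this is the feature being requested, so it does not work today):

import torch
import thunder

def foo(x):
    # desired: Tensor.cuda returns a CUDA-backed copy; if x is already on the GPU,
    # this can reduce to a plain copy without any device movement
    return x.cuda() * 2

jfoo = thunder.jit(foo)
print(jfoo(torch.randn(3, device="cuda")))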

Sunset `thunder/benchmarks/distributed.py` and Improve `thunder/benchmarks/benchmark_litgpt.py`

  • [cosmetic] improve the format of JSON output of benchmark_litgpt.py

https://github.com/Lightning-AI/lightning-thunder/blob/cdd43a7fc1110eec10f1854250299b84d1c3b2a8/thunder/benchmarks/distributed.py has been useful, but I find it hard to extend, e.g. to support gradient accumulation.

https://github.com/Lightning-AI/lightning-thunder/blob/cdd43a7fc1110eec10f1854250299b84d1c3b2a8/thunder/benchmarks/benchmark_litgpt.py is easier to work with, as shown in #45, which adds gradient accumulation with no_sync.

cc @crcrpar @carmocca @awaelchli

Skipped distributed tests show up as passed (return 0)

          Note that this can be very misleading because a skipped test also returns 0, so it can make it seem like a test passed when it didn't run

Originally posted by @carmocca in #130 (comment)

python -um pytest -sv "$test" --pythonwarnings ignore --junitxml="$test-results.xml" 2>&1 > "$test-output.txt"
pytest_status=$?
printf "$test status >>> $pytest_status\n"
if [ $pytest_status -ne 0 ]; then
    status=$pytest_status
    cat "$test-output.txt"
fi
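
One way to close this gap (a sketch, assuming pytest's junitxml schema; the helper name is made up) is a post-run check that fails when every collected test was skipped, invoked right after the pytest call above, e.g. python check_not_all_skipped.py "$test-results.xml" || status=1:

# check_not_all_skipped.py -- hypothetical helper, not part of the repo
import sys
import xml.etree.ElementTree as ET

root = ET.parse(sys.argv[1]).getroot()
# pytest writes either <testsuite> at the root or a <testsuites> wrapper, depending on version
suite = root if root.tag == "testsuite" else root.find("testsuite")
executed = int(suite.get("tests", 0)) - int(suite.get("skipped", 0))
if executed == 0:
    print(f"all tests in {sys.argv[1]} were skipped", file=sys.stderr)
    sys.exit(1)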

cc @Borda

caching in make_aug_forward_and_backward breaks TE executor.

As discussed offline, caching in make_aug_forward_and_backward leads to reusing the symbols created by transformer_engine_ex, which are stateful, producing an incorrect program.
Ref:

key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs))
cached_result = _cache.get(key, None) if subkey is not None else None
if cached_result is not None:
    return cached_result

Sample Program

import torch
import thunder
from thunder.executors.transformer_engineex import transformer_engine_ex
from transformer_engine.pytorch import fp8_autocast
dim = 256

class ThunderModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(dim, dim, bias=False)
        self.fc2 = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))

x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda()

thunder_model = ThunderModel().cuda()

jit_model = thunder.jit(thunder_model, executors=(transformer_engine_ex,),)

with fp8_autocast():
    o = jit_model(x).sum()

print(thunder.last_traces(jit_model)[-1])

Generated Trace (te_linear_0 is called twice):

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
  # args: "Collection"
  t0, t1, t2, = args
  del args
  (t6, ctx_te_1) = te_linear_0(t0, t1, None)
  t7 = torch.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
    # t7 = ltorch.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
      # t7 = prims.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
  t8 = torch.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
    # t8 = ltorch.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
      # t8 = prims.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
  del t6
  (t13, C12) = te_linear_0(t8, t2, None)
  del t8
  return {'output': t13, 'flat_args': [t0, t1, t2], 'flat_output': (t13,)}, ((t7,), (C12, ctx_te_1))

Add stride operation primitive

🚀 Feature

I would like to have Thunder manage stride information to allow for tensor manipulations.

In particular I think the following points need to be discussed:

  • Where does the stride information go? Should it be part of TensorProxy?
  • What can a stride manipulation primitive look like? Are there any particular things we need to be careful about?

Motivation

This will enable us to add new operators that reshape the tensor using the stride like torch.as_strided or torch.Tensor.unfold.
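
For illustration, here is the kind of stride-based manipulation this would enable, shown in plain PyTorch (Thunder support is what is being requested):

import torch

x = torch.arange(6.0)
windows = x.unfold(0, 3, 1)                                 # shape (4, 3): overlapping windows built from strides
strided = torch.as_strided(x, size=(3, 2), stride=(2, 1))   # explicit size/stride view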

Make DDP/FSDP a regular transform

🚀 Feature

Make DDP/FSDP a regular transform (to a large part including making transforms flexible enough to support this).

Motivation

Currently DDP/FSDP is not a regular transform, leading to things like #94 and limiting composability / sequencing.
One of the key bits is that DDP/FSDP would need to do the adjustments we currently do to the prologue during tracing with DDP/FSDP in the transform, so we need to allow mutation of prologues through transforms. This is also in line with similar needs for other transforms (lora, quantization, but also value-and-grad-things) that change prologue signatures, so this generalization should happen.

cc @carmocca @awaelchli @crcrpar

[lit-GPT] Thunder with torch.compile executor performs consistently worse than Thunder on all model sizes/batch sizes on Pythia models

๐Ÿ› Bug

The performance of the hybridized torch.compile executor with Thunder is worse than plain Thunder on Pythia models. This set of models differs from the LLaMA architecture in a few main ways:

  1. Uses LayerNorm instead of RMSNorm
  2. Uses GeLU instead of `SiLU(x) * x`
  3. Uses parallel residual (i.e. the MLP block is computed with an input computed before the Attention block, not after)

Example performance on H100, single node, FP16, for Pythia-6.9B, MBS=1, GBS=8, FSDP ZeRO2 w/o bucketing:
Thunder iteration time: 232.74 ms
Thunder + torch.compile iteration time: 239.23 ms

cc @crcrpar @apaz-cli

Cuda only?

Hi

Thanks for sharing this with the community. Much appreciated.

I am wondering if this works only with CUDA hardware. For example, does it work with AMD GPUs through ROCm?

Non-supported diffusion transformer operators

TensorBase.bfloat16
_set_grad_enabled of torch._C
_VariableFunctionsClass.empty of torch
TensorBase.long
TensorBase.type
TensorBase.__setitem__
_VariableFunctionsClass.lerp of torch
device of torch
TensorBase.clone
TensorBase.masked_fill_
TensorBase.get_device
TensorBase.grad_fn
_VariableFunctionsClass.linspace of torch

cc @apaz-cli

[feature request] Indexing with boolean masks

🚀 Feature

Indexing with a boolean mask is not supported, for example:

import torch
import thunder

x = torch.randn(4)
m = x <= 0.5

def f(x, m):
    return x[m]

jf = thunder.jit(f)
jf(x, m)

fails with

RuntimeError: Advanced indexing currently only supports zero or one-dimensional integer tensors, but found a tensor with dtype bool8 and 1 dimensions

cc @apaz-cli

Build op provenance tracking into compile trace output

🚀 Feature

The request is to be able to clearly connect the practitioner's model code to the graph trace produced by Thunder. Ideally, each traced node should map back to the model code from which it was generated.

Motivation

This would tremendously help in debugging issues around graph capture and graph optimization (such as rematerialization, DCE, etc.). It would also improve users' understanding of what Thunder is doing, and could be very helpful for developers building tools that operate on top of Thunder graphs.
Examples from TorchInductor are in the Pitch.

Pitch/Additional Context

Example of FX Graph debug from TorchInductor - mapping traced graph decomposed nodes back to practitioner model code.

        # File: /scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py:35, code: out = self.conv(x)
        convolution: f32[16, 64, 56, 56] = torch.ops.aten.convolution.default(primals_7, primals_1, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
        
        # File: /scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py:36, code: out = self.bn(out)
        add: i64[] = torch.ops.aten.add.Tensor(primals_6, 1);  primals_6 = None
        var_mean = torch.ops.aten.var_mean.correction(convolution, [0, 2, 3], correction = 0, keepdim = True)
        getitem: f32[1, 64, 1, 1] = var_mean[0]
        getitem_1: f32[1, 64, 1, 1] = var_mean[1];  var_mean = None
        add_1: f32[1, 64, 1, 1] = torch.ops.aten.add.Tensor(getitem, 1e-05)
        rsqrt: f32[1, 64, 1, 1] = torch.ops.aten.rsqrt.default(add_1);  add_1 = None
        sub: f32[16, 64, 56, 56] = torch.ops.aten.sub.Tensor(convolution, getitem_1)
        mul: f32[16, 64, 56, 56] = torch.ops.aten.mul.Tensor(sub, rsqrt);  sub = None
        squeeze: f32[64] = torch.ops.aten.squeeze.dims(getitem_1, [0, 2, 3]);  getitem_1 = None
        squeeze_1: f32[64] = torch.ops.aten.squeeze.dims(rsqrt, [0, 2, 3]);  rsqrt = None
        mul_1: f32[64] = torch.ops.aten.mul.Tensor(squeeze, 0.1)
        mul_2: f32[64] = torch.ops.aten.mul.Tensor(primals_4, 0.9);  primals_4 = None
        add_2: f32[64] = torch.ops.aten.add.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
        squeeze_2: f32[64] = torch.ops.aten.squeeze.dims(getitem, [0, 2, 3]);  getitem = None
        mul_3: f32[64] = torch.ops.aten.mul.Tensor(squeeze_2, 1.0000199302441455);  squeeze_2 = None
        mul_4: f32[64] = torch.ops.aten.mul.Tensor(mul_3, 0.1);  mul_3 = None
        mul_5: f32[64] = torch.ops.aten.mul.Tensor(primals_5, 0.9);  primals_5 = None
        add_3: f32[64] = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
        unsqueeze: f32[64, 1] = torch.ops.aten.unsqueeze.default(primals_2, -1)
        unsqueeze_1: f32[64, 1, 1] = torch.ops.aten.unsqueeze.default(unsqueeze, -1);  unsqueeze = None
        unsqueeze_2: f32[64, 1] = torch.ops.aten.unsqueeze.default(primals_3, -1);  primals_3 = None
        unsqueeze_3: f32[64, 1, 1] = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1);  unsqueeze_2 = None
        mul_6: f32[16, 64, 56, 56] = torch.ops.aten.mul.Tensor(mul, unsqueeze_1);  mul = unsqueeze_1 = None
        add_4: f32[16, 64, 56, 56] = torch.ops.aten.add.Tensor(mul_6, unsqueeze_3);  mul_6 = unsqueeze_3 = None
        
        # File: /scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py:37, code: out = self.relu(out)
        relu: f32[16, 64, 56, 56] = torch.ops.aten.relu.default(add_4);  add_4 = None
        le: b8[16, 64, 56, 56] = torch.ops.aten.le.Scalar(relu, 0)

Similarly, in the final codegen output, one can see which decomposed nodes belong to each generated kernel and which aten-level op each decomposed node came from. Inside the kernel, there are also comments describing the practitioner code stack for each kernel.
Thunder already covers this to some extent, since the trace output lists all the decomposed nodes that are part of an nvFusion block, but a mapping back to the original code would be fantastic for better understanding.

# aten._native_batch_norm_legit_functional => add_1, add_4, mul, mul_6, rsqrt, sub, var_mean
# aten.relu => relu
# aten.threshold_backward => le
triton_poi_fused__native_batch_norm_legit_functional_relu_threshold_backward_4 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[4194304], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*i1', 7: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]})
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_relu_threshold_backward_4(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3211264
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x3 = xindex
    x1 = (xindex // 3136) % 64

    # ORIGIN:
    # call_function aten.relu.default
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 37, in forward\    out = self.relu(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.add.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.rsqrt.default
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.mul.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.var_mean.correction
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.mul.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.add.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.sub.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN

    tmp0 = tl.load(in_ptr0 + (x3), None)
    tmp1 = tl.load(in_ptr1 + (x1), None)
    tmp3 = tl.load(in_ptr2 + (x1), None)
    tmp10 = tl.load(in_ptr3 + (x1), None)
    tmp12 = tl.load(in_ptr4 + (x1), None)
    tmp2 = tmp0 - tmp1
    tmp4 = 50176.0
    tmp5 = tmp3 / tmp4
    tmp6 = 1e-05
    tmp7 = tmp5 + tmp6
    tmp8 = tl.math.rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp11 = tmp9 * tmp10
    tmp13 = tmp11 + tmp12
    tmp14 = tl.where(0 != 0, 0, tl.where(0 > tmp13, 0, tmp13))

    # ORIGIN:
    # call_function aten.le.Scalar
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 37, in forward\    out = self.relu(out)\
    # END ORIGIN

    tmp15 = 0.0
    tmp16 = tmp14 <= tmp15
    tl.store(out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp14, None)
    tl.store(out_ptr1 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp16, None)
''')

cc @carmocca

Partial function is not supported in `grad_transform`

🚀 Feature

Hitting this assert below vvv

root@847841b8737c:/opt/pytorch/lightning-thunder# python /volume/pooling.py
Traceback (most recent call last):
  File "/volume/pooling.py", line 36, in <module>
    o = jit_model(image)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 632, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 265, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 574, in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
  File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 216, in split_forward_backward
    fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3879, in forward_and_backward_from_trace
    forward_trace = construct_trace()(augmented_forward_fn, *trace.args, **trace.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 528, in _trace
    result = fn(*proxyargs, **proxykwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3850, in augmented_forward_fn
    result, env = augmented_forward_pass(*args, trace=trace, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3461, in augmented_forward_pass
    result, env = eval_trace(
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 1698, in eval_trace
    prim_func = symbol_mapper(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3385, in vjp_symbol_mapper
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/vjp_utils.py", line 63, in make_aug_forward_and_backward
    joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 506, in _trace
    proxyargs, proxykwargs = _unpack_inputs(fn, trace, args, kwargs, rename_proxies=rename_proxies)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 273, in _unpack_inputs
    si = get_siginfo(fn, args, kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/codeutils.py", line 313, in get_siginfo
    check(
  File "/opt/pytorch/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
NotImplementedError: Support for partials with positional args (like ('test',)) is not implemented yet

I was trying to use something like

foo = partial(bar, pos_arg0)
OperatorExecutor.register_operator(..., grad_transform=fn)
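
For completeness, a sketch of the pattern and one obvious workaround (names are illustrative): a plain wrapper function avoids binding positional arguments through functools.partial altogether.

from functools import partial

def bar(prefix, a, b):
    ...

# what was attempted: a partial with a bound positional argument
foo = partial(bar, "test")        # hits the NotImplementedError in get_siginfo

# workaround: an ordinary wrapper function with a plain signature
def foo_wrapped(a, b):
    return bar("test", a, b)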

This isn't a high priority issue, since we can easily work around it for now. Filing the issue just to keep track of the missing feature.

Does `jit` understand monkeypatched methods?

๐Ÿ› Bug

Tensor.register_hook is currently not supported by Thunder.

In Lightning Fabric, we use this once for error checking that the user properly called backward. https://github.com/Lightning-AI/pytorch-lightning/blob/096b063d6eeb41567409f4a6b9bac6f5af28ed93/src/lightning/fabric/wrappers.py#L232-L233

Since this hook is not critical, as it's only meant to avoid user errors, I would like to be able to monkeypatch it externally.

However, it doesn't seem like it has an effect with Thunder:

To Reproduce

import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
import thunder

model = torch.nn.Linear(1, 1, bias=False, device="cuda")
x = torch.randn(1, 1, device="cuda", requires_grad=True)

fabric = Fabric(accelerator="cuda", devices=1)
model = fabric.setup(model)

# monkeypatch what's causing trouble
assert isinstance(model, _FabricModule)
assert model._register_backward_hook is not None
model._register_backward_hook = lambda *_: None

model = thunder.jit(model)

y = model(x)
y.backward()
print(y)
print(x.grad)

Which fails as Thunder doesn't support register_hook

AttributeError: The torch language context has no method register_hook

Interestingly, a non-fabric snippet doesn't fail so there is something funny going on:

import thunder
import torch

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(1, 1, bias=False)

    def forward(self, x):
        y = self.model(x)
        self.register_hook(y)
        return y

    def register_hook(self, tensor):
        tensor.register_hook(self.hook)

    def hook(self, _):
        print("hi")

model = Wrapper()
x = torch.randn(1, 1)

model.register_hook = lambda *_: None

model = thunder.jit(model)

y = model(x)
y.backward()

Represent slices natively in traces

🚀 Feature

Motivation

Tensor slices are represented in traces as:

  t107 = torch_slice_prim_impl(t53, [0, 0, 0, 0], [4, 32, 2048, 0], [1, 1, 1, 1])  # t107: "cuda:0 bf16[4, 32, 2048, 0]"

But there's no torch_slice_prim_impl import, and we could use plain Python to represent it.

This reference comes from:

slice_prim_impl = ex.register_operator("torch_slice_prim_impl", meta=prims.slice_prim.meta, fn=_slice_prim_impl)
_register_implementation(prims.slice_prim, slice_prim_impl, checker=_always_executable)

# TODO When getitem is fully supported this can be changed to be an execution transform instead of a direct impl
def _slice_prim_impl(
    a: torch.Tensor, start_indices: Sequence[int], end_indices: Sequence[int], strides: None | Sequence[int] = None
) -> torch.Tensor:
    _strides = strides if strides is not None else [1] * len(start_indices)
    slices: list = []
    for start, stop, step in zip(start_indices, end_indices, _strides):
        slices.append(slice(start, stop, step))
    return operator.getitem(a, slices)

Pitch

Instead represent it with __getitem__ and slice():

t123 = t321.__getitem__([slice(0, 3), slice(0, 5)])  # t123: "cuda:..."
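
For reference, the proposed representation is just standard Python indexing; a quick sanity check that the two spellings agree:

import operator
import torch

a = torch.arange(20).reshape(4, 5)
assert torch.equal(a[0:3, 0:5], a.__getitem__((slice(0, 3), slice(0, 5))))
assert torch.equal(a[0:3, 0:5], operator.getitem(a, (slice(0, 3), slice(0, 5))))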

Alternatives

Add the torch_slice_prim_impl import from torchex to the trace so that it's a valid program

cc @apaz-cli @nikitaved

Benchmark targets on test_nanogpt_cross_entropy_grad has some import issue

๐Ÿ› Bug

The benchmark targets for test_nanogpt_cross_entropy_grad have an import issue.

To Reproduce

Steps to reproduce the behavior:

root@8d345ed01185:/opt/pytorch/lightning-thunder# pytest -vvvs thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad]
============================================================================================== test session starts ==============================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/opt/pytorch/lightning-thunder/.hypothesis/examples'))
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: cov-4.1.0, hypothesis-6.100.0, random-order-1.1.1, timestamper-0.0.10, timeout-2.2.0, xdist-3.5.0, shard-0.1.2, benchmark-4.0.0
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 1 item
Running 1 items in this shard: thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad]

[2024-04-10 21:48:06] thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad] FAILED

=================================================================================================== FAILURES ====================================================================================================
______________________________________________________________________________ test_nanogpt_cross_entropy_grad[thunder+apex-grad] _______________________________________________________________________________

benchmark = <pytest_benchmark.fixture.BenchmarkFixture object at 0x7fe26c8f0cd0>
executor = functools.partial(<function thunder_grad_transform at 0x7fe26ca5ecb0>, compile_fn=<function thunder_apex_executor at 0x7fe26cad8b80>)

    @pytest.mark.parametrize(
        "executor,",
        (grad_executors + apex_grad_executors),
        ids=(grad_executors_ids + apex_grad_executors_ids),
    )
    def test_nanogpt_cross_entropy_grad(benchmark, executor: None | Callable):
        if executor is None:
            pytest.skip("Executor is unavailable")

        bench: Benchmark = NanoGPTCrossEntropyBenchmark(
            config="gpt2-xl", device="cuda:0", dtype=thunder.bfloat16, requires_grad=True
        )

        setup = make_setup(bench)
        fn = executor(bench)
        fn = wrap_for_benchmark(fn)

>       benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)

thunder/benchmarks/targets.py:479:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:137: in pedantic
    return self._raw_pedantic(target, args=args, kwargs=kwargs, setup=setup, rounds=rounds,
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:211: in _raw_pedantic
    runner(loops_range)
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:95: in runner
    result = function_to_benchmark(*args, **kwargs)
thunder/benchmarks/targets.py:60: in fn_
    result = fn(*args, **kwargs)
thunder/benchmarks/targets.py:235: in wrapper
    populate_grads(grads, cfn, args=args, kwargs=kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

grads = [tensor([[0.0000e+00, 2.2768e-18, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0090e-12,...4, 8.5986e-29, 0.0000e+00,  ..., 0.0000e+00, 5.7932e-31,
         0.0000e+00]], device='cuda:0', dtype=torch.bfloat16)]
tom = <function NanoGPTCrossEntropyBenchmark.fn.<locals>.foo at 0x7fe26c93c8b0>
args = (tensor([[ 79.0000, 227.0000,   8.5625,  ..., 166.0000, 152.0000, 154.0000],
        [240.0000, 224.0000,   2.3125,  ....cuda:0', dtype=torch.bfloat16, grad_fn=<ViewBackward0>), tensor([223, 144, 141,  ..., 219, 169, 186], device='cuda:0'))
kwargs = {}

    def populate_grads(grads: list[TensorProxy], tom: None | torch.nn.Module = None, args=None, kwargs=None) -> None:
        idx: int = 0
        from thunder import ThunderModule, compile_data

>       if isinstance(tom, ThunderModule) or thunder.compile_data(tom).using_jit:
E       NameError: name 'thunder' is not defined

thunder/core/transforms.py:555: NameError
============================================================================================ short test summary info ============================================================================================
FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad] - NameError: name 'thunder' is not defined
======================================================================================== 1 failed, 5 warnings in 11.31s =========================================================================================

Code sample

see above

Expected behavior

benchmark should be able to run

Environment

  • internal image: pjnl-20240410
  • thunder: dba8ce7

Additional context

same issues on those two:

FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad] - NameError: name 'thunder' is not defined
FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex+nvfuser-grad] - NameError: name 'thunder' is not defined
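
Judging from the traceback, a plausible fix (a sketch, not the actual patch) would be for populate_grads in thunder/core/transforms.py to use the compile_data it already imports instead of the undefined thunder name:

def populate_grads(grads, tom=None, args=None, kwargs=None) -> None:
    idx: int = 0
    from thunder import ThunderModule, compile_data

    # use the locally imported compile_data; `thunder` itself is not imported in this scope
    if isinstance(tom, ThunderModule) or compile_data(tom).using_jit:
        ...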

cc @tfogal @IvanYashchuk

Weight tying + FSDP = nvfuser internal error

๐Ÿ› Bug

To Reproduce

Code:

import os
import torch
import torch.distributed as tdist
import thunder
from thunder.tests.lit_gpt_model import GPT, Config

if __name__ == "__main__":
    tdist.init_process_group(backend="nccl")
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)

    config = Config(block_size=256, padded_vocab_size=32000, n_layer=6, n_head=6, head_size=48, n_embd=288, rotary_percentage=1.0, parallel_residual=False, bias=False, _norm_class='RMSNorm', _mlp_class='LLaMAMLP', intermediate_size=768)
    with device:
        model = GPT(config)

    model.transformer.wte.weight = model.lm_head.weight

    model = thunder.distributed.fsdp(model)
    model = thunder.jit(model)

    input_ids = torch.randint(1, 30010, (128, 256), dtype=torch.long, device=device)
    logits = model(input_ids)
    print(logits.shape)

Run with:

torchrun --nproc-per-node 2 --local-ranks-filter 0 repro.py

Nvfuser repro:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.ops.mul(T1, T1)
    T3 = fd.ops.sum(T2, dims=[2], keepdim=False, dtype=DataType.Null)
    S4 = fd.define_scalar(128, dtype=DataType.Int)
    S5 = fd.define_scalar(256, dtype=DataType.Int)
    S6 = fd.define_scalar(1, dtype=DataType.Int)
    V7 = fd.define_vector([S4, S5, S6], dtype=DataType.Int)
    T8 = fd.ops.broadcast_in_dim(T3, shape=V7, broadcast_dims=[0, 1])
    S9 = fd.define_scalar(288.000, dtype=DataType.Double)
    S10 = fd.ops.reciprocal(S9)
    T11 = fd.ops.mul(T8, S10)
    S12 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T13 = fd.ops.add(T11, S12)
    T14 = fd.ops.rsqrt(T13)
    S15 = fd.define_scalar(128, dtype=DataType.Int)
    S16 = fd.define_scalar(256, dtype=DataType.Int)
    S17 = fd.define_scalar(288, dtype=DataType.Int)
    V18 = fd.define_vector([S15, S16, S17], dtype=DataType.Int)
    T19 = fd.ops.broadcast_in_dim(T14, shape=V18, broadcast_dims=[0, 1, 2])
    T20 = fd.ops.mul(T1, T19)
    T21 = fd.ops.mul(T20, T0)
    fd.add_output(T14)
    fd.add_output(T21)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((128, 256, 288), (0, 0, 1)),
    torch.randn((9437184,), dtype=torch.float32, device='cuda:0').as_strided((128, 256, 288), (73728, 288, 1)),
]
fd.execute(inputs)

Short error:

RuntimeError: _result == CUDA_SUCCESS INTERNAL ASSERT FAILED at "/workspace/Fuser/csrc/executor_utils.cpp":907, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_ASSERT failed with error device-side assert triggered

Full error:

error.txt

Removing one of:

  • FSDP
  • 30010 as the highest input value
  • weight tying

makes the problem disappear.

cc @tfogal

torchex running pooling without decomposition

🚀 Feature

max_poolXd through decomposition is expensive in Thunder. The torch executor should be able to run these as a single aten call in the forward as well as in the backward, via a custom grad_transform; see the sketch below.
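
For reference, plain PyTorch already exposes single-call aten ops for both directions; a minimal sketch (illustrative shapes only, not the Thunder executor registration itself, which is being prototyped in the draft PR mentioned in the Pitch):

import torch

x = torch.randn(32, 3, 224, 224, requires_grad=True)
# forward: one aten call that returns both the pooled output and the argmax indices
out, indices = torch.ops.aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [1, 1], [1, 1], False)
# backward: one aten call that consumes the saved indices
grad_in = torch.ops.aten.max_pool2d_with_indices_backward(
    torch.ones_like(out), x, [3, 3], [2, 2], [1, 1], [1, 1], False, indices
)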

Motivation

Currently if we run the example below vvv

import torch
import torch.nn as nn

import thunder

dtype = torch.float16
batch_size = 32
test_grad = True

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        layers = list()
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.layer = nn.Sequential(*layers)

    def forward(self, inp):
        return self.layer(inp)

model = Model()

model = model.cuda()
model = model.to(dtype)

image = torch.randn(batch_size, 3, 224, 224, dtype=dtype).cuda()
if test_grad:
    image.requires_grad_()

def fn(arg):
    return model(arg)

jit_model = thunder.jit(fn)

# warm up
for i in range(20):
  o = jit_model(image)
  if test_grad:
      o.sum().backward()
  o = fn(image)
  if test_grad:
      o.sum().backward()

import time
fwd_traces = thunder.last_traces(jit_model)
print("fwd_traces:\n", fwd_traces[-1])
if test_grad:
    bwd_traces = thunder.last_backward_traces(jit_model)
    print("bwd_traces:\n", bwd_traces[-1])

torch.cuda.synchronize()

t0 = time.time()
for i in range(10):
    o = jit_model(image)
    if test_grad:
        o.sum().backward()
        image.grad = None
torch.cuda.synchronize()
print("jit_model elapsed time: ", time.time() - t0)

torch.cuda.synchronize()
t0 = time.time()
for i in range(10):
    o = fn(image)
    if test_grad:
        o.sum().backward()
        image.grad = None
torch.cuda.synchronize()
print("torch eager elapsed time: ", time.time() - t0)

jit_model elapsed time: 0.02024698257446289
torch eager elapsed time: 0.002202272415161133

fwd graph looks like

from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(arg):
  # arg: "cuda:0 f16[32, 3, 224, 224]"
  t0 = prims.pad(arg, -float('inf'), [(0, 0, 0), (0, 0, 0), (1, 1, 0), (1, 1, 0)])  # t0: "cuda:0 f16[32, 3, 226, 226]"
  t1 = ltorch.arange(9, None, 1, device=devices.Device("cuda:0"), dtype=None)  # t1: "cuda:0 i64[9]"
    # t1 = prims.iota(9, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64)  # t1: "cuda:0 i64[9]"
  t2 = prims.broadcast_in_dim(t1, [9, 1], [0])  # t2: "cuda:0 i64[9, 1]"
  t3 = prims.broadcast_in_dim(t1, [1, 9], [1])  # t3: "cuda:0 i64[1, 9]"
  t4 = prims.broadcast_in_dim(t2, (9, 9), (0, 1))  # t4: "cuda:0 i64[9, 9]"
  t5 = prims.broadcast_in_dim(t3, (9, 9), (0, 1))  # t5: "cuda:0 i64[9, 9]"
  t6 = prims.eq(t4, t5)  # t6: "cuda:0 b8[9, 9]"
  t7 = prims.convert_element_type(t6, dtypes.float16)  # t7: "cuda:0 f16[9, 9]"
  t8 = prims.reshape(t7, (1, 9, 1, 3, 3))  # t8: "cuda:0 f16[1, 9, 1, 3, 3]"
  t9 = prims.broadcast_in_dim(t8, (3, 9, 1, 3, 3), (0, 1, 2, 3, 4))  # t9: "cuda:0 f16[3, 9, 1, 3, 3]"
  t10 = prims.reshape(t9, (27, 1, 3, 3))  # t10: "cuda:0 f16[27, 1, 3, 3]"
  t11 = prims.convolution(t0, t10, None, (2,), (0,), (1,), False, (0, 0), 3)  # t11: "cuda:0 f16[32, 27, 112, 112]"
  t12 = prims.reshape(t11, (32, 3, 9, 112, 112))  # t12: "cuda:0 f16[32, 3, 9, 112, 112]"
  t13 = prims.convert_element_type(t12, dtypes.float32)  # t13: "cuda:0 f32[32, 3, 9, 112, 112]"
  t14 = prims.amax(t13, (2,))  # t14: "cuda:0 f32[32, 3, 112, 112]"
  t15 = prims.convert_element_type(t14, dtypes.float16)  # t15: "cuda:0 f16[32, 3, 112, 112]"
  return {'output': t15, 'flat_args': [arg], 'flat_output': (t15,)}, ((t10, t14, t13), (0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 3, 32, 27, 112, 112, 32, 3, 9, 112, 112, 2))

bwd graph:

 i239 = operator.neg(i86)  # i239
    # i239 = prims.neg(i86)  # i239
  del i86
  i257 = operator.neg(i27)  # i257
    # i257 = prims.neg(i27)  # i257
  del i27
  i258 = operator.neg(i28)  # i258
    # i258 = prims.neg(i28)  # i258
  del i28
  i259 = operator.neg(i30)  # i259
    # i259 = prims.neg(i30)  # i259
  del i30
  i260 = operator.neg(i31)  # i260
    # i260 = prims.neg(i31)  # i260
  del i31
  i261 = operator.neg(i33)  # i261
    # i261 = prims.neg(i33)  # i261
  del i33
  i262 = operator.neg(i34)  # i262
    # i262 = prims.neg(i34)  # i262
  del i34
  i263 = operator.neg(i36)  # i263
    # i263 = prims.neg(i36)  # i263
  del i36
  i264 = operator.neg(i37)  # i264
    # i264 = prims.neg(i37)  # i264
  del i37
  t303 = torch.unsqueeze(t19, 2)  # t303
    # t303 = ltorch.unsqueeze(t19, 2)  # t303
      # t303 = prims.broadcast_in_dim(t19, [32, 3, 1, 2, 2], [0, 1, 3, 4])  # t303
  del t19
  t220 = Tensor.expand(t303, [32, 3, 1, 2, 2])  # t220
    # t220 = ltorch.expand(t303, [32, 3, 1, 2, 2])  # t220
      # t220 = prims.broadcast_in_dim(t303, (32, 3, 1, 2, 2), (0, 1, 2, 3, 4))  # t220
  del t303
  t221 = Tensor.expand(t220, (i104, i105, i106, i107, i108))  # t221
    # t221 = ltorch.expand(t220, (i104, i105, i106, i107, i108))  # t221
      # t221 = prims.broadcast_in_dim(t220, (i104, i105, i106, i107, i108), (0, 1, 2, 3, 4))  # t221
  del t220
  t233 = torch.permute(t15, (1, 0, 2, 3))  # t233
    # t233 = ltorch.permute(t15, (1, 0, 2, 3))  # t233
      # t233 = prims.transpose(t15, (1, 0, 2, 3))  # t233
  del t15
  t234 = torch.reshape(t233, [1, i91, 9, 3, 3])  # t234
    # t234 = ltorch.reshape(t233, [1, i91, 9, 3, 3])  # t234
      # t234 = prims.reshape(t233, (1, i91, 9, 3, 3))  # t234
  del t233
  t235 = torch.permute(t234, (1, 0, 2, 3, 4))  # t235
    # t235 = ltorch.permute(t234, (1, 0, 2, 3, 4))  # t235
      # t235 = prims.transpose(t234, (1, 0, 2, 3, 4))  # t235
  del t234
  t236 = torch.reshape(t235, [3, 9, 3, 3])  # t236
    # t236 = ltorch.reshape(t235, [3, 9, 3, 3])  # t236
      # t236 = prims.reshape(t235, (3, 9, 3, 3))  # t236
  del t235
  [t230, t282] = nvFusion0(i10, i104, i105, i106, i107, i108, i109, i9, t0, t17, t21, t221)
    # t18 = prims.convert_element_type(t17, dtypes.float32)  # t18
    # t282 = prims.pad(t0, 0.0, [(0, 0, 0), (0, 0, 0), (i9, 3, 0), (i10, 3, 0)])  # t282
    # t217 = prims.convert_element_type(t21, dtypes.float32)  # t217
    # t218 = prims.broadcast_in_dim(t217, [32, 3, 1, 2, 2], [0, 1, 3, 4])  # t218
    # t219 = prims.broadcast_in_dim(t218, (i104, i105, i106, i107, i108), (0, 1, 2, 3, 4))  # t219
    # t222 = prims.eq(t18, t221)  # t222
    # t223 = prims.sum(t222, (i109,))  # t223
    # t224 = prims.broadcast_in_dim(t223, [32, 3, 1, 2, 2], [0, 1, 3, 4])  # t224
    # t225 = prims.convert_element_type(t222, dtypes.float32)  # t225
    # t226 = prims.mul(t219, t225)  # t226
    # t227 = prims.broadcast_in_dim(t224, (32, 3, 9, 2, 2), (0, 1, 2, 3, 4))  # t227
    # t228 = prims.convert_element_type(t227, dtypes.float32)  # t228
    # t229 = prims.div(t226, t228)  # t229
    # t230 = prims.convert_element_type(t229, dtypes.float16)  # t230
  del i10, i104, i105, i106, i107, i108, i109, i9, t0, t17, t21, t221
  t283 = torch.permute(t282, (1, 0, 2, 3))  # t283
    # t283 = ltorch.permute(t282, (1, 0, 2, 3))  # t283
      # t283 = prims.transpose(t282, (1, 0, 2, 3))  # t283
  del t282
  t284 = torch.reshape(t283, [i16, 3, 32, 11, 11])  # t284
    # t284 = ltorch.reshape(t283, [i16, 3, 32, 11, 11])  # t284
      # t284 = prims.reshape(t283, (i16, 3, 32, 11, 11))  # t284
  del t283
  t285 = torch.permute(t284, (1, 0, 2, 3, 4))  # t285
    # t285 = ltorch.permute(t284, (1, 0, 2, 3, 4))  # t285
      # t285 = prims.transpose(t284, (1, 0, 2, 3, 4))  # t285
  del t284
  t286 = torch.reshape(t285, [3, 32, 11, 11])  # t286
    # t286 = ltorch.reshape(t285, [3, 32, 11, 11])  # t286
      # t286 = prims.reshape(t285, (3, 32, 11, 11))  # t286
  del t285
  t231 = torch.reshape(t230, (i97, i98, i99, i100))  # t231
    # t231 = ltorch.reshape(t230, (i97, i98, i99, i100))  # t231
      # t231 = prims.reshape(t230, (i97, i98, i99, i100))  # t231
  del t230, i97, i98, i99, i100
  t232 = torch_pad_prim_impl(t231, 0.0, [(0, 0, 0), (0, 0, 0), (0, 0, 1), (0, 0, 1)])  # t232
  del t231
  t237 = torch.flip(t236, (2, 3))  # t237
    # t237 = ltorch.flip(t236, (2, 3))  # t237
      # t237 = prims.flip(t236, (2, 3))  # t237
  del t236
  t238 = torch.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91)  # t238
    # t238 = ltorch.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91)  # t238
      # t238 = prims.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91)  # t238
  del t232, t237, i87, i89, i90, i91
  [t270] = nvFusion1(i239, i257, i258, i259, i260, i261, i262, i263, i264, t2, t238)
    # t241 = prims.pad(t238, 0.0, [(0, 0, 0), (0, 0, 0), (i239, 0, 0), (i239, 0, 0)])  # t241
    # t265 = prims.pad(t241, 0.0, [(i257, i258, 0), (i259, i260, 0), (i261, i262, 0), (i263, i264, 0)])  # t265
    # t266 = prims.slice(t265, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t266
    # t267 = prims.slice(t266, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t267
    # t268 = prims.slice(t267, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t268
    # t269 = prims.slice(t268, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t269
    # t270 = prims.where(t2, t269, 0.0)  # t270
  del i239, i257, i258, i259, i260, i261, i262, i263, i264, t2, t238
  t287 = torch.permute(t270, (1, 0, 2, 3))  # t287
    # t287 = ltorch.permute(t270, (1, 0, 2, 3))  # t287
      # t287 = prims.transpose(t270, (1, 0, 2, 3))  # t287
  del t270
  t288 = torch.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16)  # t288
    # t288 = ltorch.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16)  # t288
      # t288 = prims.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16)  # t288
  del t286, t287, i11, i12, i7, i8, i14, i15, i16
  t289 = torch.permute(t288, (1, 0, 2, 3))  # t289
    # t289 = ltorch.permute(t288, (1, 0, 2, 3))  # t289
      # t289 = prims.transpose(t288, (1, 0, 2, 3))  # t289
  del t288
  return (None, t289)

Pitch

I'm prototyping this in a draft PR (not functional yet!)

Alternatives

We could have pooling layers as prims as well, but I don't think that's a necessity at this point.

Comparison with `torch.compile` instead of Eager

📚 Documentation

Hey! I saw your tool and the plots with "acceleration", but you compare against unoptimized eager PyTorch, which is obviously slower. Could you provide a graph comparing against a basic native PyTorch loop where you torch.compile the model? It would be useful for people who already have some optimizations in their pipelines but would like to try your framework instead.

Thanks

Functional JIT loading closures sharp edge

Strategy required

This issue resumes from PR2410; we need to decide on a strategy for the closures sharp edge. Let's start simple: I think we can all agree that this is a sharp edge if we jit foo:

x = 5
def foo():
      return x

And that's because we are using a variable from outside the jitted scope. However, here is where things get interesting: should we consider the following a sharp edge?

def foo(x):
    def bar():
        return x
    return bar()

I assume that, since we captured x when jitting foo, this should not be a sharp edge for bar, because the variable was declared in the scope (or, in this case, captured). To fix such a case we can remember which variables we captured and then look them up when we see a freevar. However, @mruberry has an interesting point: what happens if the variable gets deleted? How can we deal with something like:

def foo():
  a = 5

  def bar():
    nonlocal a
    del a

  bar()

  return a

In conclusion, what do you think should be the definition of a sharp edge in this context?

cc @apaz-cli @t-vi @mruberry

Support for torchvision models, e.g., a simple ViT

๐Ÿ› Bug

I was trying to run a simple torchvision ViT and am getting the following error:

File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 136, in <module>
    train(
  File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 31, in train
    logits = model(features)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 194, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 611, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 262, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 498, in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 175, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/jit_ext.py", line 1386, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6580, in fn_
    raise e
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6543, in fn_2
    return fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
    x = self.encoder(x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
    return self.ln(self.layers(self.dropout(input)))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
    x, _ = self.self_attention(x, x, x, need_weights=False)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1236, in forward
    any_nested = query.is_nested or key.is_nested or value.is_nested
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 1253, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/proxies.py", line 1234, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method is_nested

Not sure how to go about debugging this. I thought that sharing it may help improve Thunder's support for more models and edge cases.

To Reproduce

Steps to reproduce the behavior:

I attached self-contained code in the zip.

# Runs PyTorch eager, works ok!
python 01_pytorch-vit.py

# Runs torch.compile, works ok!
python 01_pytorch-vit.py --compilation_option "torch.compile"

# Runs thunder.jit(), fails! (See error above)
python 01_pytorch-vit.py --compilation_option "thunder_default"
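
For a smaller, self-contained repro (a hedged sketch; it assumes the is_nested attribute access shown at the bottom of the traceback is the actual blocker rather than anything else in the attached script):

import torch
import thunder

def f(q):
    # nn.MultiheadAttention checks query.is_nested before dispatching,
    # which is the attribute access that fails under thunder.jit
    if q.is_nested:
        return q
    return q + 1

jf = thunder.jit(f)
jf(torch.randn(2, 2))  # AttributeError: The torch language context has no method is_nested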

Code sample

See zip attached

Expected behavior

Either a clearer error message or ideally it should work :)

Environment

Same as Zero to Thunder studio.

Archive.zip

cc @apaz-cli

Handling inplace through SSA

This issue is to facilitate discussion of inplace handling, namely the "big" solution of having a static single assignment (SSA) representation.

For any handling of inplace, we want to make certain that two things are achieved:

  • we don't want to take shortcuts that complicate passes by introducing the need to detect obstacles to optimizations, because that would harm the usability and extensibility of Thunder.
  • we don't want to create ad-hoc band-aids to get things working that we would later have to back out in order to introduce proper handling, because developing in the open more or less means no regressions.

Some thoughts from video/chat discussions:

About the problem:

  • The key difficulty with SSA is that we would need to keep track of which tensors get modified by an in-place update (i.e. which
    have memory that is to be updated), so we would need to know about views (the fancy term is alias analysis; a concrete example is sketched below),
  • this is difficult for some operations in PyTorch (e.g. reshape),
  • "assuming the worst" works to some extent.

Solution considerations:

  • Likely we would want inplace updates to have all affected tensors as outputs.
  • on inputs we would need to check for aliases as part of the prologue (maybe with a separate "assume aliasing is OK" cache mode of sorts later),
  • operations need to know if their output is a view of their inputs (difficult for reshape, easy for most others),
  • initially, we would only check if tensors share storage,
  • likely the translation could be done in the interpretation phase,
  • we would need to have versioning / disambiguation of versions for tensor proxies during this, but not when we have the SSA.

Later versions could refine the alias analysis as needed.
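
To make the aliasing problem concrete, a minimal plain-PyTorch example (not Thunder-specific): an in-place update through a view is observable through its base, so an SSA form has to treat both tensors as outputs of the in-place op.

import torch

base = torch.zeros(4)
view = base[:2]   # view shares storage with base
view.add_(1)      # in-place update through the view
print(base)       # tensor([1., 1., 0., 0.]) -- base changed as well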

@tfogal @mruberry @IvanYashchuk

Operator support for `F.hardswish`

๐Ÿš€ Feature

Implement HardSwish activation function.

Motivation

A relatively easy activation function implementation; a good first issue, as nikitaved suggested under #64.

Pitch

Add HardSwish (x * ReLU6(x + 3) / 6) leveraging existing ReLU6 support.
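
For reference, a minimal sketch of the decomposition in plain PyTorch (the actual registration inside Thunder is left out):

import torch
import torch.nn.functional as F

def hardswish_reference(x: torch.Tensor) -> torch.Tensor:
    # HardSwish(x) = x * ReLU6(x + 3) / 6
    return x * F.relu6(x + 3) / 6

x = torch.randn(4)
assert torch.allclose(hardswish_reference(x), F.hardswish(x))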

cc @apaz-cli

Label tracking meta-issue (edit me to get automatically CC'ed on issues!)

This issue is used by lightning-probot to manage subscriptions to labels. To subscribe yourself to a label, add a line * label @yourusername, or add your username to an existing line (space separated) in the body of this issue. Do not try to subscribe in comments, the bot only parses the initial post.

This is a copy of pytorch/pytorch#24422.

As a courtesy to others, please do not edit the subscriptions of users who are not you.


The current list of labels can be retrieved with $ gh label list --limit 1000 --json name --jq '.[] | "* " + .name' | sort -n

Non-`topk` related issue in `mixtral`-like model tests.

๐Ÿ› Bug

Now that we have topk supported, it is time to unlock some tests. However, the following diff:

diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index d1d55073..ad69a721 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -613,7 +613,7 @@ def test_nanogpt():
         "falcon-7b-like",
         "falcon-40b-like",
         "codellama2-like",
-        pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)),
+        "mixtral-like",
     ),
 )
 @pytest.mark.parametrize(

Breaks pytest -sv thunder/tests/test_jit_general.py -k test_litgpt_variants[cpu-mixtral-like] with

___________________________________________________________________________________________________ test_litgpt_variants[cpu-mixtral-like] ___________________________________________________________________________________________________

name = 'mixtral-like', device = device(type='cpu')

    @skipif_not_pytorch_2_1
    @pytest.mark.parametrize(
        "name",
        (
            "gpt-neox-like",
            "llama1-like",
            "long-context-like",
            "llama2-like",
            "falcon-7b-like",
            "falcon-40b-like",
            "codellama2-like",
            "mixtral-like",
        ),
    )
    @pytest.mark.parametrize(
        "device",
        ("cpu", "cuda"),
    )
    def test_litgpt_variants(name, device):
        if device == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA not available")
    
        device = torch.device(device)
    
        x = torch.randint(0, 200, (5, 5), device=device)
        config = litgpt_model.Config.from_name(name)
    
        with device:
            reference = litgpt_model.GPT(config)
        expected_logits = reference(x)
    
        expected_logits.sum().backward()
    
        with device:
            model = litgpt_model.GPT(config)
        model.load_state_dict(reference.state_dict())
        tom = thunder.jit(model, executors=nvfuserex if device.type == "cuda" else torchex)
>       actual_logits = tom(x)

thunder/tests/test_jit_general.py:642: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/__init__.py:194: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:629: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/__init__.py:262: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:504: in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
thunder/__init__.py:175: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
thunder/core/jit_ext.py:1430: in thunder_general_jit
    result = jfn(*args, **kwargs)
thunder/core/interpreter.py:6684: in fn_
    raise e
thunder/core/interpreter.py:6647: in fn_2
    return fn(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:94: in forward
    x = block(x, cos, sin, mask, input_pos)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:187: in forward
    x = self.mlp(self.norm_2(x)) + x
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:347: in forward
    token_idx, expert_idx = torch.where(mask)
thunder/core/interpreter.py:1258: in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
thunder/core/symbol.py:250: in __call__
    result = self.meta(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (t157,), kwargs = {}, tok = <Token used var=<ContextVar name='langctx' at 0x7fa2ad45a340> at 0x7f9bf1b6bdc0>

    @wraps(fn)
    def _fn(*args, **kwargs):
        try:
            tok = set_langctx(self.langctx)
>           result = fn(*args, **kwargs)
E           TypeError: where() missing 2 required positional arguments: 'a' and 'b'

thunder/core/langctxs.py:124: TypeError
========================================================================================================== short test summary info ===========================================================================================================
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants[cpu-mixtral-like] - TypeError: where() missing 2 required positional arguments: 'a' and 'b'
=============================================================================================== 1 failed, 54 deselected, 10 warnings in 8.04s ================================================================================================
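
The failing call is the single-argument form of torch.where in the MoE routing (litgpt/model.py line 347 in the traceback above), which is equivalent to torch.nonzero(mask, as_tuple=True). A minimal repro, assuming that is the only blocker:

import torch
import thunder

def route(mask):
    return torch.where(mask)  # single-argument form of where

jroute = thunder.jit(route)
jroute(torch.tensor([[True, False], [False, True]]))  # TypeError: where() missing 2 required positional arguments: 'a' and 'b'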

Operator support for `F.one_hot`

๐Ÿ› Bug

Thunder fails when attempting to compile a graph containing torch.nn.functional.one_hot in the forward pass.
The error message indicates that the input to the method must be a Tensor, but a TensorProxy is received instead.

To Reproduce

Steps to reproduce the behavior:

  • Define a PyTorch model class with a forward pass involving F.one_hot to convert the input tensor to a one-hot encoded representation.
  • Create an instance of the model and evaluate it on a random input tensor.
  • Compile the model using thunder.jit.
  • Call the compiled model with the same input tensor.

Example

import torch
import torch.nn as nn
import torch.nn.functional as F

import thunder


class MLP(nn.Module):
    def __init__(self, hidden_size=1024):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(6 * 256, hidden_size, bias=False)
        self.head = nn.Linear(hidden_size, 32000, bias=False)

    def forward(self, inputs):
        x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
        x = self.hidden(x)
        logits = self.head(x)
        return logits


x = torch.randint(0, 6, (1, 256))

model = MLP(1024).eval()
print(model(x))

model = thunder.jit(model)
print(model(x))
Output
tensor([[-0.1134, -0.0827, -0.0205,  ...,  0.0757,  0.0066,  0.0974]],
       grad_fn=<MmBackward0>)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-6-6425e5faad6e>](https://localhost:8080/#) in <cell line: 23>()
     21 
     22 model = thunder.jit(model)
---> 23 print(model(x))

16 frames
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             # type ignore was added because at this point one knows that
   1510             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511             name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
   1512             if name:
   1513                 tracing_state.push_scope(name)

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518         finally:
   1519             if recording_scopes:
-> 1520                 tracing_state.pop_scope()
   1521         return result
   1522 

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in forward(self, *args, **kwargs)
    192 
    193     def forward(self, *args, **kwargs):
--> 194         res = self._forward_fn(*args, **kwargs)
    195         return res
    196 

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in fn_(*args, **kwargs)
    609         cs.calls += 1
    610 
--> 611         cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    612         cs.last_trace_host_execution_start = time.time_ns()
    613 

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in cache_info_wrapper(*args, **kwargs)
    260         tok = _cache_info_ctx.set({})
    261         try:
--> 262             res = fn(*args, **kwargs)
    263         finally:
    264             _cache_info_ctx.reset(tok)

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in get_computation_and_inputs(*args, **kwargs)
    496                 prologue_trc: TraceCtx
    497                 computation_trc: TraceCtx
--> 498                 prologue_trc, computation_trc, *maybe_epilogue = interpreter(
    499                     fn, args, kwargs, sharp_edges=cd.sharp_edges
    500                 )

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in _general_frontend(fn, args, kwargs, sharp_edges)
    173 # Translates the Python function to a thunder program using the thunder interpreter
    174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
    176 
    177 

[/usr/local/lib/python3.10/dist-packages/thunder/core/jit_ext.py](https://localhost:8080/#) in thunder_general_jit(fn, args, kwargs, sharp_edges)
   1384     with general_jit_ctx(ctx):
   1385         with tracectx(computation_trace):
-> 1386             result = jfn(*args, **kwargs)
   1387             prims.python_return(result)
   1388             process_recorded_modifications(ctx, epilogue_trace)

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in fn_(*args, **kwargs)
   6578                 assert isinstance(e, BaseException), e
   6579                 runtimectx.curexc = None
-> 6580                 raise e
   6581 
   6582             return interpretation_result

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in fn_2()
   6541                 def getfn():
   6542                     def fn_2(args, kwargs):
-> 6543                         return fn(*args, **kwargs)
   6544 
   6545                     return fn_2

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _wrapped_call_impl()
   1509             # type ignore was added because at this point one knows that
   1510             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511             name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
   1512             if name:
   1513                 tracing_state.push_scope(name)

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _call_impl()
   1518         finally:
   1519             if recording_scopes:
-> 1520                 tracing_state.pop_scope()
   1521         return result
   1522 

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in forward()
      9 
     10     def forward(self, inputs):
---> 11         x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
     12         x = self.hidden(x)
     13         logits = self.head(x)

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)
   6067         kwargs_ = {unwrap(k): unwrap(v) for k, v in kwargs.items()}
   6068         try:
-> 6069             opaque_result: Any = fn(*args_, **kwargs_)
   6070         except Exception as e:
   6071             runtimectx.curexc = e

TypeError: one_hot(): argument 'input' (position 1) must be Tensor, not TensorProxy

Environment

  • OS: Ubuntu/Google Colab
  • Python Version: 3.10
  • PyTorch Version: 2.3.0.dev20240314+cu121
  • Thunder Version: 0.1.0
  • Installation:
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
pip install lightning-thunder

Additional context

  • Other functional methods like F.relu don't seem to raise this issue.
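
For reference, one_hot can be decomposed into elementwise ops; a hedged sketch of an equivalent reference implementation (not necessarily how Thunder would register it):

import torch
import torch.nn.functional as F

def one_hot_reference(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    # one_hot(i, n)[..., c] == 1 iff i == c
    classes = torch.arange(num_classes, device=indices.device)
    return (indices.unsqueeze(-1) == classes).to(torch.int64)

x = torch.randint(0, 6, (1, 256))
assert torch.equal(one_hot_reference(x, 6), F.one_hot(x, 6))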

Add `torch.nn.Dropout` recomputation support during the backward pass to Thunder

๐Ÿš€ Feature

I would like to have Thunder save the seed and offset from random number generation to allow for the recomputation of Dropout in the backward pass.

There are two pieces needed to make it work:

  • Support stateless (deterministic) PRNG. This is done with thunder.prims.uniform_philox.
  • A trace transform to query PyTorch's PRNG state before each uniform call, replace uniform with uniform_philox, and increment the PRNG state properly. This is not implemented (see the sketch below).
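
A minimal sketch of the idea in plain PyTorch (not the Thunder transform itself; the explicitly seeded generator stands in for uniform_philox): only the seed/offset needs to be saved, and the mask is regenerated in backward instead of being stored.

import torch

def dropout_fwd(x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
    # forward: draw the mask from an explicitly seeded generator
    g = torch.Generator(device=x.device).manual_seed(seed)
    mask = torch.rand(x.shape, generator=g, device=x.device) >= p
    return x * mask / (1 - p)

def recompute_mask(shape, p: float, seed: int, device) -> torch.Tensor:
    # backward: regenerate the identical mask from the saved seed
    g = torch.Generator(device=device).manual_seed(seed)
    return torch.rand(shape, generator=g, device=device) >= p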

Motivation

Multi-head attention modules in LLMs often use dropout whose mask memory scales with the square of the sequence length.

cc @apaz-cli

adding DDP/FSDP transform after JITting does not work

๐Ÿ› Bug

The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: Lightning-AI/litgpt#1204

The objective is that fsdp|ddp can be applied after the thunder.jit call.

It works with FSDP, but not with DDP, where it fails with:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/carlos/lightning-thunder/kk.py", line 21, in <module>
[rank1]:     out = tmodel(x)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 194, in forward
[rank1]:     res = self._forward_fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 629, in fn_
[rank1]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 262, in cache_info_wrapper
[rank1]:     res = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 571, in get_computation_and_inputs
[rank1]:     computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/executors/torch_autograd.py", line 283, in split_forward_backward
[rank1]:     bw_trace = optimize_allreduce_in_ddp_backward(bw_trace, compile_data)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 198, in optimize_allreduce_in_ddp_backward
[rank1]:     updated_bwd_trace = visitor_transform(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
[rank1]:     visit_type = visit(bsym)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 133, in __call__
[rank1]:     self.gradient_buckets.tell(grads_of_bsym[0], self.process_group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 150, in tell
[rank1]:     self._maybe_allreduce(bucket, group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 138, in _maybe_allreduce
[rank1]:     self.bucket_to_future[bucket] = dist_prims.all_reduce(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/symbol.py", line 246, in __call__
[rank1]:     result = self.meta(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
[rank1]:     result = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/prims.py", line 87, in all_reduce_meta
[rank1]:     utils.check_type(group, torch.distributed.ProcessGroup)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 107, in check_type
[rank1]:     check(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 103, in check
[rank1]:     raise exception_type(s())
[rank1]: ValueError: None had an unexpected type <class 'NoneType'>. Supported types are <class 'torch.distributed.distributed_c10d.ProcessGroup'>

To Reproduce

import os
import thunder
import torch
import torch.distributed as torch_dist

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

tmodel = thunder.jit(model)
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)

out = tmodel(x)

if local_rank == 0:
    print(thunder.last_backward_traces(tmodel)[-1].python())

torchrun --nproc-per-node 2 bug.py
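
For comparison, the usual pattern of applying the transform before jitting (a hedged workaround, not what this issue asks for) would look like:

# hedged workaround: wrap with ddp() before thunder.jit instead of patching _lc_cd.fn
tmodel = thunder.jit(thunder.distributed.ddp(model))
out = tmodel(x)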

cc @carmocca @awaelchli @crcrpar @kshitij12345 since you fixed a similar issue in #23

Updating an nn.Module attribute in forward raises an exception in the prologue trace.

import torch
import thunder

import thunder.examine

class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.bar = 1

    def forward(self, x):
        self.bar = self.bar + 1
        # self.bar = 2  # This works
        return x

m = MyModule()

x = torch.randn(16, 16, device='cuda')

jit_linear = thunder.jit(m)

o = jit_linear(x)

Error:

File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 537, in get_computation_and_inputs
    inps = pro(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "thunder.prologue_0", line 16, in prologue
  File "/home/kkalambarkar/lightning-thunder/thunder/executors/pythonex.py", line 100, in _check_number_type_and_value_impl
    utils.check(
  File "/home/kkalambarkar/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: Expected 2 to be equal to and have the type of 1

cc @apaz-cli

`test_vjp_correctness` fails with ops that return tensors that do not require grads.

๐Ÿ› Bug

As per the title. To reproduce, one could uncomment these tests in #118 to get:

thunder/tests/test_grad.py:423: in test_vjp_correctness                                                                                                                                                                                       
    result = run_snippet(                                                                                                                                                                                                                     
thunder/tests/framework.py:483: in run_snippet                                                                                                                                                                                                
    raise ex                                                                                                                                                                                                                                  
thunder/tests/framework.py:475: in run_snippet                                                                                                                                                                                                
    snippet(*args, **kwargs)                                                                                                                                                                                                                  
thunder/tests/test_grad.py:394: in snippet_vjp_correctness                                                                                                                                                                                    
    check_vjp(func, *args, executor=executor)                                                                                                                                                                                                 
thunder/tests/test_grad.py:304: in check_vjp                                                                                                                                                                                                  
    _, J_star_v = executor.make_callable_legacy(vjp(f), disable_torch_autograd_support=True)(primals, v)                                                                                                                                      
thunder/common.py:783: in _fn                                                                                                                                                                                                                 
    trc_or_result = trace(compile_data=cd)(processed_function, *args, **kwargs)                                                                                                                                                               
thunder/core/interpreter.py:1298: in fn_                                                                                                                                                                                                      
    return fn(*args, **kwargs)                                                                                                                                                                                                                
thunder/common.py:534: in _trace                                                                                                                                                                                                              
    result = fn(*proxyargs, **proxykwargs)                                                                                                                                                                                                    
thunder/core/transforms.py:3629: in _vjp                                                                                                                                                                                                      
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)                                                                                                                                                                         
thunder/core/transforms.py:3603: in vjp_call_metafunc                                                                                                                                                                                         
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)                                                                                                                                                                     
thunder/core/transforms.py:3414: in augmented_forward_pass                                                                                                                                                                                    
    result, env = eval_trace(                                                                                                                                                                                                                 
thunder/core/transforms.py:1693: in eval_trace                                                                                                                                                                                                
    prim_func = symbol_mapper(symbol)                                                                                                                                                                                                         
thunder/core/transforms.py:3338: in vjp_symbol_mapper                                                                                                                                                                                         
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)                                                                                                                                                                             
thunder/core/vjp_utils.py:99: in make_aug_forward_and_backward                                                                                                                                                                                
    backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0])                                                                                                                                    
thunder/core/utils.py:1062: in find_producer_symbols                                                                                                                                                                                          
    if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:                                                                                                                                                          
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
                                                                                                                                                                                                                                              
x = None                                                                                                                                                                                                                                      
                                                                                                                                                                                                                                              
>   if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:                                                                                                                                                          
E   AttributeError: 'NoneType' object has no attribute 'name'                                                                                                                                                                                 
                                                                                                                                                                                                                                              
thunder/core/utils.py:1062: AttributeError     

Better name for elements of list in `prologue_trace` and `computation_trace`

import thunder
import torch

def foo(xs):
    result = []
    for x in xs:
        result.append(x + x)
    return result

jfoo = thunder.jit(foo)

o = jfoo([torch.randn(3,),] * 6)
print(thunder.last_prologue_traces(jfoo)[-1])
print(thunder.last_traces(jfoo)[-1])

The names for the arguments to the computation trace are: res, x, a, b, t_0_4, t_0_5. It would be nice if there were a consistent pattern.

Traces

Prologue Trace

# Constructed by Transform for execution (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 1)
    # prims.check_len(args, 1)
  # kwargs: "Any"
  check_len(kwargs, 0)
    # prims.check_len(kwargs, 0)
  subscr: "Any" = args[0]
  res: "cpu f32[3]" = subscr[0]
  x: "cpu f32[3]" = subscr[1]
  a: "cpu f32[3]" = subscr[2]
  b: "cpu f32[3]" = subscr[3]
  t_0_4: "cpu f32[3]" = subscr[4]
  t_0_5: "cpu f32[3]" = subscr[5]
  ...
  return (res, x, a, b, t_0_4, t_0_5)

Computation Trace

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def computation(res, x, a, b, t_0_4, t_0_5):
  # res: "cpu f32[3]"
  # x: "cpu f32[3]"
  # a: "cpu f32[3]"
  # b: "cpu f32[3]"
  # t_0_4: "cpu f32[3]"
  # t_0_5: "cpu f32[3]"
  result = torch.add(res, res)  # result: "cpu f32[3]"
    # result = ltorch.add(res, res, alpha=None)  # result: "cpu f32[3]"
      # result = prims.add(res, res)  # result: "cpu f32[3]"
  del res
  ...
  return [result, t1, t2, t3, t4, t5]

Support `memory_format` on `to()`

๐Ÿš€ Feature

to(memory_format=something) is part of the MegatronImagen model in NeMo.

Ideally, this would work:

$ git diff .
diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
index 4fa6cd230..2cf7a8ffa 100644
--- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
+++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
@@ -31,6 +31,7 @@ from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.core.classes.common import Serialization
 from nemo.utils import logging
+import thunder
 
 try:
     from apex import amp
@@ -190,6 +191,7 @@ class MegatronImagen(MegatronBaseModel):
         self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)
 
         self.model = self.model_provider_func()
+        self.model = thunder.jit(self.model)
 
         if self.trainer.precision in ['bf16', 'bf16-mixed']:
             self.autocast_dtype = torch.bfloat16
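
A minimal repro outside NeMo (hedged; it assumes the memory_format keyword itself is what fails) could look like:

import torch
import thunder

def f(x):
    return x.to(memory_format=torch.channels_last)

jf = thunder.jit(f)
jf(torch.randn(1, 3, 8, 8))  # fails because to() does not accept memory_format yet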

Motivation

Trying to evaluate NeMo models in thunder and expand our model support there. Megatron-based models appear to be widely used.

Alternatives

I wonder if we could temporarily just accept the keyword without actually doing anything about it. I imagine that would be very slow, but it might allow us to get models like this one into thunder more easily.

I'll start trying to convert smaller parts of the model next.

Additional context

Model in question:

https://github.com/NVIDIA/NeMo/blob/23baa48e441ecb6cc6b49c23bf8cfc076db38bdc/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py#L175

I think the to() call that is failing for me is actually this line:
https://github.com/NVIDIA/NeMo/blob/23baa48e441ecb6cc6b49c23bf8cfc076db38bdc/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py#L135

Model test:
log.txt
