
flash-linear-attention's People

Contributors

donglixp, doraemonzzz, eltociear, hypnopump, ridgerchu, sustcsonglin, yzhangcs

flash-linear-attention's Issues

'RebasedFeatureMap' is missing?

Hi, thanks for all your efforts on such a good repo.

I'm trying to build a model based on RebasedLinearAttention, but I encounter this error:

ImportError: cannot import name 'RebasedFeatureMap' from 'fla.modules.feature_map' (.*/fla/modules/feature_map.py)

Is this feature still under development, or should I use HazyResearch's implementation located at https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py?

Error loading pretrained checkpoints through `transformers` library

It raises the following error when I run `model = AutoModelForCausalLM.from_pretrained("fla-hub/gla-1.3B-200B")`.

ValueError: The checkpoint you are trying to load has model type gla but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.
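For completeness, here is the snippet I'm running. The `import fla` line is something I added after reading other issues; whether importing fla is the intended way to register the `gla` architecture with transformers' Auto classes is an assumption on my part.

# Assumption: importing fla registers the gla config/model classes with transformers.
import fla  # noqa: F401
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("fla-hub/gla-1.3B-200B")  # raises the ValueError above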

RuntimeError: Triton Error [CUDA]: device kernel image is invalid

Greetings!

First off, I want to express my gratitude for sharing your incredible work; it's truly impressive! However, when I attempted to execute the test code outlined in the README, I encountered the following error. Could you kindly offer any guidance or recommendations on how to resolve this issue?

Thank you for your assistance!
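For reference, unit_test.py is essentially the GLA example from the README; it is reproduced approximately below (exact sizes may differ, so treat this as a sketch rather than the verbatim file).

# Approximate contents of unit_test.py (reconstructed, not verbatim)
import torch
from fla.layers import GatedLinearAttention

batch_size, seq_len, hidden_size = 2, 512, 1024
gla = GatedLinearAttention(hidden_size=hidden_size, mode='fused_chunk').cuda()
x = torch.randn(batch_size, seq_len, hidden_size, device='cuda')
y2 = gla(x)  # line 19 in the traceback below; this is where the error surfaces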

Traceback (most recent call last):
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/unit_test.py", line 19, in <module>
    y2 = gla(x)
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/layers/gla.py", line 95, in forward
    o = fused_chunk_gla(q, k, v, gk)
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/ops/gla/chunk_fuse.py", line 516, in fused_chunk_gla
    o, final_state = FusedChunkGLAFunction.apply(
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/ops/utils.py", line 11, in wrapper
    return fn(ctx,
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/share/project/zzzr/jjlm/old_ckpts/VisionProjects/up-to-date-lib/flash-linear-attention/fla/ops/gla/chunk_fuse.py", line 320, in forward
    fwd_decay_cumsum[grid](
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/triton/compiler/compiler.py", line 692, in __getattribute__
    self._init_handles()
  File "/home/zzz/.conda/envs/gla/lib/python3.10/site-packages/triton/compiler/compiler.py", line 683, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
RuntimeError: Triton Error [CUDA]: device kernel image is invalid

And my environment is shown below:

PyTorch version: 2.2.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-113-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 470.129.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   43 bits physical, 48 bits virtual
CPU(s):                          256
On-line CPU(s) list:             0-255
Thread(s) per core:              2
Core(s) per socket:              64
Socket(s):                       2
NUMA node(s):                    8
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7742 64-Core Processor
Stepping:                        0
Frequency boost:                 enabled
CPU MHz:                         3235.858
CPU max MHz:                     2250.0000
CPU min MHz:                     1500.0000
BogoMIPS:                        4491.50
Virtualization:                  AMD-V
L1d cache:                       4 MiB
L1i cache:                       4 MiB
L2 cache:                        64 MiB
L3 cache:                        512 MiB
NUMA node0 CPU(s):               0-15,128-143
NUMA node1 CPU(s):               16-31,144-159
NUMA node2 CPU(s):               32-47,160-175
NUMA node3 CPU(s):               48-63,176-191
NUMA node4 CPU(s):               64-79,192-207
NUMA node5 CPU(s):               80-95,208-223
NUMA node6 CPU(s):               96-111,224-239
NUMA node7 CPU(s):               112-127,240-255
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.2.1+cu118
[pip3] torchaudio==2.2.1+cu118
[pip3] torchvision==0.17.1+cu118
[pip3] triton==2.2.0
[conda] numpy                     1.26.3                   pypi_0    pypi
[conda] torch                     2.2.1+cu118              pypi_0    pypi
[conda] torchaudio                2.2.1+cu118              pypi_0    pypi
[conda] torchvision               0.17.1+cu118             pypi_0    pypi
[conda] triton                    2.2.0                    pypi_0    pypi

RWKV6 backward issue

Hi,
I fetched the 3B version of the model from the Hugging Face Hub, and when I call loss.backward() (after model.train()) through the transformers library, I get the following error coming from your library (a rough reproduction sketch follows the traceback).

  File "/home/ostix/.virtualenvs/AI-architectures/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 61:21:
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
    mask_bv = i_v * BV + tl.arange(0, BV) < DV
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
    _u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
    h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_init_s = initial_state + i_bh * DK * DV + \
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
                     ^
ValueError('Cannot broadcast, the expanded size of the tensor (64) must match the existing size (16) at non-singleton dimension 0: [16, 64], [64, 16]')
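For reference, this is roughly how I trigger the error; the hub id below is a placeholder rather than the exact checkpoint name, and the shapes are illustrative.

# Rough reproduction sketch (hub id and shapes are placeholders, not my exact setup)
import torch
from transformers import AutoModelForCausalLM

model_id = "<the RWKV6 3B checkpoint from the hub>"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 512), device="cuda")
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()  # raises the CompilationError shown above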

Thanks for considering this issue

Triton Error in flash-linear-attention/fla/modules/rmsnorm.py

PyTorch=='2.3.0a0+ebedce2'
Triton=='2.2.0'
(from out-of-box NVIDIA PyTorch Container 24.02 nvcr.io/nvidia/pytorch:24.02-py3)
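The model is wrapped in torch.compile in my training script, which is where dynamo enters the picture. A rough sketch of the setup that reaches fla/modules/rmsnorm.py is below; module and variable names are illustrative, and I'm assuming fla.models exports the Mamba classes.

# Illustrative sketch; my real script wraps fla's Mamba model in a small
# wrapper module (model_fla.py) and compiles it with torch.compile.
import torch
from fla.models import MambaConfig, MambaForCausalLM

config = MambaConfig()
model = torch.compile(MambaForCausalLM(config).cuda())

X = torch.randint(0, config.vocab_size, (4, 1024), device="cuda")
logits = model(input_ids=X).logits  # forward reaches self.norm -> fla rmsnorm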

[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/train.py", line 705, in <module>
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     logits, loss = model(X, Y)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/model_fla.py", line 26, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     ori_output = self.model(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 579, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     mamba_outputs = self.backbone(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 494, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     hidden_states = mixer_block(hidden_states, cache_params=cache_params)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 318, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     hidden_states = self.norm(hidden_states)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 504, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return rms_norm_fn(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 488, in rms_norm_fn
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return super().apply(*args, **kwargs)  # type: ignore[misc]
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 421, in forward
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     y, mean, rstd, residual_out = _layer_norm_fwd(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 155, in _layer_norm_fwd
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     MAX_FUSED_SIZE = 65536 // x.element_size()
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 162, in resume_in__layer_norm_fwd_at_155
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     _layer_norm_fwd_1pass_kernel[(M,)](
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in run
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 128, in resume_in_run_at_127
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     if len(self.configs) > 1:
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run_at_128
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     fn()
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 109, in kernel_call
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.pre_hook(args)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call_at_109
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     self.fn.run(
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 413, in run
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     grid = get_special_arg("grid")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run_at_413
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     num_warps = get_special_arg("num_warps")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run_at_414
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     num_ctas = get_special_arg("num_ctas", 1)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run_at_415
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     num_stages = get_special_arg("num_stages")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 417, in resume_in_run_at_416
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run_at_417
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     bound_args = self.signature.bind(*args, **kwargs)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run_at_426
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     bound_args.apply_defaults()
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_427
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_429
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 430, in resume_in_run_at_429
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 435, in resume_in_run_at_430
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] ==========
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] Traceback (most recent call last):
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     result = inner_convert(frame, cache_entry, hooks, frame_state)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 332, in _convert_frame_assert
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     unimplemented("generator")
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO]     raise Unsupported(msg)
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [INFO] torch._dynamo.exc.Unsupported: generator
[2024-04-05 19:06:32,021] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* specialization_key             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 170
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] WON'T CONVERT <genexpr> /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py line 436
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] ========== TorchDynamo Stack Trace ==========
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] Traceback (most recent call last):
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     result = inner_convert(frame, cache_entry, hooks, frame_state)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 332, in _convert_frame_assert
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     unimplemented("generator")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     raise Unsupported(msg)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] torch._dynamo.exc.Unsupported: generator
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] ========== The above exception occurred while processing the following code ==========
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/train.py", line 705, in <module>
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     logits, loss = model(X, Y)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/model_fla.py", line 26, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     ori_output = self.model(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 579, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     mamba_outputs = self.backbone(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 494, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     hidden_states = mixer_block(hidden_states, cache_params=cache_params)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 318, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     hidden_states = self.norm(hidden_states)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return self._call_impl(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return forward_call(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 504, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return rms_norm_fn(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 488, in rms_norm_fn
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return super().apply(*args, **kwargs)  # type: ignore[misc]
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 421, in forward
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     y, mean, rstd, residual_out = _layer_norm_fwd(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 155, in _layer_norm_fwd
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     MAX_FUSED_SIZE = 65536 // x.element_size()
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 162, in resume_in__layer_norm_fwd_at_155
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     _layer_norm_fwd_1pass_kernel[(M,)](
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in run
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.nargs = dict(zip(self.arg_names, args))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 128, in resume_in_run_at_127
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     if len(self.configs) > 1:
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run_at_128
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     fn()
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 109, in kernel_call
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.pre_hook(args)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call_at_109
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     self.fn.run(
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 413, in run
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     grid = get_special_arg("grid")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run_at_413
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     num_warps = get_special_arg("num_warps")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run_at_414
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     num_ctas = get_special_arg("num_ctas", 1)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run_at_415
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     num_stages = get_special_arg("num_stages")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 417, in resume_in_run_at_416
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run_at_417
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     bound_args = self.signature.bind(*args, **kwargs)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run_at_426
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     bound_args.apply_defaults()
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_427
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_429
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     assert len(bound_args.arguments) == len(self.params)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 430, in resume_in_run_at_429
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 436, in resume_in_run_at_430
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] 
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] ==========
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] Traceback (most recent call last):
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     result = inner_convert(frame, cache_entry, hooks, frame_state)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 332, in _convert_frame_assert
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     unimplemented("generator")
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO]     raise Unsupported(msg)
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [INFO] torch._dynamo.exc.Unsupported: generator
[2024-04-05 19:06:32,022] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* <listcomp>             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 450
[2024-04-05 19:06:32,022] [33/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _device_of /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:231
[2024-04-05 19:06:32,022] [33/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:231 in _device_of (JITFunction)
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         @staticmethod
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:233 in _device_of (JITFunction._device_of)
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             try:
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE SETUP_FINALLY 12 []
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:234 in _device_of (JITFunction._device_of)
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]                 return arg.device.type
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST arg []
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR device [LazyVariableTracker()]
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_arg_ L['arg']
[2024-04-05 19:06:32,023] [33/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['arg'] (16384, 1308) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], constraint_sizes=[None, None], tensor_source=LocalSource(local_name='arg', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={})
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR type [ConstantVariable(device)]
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_BLOCK None [ConstantVariable(str)]
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [ConstantVariable(str)]
[2024-04-05 19:06:32,025] [33/0] torch._dynamo.convert_frame: [DEBUG] Skipping frame because no content in function call _device_of                     /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 231
[2024-04-05 19:06:32,025] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* <listcomp>             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 451
[2024-04-05 19:06:32,025] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* <listcomp>             /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py 453
[2024-04-05 19:06:32,025] [34/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _pinned_memory_of /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:238
[2024-04-05 19:06:32,025] [34/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:238 in _pinned_memory_of (JITFunction)
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         @staticmethod
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:240 in _pinned_memory_of (JITFunction._pinned_memory_of)
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             try:
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE SETUP_FINALLY 12 []
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:241 in _pinned_memory_of (JITFunction._pinned_memory_of)
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]                 return arg.is_pinned()
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST arg []
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_pinned [LazyVariableTracker()]
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_arg_ L['arg']
[2024-04-05 19:06:32,026] [34/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['arg'] (16384, 1308) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], constraint_sizes=[None, None], tensor_source=LocalSource(local_name='arg', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={})
[2024-04-05 19:06:32,028] [34/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 0 [GetAttrVariable(TensorVariable(), is_pinned)]
Traceback (most recent call last):
  File "/raid/xind/TLM/train.py", line 705, in <module>
    logits, loss = model(X, Y)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/model_fla.py", line 26, in forward
    ori_output = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 579, in forward
    mamba_outputs = self.backbone(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 494, in forward
    hidden_states = mixer_block(hidden_states, cache_params=cache_params)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/models/mamba/modeling_mamba.py", line 318, in forward
    hidden_states = self.norm(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 504, in forward
    return rms_norm_fn(
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 488, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 421, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 155, in _layer_norm_fwd
    MAX_FUSED_SIZE = 65536 // x.element_size()
  File "/raid/xind/TLM/new_models/flash-linear-attention/fla/modules/rmsnorm.py", line 162, in resume_in__layer_norm_fwd_at_155
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in run
    self.nargs = dict(zip(self.arg_names, args))
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
    self.nargs = dict(zip(self.arg_names, args))
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run_at_127
    self.nargs = dict(zip(self.arg_names, args))
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 128, in resume_in_run_at_127
    if len(self.configs) > 1:
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run_at_128
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 109, in kernel_call
    self.pre_hook(args)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call_at_109
    self.fn.run(
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 413, in run
    grid = get_special_arg("grid")
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run_at_413
    num_warps = get_special_arg("num_warps")
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run_at_414
    num_ctas = get_special_arg("num_ctas", 1)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run_at_415
    num_stages = get_special_arg("num_stages")
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 417, in resume_in_run_at_416
    enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run_at_417
    bound_args = self.signature.bind(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run_at_426
    bound_args.apply_defaults()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_427
    assert len(bound_args.arguments) == len(self.params)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run_at_429
    assert len(bound_args.arguments) == len(self.params)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 430, in resume_in_run_at_429
    args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 453, in resume_in_run_at_430
    [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 453, in <listcomp>
    [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 580, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 741, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 384, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 643, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 246, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 524, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 489, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2110, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 780, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 462, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1190, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 644, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 645, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 770, in call_method
    return wrap_fx_proxy(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1302, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1387, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1590, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1545, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1086, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1546, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1657, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1638, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1480, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1711, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 442, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 827, in nyi
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"
torch._dynamo.exc.TorchRuntimeError: Failed running call_method is_pinned(*(FakeTensor(..., device='cuda:0', size=(16384, 1308), requires_grad=True),), **{}):
NYI: aten.is_pinned.default

from user code:
   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 241, in _pinned_memory_of
    return arg.is_pinned()


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] Function                                  Runtimes (s)
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] --------------------------------------  --------------
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner                 2.5749
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] OutputGraph.call_user_compiler                  1.282
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] create_aot_dispatcher_function                  1.3513
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] compile_fx.<locals>.fw_compiler_base            1.2468
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] GraphLowering.run                               0.0301
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] GraphLowering.compile_to_module                 1.0085
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] Scheduler.__init__                              0.0083
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] Scheduler.codegen                               0.1939
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] WrapperCodeGen.generate                         0.0011
[2024-04-05 19:06:32,062] torch._dynamo.utils: [INFO] CachingAutotuner.benchmark_all_configs          0.1649

inconsistent results when "masking" gating term between "fused_recurrent" and "fused_chunk" (fused_chunk presumably wrong)

Hi,

I tried to implement an option where I can pass a mask to a GLA layer that zeroes out certain hidden states, so that I can pack sequences without data contamination between them. For that I set the corresponding gk to -infty so that gk.exp() = 0.
It boils down to these two lines:

if reset_mask is not None:
    gk = gk.masked_fill(reset_mask, -0.1*torch.finfo(gk.dtype).max)

where reset_mask is a boolean mask that indicates first token of packed sequences in each row.

Apparently it works when using "fused_recurrent" mode. In this test I compare concat(gla(x1), gla(x2)) against gla(concat(x1, x2), reset_mask=reset_mask). It also works on real data (while the contaminated version fails).


device = "cuda"
mode = "fused_recurrent"
b, n, d = 1, 8, 64
gla = GatedLinearAttention(hidden_size=d, mode=mode).to(device)

x = torch.randn(b, n, d).to(device)
x1, x2 = x.chunk(2, dim=1)

reset_mask = (torch.arange(n, device=device)%(n//2)) == 0 #False everywhere except in first position of x1 and x2
reset_mask = rearrange(reset_mask, "n -> 1 n 1")

y = gla(x, reset_mask=reset_mask)
y1, y2 = gla(x1), gla(x2)

assert torch.allclose(torch.cat((y1, y2), dim=1), y)

However "fused_chunk" curiously fails the test, and training on real data is unstable.

mode = "fused_recurrent"
...
>>> assert torch.allclose(torch.cat((y1, y2), dim=1), y)
AssertionError

Is that a bug?

Thank you,
Théodor
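
For reference, here is a minimal pure-PyTorch sketch (deliberately not using the fla kernels) of why forcing the log-gate to -inf should reset the state under the standard gated linear-attention recurrence; a naive loop like this can also serve as a ground truth to decide which fused mode disagrees with it:

import torch

# Naive recurrence: S_t = diag(exp(gk_t)) @ S_{t-1} + k_t^T v_t, then o_t = q_t @ S_t.
# Setting gk_t = -inf at a position makes exp(gk_t) = 0 and wipes the state,
# which is exactly the behaviour needed between packed sequences.
B, H, L, K, V = 1, 1, 8, 4, 4
q, k = torch.randn(2, B, H, L, K)
v = torch.randn(B, H, L, V)
gk = torch.nn.functional.logsigmoid(torch.randn(B, H, L, K))
gk[:, :, L // 2] = float('-inf')  # reset point between the two packed sequences

S = torch.zeros(B, H, K, V)
outputs = []
for t in range(L):
    S = torch.exp(gk[:, :, t]).unsqueeze(-1) * S + k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)
    outputs.append((q[:, :, t].unsqueeze(-2) @ S).squeeze(-2))
o = torch.stack(outputs, dim=2)  # matches running the two halves independently and concatenating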

AssertionError('All values in both first input shape ([constexpr[16], constexpr[8]]) and second input shape ([constexpr[8], constexpr[16]]) must be >= 16!')

I am trying to use GLABlock with a batch size of 1, but I encounter this error. How can I resolve it?

My current config:

config = GLAConfig(
    hidden_size=channels,
    num_hidden_layers=n_layer,
    num_attention_heads=num_heads,
    num_heads=num_heads,
    attn_mode='fused_chunk',
    expand_k=0.5,
    expand_v=1.0,
    hidden_act="swish",
    bid_mode='layer',
    use_dirpe=False,
    rms_norm_eps=1e-6,
    if_norm_qkv=False,
    if_scale_qkv=False,
    fuse_norm=True,
)
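
Not a confirmed diagnosis, but the shapes in the assertion suggest the per-head key dimension ends up below 16: the chunked Triton kernels rely on tl.dot, which requires every matmul dimension to be at least 16. A quick sanity check on the config (reusing the channels / num_heads variables from the snippet above):

head_k_dim = int(channels * 0.5) // num_heads  # expand_k = 0.5
head_v_dim = int(channels * 1.0) // num_heads  # expand_v = 1.0
assert head_k_dim >= 16 and head_v_dim >= 16, (
    f"per-head dims {head_k_dim}/{head_v_dim} are too small for the chunked kernels; "
    "increase hidden_size, lower num_heads/expand_k, or try attn_mode='fused_recurrent'"
)

If changing the dimensions is not an option, attn_mode='fused_recurrent' may sidestep the constraint, since the recurrent kernel works element-wise per step rather than through tl.dot, though that is worth verifying on your setup.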

Finetune RWKV6 with fla implementations

Starting from Bo's code, after replacing the CUDA kernel with the fla operator, the initial loss is very high, starting at around 8. Previously I replaced the CUDA kernel with the gla operator directly; because the state-update order of gla and rwkv is offset by one step, r had to be rolled there, and fine-tuning worked fine in that setup. Could you please check whether my code is missing some necessary operation?

[image attachment]
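
Hard to say without seeing the full code, but when swapping the (B, T, C)-shaped CUDA kernel for the fla op, the tensor layout and the decay convention are the usual suspects. Below is only a sketch of the layout the fla kernel appears to expect, inferred from the call shown in the nan-gradient issue further down (shapes (B, H, L, K) plus u of shape (H, K)); the rwkv6_attention name and the handling of w are assumptions to verify against your checkpoint:

from einops import rearrange
from fla.ops.rwkv_6.recurrent_fuse import fused_recurrent_rwkv6

def rwkv6_attention(r, k, v, w, u, num_heads):
    # r, k, v, w arrive as (B, T, C) in the reference CUDA path; u as (C,).
    r, k, v, w = (rearrange(x, 'b t (h d) -> b h t d', h=num_heads) for x in (r, k, v, w))
    u = rearrange(u, '(h d) -> h d', h=num_heads)
    o, _ = fused_recurrent_rwkv6(r, k, v, w, u)  # check whether w should be the per-step decay or its log here
    return rearrange(o, 'b h t d -> b t (h d)')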

illegal memory access error

Hi Songlin,

Thanks for your great work! I ran some comparative experiments with FLA recently and it shows great performance. However, I hit an error when I increased the hidden dimension to 1152 and set the number of heads to 16. The details of the error are shown below:

n136-180-028:1707397:1709473 [6] include/alloc.h:124 NCCL WARN Cuda failure 700 'an illegal memory access was encountered'
n136-180-028:1707397:1709473 [6] NCCL INFO include/alloc.h:245 -> 1

Could you give me some advice or guidance about this? Thanks a lot! By the way, could I get your WeChat for further communication? :D

Yours,
Lianghui
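
One guess (an assumption, not a confirmed diagnosis): 1152 with 16 heads gives a per-head dimension of 72, which is not a power of two, and some of the Triton kernels pad or assume power-of-two head dimensions, which can lead to out-of-bounds accesses. A quick check:

hidden_size, num_heads = 1152, 16
head_dim = hidden_size // num_heads
print(head_dim, (head_dim & (head_dim - 1)) == 0)  # 72, False -> possibly unsupported head dim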

RWKV6 backward gives nan gradients

The u gradient is fine; all other gradients grow uncontrollably.

import torch as th
from fla.ops.rwkv_6.recurrent_fuse import fused_recurrent_rwkv6

B, H, L, K, V = 2, 4, 256, 64, 64

r, k, v, w = th.randn(4, B, H, L, K).cuda()
w = w.sigmoid()
u = th.randn(H, K).cuda()

r.requires_grad = True
k.requires_grad = True
v.requires_grad = True
w.requires_grad = True
u.requires_grad = True

o, state = fused_recurrent_rwkv6(r, k, v, w, u)

print(o.shape)
o.mean().backward()
print(u.grad.shape)
print(w.grad)
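
To localize where the nan first enters, standard PyTorch anomaly detection (nothing fla-specific) can be wrapped around the same snippet; it reruns the backward with extra checks and reports the forward op that produced the bad gradient:

# Re-run the forward/backward under anomaly detection.
with th.autograd.detect_anomaly():
    o, state = fused_recurrent_rwkv6(r, k, v, w, u)
    o.mean().backward()

It may also be worth checking whether the kernel expects w in log space rather than the raw sigmoid output passed above; a decay greater than one would let the state grow without bound, which would match the exploding gradients, but that is only a guess.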

Mistakes in the GLA paper

Thank you for all your great work on linear attention and I'm very excited about this repo!

I just wanted to point out some errors in the GLA paper that might detract from it; you may want to fix them in a future revision.

Shouldn't the division by the temperature term be inside the brackets?

Why is the Q_t transposed here?

For the B calculation it should be the product of beta not B

Again here why is the Q transposed?

The V = V[...] is missing here

Shouldn't a_normaliser and b_normaliser be a_q[...] not a[...]?

Shouldn't it be k = K[:,:,iC + jc,....] not k*c?

Shouldn't it be V_iC + k, not V_iCk?

If any of these are not actually mistakes then feel free to point that out, hope this helps!

Using operators without having `transformers` installed

I'm currently trying to use just the operators defined in fla.ops; however, because of the __init__.py script for the main package, it's not possible to do this without importing things from the HF transformers package, which makes the import slower (and broke it entirely until I upgraded the package).

It would be nice if there were a way to just import the operators without the layer modules or anything else.
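
Agreed that this would be useful. One possible direction, purely as a sketch and not the package's current behaviour, is to keep fla/__init__.py cheap and defer the transformers-dependent model classes behind a module-level __getattr__ (PEP 562); the GLAForCausalLM / fla.models names below are placeholders for whatever the heavy symbols actually are:

# fla/__init__.py (sketch)
import importlib

_LAZY = {'GLAForCausalLM': 'fla.models'}  # hypothetical mapping of heavy symbols to submodules

def __getattr__(name):
    # Deferred import: transformers is only pulled in when a model class is touched,
    # so `from fla.ops.gla.chunk_fuse import fused_chunk_gla` stays lightweight.
    if name in _LAZY:
        return getattr(importlib.import_module(_LAZY[name]), name)
    raise AttributeError(f"module 'fla' has no attribute {name!r}")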
