
rwkv_pytorch's Issues

What does the parallel forward in the code mean?

    def forward(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor:
        """
        Forward pass of the model.
        Args:
            x (torch.Tensor): input tensor of shape [Batch, N_embd].
            state (torch.Tensor): hidden-state tensor of shape [Batch, State Size, N_embd].
            i (int): time index.
        Returns:
            torch.Tensor: output tensor with the same shape as the input x.
        """
        if self.onnx_opset >= 17:
            x = x + self.time_mixing(self.ln1(x), state, i)
            x = x + self.channel_mixing(self.ln2(x), state, i)
        else:
            x = x + self.time_mixing(self.manual_layer_norm(x, self.ln1_weight, self.ln1_bias, 1e-5), state, i)
            x = x + self.channel_mixing(self.manual_layer_norm(x, self.ln2_weight, self.ln2_bias, 1e-5), state, i)
        return x

    def forward_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor:
        """
        Parallel forward pass of the model.
        Args:
            x (torch.Tensor): input tensor of shape [Batch, L, N_embd].
            state (torch.Tensor): hidden-state tensor of shape [Batch, State Size, N_embd].
            i (int): time index.
        Returns:
            torch.Tensor: output tensor with the same shape as the input x.
        """
        if self.onnx_opset >= 17:
            x = x + self.time_mixing_parallel(self.ln1(x), state, i)
            x = x + self.channel_mixing_parallel(self.ln2(x), state, i)
        else:
            x = x + self.time_mixing_parallel(self.manual_layer_norm(x, self.ln1_weight, self.ln1_bias, 1e-5), state, i)
            x = x + self.channel_mixing_parallel(self.manual_layer_norm(x, self.ln2_weight, self.ln2_bias, 1e-5), state, i)
        return x

I noticed that one of these takes an input of shape [Batch, N_embd] and the other [Batch, L, N_embd]. What does L mean here?
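Judging only from the two docstrings, L is most likely the sequence length, i.e. the number of tokens handled in a single call: `forward` consumes one token per call and carries the history in `state`, while `forward_parallel` receives all L tokens at once. The sketch below illustrates that difference on a hypothetical toy layer, `ToyDecayMixer` (a per-channel decayed running sum, not RWKV's actual time mixing). It assumes both modes are meant to give identical results and checks that stepping through the L tokens one by one matches a single parallel call.

```python
from typing import Tuple

import torch


class ToyDecayMixer(torch.nn.Module):
    """Hypothetical toy layer (NOT RWKV's time mixing): y_t = w * y_{t-1} + x_t per channel."""

    def __init__(self, n_embd: int):
        super().__init__()
        # Illustrative per-channel decay in [0.25, 0.75).
        self.decay = torch.nn.Parameter(torch.rand(n_embd) * 0.5 + 0.25)

    def forward(self, x: torch.Tensor, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Serial mode: x is [Batch, N_embd], one token per call; state is [Batch, N_embd].
        y = self.decay * state + x
        return y, y  # (output, new state)

    def forward_parallel(self, x: torch.Tensor, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Parallel mode: x is [Batch, L, N_embd], all L tokens in one call.
        L = x.shape[1]
        t = torch.arange(L, dtype=x.dtype)
        lag = t[:, None] - t[None, :]          # lag[t, s] = t - s
        causal = (lag >= 0).to(x.dtype)        # only tokens s <= t contribute
        # weights[t, s, c] = decay_c ** (t - s) for s <= t, else 0
        weights = self.decay[None, None, :] ** lag.clamp(min=0)[:, :, None] * causal[:, :, None]
        y = torch.einsum("tsc,bsc->btc", weights, x)
        # The incoming state decays by decay ** (t + 1) by the time it reaches token t.
        y = y + self.decay[None, None, :] ** (t + 1)[None, :, None] * state[:, None, :]
        return y, y[:, -1]  # final state = output at the last token


torch.manual_seed(0)
B, L, C = 2, 5, 4
layer = ToyDecayMixer(C)
x = torch.randn(B, L, C)
state0 = torch.randn(B, C)

with torch.no_grad():
    y_par, s_par = layer.forward_parallel(x, state0)  # one call over L tokens
    state = state0
    outs = []
    for i in range(L):                                # L calls, one token each
        y_i, state = layer(x[:, i], state)
        outs.append(y_i)
    y_ser = torch.stack(outs, dim=1)

print(torch.allclose(y_par, y_ser, atol=1e-5))
```

The serial form is what you would use for token-by-token generation, while the parallel form lets a whole prompt be processed in one call; both end in the same state.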

Is there a way to export using torch.jit.script?

Thanks for this great repository!

I was wondering if there is a way to export using TorchScript. I tried a simple approach with torch.jit.script(model), but I get:

RuntimeError: 
Module 'RWKV_Block' has no attribute 'att_group_norm' :
  File "/data/workspaces/jp/LLMs/RWKV_Pytorch/src/model.py", line 229
        # Flatten x and apply group normalization and gating
        if self.onnx_opset >= 18:
            x = self.att_group_norm(x.flatten(start_dim=1)) * g
                ~~~~~~~~~~~~~~~~~~~ <--- HERE
        else:
            x = x.flatten(start_dim=1) 
'RWKV_Block.time_mixing' is being compiled since it was called from 'RWKV_Block.forward'
  File "/data/workspaces/jp/LLMs/RWKV_Pytorch/src/model.py", line 319
        """
        if self.onnx_opset >= 17:
            x = x + self.time_mixing(self.ln1(x), state, i)
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            x = x + self.channel_mixing(self.ln2(x), state, i)
        else:
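The traceback indicates that TorchScript compiles both branches of `if self.onnx_opset >= 18`, so `self.att_group_norm` has to exist even when that branch is never taken. Presumably the repository only creates it for opset ≥ 18 (an assumption; the constructor is not shown here). One possible workaround, sketched on a hypothetical `ToyBlock` that keeps only the group-norm step rather than the real `RWKV_Block`, is to register the submodule unconditionally. Another option worth trying is `torch.jit.trace`, which only records the path that actually executes.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for RWKV_Block: only the group-norm step from the traceback."""

    def __init__(self, n_embd: int, n_head: int, onnx_opset: int):
        super().__init__()
        self.onnx_opset = onnx_opset
        self.n_head = n_head
        # Registered unconditionally: TorchScript type-checks every branch,
        # so the attribute must exist even if the opset >= 18 path is never taken.
        self.att_group_norm = nn.GroupNorm(n_head, n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [Batch, N_embd]
        if self.onnx_opset >= 18:
            return self.att_group_norm(x)
        # Manual group norm for older opsets (same math as nn.GroupNorm).
        b = x.size(0)
        c = x.size(1)
        xg = x.reshape(b, self.n_head, c // self.n_head)
        mean = xg.mean(dim=-1, keepdim=True)
        var = ((xg - mean) ** 2).mean(dim=-1, keepdim=True)
        xg = (xg - mean) / torch.sqrt(var + 1e-5)  # 1e-5 is nn.GroupNorm's default eps
        return xg.reshape(b, c) * self.att_group_norm.weight + self.att_group_norm.bias


torch.manual_seed(0)
x = torch.randn(3, 8)
for opset in (17, 18):
    block = ToyBlock(n_embd=8, n_head=2, onnx_opset=opset)
    scripted = torch.jit.script(block)  # compiles for both opset values
    print(opset, torch.allclose(scripted(x), block(x)))
```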

Some issues with deploying on the Orange Pi

I actually tried running inference on an Orange Pi AI Pro 16G and ran into the following problems:

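  1. The Orange Pi does not support bf16; only fp16 and fp32 are available.
  2. fp16 produces NaN. To avoid it, x has to be divided by 2 every 6 layers, and attention has to run in fp32.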
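The "divide x by 2 every 6 layers" workaround keeps the residual stream inside fp16's range (largest finite value 65504). Because LayerNorm is, up to its eps, invariant to rescaling its input, halving x leaves each layer's normalized input unchanged. The remaining concern is that later layers keep adding branch outputs at the original scale. The sketch below assumes each branch output in layer i is also scaled by 1/2**(i // 6), which makes the rescaled network agree with the unscaled one. `ToyResidualNet` is a hypothetical pre-LN residual stack, not the RWKV block, and it does not model the fp32-attention part of the workaround.

```python
import torch
import torch.nn as nn

RESCALE_EVERY = 6  # halve the residual stream after every 6th layer (value from the issue)


class ToyResidualNet(nn.Module):
    """Hypothetical pre-LN residual stack standing in for a stack of RWKV blocks."""

    def __init__(self, n_layer: int, n_embd: int):
        super().__init__()
        self.lns = nn.ModuleList(nn.LayerNorm(n_embd) for _ in range(n_layer))
        self.branches = nn.ModuleList(
            nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd))
            for _ in range(n_layer)
        )
        self.ln_out = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor, rescale: bool = False) -> torch.Tensor:
        for i, (ln, branch) in enumerate(zip(self.lns, self.branches)):
            out = branch(ln(x))  # LayerNorm makes this (almost) independent of x's scale
            if rescale:
                # x has already been halved i // RESCALE_EVERY times, so shrink the
                # branch output by the same factor to keep the residual sum consistent.
                out = out / (2 ** (i // RESCALE_EVERY))
            x = x + out
            if rescale and (i + 1) % RESCALE_EVERY == 0:
                x = x / 2  # the halving described in the issue
        return self.ln_out(x)  # the final LayerNorm also ignores the accumulated scale


torch.manual_seed(0)
net = ToyResidualNet(n_layer=12, n_embd=16).eval()
x = torch.randn(3, 16)
with torch.no_grad():
    ref = net(x)
    resc = net(x, rescale=True)
print(torch.allclose(ref, resc, atol=1e-3, rtol=1e-3))
```

In a real fp16 deployment the 1/2**(i // 6) factor would more likely be folded into each branch's last projection weights ahead of time, so the fp16 intermediates never become large. Dividing at runtime, as above, only shows that the two versions agree.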
