def forward(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor:
    """Apply time mixing and then channel mixing, each as a residual update.

    Each sub-layer receives a layer-normalised copy of the running
    activation, and its output is added back onto that activation.

    Args:
        x (torch.Tensor): Input tensor. The original documentation gives
            its shape as [Batch, N_embd].
        state (torch.Tensor): Hidden-state tensor, handed unchanged to both
            sub-layers. The original documentation gives its shape as
            [Batch, State Size, N_embd].
        i (int): Index forwarded to both sub-layers (described as a time
            index in the original documentation).

    Returns:
        torch.Tensor: The updated activation; documented as having the same
        shape as ``x``.
    """
    # Presumably gated on opset 17 because that is where ONNX gained a
    # native LayerNormalization operator; older targets use the manual
    # implementation instead -- TODO confirm against the export code.
    native_layer_norm = self.onnx_opset >= 17

    # --- time mixing (pre-norm + residual) ---
    if native_layer_norm:
        normed = self.ln1(x)
    else:
        # 1e-5 is presumably the layer-norm epsilon.
        normed = self.manual_layer_norm(x, self.ln1_weight, self.ln1_bias, 1e-5)
    hidden = x + self.time_mixing(normed, state, i)

    # --- channel mixing (pre-norm + residual) ---
    if native_layer_norm:
        normed = self.ln2(hidden)
    else:
        normed = self.manual_layer_norm(hidden, self.ln2_weight, self.ln2_bias, 1e-5)
    return hidden + self.channel_mixing(normed, state, i)
def forward_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor:
    """Apply the parallel time-mixing and channel-mixing sub-layers with residuals.

    Each sub-layer receives a layer-normalised copy of the running
    activation, and its output is added back onto that activation.

    Args:
        x (torch.Tensor): Input tensor. The original documentation gives
            its shape as [Batch, L, N_embd].
        state (torch.Tensor): Hidden-state tensor, handed unchanged to both
            sub-layers. The original documentation gives its shape as
            [Batch, State Size, N_embd].
        i (int): Index forwarded to both sub-layers (described as a time
            index in the original documentation).

    Returns:
        torch.Tensor: The updated activation; documented as having the same
        shape as ``x``.
    """
    # Presumably gated on opset 17 because that is where ONNX gained a
    # native LayerNormalization operator; older targets use the manual
    # implementation instead -- TODO confirm against the export code.
    native_layer_norm = self.onnx_opset >= 17

    # --- time mixing (pre-norm + residual) ---
    if native_layer_norm:
        normed = self.ln1(x)
    else:
        # 1e-5 is presumably the layer-norm epsilon.
        normed = self.manual_layer_norm(x, self.ln1_weight, self.ln1_bias, 1e-5)
    hidden = x + self.time_mixing_parallel(normed, state, i)

    # --- channel mixing (pre-norm + residual) ---
    if native_layer_norm:
        normed = self.ln2(hidden)
    else:
        normed = self.manual_layer_norm(hidden, self.ln2_weight, self.ln2_bias, 1e-5)
    return hidden + self.channel_mixing_parallel(normed, state, i)