From f3281a484e9c7b1235388efa8f6f7b4668466e7d Mon Sep 17 00:00:00 2001
From: Alexandre TL
Date: Mon, 22 Jul 2024 09:22:01 +0200
Subject: [PATCH] shape comment

---
 src/transformers/models/mamba/modeling_mamba.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 09ca0803d2b0..180027fba585 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -271,7 +271,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         if self.use_mambapy and self.training and cache_params is None:
-            hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, intermediate_size, seq_len, ssm_state_size]
+            hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]
             scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
             scan_output = scan_output + hidden_states * self.D[None, :, None]
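
To sanity-check the corrected comment, a minimal shape sketch follows. It is not the library's code: naive_scan is a sequential stand-in for mambapy's pscan (assumed to take and return tensors laid out as [batch, seq_len, intermediate_size, ssm_state_size]), the dimension sizes are arbitrary placeholders, and C is assumed to be [batch, seq_len, ssm_state_size] as in slow_forward.

import torch

def naive_scan(A, X):
    # Stand-in for pscan: h_t = A_t * h_{t-1} + X_t, scanned over the
    # seq_len dimension (dim=1); output has the same shape as X.
    h = torch.zeros_like(X[:, 0])
    hs = []
    for t in range(X.size(1)):
        h = A[:, t] * h + X[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)

# Placeholder sizes for illustration only.
batch, intermediate_size, seq_len, ssm_state_size = 2, 8, 5, 4
discrete_A = torch.rand(batch, intermediate_size, seq_len, ssm_state_size)
deltaB_u = torch.rand(batch, intermediate_size, seq_len, ssm_state_size)
C = torch.rand(batch, seq_len, ssm_state_size)

# transpose(1, 2) swaps intermediate_size and seq_len before the scan.
hs = naive_scan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2))
print(hs.shape)  # torch.Size([2, 5, 8, 4]) -> [batch, seq_len, intermediate_size, ssm_state_size]

# Contracting over ssm_state_size and transposing back recovers the layout
# given in the next comment of the diff.
scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2)
print(scan_output.shape)  # torch.Size([2, 8, 5]) -> [batch, intermediate_size, seq_len]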