From a8e12bfb86ed2472d99f569c93d14496838a9d49 Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Tue, 7 May 2024 11:36:00 +0200
Subject: [PATCH 1/2] fix typos and one shape comment

---
 src/transformers/models/mamba/modeling_mamba.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 8f19c361269e..1d4ea6e6083e 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -282,16 +282,16 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None):
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())                                              # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
 
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, seq_len, intermediade_size]
+        scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))
 
@@ -299,7 +299,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None):
             cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
 
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_output.transpose(1, 2))             # [batch, seq_len, hidden_size]
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
         return contextualized_states
     # fmt: on

From 190c2d9a8da1fb5dfba8ea934646876a0fcf07ed Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Wed, 8 May 2024 17:58:55 +0200
Subject: [PATCH 2/2] fix `intermediade` typo in jamba

---
 src/transformers/models/jamba/modeling_jamba.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 1dbcbc76f3c2..a90cb89f8a28 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -962,15 +962,15 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())                                              # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
 
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, intermediade_size, seq_len]
+        scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))
 
@@ -978,7 +978,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
             cache_params.ssm_states[self.layer_idx] = ssm_state
 
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_output.transpose(1, 2))             # [batch, seq_len, hidden_size]
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
         return contextualized_states
     # fmt: on
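
Reviewer note (not part of either patch): the corrected shape comments can be sanity-checked with a minimal, self-contained PyTorch sketch. The tensor sizes below are illustrative only; the variable names mirror those in slow_forward. Each loop step keeps the state at [batch, intermediate_size, ssm_state_size] and produces a [batch, intermediate_size] slice, so stacking along the last dimension yields [batch, intermediate_size, seq_len] (the shape comment fixed in the first patch), which is why out_proj is fed scan_output.transpose(1, 2).

    import torch

    # illustrative sizes only
    batch, intermediate_size, seq_len, ssm_state_size = 2, 8, 5, 4

    discrete_A = torch.rand(batch, intermediate_size, seq_len, ssm_state_size)
    deltaB_u = torch.rand(batch, intermediate_size, seq_len, ssm_state_size)
    C = torch.rand(batch, seq_len, ssm_state_size)
    ssm_state = torch.zeros(batch, intermediate_size, ssm_state_size)

    scan_outputs = []
    for i in range(seq_len):
        # state update: [batch, intermediate_size, ssm_state_size]
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
        # contract the state dim with C_i: [batch, intermediate_size, 1]
        scan_output = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1))
        scan_outputs.append(scan_output[:, :, 0])  # [batch, intermediate_size]

    scan_output = torch.stack(scan_outputs, dim=-1)
    print(scan_output.shape)  # torch.Size([2, 8, 5]) == [batch, intermediate_size, seq_len]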