diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index c6e8d425459f..230536a83a14 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -630,7 +630,7 @@ def __init__(self, config: JambaConfig, layer_idx): # S4D real initialization. These are not discretized! # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded - A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = torch.arange(1, self.ssm_state_size + 1)[None, :] A = A.expand(self.intermediate_size, -1).contiguous() self.A_log = nn.Parameter(torch.log(A))