diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index bf04c3e6a3ca..3c79cdc77b7c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -152,15 +152,15 @@ def _set_attention_slice(self, slice_size):

     def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        batch, channel, height, weight = hidden_states.shape
+        batch, channel, height, width = hidden_states.shape
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
         inner_dim = hidden_states.shape[1]
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
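
For context, a minimal standalone sketch (not part of the patch, tensor sizes are arbitrary dummy values) of the NCHW -> (batch, height * width, channels) round-trip that the renamed variables describe:

    # Minimal sketch, not part of the patch: the permute/reshape round-trip that
    # forward() performs, with dummy sizes chosen only for illustration.
    import torch

    batch, channel, height, width = 2, 8, 4, 6
    hidden_states = torch.randn(batch, channel, height, width)

    # NCHW -> (batch, height * width, channel): flatten spatial dims into tokens
    tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
    assert tokens.shape == (batch, height * width, channel)

    # (batch, height * width, channel) -> NCHW: undo the flattening
    restored = tokens.reshape(batch, height, width, channel).permute(0, 3, 1, 2)
    assert torch.equal(restored, hidden_states)

The rename is purely cosmetic: the fourth dimension of the NCHW feature map is the spatial width, so calling it "weight" was misleading even though the computation was already correct.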