From 3be47edfa5fde1b1b0ddeaa42939ea46be312938 Mon Sep 17 00:00:00 2001
From: Omiita <77219025+omihub777@users.noreply.github.com>
Date: Wed, 2 Nov 2022 22:46:52 +0900
Subject: [PATCH] Fix a small typo of a variable name (#1063)

Fix a small typo in `models/attention.py`: `weight` -> `width`
---
 models/attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/models/attention.py b/models/attention.py
index 1f9cf641c32d..372c8492b485 100644
--- a/models/attention.py
+++ b/models/attention.py
@@ -165,15 +165,15 @@ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atte
 
     def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        batch, channel, height, weight = hidden_states.shape
+        batch, channel, height, width = hidden_states.shape
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
         inner_dim = hidden_states.shape[1]
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
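
Editor's note: the change is purely a rename. The fourth element of `hidden_states.shape` was already the spatial width; calling it `weight` was only misleading, so behavior is unchanged. Below is a minimal sketch (assuming PyTorch, with toy shapes chosen for illustration) of the reshape round trip the hunk touches; the real code additionally projects through `proj_in`/`proj_out` and runs `self.transformer_blocks` on the flattened tokens.

```python
import torch

# Image-style activations: (batch, channel, height, width)
hidden_states = torch.randn(2, 8, 4, 6)
batch, channel, height, width = hidden_states.shape

# Flatten spatial dims into a token sequence: (batch, height * width, inner_dim).
# Here inner_dim == channel; in the real module it is proj_in's output channels.
inner_dim = channel
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
assert tokens.shape == (2, 4 * 6, 8)

# ... transformer blocks would operate on `tokens` here ...

# Restore the spatial layout: (batch, inner_dim, height, width)
restored = tokens.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
assert restored.shape == hidden_states.shape
```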