@@ -19,7 +19,7 @@
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
-from diffusers.utils import USE_PEFT_BACKEND, logging
+from diffusers.utils import deprecate, logging
 from diffusers.utils.import_utils import is_xformers_available
 from torch import nn

@@ -107,8 +107,13 @@ def __call__(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -132,16 +137,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

@@ -171,7 +175,7 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)

         # dropout
         hidden_states = attn.to_out[1](hidden_states)
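
The deprecation message above redirects callers to `cross_attention_kwargs`. As a rough usage sketch (not part of this diff; it assumes a standard diffusers pipeline with LoRA weights loaded, and the model and LoRA ids are placeholders), the scale would now be supplied at the pipeline level rather than to the processor:

# Usage sketch (assumption, not from this commit): pass `scale` through
# `cross_attention_kwargs` instead of to the attention processor directly.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA checkpoint

# The LoRA scale now travels via `cross_attention_kwargs`; the processor's
# `__call__` no longer accepts a `scale` argument.
image = pipe(
    "an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]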