diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index ce645f70e524..cec83bdded9e 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -17,8 +17,7 @@ import torch.nn.functional as F from torch import nn -from ..utils import USE_PEFT_BACKEND -from .lora import LoRACompatibleLinear +from ..utils import deprecate ACTIVATION_FUNCTIONS = { @@ -87,9 +86,7 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() - linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear - - self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": @@ -97,9 +94,12 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor: # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - def forward(self, hidden_states, scale: float = 1.0): - args = () if USE_PEFT_BACKEND else (scale,) - hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index c65e6717959c..3d45cfa828a3 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,18 +17,18 @@ import torch.nn.functional as F from torch import nn -from ..utils import USE_PEFT_BACKEND +from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU from .attention_processor import Attention from .embeddings import SinusoidalPositionalEmbedding -from .lora import LoRACompatibleLinear from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm -def _chunked_feed_forward( - ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None -): +logger = logging.get_logger(__name__) + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: raise ValueError( @@ -36,18 +36,10 @@ def _chunked_feed_forward( ) num_chunks = hidden_states.shape[chunk_dim] // chunk_size - if lora_scale is None: - ff_output = torch.cat( - [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, - ) - else: - # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete - ff_output = torch.cat( - [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, - ) - + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) return ff_output @@ -299,6 +291,10 @@ def forward( class_labels: Optional[torch.LongTensor] = None, 
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention batch_size = hidden_states.shape[0] @@ -326,10 +322,7 @@ def forward( if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) - # 1. Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - # 2. Prepare GLIGEN inputs + # 1. Prepare GLIGEN inputs cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} gligen_kwargs = cross_attention_kwargs.pop("gligen", None) @@ -348,7 +341,7 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - # 2.5 GLIGEN Control + # 1.2 GLIGEN Control if gligen_kwargs is not None: hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) @@ -394,11 +387,9 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward( - self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale - ) + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: - ff_output = self.ff(norm_hidden_states, scale=lora_scale) + ff_output = self.ff(norm_hidden_states) if self.norm_type == "ada_norm_zero": ff_output = gate_mlp.unsqueeze(1) * ff_output @@ -643,7 +634,7 @@ def __init__( if inner_dim is None: inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + linear_cls = nn.Linear if activation_fn == "gelu": act_fn = GELU(dim, inner_dim, bias=bias) @@ -665,11 +656,10 @@ def __init__( if final_dropout: self.net.append(nn.Dropout(dropout)) - def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: - compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
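# Illustrative sketch (not part of the patch): what the simplified `_chunked_feed_forward`
# helper above computes -- split the token dimension into equal chunks, run the feed-forward
# on each chunk, and concatenate the results, trading a little speed for lower peak memory.
# The module and tensor sizes below are arbitrary examples.
import torch
from torch import nn

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
hidden_states = torch.randn(2, 1024, 64)  # (batch, tokens, dim)
chunk_dim, chunk_size = 1, 256

num_chunks = hidden_states.shape[chunk_dim] // chunk_size  # chunk_size must divide the dim evenly
ff_output = torch.cat(
    [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)
print(ff_output.shape)  # torch.Size([2, 1024, 64])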
+ deprecate("scale", "1.0.0", deprecation_message) for module in self.net: - if isinstance(module, compatible_cls): - hidden_states = module(hidden_states, scale) - else: - hidden_states = module(hidden_states) + hidden_states = module(hidden_states) return hidden_states diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5ec8876fc114..44fbd584cd7c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -20,10 +20,10 @@ from torch import nn from ..image_processor import IPAdapterMaskProcessor -from ..utils import USE_PEFT_BACKEND, deprecate, logging +from ..utils import deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph -from .lora import LoRACompatibleLinear, LoRALinearLayer +from .lora import LoRALinearLayer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -181,10 +181,7 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - if USE_PEFT_BACKEND: - linear_cls = nn.Linear - else: - linear_cls = LoRACompatibleLinear + linear_cls = nn.Linear self.linear_cls = linear_cls self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) @@ -741,11 +738,14 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.Tensor: - residual = hidden_states + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) - args = () if USE_PEFT_BACKEND else (scale,) + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -764,15 +764,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) @@ -783,7 +783,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -914,11 +914,14 @@ def __call__( hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.Tensor: - residual = hidden_states + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) - args = () if USE_PEFT_BACKEND else (scale,) + residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -932,17 +935,17 @@ def __call__( hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, *args) - value = attn.to_v(hidden_states, *args) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) @@ -956,7 +959,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -984,11 +987,14 @@ def __call__( hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.Tensor: - residual = hidden_states + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
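# Illustrative sketch (not library code): the backward-compatibility pattern used throughout
# this diff. A module keeps accepting a legacy `scale` argument through `*args`/`**kwargs`,
# emits a deprecation warning via `deprecate`, and otherwise ignores the value. The class
# name and dimensions here are made up for the example.
import torch
from torch import nn
from diffusers.utils import deprecate


class ExampleProjection(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Legacy callers may still pass `scale`; warn and ignore it.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecate("scale", "1.0.0", "The `scale` argument is deprecated and will be ignored.")
        return self.proj(hidden_states)


out = ExampleProjection()(torch.randn(1, 8, 64), scale=0.5)  # warns, result unchanged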
+ deprecate("scale", "1.0.0", deprecation_message) - args = () if USE_PEFT_BACKEND else (scale,) + residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -1002,7 +1008,7 @@ def __call__( hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query, out_dim=4) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -1011,8 +1017,8 @@ def __call__( encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, *args) - value = attn.to_v(hidden_states, *args) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key, out_dim=4) value = attn.head_to_batch_dim(value, out_dim=4) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) @@ -1029,7 +1035,7 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1132,11 +1138,14 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: - residual = hidden_states + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) - args = () if USE_PEFT_BACKEND else (scale,) + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1165,15 +1174,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() @@ -1186,7 +1195,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1217,8 +1226,13 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. 
Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1242,16 +1256,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1271,7 +1284,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1312,8 +1325,13 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
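# Illustrative sketch (not library code) of the multi-head attention flow these processors
# rely on now that the q/k/v projections are plain `nn.Linear`: project, reshape to
# (batch, heads, seq, head_dim), call F.scaled_dot_product_attention, and reshape back.
# Shapes are arbitrary examples.
import torch
import torch.nn.functional as F
from torch import nn

batch, seq_len, heads, head_dim = 2, 16, 8, 64
inner_dim = heads * head_dim
hidden_states = torch.randn(batch, seq_len, inner_dim)

to_q = nn.Linear(inner_dim, inner_dim)
to_k = nn.Linear(inner_dim, inner_dim)
to_v = nn.Linear(inner_dim, inner_dim)

query = to_q(hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)
key = to_k(hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)
value = to_v(hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(query, key, value)  # (batch, heads, seq, head_dim)
out = out.transpose(1, 2).reshape(batch, -1, inner_dim)  # back to (batch, seq, inner_dim)
print(out.shape)  # torch.Size([2, 16, 512])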
+ deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1337,17 +1355,16 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) if encoder_hidden_states is None: - qkv = attn.to_qkv(hidden_states, *args) + qkv = attn.to_qkv(hidden_states) split_size = qkv.shape[-1] // 3 query, key, value = torch.split(qkv, split_size, dim=-1) else: if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) - kv = attn.to_kv(encoder_hidden_states, *args) + kv = attn.to_kv(encoder_hidden_states) split_size = kv.shape[-1] // 2 key, value = torch.split(kv, split_size, dim=-1) @@ -1368,7 +1385,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1859,7 +1876,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1877,7 +1894,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = AttnProcessor() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class LoRAAttnProcessor2_0(nn.Module): @@ -1920,7 +1937,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1938,7 +1955,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = AttnProcessor2_0() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class LoRAXFormersAttnProcessor(nn.Module): @@ -1999,7 +2016,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -2017,7 +2034,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = 
XFormersAttnProcessor() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class LoRAAttnAddedKVProcessor(nn.Module): @@ -2058,7 +2075,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -2076,7 +2093,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = AttnAddedKVProcessor() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class IPAdapterAttnProcessor(nn.Module): diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 80fb065a6f4c..9ae28e950e83 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -18,8 +18,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import USE_PEFT_BACKEND -from .lora import LoRACompatibleConv +from ..utils import deprecate from .normalization import RMSNorm from .upsampling import upfirdn2d_native @@ -103,7 +102,7 @@ def __init__( self.padding = padding stride = 2 self.name = name - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + conv_cls = nn.Conv2d if norm_type == "ln_norm": self.norm = nn.LayerNorm(channels, eps, elementwise_affine) @@ -131,7 +130,10 @@ def __init__( else: self.conv = conv - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
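# Illustrative usage sketch (channel count and input size are arbitrary): after this change,
# Downsample2D is a plain strided convolution, and a legacy call that still passes `scale`
# only triggers the deprecation warning above while the value itself is ignored.
import torch
from diffusers.models.downsampling import Downsample2D

down = Downsample2D(channels=64, use_conv=True)
x = torch.randn(1, 64, 32, 32)

out = down(x)                # new call style
legacy = down(x, scale=1.0)  # still works, but emits the `scale` deprecation warning
print(out.shape)             # torch.Size([1, 64, 16, 16]) -- spatial size halved by stride 2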
+ deprecate("scale", "1.0.0", deprecation_message) assert hidden_states.shape[1] == self.channels if self.norm is not None: @@ -143,13 +145,7 @@ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch assert hidden_states.shape[1] == self.channels - if not USE_PEFT_BACKEND: - if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states) return hidden_states diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 49f385d5f493..c15ff24cbcda 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -18,10 +18,9 @@ import torch from torch import nn -from ..utils import USE_PEFT_BACKEND, deprecate +from ..utils import deprecate from .activations import get_activation from .attention_processor import Attention -from .lora import LoRACompatibleLinear def get_timestep_embedding( @@ -200,7 +199,7 @@ def __init__( sample_proj_bias=True, ): super().__init__() - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + linear_cls = nn.Linear self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1bbc96c6f5a7..4e9e0c07ca75 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -204,6 +204,9 @@ def __init__( ): super().__init__() + deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("LoRALinearLayer", "1.0.0", deprecation_message) + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. @@ -264,6 +267,9 @@ def __init__( ): super().__init__() + deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." 
+ deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message) + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) # according to the official kohya_ss trainer kernel_size are always fixed for the up layer # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 84cb31f430a0..ec75861e2da0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,7 +20,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import USE_PEFT_BACKEND +from ..utils import deprecate from .activations import get_activation from .attention_processor import SpatialNorm from .downsampling import ( # noqa @@ -30,7 +30,6 @@ KDownsample2D, downsample_2d, ) -from .lora import LoRACompatibleConv, LoRACompatibleLinear from .normalization import AdaGroupNorm from .upsampling import ( # noqa FirUpsample2D, @@ -102,7 +101,7 @@ def __init__( self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + conv_cls = nn.Conv2d if groups_out is None: groups_out = groups @@ -149,12 +148,11 @@ def __init__( bias=conv_shortcut_bias, ) - def forward( - self, - input_tensor: torch.FloatTensor, - temb: torch.FloatTensor, - scale: float = 1.0, - ) -> torch.FloatTensor: + def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + hidden_states = input_tensor hidden_states = self.norm1(hidden_states, temb) @@ -166,26 +164,24 @@ def forward( if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, scale=scale) - hidden_states = self.upsample(hidden_states, scale=scale) + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, scale=scale) - hidden_states = self.downsample(hidden_states, scale=scale) + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) - hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states, temb) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = ( - self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) - ) + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor @@ -267,8 +263,8 @@ def __init__( self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear + conv_cls = nn.Conv2d if groups_out is None: groups_out = groups @@ -326,12 +322,11 @@ def __init__( bias=conv_shortcut_bias, ) - def forward( - self, - input_tensor: torch.FloatTensor, - temb: torch.FloatTensor, - scale: float = 1.0, - ) -> torch.FloatTensor: + def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + hidden_states = input_tensor hidden_states = self.norm1(hidden_states) @@ -342,38 +337,18 @@ def forward( if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = ( - self.upsample(input_tensor, scale=scale) - if isinstance(self.upsample, Upsample2D) - else self.upsample(input_tensor) - ) - hidden_states = ( - self.upsample(hidden_states, scale=scale) - if isinstance(self.upsample, Upsample2D) - else self.upsample(hidden_states) - ) + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) elif self.downsample is not None: - input_tensor = ( - self.downsample(input_tensor, scale=scale) - if isinstance(self.downsample, Downsample2D) - else self.downsample(input_tensor) - ) - hidden_states = ( - self.downsample(hidden_states, scale=scale) - if isinstance(self.downsample, Downsample2D) - else self.downsample(hidden_states) - ) + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) - hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = ( - self.time_emb_proj(temb, scale)[:, :, None, None] - if not USE_PEFT_BACKEND - else self.time_emb_proj(temb)[:, :, None, None] - ) + temb = self.time_emb_proj(temb)[:, :, None, None] if self.time_embedding_norm == "default": if temb is not None: @@ -393,12 +368,10 @@ def forward( hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = ( - self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) - ) + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 8b391eeebfd9..555ea4f63808 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -19,14 +19,16 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from ...utils import BaseOutput, deprecate, is_torch_version, logging from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection -from ..lora import LoRACompatibleConv, LoRACompatibleLinear from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + @dataclass class Transformer2DModelOutput(BaseOutput): """ @@ -115,8 +117,8 @@ def __init__( self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + conv_cls = nn.Conv2d + linear_cls = nn.Linear # 1. 
Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration @@ -304,6 +306,9 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. @@ -327,9 +332,6 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -337,21 +339,13 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = ( - self.proj_in(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_in(hidden_states) - ) + hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = ( - self.proj_in(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_in(hidden_states) - ) + hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -414,17 +408,9 @@ def custom_forward(*inputs): if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = ( - self.proj_out(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_out(hidden_states) - ) + hidden_states = self.proj_out(hidden_states) else: - hidden_states = ( - self.proj_out(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_out(hidden_states) - ) + hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index a0ec2a116664..b9e9e63bbc18 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from ...utils import is_torch_version, logging +from ...utils import deprecate, is_torch_version, logging from ...utils.torch_utils import apply_freeu from ..activations import get_activation from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 @@ -844,8 +844,11 @@ def forward( cross_attention_kwargs: 
Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -882,7 +885,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -982,7 +985,8 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -995,7 +999,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -1006,7 +1010,7 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1111,23 +1115,22 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - lora_scale = cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") output_states = () for resnet, attn in zip(self.resnets, self.attentions): - cross_attention_kwargs.update({"scale": lora_scale}) - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: if self.downsample_type == "resnet": - hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale) + hidden_states = downsampler(hidden_states, temb=temb) else: - hidden_states = downsampler(hidden_states, scale=lora_scale) + hidden_states = downsampler(hidden_states) output_states += (hidden_states,) @@ -1236,9 +1239,11 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, additional_residuals: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + output_states = () blocks = list(zip(self.resnets, self.attentions)) @@ -1270,7 +1275,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1288,7 +1293,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=lora_scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1348,8 +1353,12 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
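# Toy sketch of the control flow the down blocks above keep after the cleanup
# (nn.Identity stands in for the real resnet/attention modules, nn.AvgPool2d for the
# downsampler): run each resnet/attention pair, collect every intermediate output as a
# residual for the up path, then downsample once at the end.
import torch
from torch import nn

resnets = nn.ModuleList([nn.Identity(), nn.Identity()])
attentions = nn.ModuleList([nn.Identity(), nn.Identity()])
downsampler = nn.AvgPool2d(2)  # placeholder for Downsample2D

hidden_states = torch.randn(1, 64, 32, 32)
output_states = ()
for resnet, attn in zip(resnets, attentions):
    hidden_states = resnet(hidden_states)
    hidden_states = attn(hidden_states)
    output_states += (hidden_states,)
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
print(len(output_states), hidden_states.shape)  # 3 torch.Size([1, 64, 16, 16])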
+ deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet in self.resnets: @@ -1370,13 +1379,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1447,13 +1456,17 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None, scale=scale) + hidden_states = resnet(hidden_states, temb=None) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) + hidden_states = downsampler(hidden_states) return hidden_states @@ -1545,15 +1558,18 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) + hidden_states = downsampler(hidden_states) return hidden_states @@ -1644,18 +1660,22 @@ def forward( hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, skip_sample: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb, scale=scale) + hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1731,16 +1751,21 @@ def forward( hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, skip_sample: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb, scale) + hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb, scale) + hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1816,8 +1841,12 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet in self.resnets: @@ -1838,13 +1867,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale) + hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb, scale) + hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) @@ -1955,10 +1984,11 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") - lora_scale = cross_attention_kwargs.get("scale", 1.0) + output_states = () if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -1991,7 +2021,7 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -2004,7 +2034,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb, scale=lora_scale) + hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) @@ -2058,8 +2088,12 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet in self.resnets: @@ -2080,7 +2114,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale) + hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) @@ -2165,8 +2199,11 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -2196,7 +2233,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2316,24 +2353,28 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
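# Toy sketch of the skip-connection handling in the up blocks above (nn.Identity stands
# in for the real resnet): each iteration pops the most recent down-path residual and
# concatenates it along the channel dimension before running the resnet.
import torch
from torch import nn

resnets = nn.ModuleList([nn.Identity(), nn.Identity()])
hidden_states = torch.randn(1, 64, 16, 16)
res_hidden_states_tuple = (torch.randn(1, 64, 16, 16), torch.randn(1, 64, 16, 16))

for resnet in resnets:
    res_hidden_states = res_hidden_states_tuple[-1]
    res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    hidden_states = resnet(hidden_states)
print(hidden_states.shape)  # channels grow with each concat: 64 -> 128 -> 192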
+ deprecate("scale", "1.0.0", deprecation_message) + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: if self.upsample_type == "resnet": - hidden_states = upsampler(hidden_states, temb=temb, scale=scale) + hidden_states = upsampler(hidden_states, temb=temb) else: - hidden_states = upsampler(hidden_states, scale=scale) + hidden_states = upsampler(hidden_states) return hidden_states @@ -2440,7 +2481,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -2494,7 +2538,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2506,7 +2550,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -2567,8 +2611,13 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -2612,11 +2661,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -2683,11 +2732,9 @@ def __init__( self.resolution_idx = resolution_idx - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=temb, scale=scale) + hidden_states = resnet(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2783,17 +2830,14 @@ def __init__( self.resolution_idx = resolution_idx - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, scale=scale) + hidden_states = upsampler(hidden_states) return hidden_states @@ -2898,18 +2942,22 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, skip_sample=None, - scale: float = 1.0, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) - cross_attention_kwargs = {"scale": scale} - hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs) + hidden_states = self.attentions[0](hidden_states) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) @@ -2923,7 +2971,7 @@ def forward( skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb, scale=scale) + hidden_states = self.resnet_up(hidden_states, temb) return hidden_states, skip_sample @@ -3006,15 +3054,20 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, skip_sample=None, - scale: float = 1.0, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) @@ -3028,7 +3081,7 @@ def forward( skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb, scale=scale) + hidden_states = self.resnet_up(hidden_states, temb) return hidden_states, skip_sample @@ -3108,8 +3161,13 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message)
+
 for resnet in self.resnets:
 # pop res hidden states
 res_hidden_states = res_hidden_states_tuple[-1]
@@ -3133,11 +3191,11 @@ def custom_forward(*inputs):
 create_custom_forward(resnet), hidden_states, temb
 )
 else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = resnet(hidden_states, temb)
 if self.upsamplers is not None:
 for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, temb, scale=scale)
+ hidden_states = upsampler(hidden_states, temb)
 return hidden_states
@@ -3253,8 +3311,9 @@ def forward(
 encoder_attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
 cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
- lora_scale = cross_attention_kwargs.get("scale", 1.0)
 if attention_mask is None:
 # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
 mask = None if encoder_hidden_states is None else encoder_attention_mask
@@ -3292,7 +3351,7 @@ def custom_forward(*inputs):
 **cross_attention_kwargs,
 )
 else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
 hidden_states = attn(
 hidden_states,
@@ -3303,7 +3362,7 @@ def custom_forward(*inputs):
 if self.upsamplers is not None:
 for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
+ hidden_states = upsampler(hidden_states, temb)
 return hidden_states
@@ -3364,8 +3423,13 @@ def forward(
 res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
 temb: Optional[torch.FloatTensor] = None,
 upsample_size: Optional[int] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
 ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
 res_hidden_states_tuple = res_hidden_states_tuple[-1]
 if res_hidden_states_tuple is not None:
 hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3388,7 +3452,7 @@ def custom_forward(*inputs):
 create_custom_forward(resnet), hidden_states, temb
 )
 else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = resnet(hidden_states, temb)
 if self.upsamplers is not None:
 for upsampler in self.upsamplers:
@@ -3498,7 +3562,6 @@ def forward(
 if res_hidden_states_tuple is not None:
 hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
 for resnet, attn in zip(self.resnets, self.attentions):
 if self.training and self.gradient_checkpointing:
@@ -3527,7 +3590,7 @@ def custom_forward(*inputs):
 encoder_attention_mask=encoder_attention_mask,
 )
 else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
 hidden_states = attn(
 hidden_states,
 encoder_hidden_states=encoder_hidden_states,
@@ -3630,6 +3693,8 @@ def forward(
 encoder_attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
 cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 # 1. Self-Attention
 if self.add_self_attention:
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index d067767315f1..fb40d6ea31b4 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1226,7 +1226,7 @@ def forward(
 **additional_residuals,
 )
 else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
 if is_adapter and len(down_intrablock_additional_residuals) > 0:
 sample += down_intrablock_additional_residuals.pop(0)
@@ -1297,7 +1297,6 @@ def forward(
 temb=emb,
 res_hidden_states_tuple=res_samples,
 upsample_size=upsample_size,
- scale=lora_scale,
 )
 # 6. post-process
diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py
index e1ada1021b3a..a48f1841c683 100644
--- a/src/diffusers/models/unets/unet_3d_blocks.py
+++ b/src/diffusers/models/unets/unet_3d_blocks.py
@@ -17,7 +17,7 @@
 import torch
 from torch import nn
-from ...utils import is_torch_version
+from ...utils import deprecate, is_torch_version, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import Attention
 from ..resnet import (
@@ -35,6 +35,9 @@
 )
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
 def get_down_block(
 down_block_type: str,
 num_layers: int,
@@ -1005,9 +1008,14 @@ def forward(
 self,
 hidden_states: torch.FloatTensor,
 temb: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
 num_frames: int = 1,
+ *args,
+ **kwargs,
 ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
 output_states = ()
 blocks = zip(self.resnets, self.motion_modules)
@@ -1029,18 +1037,18 @@ def custom_forward(*inputs):
 )
 else:
 hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, scale
+ create_custom_forward(resnet), hidden_states, temb
 )
 else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = resnet(hidden_states, temb)
 hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
 output_states = output_states + (hidden_states,)
 if self.downsamplers is not None:
 for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, scale=scale)
+ hidden_states = downsampler(hidden_states)
 output_states = output_states + (hidden_states,)
@@ -1173,9 +1181,11 @@ def forward(
 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
 additional_residuals: Optional[torch.FloatTensor] = None,
 ):
- output_states = ()
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ output_states = ()
 blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
 for i, (resnet, attn, motion_module) in enumerate(blocks):
@@ -1206,7 +1216,7 @@ def custom_forward(*inputs):
 return_dict=False,
 )[0]
 else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
 hidden_states = attn(
 hidden_states,
 encoder_hidden_states=encoder_hidden_states,
@@ -1228,7 +1238,7 @@ def custom_forward(*inputs):
 if self.downsamplers is not None:
 for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, scale=lora_scale)
+ hidden_states = downsampler(hidden_states)
 output_states = output_states + (hidden_states,)
@@ -1355,7 +1365,10 @@ def forward(
 encoder_attention_mask: Optional[torch.FloatTensor] = None,
 num_frames: int = 1,
 ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
 is_freeu_enabled = (
 getattr(self, "s1", None)
 and getattr(self, "s2", None)
@@ -1410,7 +1423,7 @@ def custom_forward(*inputs):
 return_dict=False,
 )[0]
 else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
 hidden_states = attn(
 hidden_states,
 encoder_hidden_states=encoder_hidden_states,
@@ -1426,7 +1439,7 @@ def custom_forward(*inputs):
 if self.upsamplers is not None:
 for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+ hidden_states = upsampler(hidden_states, upsample_size)
 return hidden_states
@@ -1507,9 +1520,14 @@ def forward(
 res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
 temb: Optional[torch.FloatTensor] = None,
 upsample_size=None,
- scale: float = 1.0,
 num_frames: int = 1,
+ *args,
+ **kwargs,
 ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
 is_freeu_enabled = (
 getattr(self, "s1", None)
 and getattr(self, "s2", None)
@@ -1559,12 +1577,12 @@ def custom_forward(*inputs):
 )
 else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = resnet(hidden_states, temb)
 hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
 if self.upsamplers is not None:
 for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+ hidden_states = upsampler(hidden_states, upsample_size)
 return hidden_states
@@ -1687,8 +1705,11 @@ def forward(
 encoder_attention_mask: Optional[torch.FloatTensor] = None,
 num_frames: int = 1,
 ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated.
`scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: @@ -1737,7 +1758,7 @@ def custom_forward(*inputs): hidden_states, num_frames=num_frames, )[0] - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index a096f842ab6c..5c5c6a2cc5ec 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -89,7 +89,7 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - ff_output = self.ff(hidden_states, scale=1.0) + ff_output = self.ff(hidden_states) hidden_states = ff_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 080103504c53..4ecf6ebc26d2 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -18,8 +18,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import USE_PEFT_BACKEND -from .lora import LoRACompatibleConv +from ..utils import deprecate from .normalization import RMSNorm @@ -111,7 +110,7 @@ def __init__( self.use_conv_transpose = use_conv_transpose self.name = name self.interpolate = interpolate - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + conv_cls = nn.Conv2d if norm_type == "ln_norm": self.norm = nn.LayerNorm(channels, eps, elementwise_affine) @@ -141,11 +140,12 @@ def __init__( self.Conv2d_0 = conv def forward( - self, - hidden_states: torch.FloatTensor, - output_size: Optional[int] = None, - scale: float = 1.0, + self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + assert hidden_states.shape[1] == self.channels if self.norm is not None: @@ -180,15 +180,9 @@ def forward( # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states) else: - if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.Conv2d_0(hidden_states, scale) - else: - hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.Conv2d_0(hidden_states) return hidden_states diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index f802a37de4a0..62a3a8728a2a 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1333,7 +1333,7 @@ def forward( **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) @@ -1589,7 +1589,7 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () @@ -1611,13 +1611,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1728,8 +1728,6 @@ def forward( ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1760,7 +1758,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1778,7 +1776,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=lora_scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1842,8 +1840,13 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and 
will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
 is_freeu_enabled = (
 getattr(self, "s1", None)
 and getattr(self, "s2", None)
@@ -1887,11 +1890,11 @@ def custom_forward(*inputs):
 create_custom_forward(resnet), hidden_states, temb
 )
 else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = resnet(hidden_states, temb)
 if self.upsamplers is not None:
 for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+ hidden_states = upsampler(hidden_states, upsample_size)
 return hidden_states
@@ -1999,7 +2002,10 @@ def forward(
 attention_mask: Optional[torch.FloatTensor] = None,
 encoder_attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
 is_freeu_enabled = (
 getattr(self, "s1", None)
 and getattr(self, "s2", None)
@@ -2053,7 +2059,7 @@ def custom_forward(*inputs):
 return_dict=False,
 )[0]
 else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
 hidden_states = attn(
 hidden_states,
 encoder_hidden_states=encoder_hidden_states,
@@ -2065,7 +2071,7 @@ def custom_forward(*inputs):
 if self.upsamplers is not None:
 for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+ hidden_states = upsampler(hidden_states, upsample_size)
 return hidden_states
@@ -2330,8 +2336,11 @@ def forward(
 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
 encoder_attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ hidden_states = self.resnets[0](hidden_states, temb)
 for attn, resnet in zip(self.attentions, self.resnets[1:]):
 if self.training and self.gradient_checkpointing:
@@ -2368,7 +2377,7 @@ def custom_forward(*inputs):
 encoder_attention_mask=encoder_attention_mask,
 return_dict=False,
 )[0]
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
 return hidden_states
@@ -2469,7 +2478,8 @@ def forward(
 encoder_attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
 cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
- lora_scale = cross_attention_kwargs.get("scale", 1.0)
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 if attention_mask is None:
 # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -2482,7 +2492,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -2493,6 +2503,6 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py index a4c6f9c6a8b9..101acafcff1f 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -2,8 +2,6 @@ import torch.nn as nn from ...models.attention_processor import Attention -from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear -from ...utils import USE_PEFT_BACKEND class WuerstchenLayerNorm(nn.LayerNorm): @@ -19,7 +17,7 @@ def forward(self, x): class TimestepBlock(nn.Module): def __init__(self, c, c_timestep): super().__init__() - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + linear_cls = nn.Linear self.mapper = linear_cls(c_timestep, c * 2) def forward(self, x, t): @@ -31,8 +29,8 @@ class ResBlock(nn.Module): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): super().__init__() - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + conv_cls = nn.Conv2d + linear_cls = nn.Linear self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) @@ -66,7 +64,7 @@ class AttnBlock(nn.Module): def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + linear_cls = nn.Linear self.self_attn = self_attn self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index c44c259ab0b4..8cc294eaf79a 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -28,9 +28,8 @@ AttnAddedKVProcessor, AttnProcessor, ) -from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear from ...models.modeling_utils import ModelMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version +from ...utils import is_torch_version from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm @@ -41,8 +40,8 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + conv_cls = nn.Conv2d + linear_cls = nn.Linear self.c_r = c_r self.projection = conv_cls(c_in, c, kernel_size=1) diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index f08388921a4f..b5a5bec471a6 100644 --- a/tests/models/test_layers_utils.py +++ 
b/tests/models/test_layers_utils.py @@ -22,7 +22,6 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU from diffusers.models.embeddings import get_timestep_embedding -from diffusers.models.lora import LoRACompatibleLinear from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformers.transformer_2d import Transformer2DModel from diffusers.utils.testing_utils import ( @@ -482,7 +481,7 @@ def test_spatial_transformer_default_ff_layers(self): assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout - assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear dim = 32 inner_dim = 128 @@ -506,7 +505,7 @@ def test_spatial_transformer_geglu_approx_ff_layers(self): assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout - assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear dim = 32 inner_dim = 128
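The hunks above converge on two patterns. First, block `forward` methods keep accepting a stray `scale` through `*args, **kwargs` and only emit a deprecation. A minimal, self-contained sketch of that calling pattern (the class name and layer are invented for illustration, and `warnings` stands in for diffusers' `deprecate` helper):

```python
# Illustrative only: a toy module mirroring the `*args, **kwargs` deprecation shim.
import warnings

import torch
from torch import nn


class ToyUpBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.proj = nn.Linear(channels, channels)

    def forward(self, hidden_states: torch.Tensor, temb=None, *args, **kwargs) -> torch.Tensor:
        # `scale` is accepted but ignored, matching the deprecate() path in the diff.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            warnings.warn(
                "The `scale` argument is deprecated and will be ignored. Pass it via "
                "`cross_attention_kwargs` on the pipeline instead.",
                FutureWarning,
            )
        return self.proj(hidden_states)


block = ToyUpBlock(8)
sample = torch.randn(2, 8)
_ = block(sample, None, scale=0.5)  # still runs, but emits a FutureWarning
_ = block(sample)                   # preferred call after this change
```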
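Second, blocks that still take `cross_attention_kwargs` log and ignore a leftover `scale` key rather than failing. A rough equivalent of that guard, with an invented helper name:

```python
# Illustrative guard: a stale `scale` key is reported once and then left alone.
import logging

logger = logging.getLogger(__name__)


def warn_on_stale_scale(cross_attention_kwargs):
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    return cross_attention_kwargs if cross_attention_kwargs is not None else {}


kwargs = warn_on_stale_scale({"scale": 0.7})  # logs a warning, keys are left untouched
```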
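For callers, the practical consequence is that the LoRA scale is passed once at the pipeline boundary instead of being threaded through every resnet and upsampler. A usage sketch under the assumption that the PEFT backend is active and a LoRA has been loaded with `load_lora_weights`; the checkpoint id and LoRA path are placeholders, not part of this diff:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder checkpoint
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA

# The LoRA scale is consumed by the attention processors via `cross_attention_kwargs`.
image = pipe(
    "an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.8},
).images[0]
```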
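Mirroring the updated test assertions, a toy check that the feed-forward output projection is now a plain `nn.Linear` (layer sizes arbitrary):

```python
import torch.nn as nn

# [activation, dropout, output projection] — the projection is a plain Linear
# now that LoRACompatibleLinear has been removed.
ff_net = nn.ModuleList([nn.GELU(), nn.Dropout(0.0), nn.Linear(128, 32)])
assert ff_net[2].__class__ == nn.Linear
```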