
[LoRA] use the PyTorch classes wherever needed and start deprecation cycles #7204

Merged: 39 commits into main on Mar 13, 2024

Conversation

@sayakpaul (Member):

What does this PR do?

Since we have shifted to the PEFT backend for all things LoRA, there's no longer any need for us to use the LoRACompatible* classes.

We should also start the deprecation cycles for the LoRALinearLayer and LoRAConv2dLayer. This PR does that.
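
For context, here is a minimal before/after sketch of what the swap to plain PyTorch classes looks like (the module and attribute names below are illustrative, not taken from the diff):

    import torch
    import torch.nn as nn

    class FeedForwardBlock(nn.Module):
        # Illustrative module, not copied from the PR.
        def __init__(self, dim: int, inner_dim: int):
            super().__init__()
            # Before: diffusers wrapped linear layers in its own LoRA-aware class,
            # e.g. self.proj = LoRACompatibleLinear(dim, inner_dim)
            # After: with the PEFT backend injecting LoRA adapters into nn.Linear,
            # the plain PyTorch class is all that is needed.
            self.proj = nn.Linear(dim, inner_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)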

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator) left a comment:

Ohh great! Thanks so much!

We can deprecate the scale argument everywhere now, too, right?
e.g., all the attention processors

Review threads: src/diffusers/models/attention.py, src/diffusers/models/attention_processor.py
sayakpaul requested a review from yiyixuxu on March 5, 2024 at 08:46
@sayakpaul (Member, Author):

@yiyixuxu up for another review. Rigorous review appreciated!

@yiyixuxu (Collaborator) left a comment:

Thanks!
I have two main pieces of feedback:

  1. Let's deprecate the scale argument by removing it from the signature (see #7204 (comment)).
  2. I see that you deprecated the scale argument in some of the more public classes, but in some of the less public ones you simply removed it or silently ignored it. Maybe we should just deprecate it everywhere and remove it all at once later? E.g. #7204 (comment) and #7204 (comment).

Review threads: src/diffusers/models/attention.py, src/diffusers/models/resnet.py (4 threads, outdated), src/diffusers/models/unets/unet_2d_blocks.py (5 threads, 4 outdated)
@yiyixuxu (Collaborator) commented on Mar 8, 2024:

cc @BenjaminBossan, we would appreciate it if you could also give this a review.

@sayakpaul (Member, Author):

@yiyixuxu, I addressed all your comments. They were very, very helpful. Thank you!

@BenjaminBossan (Member) left a comment:

Very thorough PR, thanks a lot, Sayak. This LGTM overall.

Just a suggestion for the deprecation. Currently, the message is:

Use of scale is deprecated. Please remove the argument

As a user, I might not know what to do with that info: Is this feature removed completely or can I still use it, but have to do it differently? Also, I might get the impression that I can still pass scale and it works, it's just deprecated, when in fact the argument doesn't do anything, right? Perhaps the message could be clarified.

Moreover, if we already have an idea of which diffusers version this will be removed in (hence raising an error), it could be added to the warning. On top of that, we could add a comment like # TODO remove argument in diffusers X.Y to make it more likely that this will indeed be cleaned up when that version is released.

@sayakpaul (Member, Author) commented on Mar 11, 2024:

Thanks, Benjamin!

> As a user, I might not know what to do with that info: Is this feature removed completely or can I still use it, but have to do it differently? Also, I might get the impression that I can still pass scale and it works, it's just deprecated, when in fact the argument doesn't do anything, right? Perhaps the message could be clarified.

Very good point. I clarified that as much as I could.

> Moreover, if we already have an idea of which diffusers version this will be removed in (hence raising an error), it could be added to the warning. On top of that, we could add a comment like # TODO remove argument in diffusers X.Y to make it more likely that this will indeed be cleaned up when that version is released.

The deprecate() utility I am using will automatically take care of that. So, once we hit 1.0.0, it will raise an error on any PR unless handled accordingly. Here is an example: #6885. Does that work?
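
For reference, a rough sketch of the pattern (the function name and message below are simplified placeholders, not the exact call sites in the diff): the deprecate() helper warns today and starts raising once the installed diffusers version reaches the one passed to it.

    from diffusers.utils import deprecate

    def forward_with_deprecated_scale(hidden_states, scale=None):
        # Sketch only: warn if the deprecated argument is supplied, then ignore it.
        if scale is not None:
            deprecate(
                "scale",
                "1.0.0",
                "The `scale` argument is deprecated and will be ignored.",
            )
        return hidden_states  # placeholder for the real forward computation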

@BenjaminBossan (Member):

Regarding the error message:

deprecation_message = f"Use of scale is deprecated. Please remove the argument. Even if you pass it to the forward() of the {attn.__class__.__name__} class, it won't have any effect."

I think it's almost too detailed; users will not normally pass the argument directly to the {attn.__class__.__name__}, right? Instead, the argument was probably passed along by something higher up. I think the message could be shortened to:

deprecation_message = f"The scale is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future."

Can we also add a sentence on how to control the scale instead?

> The deprecate() utility I am using will automatically take care of that. So, once we hit 1.0.0, it will raise an error on any PR unless handled accordingly. Here is an example: #6885. Does that work?

Cool, I didn't know 👍

@sayakpaul (Member, Author):

How about this?

The scale is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. scale should be passed directly when calling the underlying pipeline component, i.e., via cross_attention_kwargs.

@BenjaminBossan (Member):

Yes, that sounds good, as it clarifies to the user what they need to do.

@yiyixuxu (Collaborator) left a comment:

Oh, thanks!
I did another round of review:

  1. I left a question about the deprecation message, and I think we should use the same message everywhere (I saw you updated it in some places but not others).
  2. Let's add a warning everywhere the scale passed via cross_attention_kwargs is ignored; we can then remove all of these warnings together in the future.

Review threads: src/diffusers/models/attention.py (outdated), src/diffusers/models/attention_processor.py, src/diffusers/models/unets/unet_2d_blocks.py (5 threads)
@sayakpaul (Member, Author):

@yiyixuxu, I have resolved your comments, except for #7204. Thank you!

@younesbelkada (Contributor) left a comment:

Thanks so much! IMO all good on the PEFT end! Great work, @sayakpaul!

Review threads: src/diffusers/models/attention_processor.py, src/diffusers/models/embeddings.py (2 threads), src/diffusers/models/attention.py (outdated), src/diffusers/models/downsampling.py (outdated)
@@ -327,31 +326,20 @@ def forward(
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
yiyixuxu (Collaborator) suggested:

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

@@ -1238,8 +1241,6 @@ def forward(
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()

lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
yiyixuxu (Collaborator) suggested:

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

@@ -2440,7 +2477,6 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
yiyixuxu (Collaborator) suggested:

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

@@ -1175,8 +1180,6 @@ def forward(
):
output_states = ()

lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
yiyixuxu (Collaborator) suggested:

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

@@ -1355,7 +1358,6 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
yiyixuxu (Collaborator) suggested:

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

@@ -1687,8 +1694,7 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
yiyixuxu (Collaborator) suggested:

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

sayakpaul merged commit 531e719 into main on Mar 13, 2024
17 checks passed
sayakpaul deleted the remove-linear_cls branch on March 13, 2024 at 02:26
@panxiaoguang:

Can anyone tell me how to set the scale for a LoRA if I can't use scale anymore?

@younesbelkada (Contributor):

@panxiaoguang scale corresponds to lora_alpha / r; you simply need to make sure to pass your desired lora_alpha and r arguments in LoraConfig.
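
For concreteness, a minimal sketch of setting this at training time with PEFT (the target module names below are illustrative assumptions, not prescribed by this PR):

    from peft import LoraConfig

    # Effective LoRA scale = lora_alpha / r = 16 / 8 = 2.0
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )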

@sayakpaul (Member, Author):

> Can anyone tell me how to set the scale for a LoRA if I can't use scale anymore?

Apart from what @younesbelkada mentioned (which applies to training only), you can definitely use cross_attention_kwargs={"scale": ...} at inference and it will all work. PR #7338 will probably clear up any confusion users may have had.
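
For example, a minimal inference-time sketch (the checkpoint IDs, LoRA path, and scale value are placeholders):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("path/to/lora")  # hypothetical LoRA checkpoint

    image = pipe(
        "a photo of an astronaut riding a horse",
        cross_attention_kwargs={"scale": 0.7},  # LoRA scale applied at inference
    ).images[0]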
