[LoRA] use the PyTorch classes wherever needed and start deprecation cycles #7204
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Ohh great! Thanks so much.
We can deprecate the `scale` argument everywhere now, too, right? E.g., all the attention processors.
@yiyixuxu up for another review. Rigorous review appreciated!
Thanks!
I have two main pieces of feedback:
- let's deprecate the `scale` argument by removing it from the signature (see here: #7204 (comment), and the sketch below)
- I see that you deprecated the `scale` argument from some more public classes, but for some less public classes you simply removed or silently ignored it - maybe we should just deprecate it everywhere and remove them altogether later? e.g. #7204 (comment) and #7204 (comment)

cc @BenjaminBossan - we would appreciate it if you could also give a review
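A minimal sketch of the first point above (dropping `scale` from the signature while still catching and warning on it), assuming plain Python `warnings` rather than diffusers' own deprecation utilities; the class name is made up for illustration:

```python
import warnings


class AttnProcessorSketch:
    """Illustrative attention processor with `scale` removed from the signature."""

    def __call__(self, attn, hidden_states, *args, **kwargs):
        # `scale` is no longer an explicit parameter; if a caller still passes it
        # (positionally or by keyword), warn and ignore the value.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            warnings.warn(
                "The `scale` argument is deprecated and will be ignored.",
                FutureWarning,
            )
        # ... the actual attention computation would follow here ...
        return hidden_states
```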
@yiyixuxu addressed all your comments. They were very, very helpful! Thank you!
Very thorough PR, thanks a lot Sayak. This LGTM overall.
Just a suggestion for the deprecation. Currently, the message is:

"Use of `scale` is deprecated. Please remove the argument."

As a user, I might not know what to do with that info: is this feature removed completely, or can I still use it but have to do it differently? Also, I might get the impression that I can still pass `scale` and it works, it's just deprecated, when in fact the argument doesn't do anything, right? Perhaps the message could be clarified.
Moreover, if we already have an idea in which diffusers version this will be removed (hence raise an error), it could be added to the warning. On top, we could add a comment like `# TODO remove argument in diffusers X.Y` to make it more likely that this will indeed be cleaned up when that version is released.
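As a hedged illustration of that suggestion (not the wording the PR ultimately adopted), a warning that names the removal version and carries a matching TODO might look like the following sketch; the version number and helper name are assumptions:

```python
import warnings

# TODO: remove the `scale` handling entirely in the release named below.
_SCALE_REMOVAL_VERSION = "1.0.0"  # placeholder version, not confirmed by the PR


def warn_scale_deprecated() -> None:
    """Emit a deprecation warning that says what to do and when removal happens."""
    warnings.warn(
        "Passing `scale` directly is deprecated and the value is ignored. "
        "Control the LoRA scale when calling the pipeline instead. "
        f"Passing `scale` will raise an error in diffusers {_SCALE_REMOVAL_VERSION}.",
        FutureWarning,
    )
```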
Thanks, Benjamin!
Very good point. I clarified that as much as I could.
Regarding the error message:
I think it's almost too detailed; users will not normally pass the argument directly to the
Can we also add a sentence on how to control the scale instead?

Cool, I didn't know 👍
How about?
Yes, that sounds good, as it clarifies to the user what they need to do.
oh thanks!
I did another round of review:
- I left a question about the deprecation message, and I think we should use the same message everywhere (I saw you updated it in some places but not others)
- let's add a warning everywhere `scale` passed via `cross_attention_kwargs` is ignored - we can then remove all these warnings at the same time in the future (a helper along these lines is sketched below)
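Since that same three-line check recurs in many `forward` methods (see the diff hunks below), one hypothetical way to keep the message identical everywhere would be a small shared helper; this is a sketch of the idea, not what the PR implements:

```python
from diffusers.utils import logging

logger = logging.get_logger(__name__)


def warn_if_scale_passed(cross_attention_kwargs):
    # Emit one consistent message whenever a now-ignored `scale` is passed.
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
        logger.warning(
            "Passing `scale` via `cross_attention_kwargs` is deprecated. `scale` will be ignored."
        )
```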
Thanks so much! IMO all good on the PEFT end! Great work @sayakpaul!
@@ -327,31 +326,20 @@ def forward(
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1238,8 +1241,6 @@ def forward(
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()

lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -2440,7 +2477,6 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1175,8 +1180,6 @@ def forward(
):
output_states = ()

lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1355,7 +1358,6 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1687,8 +1694,7 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Can anyone tell me how to set a scale for LoRA if I can't use
@panxiaoguang `scale` corresponds to
Apart from what @younesbelkada mentioned (which applies to "training" only), you can definitely use
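For context on controlling the scale without touching the attention processors, here is a hedged usage sketch of the documented pipeline-level route with the peft backend; the checkpoint, LoRA path, and scale value are illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/your/lora")  # illustrative path

# The LoRA scale is passed at the pipeline level; it is no longer forwarded
# to individual layers as a `scale` argument.
image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```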
What does this PR do?
Since we have shifted to the peft backend for all things LoRA, there's no need for us to use the `LoRACompatible*` classes now. We should also start the deprecation cycles for `LoRALinearLayer` and `LoRAConv2dLayer`. This PR does that.
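To illustrate the direction of the PR, here is a minimal, hedged sketch of how a plain `torch.nn.Linear` picks up LoRA behavior through the peft backend instead of a custom `LoRALinearLayer`; the module name, adapter config values, and shapes are made up for the example:

```python
import torch
from peft import LoraConfig, inject_adapter_in_model


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain nn.Linear: no LoRACompatibleLinear / LoRALinearLayer wrapper needed.
        self.to_q = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.to_q(x)


model = TinyBlock()
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q"])
# peft swaps `to_q` for a LoRA-augmented linear layer in place.
model = inject_adapter_in_model(lora_config, model)

out = model(torch.randn(1, 64))
print(out.shape)  # torch.Size([1, 64])
```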