FIX [PEFT / Core] Copy the state dict when passing it to load_lora_weights #7058
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
In general, this LGTM, thanks for working on this.
I tried to track down the source of the mutation. If I'm not missing something, the culprits seem to be `_maybe_map_sgm_blocks_to_diffusers` and `_convert_kohya_lora_to_diffusers` here, because they pop from the `state_dict`. I wonder if it wouldn't be better to create the shallow copy in these two functions. The advantage would be that if we call them from somewhere else, the `state_dict` is still not mutated. As is, we are only safe if we go via `load_lora_weights`. Right now, AFAICT, that's the only function that calls them, but this could change in the future.
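A rough sketch of that alternative, with an illustrative converter (the real diffusers converters do much more; only the copy-before-pop pattern is the point here):

```python
# Illustrative only, not the actual _convert_kohya_lora_to_diffusers code.
# The shallow copy at the top protects every caller from the pop() calls below.
def convert_lora_state_dict(state_dict):
    state_dict = dict(state_dict)  # shallow copy; popping no longer mutates the caller's dict
    unet_sd, text_encoder_sd = {}, {}
    for key in list(state_dict.keys()):
        if key.startswith("lora_unet"):
            unet_sd[key] = state_dict.pop(key)
        elif key.startswith("lora_te"):
            text_encoder_sd[key] = state_dict.pop(key)
    return unet_sd, text_encoder_sd
```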
In general, this looks good to me. But I concur with @BenjaminBossan’s notes. Could we approach the PR along those lines? @yiyixuxu what are your thoughts?
tests/lora/test_lora_layers_peft.py (Outdated)
@@ -1727,6 +1729,20 @@ def test_load_unload_load_kohya_lora(self):
        self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
        release_memory(pipe)

    def test_empty_state_dict(self):
Thank you, but can we maybe add this as a fast test with smaller ckpts?
I think having two slow tests is fine
lcm_lora = load_file(cached_file)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {})
WDYT of making the test more rigid by comparing the length of the state dict instead?
@BenjaminBossan what are your thoughts?
So basically remember the size before passing it and then ensure that it's the same after? I don't see why not.
Since `load_file` already gives you a dict, we could store the original state dict length with `len(lcm_lora)` and use the value for assertion.
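A rough sketch of that stricter assertion, written as it would sit inside the test method; `pipe` and `cached_file` are assumed to come from the existing test setup:

```python
from safetensors.torch import load_file

lcm_lora = load_file(cached_file)
num_keys_before = len(lcm_lora)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
# Check the caller's dict kept every key, not merely that it is non-empty.
self.assertEqual(len(lcm_lora), num_keys_before)
```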
IMO we can just make another `test_load_unload_load` test but with a state dict, where we call `pipe.load_lora_weights(lcm_lora, adapter_name="lcm")` twice and make sure it gives the same results. That way we make sure we can re-use the state dict we passed.
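A hypothetical outline of such a test; `pipe`, `cached_file`, and `prompt` are assumed to be set up as in the surrounding slow tests, and this is a sketch rather than the test that was actually added:

```python
import numpy as np
import torch
from safetensors.torch import load_file

lcm_lora = load_file(cached_file)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
images = pipe(prompt, num_inference_steps=4, output_type="np",
              generator=torch.manual_seed(0)).images
pipe.unload_lora_weights()

# Re-using the very same dict must still work; it would fail (or silently load
# nothing) if the first call had popped all of its keys.
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
images_again = pipe(prompt, num_inference_steps=4, output_type="np",
                    generator=torch.manual_seed(0)).images

self.assertTrue(np.allclose(images, images_again, atol=1e-3))
```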
Added the test!
thank you!
Thanks everyone for the review! Merging for now! If we see any other similar issue in the future I'd be happy to refactor that a bit as suggested by Benjamin.
What does this PR do?
As per title
Fixes: #7054
There should be no reason not to copy the state dict of the LoRA layers when one passes a dict into `load_lora_weights`; copying avoids silently modifying the passed `state_dict` in place. Also added a test with a state dict pushed under hf-internal-testing.

cc @yiyixuxu @sayakpaul @pacman100
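For reference, a minimal sketch of the idea behind the change (illustrative only, not the exact diffusers code):

```python
def load_lora_weights(pretrained_model_name_or_path_or_dict, **kwargs):
    state_dict = pretrained_model_name_or_path_or_dict
    if isinstance(state_dict, dict):
        # Shallow-copy a user-provided dict so the downstream kohya/SGM
        # conversion helpers, which pop() keys, never touch the caller's copy.
        state_dict = state_dict.copy()
    # ... conversion and loading into the unet / text encoder would follow here
    return state_dict
```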