
FIX [PEFT / Core] Copy the state dict when passing it to load_lora_weights #7058

Merged

Conversation

@younesbelkada (Contributor) commented Feb 22, 2024

What does this PR do?

As per title

Fixes: #7054

There is no reason not to copy the state dict of the LoRA layers when one passes a dict into load_lora_weights; copying avoids silently modifying the passed state_dict in place. Also added a test with a state dict pushed under hf-internal-testing. A minimal sketch of the idea is shown below.

cc @yiyixuxu @sayakpaul @pacman100
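
For illustration only, a minimal sketch of the idea behind the fix (simplified; the real load_lora_weights does far more, and the exact copy call here is an assumption rather than a quote of the diff):

```python
import copy


class LoraLoaderMixin:
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name=None, **kwargs):
        # If the user handed us a dict, work on a shallow copy: the conversion
        # helpers pop() keys from it, and without the copy those pops would
        # silently empty the caller's state_dict.
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = copy.copy(pretrained_model_name_or_path_or_dict)
        # ... the usual resolution / conversion / loading logic would follow here ...
```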

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

In general, this LGTM, thanks for working on this.

I tried to track down the source of the mutation. If I'm not missing something, the culprits seem to be _maybe_map_sgm_blocks_to_diffusers and _convert_kohya_lora_to_diffusers here because they pop from the state_dict. I wonder if it wouldn't be better to create the shallow copy in these two functions. The advantage would be that if we call them from somewhere else, the state_dict is still not mutated. As is, we are only safe if we go via load_lora_weights. Right now, AFAICT, that's the only function that calls them, but this could change in the future.
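
A rough sketch of that alternative (the helper below is a simplified, hypothetical stand-in for _maybe_map_sgm_blocks_to_diffusers / _convert_kohya_lora_to_diffusers, not their real signatures):

```python
def _convert_lora_state_dict(state_dict):
    # Shallow-copy at the top of the conversion helper itself, so every caller
    # is protected from mutation, not just load_lora_weights.
    state_dict = dict(state_dict)

    converted = {}
    for key in list(state_dict.keys()):
        value = state_dict.pop(key)  # pops now only shrink the local copy
        converted[key] = value
    return converted
```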

@sayakpaul (Member) left a comment

In general, this looks good to me. But I concur with @BenjaminBossan’s notes. Could we approach the PR along those lines? @yiyixuxu what are your thoughts?

@@ -1727,6 +1729,20 @@ def test_load_unload_load_kohya_lora(self):
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
release_memory(pipe)

def test_empty_state_dict(self):
Member commented:

Thank you but can we maybe add this as a fast test with smaller ckpts?

Contributor (PR author) commented:

I think having two slow tests is fine.

lcm_lora = load_file(cached_file)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {})
Member commented:

WDYT of making the test more rigid by comparing the length of the state dict instead?

@BenjaminBossan what are your thoughts?

Member commented:

So basically remember the size before passing it and then ensure that it's the same after? I don't see why not.

Member commented:

Since load_file already gives you a dict, we could store the original state dict length with len(lcm_lora) and use the value for assertion.
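
In test form, that suggestion would look roughly like this (a sketch reusing the fixtures of the surrounding slow test, i.e. pipe and cached_file are assumed to exist):

```python
lcm_lora = load_file(cached_file)
num_keys_before = len(lcm_lora)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")

# Stricter than `lcm_lora != {}`: the dict must keep every key it had.
self.assertEqual(len(lcm_lora), num_keys_before)
```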

Collaborator commented:

IMO we can just make another test_load_unload_load test, but with a state dict, where we call pipe.load_lora_weights(lcm_lora, adapter_name="lcm") twice and make sure it gives the same results; that way we make sure we can re-use the state dict we passed (sketched below).
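
Sketched out, such a test could look like the following (again reusing the surrounding test fixtures pipe and cached_file plus the existing image-comparison pattern; the prompt and inference settings here are placeholders):

```python
lcm_lora = load_file(cached_file)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
images = pipe(prompt, num_inference_steps=4, generator=torch.manual_seed(0)).images

pipe.unload_lora_weights()

# Loading again from the very same dict must still work and reproduce the
# outputs; this fails if the first call emptied lcm_lora in place.
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
images_again = pipe(prompt, num_inference_steps=4, generator=torch.manual_seed(0)).images

self.assertTrue(np.allclose(images, images_again, atol=1e-3))
```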

Contributor (PR author) commented:

Added the test!

@yiyixuxu (Collaborator) left a comment

thank you!


@younesbelkada (Contributor, PR author) commented:

Thanks everyone for the review! Merging for now. If we see any other similar issue in the future, I'd be happy to refactor this a bit as suggested by Benjamin.

@younesbelkada merged commit 8a69273 into huggingface:main on Feb 27, 2024 (13 checks passed).
@younesbelkada deleted the fix-peft-state-dict-issue branch on February 27, 2024 at 01:42.
Linked issue closed by this PR: LoraLoaderMixin.load_lora_weights() empties state_dict passed as input param. (#7054)