Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add single file support for Stable Cascade #7274

Merged
merged 8 commits into from
Mar 13, 2024
Merged

Conversation

DN6
Copy link
Collaborator

@DN6 DN6 commented Mar 11, 2024

What does this PR do?

Adds single file support to the StableCascadeUNet, which allows users to load in checkpoints published in the original format.

The single file loading logic and mappings are defined within the UNet without a Mixin, since the mapping functions have be accessible to the from_single_file method somehow. It relies on fetching the configs from a hosted repo of configs hosted on the diffusers org.

This is a more practical approach, since a single file pipeline checkpoint would be quite large (~34GB) to load. Single File loading with combined pipelines is also not something we support at the moment and it is particularly challenging with Cascade due to the dtype restrictions in the Prior (Stage C) pipeline, which cannot work with float16, while the decoder (Stage B) can.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 DN6 requested review from yiyixuxu and sayakpaul March 11, 2024 07:15
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -257,6 +257,39 @@ def fetch_ldm_config_and_checkpoint(
return original_config, checkpoint


def load_single_file_model_checkpoint(pretrained_model_link_or_path, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it confusing with fetch_ldm_config_and_checkpoint():

def fetch_ldm_config_and_checkpoint(

Can't we repurpose fetch_ldm_config_and_checkpoint() here? Code looks almost the same to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can move load_single_file_model_checkpoint into the fetch_ldm_config_and_checkpoint` function.

}


def convert_single_file_to_diffusers(original_state_dict):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could leverage these functions in the conversion script to massively reduce the duplicated code, no?

(Future PR candidate)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied these over from the conversion script. But yes, I can update the scripts to use this function once merged.


assert unet.config[param_name] == param_value

def test_stable_cascade_config_loading(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): We're testing the config for the UNet. The test, hence, should be named accordingly.

unet = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade-prior",
subfolder="prior",
revision="refs/pr/2",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not merged yet? 😱

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a couple of nits.

My biggest concern is that with the current design to support single-file checkpoint loading, we are breaking the design of how we support it for ControlNet and VAE massively. Personally, I like this current approach, though!

I also think we should have a dedicated section in the Stable Cascade documentation page about how one should properly load a single-file checkpoint and run inference to obtain valid results.

@DN6
Copy link
Collaborator Author

DN6 commented Mar 11, 2024

My biggest concern is that with the current design to support single-file checkpoint loading, we are breaking the design of how we support it for ControlNet and VAE massively. Personally, I like this current approach, though!

I think what we can do is move towards a single SingleFileModelMixin and use class attributes to pass in the functions to convert the checkpoint from original format to diffusers. I can do that in a follow up for all the model classes that use from_single_file

@sayakpaul
Copy link
Member

I think what we can do is move towards a single SingleFileModelMixin and use class attributes to pass in the functions to convert the checkpoint from original format to diffusers.

I don't fully understand. Could you maybe elaborate this with pseudocode?

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@@ -387,6 +475,52 @@ def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=Tru

self.gradient_checkpointing = False

@classmethod
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we quickly create a FromOriginalUnetMixin, move everything there, and throw a warning if the cls.__name != StableCascadeUNet?

we can refactor the code later and apply to other unets

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would prefer the other way around then. I.e., create a FromOriginalUnetMixin and then land this PR. Since with this PR, we are including a big anti-pattern in the code, which I would like to avoid at all costs.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@yiyixuxu yiyixuxu requested a review from sayakpaul March 12, 2024 16:04
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@sayakpaul sayakpaul merged commit ed224f9 into main Mar 13, 2024
17 checks passed
@sayakpaul sayakpaul deleted the cascade-single-file branch March 13, 2024 03:07
@vladmandic
Copy link
Contributor

great stuff.
minor nitpick, there are several references to ControlNetModel, i guess since this implementation is copied from there?

@sayakpaul
Copy link
Member

The docstring is. #7295 should clean it up.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants