
SDXL Inpainting VAE Normalization #7225

Closed

Conversation

AlekseyKorshuk

What does this PR do?

Since the release of Playground V2.5, which ships a "custom" VAE, we should normalize and denormalize the latents.
This is already implemented in StableDiffusionXLPipeline:

# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
    latents_mean = (
        torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    )
    latents_std = (
        torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    )
    latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
    latents = latents / self.vae.config.scaling_factor
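For reference, the encode-side counterpart of the decode formula above can be sketched as follows. The statistics and scaling factor here are illustrative placeholders, not the real Playground V2.5 config values, and the function names are made up for the sketch:

```python
import torch

def normalize(latents, latents_mean, latents_std, scaling_factor):
    # Encode side: shift by the per-channel mean, rescale into the range
    # the scheduler/UNet expects.
    return (latents - latents_mean) * scaling_factor / latents_std

def denormalize(latents, latents_mean, latents_std, scaling_factor):
    # Decode side, matching the StableDiffusionXLPipeline snippet quoted above:
    # x * std / scaling_factor + mean.
    return latents * latents_std / scaling_factor + latents_mean

# Illustrative per-channel statistics (NOT the Playground V2.5 values).
latents_mean = torch.tensor([0.1, -0.2, 0.3, 0.0]).view(1, 4, 1, 1)
latents_std = torch.tensor([1.5, 0.8, 1.2, 1.0]).view(1, 4, 1, 1)
scaling_factor = 0.5

x = torch.randn(1, 4, 8, 8)
roundtrip = denormalize(
    normalize(x, latents_mean, latents_std, scaling_factor),
    latents_mean, latents_std, scaling_factor,
)
assert torch.allclose(x, roundtrip, atol=1e-5)
```

The round trip recovers the input exactly, which is the property the PR relies on: whatever normalization is applied at encode time must be undone by the existing decode path.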

Who can review?

cc: @sayakpaul @patil-suraj

@@ -308,6 +308,25 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


def requires_vae_latents_normalization(vae):
Member

I think we can definitely do this in place without delegating it to a method. That way the code reads linearly, and the reader doesn't have to jump to another method to see what's going on.

Comment on lines 316 to 320
def normalize_vae_latents(latents, latents_mean, latents_std):
    latents_mean = latents_mean.to(device=latents.device, dtype=latents.dtype)
    latents_std = latents_std.to(device=latents.device, dtype=latents.dtype)
    latents = (latents - latents_mean) / latents_std
    return latents
Member

Same as above.

Comment on lines 323 to 327
def denormalize_vae_latents(latents, latents_mean, latents_std):
    latents_mean = latents_mean.to(device=latents.device, dtype=latents.dtype)
    latents_std = latents_std.to(device=latents.device, dtype=latents.dtype)
    latents = latents * latents_std + latents_mean
    return latents
Member

Same as above.
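For what it's worth, the two helpers quoted above are exact inverses of each other (the `scaling_factor` is handled separately in this formulation). A quick standalone check, with made-up per-channel statistics:

```python
import torch

def normalize_vae_latents(latents, latents_mean, latents_std):
    # As proposed in the diff: shift by the per-channel mean, divide by std.
    latents_mean = latents_mean.to(device=latents.device, dtype=latents.dtype)
    latents_std = latents_std.to(device=latents.device, dtype=latents.dtype)
    return (latents - latents_mean) / latents_std

def denormalize_vae_latents(latents, latents_mean, latents_std):
    # Inverse: multiply by std, then add the mean back.
    latents_mean = latents_mean.to(device=latents.device, dtype=latents.dtype)
    latents_std = latents_std.to(device=latents.device, dtype=latents.dtype)
    return latents * latents_std + latents_mean

# Made-up statistics for the check (not real VAE config values).
mean = torch.full((1, 4, 1, 1), 0.25)
std = torch.full((1, 4, 1, 1), 1.3)
x = torch.randn(2, 4, 16, 16)
roundtrip = denormalize_vae_latents(normalize_vae_latents(x, mean, std), mean, std)
assert torch.allclose(x, roundtrip, atol=1e-5)
```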

Member

@sayakpaul sayakpaul left a comment


Do you have some results showing how the Playground v2.5 checkpoint plays out with inpainting?

I am okay with the changes you're introducing, provided the comments are addressed.

Cc: @patil-suraj as well.

@AlekseyKorshuk
Author

@sayakpaul Thanks for commenting!

Don't you think the code should be DRY (Don't Repeat Yourself)? These three functions would be used anywhere the SDXL pipelines handle this VAE, so maybe it is reasonable to move them to "utils"/"common" or into a base SDXL pipeline.
Just my thoughts; otherwise I'm happy to inline the logic without delegating to a function.

I don't have a checkpoint yet. I trained one but figured out this normalization too late...

@sayakpaul
Member

> Don't you think that the code should be DRY (Don't Repeat Yourself)? Overall these 3 functions should be used in any place of SDXL pipeline with this VAE. Maybe it is reasonable to move it to "utils"/"common" or in base SDXL pipeline.

Please refer to https://huggingface.co/docs/diffusers/en/conceptual/philosophy. It's okay to trade away DRY in the interest of readability as our pipelines are also read for educational purposes.

@AlekseyKorshuk
Author

I addressed all the comments. Please let me know how I can be helpful.

By the way, while you are replying: would it be possible by any chance to share the training args used in this PR to train the inpainting model:

This would be very helpful, thank you!

@sayakpaul
Member

> Btw while you are replying, if it is possible by any chance to share training args used for this PR to train inpainting model:

Will defer that to the training ninja @patil-suraj here :)

@@ -939,6 +939,17 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
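The guard quoted above presumably gates normalization of the encoded image latents inside `_encode_vae_image`. A standalone sketch of that logic, with a hypothetical helper name and a dummy config object, mirroring the decode-side snippet from the PR description (not the exact PR code):

```python
import types
import torch

def encode_side_normalize(image_latents, vae_config):
    # If the VAE config carries per-channel statistics (as Playground V2.5's
    # does), normalize into the range the UNet was trained on; otherwise fall
    # back to plain scaling, as vanilla SDXL does.
    has_mean = getattr(vae_config, "latents_mean", None) is not None
    has_std = getattr(vae_config, "latents_std", None) is not None
    if has_mean and has_std:
        mean = torch.tensor(vae_config.latents_mean).view(1, 4, 1, 1).to(
            image_latents.device, image_latents.dtype)
        std = torch.tensor(vae_config.latents_std).view(1, 4, 1, 1).to(
            image_latents.device, image_latents.dtype)
        return (image_latents - mean) * vae_config.scaling_factor / std
    return image_latents * vae_config.scaling_factor

# Dummy config with illustrative values (mean 0, std 1 reduces to plain scaling).
cfg = types.SimpleNamespace(scaling_factor=0.5, latents_mean=[0.0] * 4, latents_std=[1.0] * 4)
out = encode_side_normalize(torch.ones(1, 4, 2, 2), cfg)
assert torch.allclose(out, torch.full((1, 4, 2, 2), 0.5))
```

With zero mean and unit std this reduces to the stock `latents * scaling_factor` path, which is why the change is a no-op for the standard SDXL VAE.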
Collaborator

@yiyixuxu yiyixuxu Mar 9, 2024


Why do we need it here? We didn't normalize the latents in SDXL text-to-image.

Member

I think this would be needed because here we are encoding the input image. This is what we do in the SDXL DreamBooth LoRA training script.

Member

But I agree that #7132 tackles it from a comprehensive perspective.

Collaborator

OK, I reopened it, and I'm happy to be proven wrong.
Once we merge #7132, you can update the branch, then test it out and show some results with and without the normalization.

Member

Works for me!

@yiyixuxu
Collaborator

yiyixuxu commented Mar 9, 2024

Just saw this PR: #7132.
I think it is correctly done there, so I'm closing this PR for now. Feel free to leave any questions and we will continue the discussion here.

@yiyixuxu yiyixuxu closed this Mar 9, 2024
@yiyixuxu yiyixuxu reopened this Mar 9, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


github-actions bot commented Apr 5, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 5, 2024
@sayakpaul sayakpaul requested a review from yiyixuxu June 30, 2024 05:38
@sayakpaul
Copy link
Member

@yiyixuxu could you give this another look?

@github-actions github-actions bot removed the stale Issues that haven't received updates label Sep 14, 2024
github-actions bot commented Oct 9, 2024

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 9, 2024
@sayakpaul sayakpaul removed the stale Issues that haven't received updates label Oct 9, 2024
@sayakpaul
Copy link
Member

@yiyixuxu a gentle ping.

github-actions bot commented Nov 3, 2024

@github-actions github-actions bot added the stale Issues that haven't received updates label Nov 3, 2024
@a-r-r-o-w a-r-r-o-w removed the stale Issues that haven't received updates label Nov 3, 2024

@github-actions github-actions bot added the stale Issues that haven't received updates label Nov 28, 2024
@sayakpaul
Copy link
Member

Safe to close it now.

@sayakpaul sayakpaul closed this Nov 28, 2024