SDXL Inpainting VAE Normalization #7225

Closed
@@ -308,6 +308,25 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


def requires_vae_latents_normalization(vae):
Review comment (Member): I think we can definitely do this in place without delegating it to a method. That way, the readability of the code stays linear and the reader doesn't have to refer to another method to see what's going on.
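One possible reading of this suggestion, sketched inline (a sketch, not the author's code; names are taken from this diff, and `self` and `image_latents` are assumed to be in scope inside `_encode_vae_image`):

    # Sketch of the suggested inline alternative to the helper functions below.
    has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
    has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
    if has_latents_mean and has_latents_std:
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
            .to(image_latents.device, image_latents.dtype)
        )
        latents_std = (
            torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
            .to(image_latents.device, image_latents.dtype)
        )
        image_latents = (image_latents - latents_mean) / latents_std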

    return hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None and \
        hasattr(vae.config, "latents_std") and vae.config.latents_std is not None


def normalize_vae_latents(latents, latents_mean, latents_std):
    latents_mean = latents_mean.to(device=latents.device, dtype=latents.dtype)
    latents_std = latents_std.to(device=latents.device, dtype=latents.dtype)
    latents = (latents - latents_mean) / latents_std
    return latents
Review comment (Member): Same as above.



def denormalize_vae_latents(latents, latents_mean, latents_std):
    latents_mean = latents_mean.to(device=latents.device, dtype=latents.dtype)
    latents_std = latents_std.to(device=latents.device, dtype=latents.dtype)
    latents = latents * latents_std + latents_mean
    return latents
Review comment (Member): Same as above.
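As a quick illustration of what these two helpers do together (a hypothetical standalone snippet, not part of the diff): normalizing and then denormalizing with the same statistics recovers the original latents.

    import torch

    latents = torch.randn(1, 4, 64, 64)
    latents_mean = torch.tensor([0.1, -0.2, 0.0, 0.3]).view(1, 4, 1, 1)
    latents_std = torch.tensor([1.5, 0.9, 1.1, 1.2]).view(1, 4, 1, 1)

    normalized = normalize_vae_latents(latents, latents_mean, latents_std)    # (x - mean) / std
    recovered = denormalize_vae_latents(normalized, latents_mean, latents_std)  # x * std + mean
    assert torch.allclose(recovered, latents, atol=1e-6)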



class StableDiffusionXLInpaintPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
@@ -939,6 +958,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        if requires_vae_latents_normalization(self.vae):
            image_latents = normalize_vae_latents(
                image_latents,
                latents_mean=torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1),
                latents_std=torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1),
            )
        if self.vae.config.force_upcast:
            self.vae.to(dtype)

@@ -1763,6 +1788,13 @@ def denoising_value_valid(dnv):
                if XLA_AVAILABLE:
                    xm.mark_step()

        if requires_vae_latents_normalization(self.vae):
            latents = denormalize_vae_latents(
                latents,
                latents_mean=torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1),
                latents_std=torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1),
            )

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
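For context, the overflow guard referenced in the comment above follows the usual upcast-decode-downcast pattern in the SDXL pipelines (a simplified sketch, not the exact code; the actual decode block continues below the shown diff context):

    if needs_upcasting:
        self.vae.to(torch.float32)          # decode in float32 to avoid fp16 overflow
        latents = latents.to(torch.float32)
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    if needs_upcasting:
        self.vae.to(torch.float16)          # restore the original dtype afterwards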