
add DPM scheduler with EDM formulation #7120

Merged (8 commits into main from edm-dpmsolver) on Feb 27, 2024

Conversation

patil-suraj (Contributor)

What does this PR do?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment:

Generally looks good. Let's make sure to add "Copied from ..." statements wherever applicable.
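
For reference, the convention referred to here is a one-line comment placed directly above a duplicated method, in the form that also appears later in this diff for _threshold_sample:

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        ...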

@pcuenca (Member) left a comment:

Looks good in principle!

    else:
        raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
Member:

should deis be accepted here as well?

patil-suraj (Contributor, Author):

Not sure; for now I followed the logic in the existing DPM scheduler.

@yiyixuxu (Collaborator) commented on Feb 27, 2024:

I think we just map it to dpmsolver++ here, so it should be fine; although I absolutely have no clue why "deis" is accepted as an algorithm type (I looked back at the original PR that added it too, and that did not help).
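
For readers following the thread, here is a minimal standalone sketch of the mapping being described, assuming the existing multistep DPM-Solver scheduler simply remaps "deis" to "dpmsolver++" when validating the config (the function name and error message below are illustrative, not the library API):

    def resolve_algorithm_type(algorithm_type: str) -> str:
        supported = ["dpmsolver", "sde-dpmsolver", "dpmsolver++", "sde-dpmsolver++"]
        if algorithm_type in supported:
            return algorithm_type
        if algorithm_type == "deis":
            # remapped to the equivalent solver, as discussed in this thread
            return "dpmsolver++"
        raise NotImplementedError(f"{algorithm_type} is not implemented")

    assert resolve_algorithm_type("deis") == "dpmsolver++"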

@patil-suraj changed the title from "[wip] add DPM scheduler with EDM formulation" to "add DPM scheduler with EDM formulation" on Feb 27, 2024
@patil-suraj marked this pull request as ready for review on February 27, 2024 at 16:59
@patil-suraj requested a review from yiyixuxu on February 27, 2024 at 16:59
@patil-suraj merged commit 8492db2 into main on Feb 27, 2024
15 checks passed
@patil-suraj deleted the edm-dpmsolver branch on February 27, 2024 at 18:09
@yiyixuxu (Collaborator) left a comment:

Looks good!
I left a comment about the _sigma_to_alpha_sigma_t function - I think maybe we don't need it here and can simplify the math a little bit.

        return sigmas

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
Collaborator:

does this work? asking because we precondition the inputs so sample has a different scale now

        return t

    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = torch.tensor(1)  # Inputs are pre-scaled before going into unet, so alpha_t = 1
Collaborator:

ohh nice! that's why this formula still works here

    x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output

or maybe we do not need this function at all, just change the math directly inside the steps as

    x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
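
To make the equivalence concrete, here is a small self-contained check (not from the PR; shapes and values are arbitrary) that the two updates agree once alpha_t is pinned to 1:

    import torch

    sigma_s, sigma_t = torch.tensor(2.0), torch.tensor(1.0)
    alpha_t = torch.tensor(1.0)  # EDM preconditioning fixes alpha_t to 1
    h = torch.log(sigma_s) - torch.log(sigma_t)  # lambda_t - lambda_s when alpha = 1

    sample = torch.randn(4)
    model_output = torch.randn(4)

    # update as written in the PR, keeping the alpha_t factor
    x_t_with_alpha = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    # simplified update with the alpha_t factor dropped
    x_t_simplified = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
    assert torch.allclose(x_t_with_alpha, x_t_simplified)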

Comment on lines +398 to +403:

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

    h = lambda_t - lambda_s
Collaborator:

Suggested change:

    - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    - alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    - lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    - lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    - h = lambda_t - lambda_s
    + h = torch.log(sigma_s) - torch.log(sigma_t)

I'm not 100% sure the math is correct here but the idea is we can calculate h from sigma directly
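
As a sanity check on that suggestion: the EDM preconditioning fixes alpha to 1 (see _sigma_to_alpha_sigma_t above), so lambda = log(alpha) - log(sigma) = -log(sigma). Then h = lambda_t - lambda_s = -log(sigma_t) + log(sigma_s) = log(sigma_s) - log(sigma_t), which is exactly the suggested one-liner.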


    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
Collaborator:

Suggested change:

    - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    + x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output

    assert noise is not None
    x_t = (
        (sigma_t / sigma_s * torch.exp(-h)) * sample
        + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
Collaborator:

Suggested change:

    - + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
    + + (1 - torch.exp(-2.0 * h)) * model_output

        self.sigmas[self.step_index - 1],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
Collaborator:

same comments as #7120 (comment)
