Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement zero terminal SNR noise schedule option #14145

Merged
merged 13 commits into from
Jan 1, 2024
28 changes: 28 additions & 0 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"

def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return alphas_bar

if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)

if opts.use_downcasted_alpha_bar:
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
if opts.sd_noise_schedule == "Zero Terminal SNR":
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)

with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

Expand Down
6 changes: 6 additions & 0 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer

if shared.cmd_opts.no_half:
model.float()
model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
Expand All @@ -387,7 +388,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None

alphas_cumprod = model.alphas_cumprod
model.alphas_cumprod = None
model.half()
model.alphas_cumprod = alphas_cumprod
model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
Expand Down Expand Up @@ -642,6 +647,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
'alphas_cumprod': None,
'': torch.float16,
}

Expand Down
2 changes: 1 addition & 1 deletion modules/sd_samplers_timesteps.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, model, *args, **kwargs):
self.inner_model = model

def predict_eps_from_z_and_v(self, x_t, t, v):
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t

def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions modules/shared_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
"use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
}))

options_templates.update(options_section(('interrogate', "Interrogate"), {
Expand Down Expand Up @@ -335,6 +336,7 @@
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
}))

options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
Expand Down