Skip to content

Commit

Permalink
ensure that original alpha bar always exists
Browse files Browse the repository at this point in the history
  • Loading branch information
drhead authored Dec 2, 2023
1 parent 668ae34 commit 309a606
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,15 +882,17 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar[-1] = 4.8973451890853435e-08
return alphas_bar

if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)

if opts.use_downcasted_alpha_bar:
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
if opts.sd_noise_schedule == "Zero Terminal SNR":
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'):
p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod

p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)

if opts.use_downcasted_alpha_bar:
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
if opts.sd_noise_schedule == "Zero Terminal SNR":
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)

with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
Expand Down

0 comments on commit 309a606

Please sign in to comment.