From 3a55e7e3910b8ae58f82a5a0e4c11d7d4fa3143f Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Sat, 18 May 2024 15:53:34 +0200 Subject: [PATCH] feat: add AlignYourStepsScheduler (#2905) --- .../contrib/external_align_your_steps.py | 55 +++++++++++++++++++ modules/flags.py | 2 +- modules/sample_hijack.py | 4 ++ 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 ldm_patched/contrib/external_align_your_steps.py diff --git a/ldm_patched/contrib/external_align_your_steps.py b/ldm_patched/contrib/external_align_your_steps.py new file mode 100644 index 000000000..624bbce2a --- /dev/null +++ b/ldm_patched/contrib/external_align_your_steps.py @@ -0,0 +1,55 @@ +# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py + +#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html +import numpy as np +import torch + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + +NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582], + "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], + "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} + +class AlignYourStepsScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model_type": (["SD1", "SDXL", "SVD"], ), + "steps": ("INT", {"default": 10, "min": 10, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model_type, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) + total_steps = round(steps * denoise) + + sigmas = NOISE_LEVELS[model_type][:] + if (steps + 1) != len(sigmas): + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return (torch.FloatTensor(sigmas), ) + +NODE_CLASS_MAPPINGS = { + "AlignYourStepsScheduler": AlignYourStepsScheduler, +} \ No newline at end of file diff --git a/modules/flags.py b/modules/flags.py index 9f2aefb3b..0c6054394 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -47,7 +47,7 @@ KSAMPLER_NAMES = list(KSAMPLER.keys()) -SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"] +SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps"] SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys()) sampler_list = SAMPLER_NAMES diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py index 5936a096d..4ab3cbbde 100644 --- a/modules/sample_hijack.py +++ b/modules/sample_hijack.py @@ -3,6 +3,7 @@ import ldm_patched.modules.model_management from collections import namedtuple +from ldm_patched.contrib.external_align_your_steps import AlignYourStepsScheduler from ldm_patched.contrib.external_custom_sampler import SDTurboScheduler from ldm_patched.k_diffusion import sampling as k_diffusion_sampling from ldm_patched.modules.samplers import normal_scheduler, simple_scheduler, ddim_scheduler @@ -175,6 +176,9 @@ def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps): sigmas = normal_scheduler(model, steps, sgm=True) elif scheduler_name == "turbo": sigmas = SDTurboScheduler().get_sigmas(namedtuple('Patcher', ['model'])(model=model), steps=steps, denoise=1.0)[0] + elif scheduler_name == "align_your_steps": + model_type = 'SDXL' if isinstance(model.latent_format, ldm_patched.modules.latent_formats.SDXL) else 'SD1' + sigmas = AlignYourStepsScheduler().get_sigmas(model_type=model_type, steps=steps, denoise=1.0)[0] else: raise TypeError("error invalid scheduler") return sigmas