From 34f67c01a8164f4b596ea68e336f7a66bc4f5c30 Mon Sep 17 00:00:00 2001 From: licyk <76895225+licyk@users.noreply.github.com> Date: Mon, 1 Jul 2024 20:24:21 +0800 Subject: [PATCH] feat: add restart sampler (#3219) --- ldm_patched/k_diffusion/sampling.py | 70 +++++++++++++++++++++++++++++ ldm_patched/modules/samplers.py | 2 +- modules/flags.py | 3 +- 3 files changed, 73 insertions(+), 2 deletions(-) diff --git a/ldm_patched/k_diffusion/sampling.py b/ldm_patched/k_diffusion/sampling.py index ea5540a42..4d9d4ea64 100644 --- a/ldm_patched/k_diffusion/sampling.py +++ b/ldm_patched/k_diffusion/sampling.py @@ -835,4 +835,74 @@ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, n else: x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2) + return x + + +@torch.no_grad() +def sample_restart(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None): + """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023) + Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]} + If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list + """ + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + step_id = 0 + + def heun_step(x, old_sigma, new_sigma, second_order=True): + nonlocal step_id + denoised = model(x, old_sigma * s_in, **extra_args) + d = to_d(x, old_sigma, denoised) + if callback is not None: + callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised}) + dt = new_sigma - old_sigma + if new_sigma == 0 or not second_order: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, new_sigma * s_in, **extra_args) + d_2 = to_d(x_2, new_sigma, denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + step_id += 1 + return x + + steps = sigmas.shape[0] - 1 + if restart_list is None: + if steps >= 20: + restart_steps = 9 + restart_times = 1 + if steps >= 36: + restart_steps = steps // 4 + restart_times = 2 + sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device) + restart_list = {0.1: [restart_steps + 1, restart_times, 2]} + else: + restart_list = {} + + restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()} + + step_list = [] + for i in range(len(sigmas) - 1): + step_list.append((sigmas[i], sigmas[i + 1])) + if i + 1 in restart_list: + restart_steps, restart_times, restart_max = restart_list[i + 1] + min_idx = i + 1 + max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0)) + if max_idx < min_idx: + sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] + while restart_times > 0: + restart_times -= 1 + step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:])) + + last_sigma = None + for old_sigma, new_sigma in tqdm(step_list, disable=disable): + if last_sigma is None: + last_sigma = old_sigma + elif last_sigma < old_sigma: + x = x + torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5 + x = heun_step(x, old_sigma, new_sigma) + last_sigma = new_sigma + return x \ No newline at end of file diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index 9ed1fcd28..05b4b3174 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -523,7 +523,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd", "edm_playground_v2.5"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd", "edm_playground_v2.5", "restart"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/modules/flags.py b/modules/flags.py index 29ac4615f..2addb8435 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -35,7 +35,8 @@ "dpmpp_3m_sde_gpu": "", "ddpm": "", "lcm": "LCM", - "tcd": "TCD" + "tcd": "TCD", + "restart": "Restart" } SAMPLER_EXTRA = {