Add Lu's uniform logSNR time steps
LuChengTHU committed Oct 26, 2023
1 parent 2748765 commit 892fec9
Showing 1 changed file with 24 additions and 0 deletions.
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -124,6 +124,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
Whether to use the uniform-logSNR time steps proposed by Lu's DPM-Solver for step sizes in the noise
schedule during the sampling process. If `True`, the sigmas and time steps are determined according to a
sequence of `lambda(t)`.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
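For reference, enabling the new flag mirrors the existing `use_karras_sigmas` pattern. A minimal usage sketch (the checkpoint name is illustrative, not part of this commit):

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Any pipeline that uses DPMSolverMultistepScheduler works here.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Override the scheduler config with the flag added in this commit.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_lu_lambdas=True
)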
@@ -160,6 +164,7 @@ def __init__(
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
@@ -263,6 +268,12 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
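A standalone sketch of what the new branch computes, assuming an ascending `sigmas` array (the values below are made up; in the scheduler, `log_sigmas` is derived from `alphas_cumprod`):

import numpy as np

# Stand-in for the scheduler's precomputed noise levels (ascending).
sigmas = np.linspace(0.01, 80.0, 1000)
log_sigmas = np.log(sigmas)
num_inference_steps = 10

# Flip so lambdas run from the largest log(sigma) down to the smallest,
# re-space them uniformly (what _convert_to_lu does with rho = 1.0),
# and map back to sigmas. Spacing uniform in log(sigma) is uniform in
# logSNR, hence "uniform logSNR time steps".
lambdas = np.flip(log_sigmas.copy())
lambdas = np.linspace(lambdas[0], lambdas[-1], num_inference_steps)
sigmas_lu = np.exp(lambdas)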
@@ -359,6 +370,19 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Lu et al. (2022)."""

lambda_min: float = in_lambdas[-1].item()
lambda_max: float = in_lambdas[0].item()

rho = 1.0 # 1.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = lambda_min ** (1 / rho)
max_inv_rho = lambda_max ** (1 / rho)
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return lambdas

def convert_model_output(
self,
model_output: torch.FloatTensor,
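As a sanity check on `_convert_to_lu`: with `rho = 1.0` the Karras-style interpolation collapses to a plain linear ramp in lambda, which is presumably why `rho` is pinned at 1.0 (fractional powers of the typically negative lambda values would be ill-defined). A small self-contained check with illustrative endpoints:

import numpy as np

lambda_max, lambda_min = 4.0, -6.0  # illustrative log(sigma) endpoints
rho = 1.0
ramp = np.linspace(0, 1, 20)
min_inv_rho = lambda_min ** (1 / rho)
max_inv_rho = lambda_max ** (1 / rho)
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# With rho = 1.0 this is exactly a uniform grid from lambda_max to lambda_min.
assert np.allclose(lambdas, np.linspace(lambda_max, lambda_min, 20))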
