diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index a906c39eb24c..91026b7d487f 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -34,7 +34,7 @@ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler - from .scheduling_utils_flax import FlaxSchedulerMixin + from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left else: from ..utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 7726ef2eda0d..590e3aac2e49 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -173,7 +173,9 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod): return variance - def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState: + def set_timesteps( + self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -211,9 +213,6 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - key (`random.KeyArray`): a PRNG key. - eta (`float`): weight of noise for added noise in diffusion step. 
-            use_clipped_model_output (`bool`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
@@ -279,13 +278,11 @@ def add_noise(
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index 7b3265611101..7220a0145471 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -129,11 +133,12 @@ def __init__(
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
         self.one = jnp.array(1.0)

-        self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
-
-        self.variance_type = variance_type
+    def create_state(self):
+        return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

-    def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
+    def set_timesteps(
+        self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDPMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -214,7 +219,7 @@ def step( """ t = timestep - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) else: predicted_variance = None @@ -267,13 +272,11 @@ def add_noise( ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index caf27aa4c226..78ab007954aa 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): A reasonable range is [0.2, 80]. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -97,10 +101,13 @@ def __init__( s_min: float = 0.05, s_max: float = 50, ): - self.state = KarrasVeSchedulerState.create() + pass + + def create_state(self): + return KarrasVeSchedulerState.create() def set_timesteps( - self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple + self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> KarrasVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index cd71e1960e8c..20982d38aa40 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,7 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left @flax.struct.dataclass @@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 
""" + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -85,8 +89,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + def create_state(self): self.state = LMSDiscreteSchedulerState.create( - num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + num_train_timesteps=self.config.num_train_timesteps, + sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5, ) def get_lms_coefficient(self, state, order, t, current_order): @@ -112,7 +118,7 @@ def lms_derivative(tau): return integrated_coeff def set_timesteps( - self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple + self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -199,8 +205,7 @@ def add_noise( timesteps: jnp.ndarray, ) -> jnp.ndarray: sigma = state.sigmas[timesteps].flatten() - while len(sigma.shape) < len(noise.shape): - sigma = sigma[..., None] + sigma = broadcast_to_shape_from_left(sigma, noise.shape) noisy_samples = original_samples + noise * sigma diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 2e587d5e2121..357ecfe046c1 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -168,6 +168,8 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha the `FlaxPNDMScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. 
""" offset = self.config.steps_offset @@ -509,13 +511,11 @@ def add_noise( ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index c4d802f83f94..d3eadede61b5 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -22,7 +22,7 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left @flax.struct.dataclass @@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): correct_steps (`int`): number of correction steps performed on a produced sample. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -90,12 +94,20 @@ def __init__( sampling_eps: float = 1e-5, correct_steps: int = 1, ): - state = ScoreSdeVeSchedulerState.create() + pass - self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps) + def create_state(self): + state = ScoreSdeVeSchedulerState.create() + return self.set_sigmas( + state, + self.config.num_train_timesteps, + self.config.sigma_min, + self.config.sigma_max, + self.config.sampling_eps, + ) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. 
@@ -193,8 +205,7 @@ def step_pred(
         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
         diffusion = diffusion.flatten()
-        while len(diffusion.shape) < len(sample.shape):
-            diffusion = diffusion[:, None]
+        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
         drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ def step_correct(

         # compute corrected sample: model_output term and noise term
         step_size = step_size.flatten()
-        while len(step_size.shape) < len(sample.shape):
-            step_size = step_size[:, None]
+        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
         prev_sample_mean = sample + step_size * model_output
         prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index 63de51f146f5..e545cfe2479e 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from typing import Tuple

 import jax.numpy as jnp

@@ -41,3 +42,9 @@ class FlaxSchedulerMixin:
     """

     config_name = SCHEDULER_CONFIG_NAME
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
+    """Append singleton axes to `x` until it has `len(shape)` dimensions, then broadcast it to `shape`."""
+    assert len(shape) >= x.ndim
+    return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py
new file mode 100644
index 000000000000..d2feaa752ae2
--- /dev/null
+++ b/tests/test_scheduler_flax.py
@@ -0,0 +1,868 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
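+
+# These tests exercise the common scheduler-state API introduced in this patch:
+# schedulers advertise `has_state`, construct their state with `create_state()`,
+# and `from_config` returns a `(scheduler, state)` pair; the state is then
+# threaded explicitly through `set_timesteps` and `step`.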
+import tempfile +import unittest +from typing import Dict, List, Tuple + +from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler +from diffusers.utils import is_flax_available +from diffusers.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax.numpy as jnp + from jax import random + + +@require_flax +class FlaxSchedulerCommonTest(unittest.TestCase): + scheduler_classes = () + forward_default_kwargs = () + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + key1, key2 = random.split(random.PRNGKey(0)) + sample = random.uniform(key1, (batch_size, num_channels, height, width)) + + return sample, key2 + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = jnp.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + return jnp.transpose(sample, (3, 0, 1, 2)) + + def get_scheduler_config(self): + raise NotImplementedError + + def dummy_model(self): + def model(sample, t, *args): + return sample * t / (t + 1) + + return model + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, 
residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + +@require_flax +class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxDDPMScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "variance_type": "fixed_small", + "clip_sample": True, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [1, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_variance_type(self): + for variance in ["fixed_small", "fixed_large", "other"]: + self.check_over_configs(variance_type=variance) + + def test_clip_sample(self): + for clip_sample in [True, False]: + self.check_over_configs(clip_sample=clip_sample) + + def test_time_indices(self): + for t in [0, 500, 999]: + self.check_over_forward(time_step=t) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + assert jnp.sum(jnp.abs(scheduler._get_variance(0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(999) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_trained_timesteps = len(scheduler) + + model = self.dummy_model() + sample = self.dummy_sample_deter + key1, key2 = random.split(random.PRNGKey(0)) + + for t in reversed(range(num_trained_timesteps)): + # 1. predict noise residual + residual = model(sample, t) + + # 2. 
predict previous mean of sample x_t-1 + output = scheduler.step(state, residual, t, sample, key1) + pred_prev_sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + # if t > 0: + # noise = self.dummy_sample_deter + # variance = scheduler.get_variance(t) ** (0.5) * noise + # + # sample = pred_prev_sample + variance + sample = pred_prev_sample + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 255.1113) < 1e-2 + assert abs(result_mean - 0.332176) < 1e-3 + + +@require_flax +class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxDDIMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + key1, key2 = random.split(random.PRNGKey(0)) + + num_inference_steps = 10 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + state = scheduler.set_timesteps(state, num_inference_steps) + + for t in state.timesteps: + residual = model(sample, t) + output = scheduler.step(state, residual, t, sample) + sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + return sample + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = 
new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 500, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 10, 49]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5 + assert 
jnp.sum(jnp.abs(scheduler._get_variance(420, 400, state.alphas_cumprod) - 0.14771)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(980, 960, state.alphas_cumprod) - 0.32460)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(487, 486, state.alphas_cumprod) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(999, 998, state.alphas_cumprod) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 172.0067) < 1e-2 + assert abs(result_mean - 0.223967) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 149.8295) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 149.0784) < 1e-2 + assert abs(result_mean - 0.1941) < 1e-3 + + +@require_flax +class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxPNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + state = state.replace(ets=dummy_past_residuals[:]) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + new_state = new_state.replace(ets=dummy_past_residuals[:]) + + (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + pass + + def test_scheduler_outputs_equivalence(self): + def 
set_nan_tensor_to_zero(t):
+            return t.at[t != t].set(0)
+
+        def recursive_check(tuple_object, dict_object):
+            if isinstance(tuple_object, (List, Tuple)):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif isinstance(tuple_object, Dict):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif tuple_object is None:
+                return
+            else:
+                self.assertTrue(
+                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
+                    msg=(
+                        "Tuple and dict output are not equal. Difference:"
+                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
+                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
+                    ),
+                )
+
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            state = scheduler.create_state()
+
+            sample, _ = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
+
+            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        kwargs.update(forward_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample, _ = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            state = scheduler.create_state()
+            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
+
+            # copy over dummy past residuals (must be after setting timesteps)
+            state = state.replace(ets=dummy_past_residuals[:])
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
+
+                # copy over dummy past residuals (must be after setting timesteps)
+                new_state = new_state.replace(ets=dummy_past_residuals[:])
+
+            output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
+            new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
+
+            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are
not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + for i, t in enumerate(state.prk_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_prk(state, residual, t, sample) + + for i, t in enumerate(state.plms_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_plms(state, residual, t, sample) + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + state = state.replace(ets=dummy_past_residuals[:]) + + output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 10, shape=()) + assert jnp.equal( + state.timesteps, + jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), + ).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 5, 10]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + 
self.check_over_forward(num_inference_steps=num_inference_steps) + + def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(state.prk_timesteps[:2]): + sample, state = scheduler.step_prk(state, residual, t, sample) + + def test_inference_plms_no_past_residuals(self): + with self.assertRaises(ValueError): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 198.1318) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 186.9466) < 1e-2 + assert abs(result_mean - 0.24342) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 186.9482) < 1e-2 + assert abs(result_mean - 0.2434) < 1e-3
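
A minimal usage sketch of the scheduler-state API after this patch. This is illustrative, not part of the change: the model function below is a stand-in with the same form as the tests' dummy_model, and the shape and step count are arbitrary.

    import jax
    import jax.numpy as jnp
    from diffusers import FlaxDDPMScheduler

    # The scheduler instance now holds only static config; per-run state lives in
    # a separate immutable dataclass created explicitly by the caller.
    scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)
    state = scheduler.create_state()
    state = scheduler.set_timesteps(state, num_inference_steps=10)

    # Stand-in for a trained model's apply function (same form as dummy_model above).
    def model(sample, t):
        return sample * t / (t + 1)

    sample = jnp.zeros((4, 3, 8, 8))
    key = jax.random.PRNGKey(0)
    for t in state.timesteps:
        residual = model(sample, t)
        key, step_key = jax.random.split(key)
        # FlaxDDPMScheduler.step takes a PRNG key and returns both the previous
        # sample and the updated scheduler state.
        output = scheduler.step(state, residual, t, sample, step_key)
        sample, state = output.prev_sample, output.state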