
[Flax] added broadcast_to_shape_from_left helper and Scheduler tests (#864)

* added broadcast_to_shape_from_left helper

* initial tests

* fixed pndm tests

* shape required for pndm

* added require_flax

* fix style

* fix more imports

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
kashif and patrickvonplaten authored Oct 25, 2022
1 parent 38ae5a2 commit 240abdd
Showing 9 changed files with 931 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/__init__.py
@@ -34,7 +34,7 @@
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
else:
from ..utils.dummy_flax_objects import * # noqa F403

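With this re-export in place, downstream code can pull the helper straight from the schedulers package (gated on flax being installed, as the else branch above shows); a minimal sketch:

    from diffusers.schedulers import FlaxSchedulerOutput, broadcast_to_shape_from_left
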
15 changes: 6 additions & 9 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -173,7 +173,9 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):

return variance

def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -211,9 +213,6 @@ def step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
Returns:
@@ -279,13 +278,11 @@ def add_noise(
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[:, None]
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
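The while-loop this diff removes and the new helper are interchangeable; a small sanity check of that equivalence, with illustrative shapes:

    import jax.numpy as jnp

    x = jnp.ones((4,))                # one coefficient per batch element
    target = jnp.zeros((4, 8, 8, 3))  # batch of samples

    # Old pattern: append trailing singleton axes one at a time.
    old = x
    while len(old.shape) < len(target.shape):
        old = old[:, None]            # ends at shape (4, 1, 1, 1)

    # New pattern: one reshape plus an explicit broadcast.
    new = jnp.broadcast_to(x.reshape(x.shape + (1,) * (target.ndim - x.ndim)), target.shape)

    # Both broadcast identically against the sample batch.
    assert (old * target).shape == new.shape == target.shape
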
23 changes: 13 additions & 10 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -129,11 +133,12 @@ def __init__(
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.one = jnp.array(1.0)

self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)

self.variance_type = variance_type
def create_state(self):
return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
def set_timesteps(
self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> DDPMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -214,7 +219,7 @@ def step(
"""
t = timestep

if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
else:
predicted_variance = None
@@ -267,13 +272,11 @@ def add_noise(
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
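The has_state property and create_state method move scheduler state out of __init__ into an explicit object that callers thread through every call, keeping the scheduler instance immutable and jit-friendly. A usage sketch, assuming the top-level FlaxDDPMScheduler export; the wiring is illustrative, not part of this diff:

    from diffusers import FlaxDDPMScheduler

    scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)
    state = scheduler.create_state()
    state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 3, 32, 32))
    # Subsequent calls such as step() consume and return state rather than
    # mutating the scheduler.
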
11 changes: 9 additions & 2 deletions src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
A reasonable range is [0.2, 80].
"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -97,10 +101,13 @@ def __init__(
s_min: float = 0.05,
s_max: float = 50,
):
self.state = KarrasVeSchedulerState.create()
pass

def create_state(self):
return KarrasVeSchedulerState.create()

def set_timesteps(
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> KarrasVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
15 changes: 10 additions & 5 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,7 @@
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


@flax.struct.dataclass
@@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -85,8 +89,10 @@ def __init__(
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

def create_state(self):
self.state = LMSDiscreteSchedulerState.create(
num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
num_train_timesteps=self.config.num_train_timesteps,
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
)

def get_lms_coefficient(self, state, order, t, current_order):
Expand All @@ -112,7 +118,7 @@ def lms_derivative(tau):
return integrated_coeff

def set_timesteps(
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -199,8 +205,7 @@ def add_noise(
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigma = state.sigmas[timesteps].flatten()
while len(sigma.shape) < len(noise.shape):
sigma = sigma[..., None]
sigma = broadcast_to_shape_from_left(sigma, noise.shape)

noisy_samples = original_samples + noise * sigma

10 changes: 5 additions & 5 deletions src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -23,7 +23,7 @@
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -168,6 +168,8 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
the `FlaxPNDMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
shape (`Tuple`):
the shape of the samples to be generated.
"""
offset = self.config.steps_offset

@@ -509,13 +511,11 @@ def add_noise(
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
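Per the commit note "shape required for pndm", FlaxPNDMScheduler is the one scheduler where shape is load-bearing rather than a default: its state carries running buffers (ets, cur_sample) sized to the samples. A sketch under the assumption that it follows the same create_state pattern as the schedulers above:

    from diffusers import FlaxPNDMScheduler

    scheduler = FlaxPNDMScheduler(num_train_timesteps=1000)
    state = scheduler.create_state()
    # shape must match the samples being denoised; it is not optional here.
    state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 3, 64, 64))
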
26 changes: 18 additions & 8 deletions src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -22,7 +22,7 @@
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


@flax.struct.dataclass
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
correct_steps (`int`): number of correction steps performed on a produced sample.
"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -90,12 +94,20 @@ def __init__(
sampling_eps: float = 1e-5,
correct_steps: int = 1,
):
state = ScoreSdeVeSchedulerState.create()
pass

self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def create_state(self):
state = ScoreSdeVeSchedulerState.create()
return self.set_sigmas(
state,
self.config.num_train_timesteps,
self.config.sigma_min,
self.config.sigma_max,
self.config.sampling_eps,
)

def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -193,8 +205,7 @@ def step_pred(
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
diffusion = diffusion.flatten()
while len(diffusion.shape) < len(sample.shape):
diffusion = diffusion[:, None]
diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
drift = drift - diffusion**2 * model_output

# equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ def step_correct(

# compute corrected sample: model_output term and noise term
step_size = step_size.flatten()
while len(step_size.shape) < len(sample.shape):
step_size = step_size[:, None]
step_size = broadcast_to_shape_from_left(step_size, sample.shape)
prev_sample_mean = sample + step_size * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

6 changes: 6 additions & 0 deletions src/diffusers/schedulers/scheduling_utils_flax.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple

import jax.numpy as jnp

@@ -41,3 +42,8 @@ class FlaxSchedulerMixin:
"""

config_name = SCHEDULER_CONFIG_NAME


def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
assert len(shape) >= x.ndim
return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
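
For reference, a tiny usage sketch of the new helper, with illustrative values:

    import jax.numpy as jnp

    from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left

    coeff = jnp.array([0.9, 0.5])  # one scalar per batch element, shape (2,)
    expanded = broadcast_to_shape_from_left(coeff, (2, 4, 4, 3))
    assert expanded.shape == (2, 4, 4, 3)  # trailing axes padded, then broadcast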