From ea5a3f51c8074b265fa8c0cd25e539c2f206ee77 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 28 Jul 2022 09:08:56 +0200
Subject: [PATCH 1/3] start repaint

---
 src/diffusers/pipelines/repaint/__init__.py |   1 +
 .../pipelines/repaint/pipeline_repaint.py   |  67 +++++++
 .../schedulers/scheduling_repaint.py        | 177 ++++++++++++++++++
 3 files changed, 245 insertions(+)
 create mode 100644 src/diffusers/pipelines/repaint/__init__.py
 create mode 100644 src/diffusers/pipelines/repaint/pipeline_repaint.py
 create mode 100644 src/diffusers/schedulers/scheduling_repaint.py

diff --git a/src/diffusers/pipelines/repaint/__init__.py b/src/diffusers/pipelines/repaint/__init__.py
new file mode 100644
index 000000000000..16bc86d1cedf
--- /dev/null
+++ b/src/diffusers/pipelines/repaint/__init__.py
@@ -0,0 +1 @@
+from .pipeline_repaint import RePaintPipeline
diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py
new file mode 100644
index 000000000000..7ff63495c530
--- /dev/null
+++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -0,0 +1,67 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from tqdm.auto import tqdm
+
+from ...pipeline_utils import DiffusionPipeline
+
+
+class RePaintPipeline(DiffusionPipeline):
+    def __init__(self, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        batch_size=1,
+        num_inference_steps=250,
+        jump_length=10,
+        jump_n_sample=10,
+        generator=None,
+        torch_device=None,
+        output_type="pil",
+    ):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+
+        # sample gaussian noise to begin the loop
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            generator=generator,
+        )
+        image = image.to(torch_device)
+        # set step values
+        self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample)
+
+        for t in tqdm(self.scheduler.timesteps):
+            # 1. predict the noise residual
+            model_output = self.unet(image, t)["sample"]
+
+            # 2. compute previous image: x_t -> x_t-1
+            image = self.scheduler.step(model_output, t, image)["prev_sample"]
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return {"sample": image}
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
new file mode 100644
index 000000000000..ed76873f8a96
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -0,0 +1,177 @@
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from typing import Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas, dtype=np.float32)
+
+
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps=1000,
+        beta_start=0.0001,
+        beta_end=0.02,
+        beta_schedule="linear",
+        trained_betas=None,
+        timestep_values=None,
+        clip_sample=True,
+        tensor_format="pt",
+    ):
+
+        if beta_schedule == "linear":
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.one = np.array(1.0)
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+        return variance
+
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(
+            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+        )[::-1].copy()
+        self.set_format(tensor_format=self.tensor_format)
+
+    def step(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        eta: float = 0.0,
+        use_clipped_model_output: bool = False,
+        generator=None,
+    ):
+        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail to understand the following
+
+        # Notation (<variable name> -> <name in paper>
+        # - pred_noise_t -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_prev_sample -> "x_t-1"
+
+        # 1. get previous step value (=t-1)
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+        # 4. Clip "predicted x_0"
+        if self.config.clip_sample:
+            pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+        # 5. compute variance: "sigma_t(η)" -> see formula (16)
+        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+        variance = self._get_variance(timestep, prev_timestep)
+        std_dev_t = eta * variance ** (0.5)
+
+        if use_clipped_model_output:
+            # the model_output is always re-derived from the clipped x_0 in Glide
+            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + device = model_output.device if torch.is_tensor(model_output) else "cpu" + noise = torch.randn(model_output.shape, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + if not torch.is_tensor(model_output): + variance = variance.numpy() + + prev_sample = prev_sample + variance + + return {"prev_sample": prev_sample} + + def add_noise(self, original_samples, noise, timesteps): + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps From aa6da5ad722ff361a599d2196e2be91f06744813 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 1 Aug 2022 11:56:46 +0200 Subject: [PATCH 2/3] Test with a DDPM model --- examples/train_unconditional.py | 18 ++- src/diffusers/__init__.py | 11 +- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/repaint/pipeline_repaint.py | 39 ++++-- src/diffusers/schedulers/__init__.py | 1 + .../schedulers/scheduling_repaint.py | 120 ++++++++++-------- src/diffusers/schedulers/scheduling_utils.py | 3 + tests/test_modeling_utils.py | 28 ++++ 8 files changed, 140 insertions(+), 81 deletions(-) diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index 3d260c6faeed..5854a0d44495 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -7,7 +7,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from datasets import load_dataset -from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel from diffusers.hub_utils import init_git_repo, push_to_hub from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel @@ -39,25 +39,23 @@ def main(args): in_channels=3, out_channels=3, layers_per_block=2, - block_out_channels=(128, 128, 256, 256, 512, 512), + block_out_channels=(128, 128, 256, 256, 512), down_block_types=( "DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", - "AttnDownBlock2D", "DownBlock2D", ), up_block_types=( "UpBlock2D", - "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", ), ) - noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") + noise_scheduler = DDIMScheduler(num_train_timesteps=1000, tensor_format="pt") optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, @@ -150,7 +148,7 @@ def transforms(examples): # Generate sample images for visual inspection if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: - pipeline = DDPMPipeline( + pipeline = DDIMPipeline( unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), scheduler=noise_scheduler, ) @@ -179,15 +177,15 @@ def transforms(examples): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument("--local_rank", type=int, default=-1) - 
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories") + parser.add_argument("--dataset", type=str, default="huggan/smithsonian_butterflies_subset") parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64") parser.add_argument("--overwrite_output_dir", action="store_true") - parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--resolution", type=int, default=32) parser.add_argument("--train_batch_size", type=int, default=16) parser.add_argument("--eval_batch_size", type=int, default=16) parser.add_argument("--num_epochs", type=int, default=100) - parser.add_argument("--save_images_epochs", type=int, default=10) - parser.add_argument("--save_model_epochs", type=int, default=10) + parser.add_argument("--save_images_epochs", type=int, default=1) + parser.add_argument("--save_model_epochs", type=int, default=100) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--lr_scheduler", type=str, default="cosine") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 32af42b56017..dae95f4606bf 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,8 +18,15 @@ get_scheduler, ) from .pipeline_utils import DiffusionPipeline -from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline -from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler +from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + PNDMScheduler, + RePaintScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, +) from .training_utils import EMAModel diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 50855568ddb7..473fad770327 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -3,6 +3,7 @@ from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline from .pndm import PNDMPipeline +from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py index 7ff63495c530..5e568397f389 100644 --- a/src/diffusers/pipelines/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -18,10 +18,15 @@ from tqdm.auto import tqdm +from ...models import UNet2DModel from ...pipeline_utils import DiffusionPipeline +from ...schedulers import RePaintScheduler class RePaintPipeline(DiffusionPipeline): + unet: UNet2DModel + scheduler: RePaintScheduler + def __init__(self, unet, scheduler): super().__init__() scheduler = scheduler.set_format("pt") @@ -30,7 +35,8 @@ def __init__(self, unet, scheduler): @torch.no_grad() def __call__( self, - batch_size=1, + original_image: torch.Tensor, + mask: torch.Tensor, num_inference_steps=250, jump_length=10, jump_n_sample=10, @@ -42,26 +48,33 @@ def __call__( torch_device = "cuda" if torch.cuda.is_available() else "cpu" self.unet.to(torch_device) + original_image = original_image.to(torch_device) + mask = mask.to(torch_device) # sample gaussian noise to begin the loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + sample = torch.randn( + (original_image.shape[0], 
             generator=generator,
         )
-        image = image.to(torch_device)
+        sample = sample.to(torch_device)
         # set step values
         self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample)
+        t_last = self.scheduler.timesteps[0] + 1
 
         for t in tqdm(self.scheduler.timesteps):
-            # 1. predict the noise residual
-            model_output = self.unet(image, t)["sample"]
-
-            # 2. compute previous image: x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image)["prev_sample"]
+            if t < t_last:
+                # predict the noise residual
+                model_output = self.unet(sample, t)["sample"]
+                # compute previous image: x_t -> x_t-1
+                sample = self.scheduler.step(model_output, t, sample, original_image, mask, generator)["prev_sample"]
+            else:
+                # compute the reverse: x_t-1 -> x_t
+                sample = self.scheduler.undo_step(sample, t, generator)
+            t_last = t
 
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        sample = (sample / 2 + 0.5).clamp(0, 1)
+        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
-            image = self.numpy_to_pil(image)
+            sample = self.numpy_to_pil(sample)
 
-        return {"sample": image}
+        return {"sample": sample}
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 57a5c994522e..be7c0fba5b3b 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -19,6 +19,7 @@
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
 from .scheduling_pndm import PNDMScheduler
+from .scheduling_repaint import RePaintScheduler
 from .scheduling_sde_ve import ScoreSdeVeScheduler
 from .scheduling_sde_vp import ScoreSdeVpScheduler
 from .scheduling_utils import SchedulerMixin
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index ed76873f8a96..da11eb9f1e10 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -48,7 +48,7 @@ def alpha_bar(time_step):
     return np.array(betas, dtype=np.float32)
 
 
-class DDIMScheduler(SchedulerMixin, ConfigMixin):
+class RePaintScheduler(SchedulerMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
@@ -94,11 +94,27 @@ def _get_variance(self, timestep, prev_timestep):
 
         return variance
 
-    def set_timesteps(self, num_inference_steps):
+    def set_timesteps(self, num_inference_steps, jump_length=10, jump_n_sample=10):
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(
-            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
-        )[::-1].copy()
+        timesteps = []
+
+        jumps = {}
+        for j in range(0, num_inference_steps - jump_length, jump_length):
+            jumps[j] = jump_n_sample - 1
+
+        t = num_inference_steps
+        while t >= 1:
+            t = t - 1
+            timesteps.append(t)
+
+            if jumps.get(t, 0) > 0:
+                jumps[t] = jumps[t] - 1
+                for _ in range(jump_length):
+                    t = t + 1
+                    timesteps.append(t)
+
+        self.timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps)
+
         self.set_format(tensor_format=self.tensor_format)
 
     def step(
@@ -106,72 +122,64 @@
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-        eta: float = 0.0,
-        use_clipped_model_output: bool = False,
+        original_sample: Union[torch.FloatTensor, np.ndarray],
+        mask: Union[torch.FloatTensor, np.ndarray],
         generator=None,
     ):
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read the DDIM paper in detail to understand the following
-
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_sample -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_sample_direction -> "direction pointing to x_t"
-        # - pred_prev_sample -> "x_t-1"
-
-        # 1. get previous step value (=t-1)
+        device = model_output.device
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
-        # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
-        beta_prod_t = 1 - alpha_prod_t
+        alpha = self.alphas[timestep]
+        alpha_prod = self.alphas_cumprod[timestep]
+        beta = self.betas[timestep]
+        alpha_prod_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+        std_dev = self.sqrt(self._get_variance(timestep, prev_timestep))
+
+        if timestep > 1:
+            noise = torch.randn(model_output.shape, generator=generator).to(device)
+        else:
+            noise = torch.zeros(model_output.shape, device=device)
 
-        # 3. compute predicted original sample from predicted noise also called
-        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        # compute predicted original sample from predicted noise
+        pred_original_sample = (sample - self.sqrt(1 - alpha_prod) * model_output) / self.sqrt(alpha_prod)
 
-        # 4. Clip "predicted x_0"
+        # clip "predicted x_0"
         if self.config.clip_sample:
             pred_original_sample = self.clip(pred_original_sample, -1, 1)
 
-        # 5. compute variance: "sigma_t(η)" -> see formula (16)
-        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep)
-        std_dev_t = eta * variance ** (0.5)
-
-        if use_clipped_model_output:
-            # the model_output is always re-derived from the clipped x_0 in Glide
-            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-
-        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
-
-        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + # add noise to the known pixels of the image + prev_known_part = self.sqrt(alpha_prod) * original_sample + self.sqrt(1 - alpha_prod) * noise + + # add noise to the unknown pixels of the image + posterior_mean_coef1 = ( + beta * self.sqrt(alpha_prod_prev) / + (1.0 - alpha_prod) + ) + posterior_mean_coef2 = ( + (1.0 - alpha_prod_prev) + * self.sqrt(alpha) + / (1.0 - alpha_prod) + ) + prev_unknown_part = posterior_mean_coef1 * pred_original_sample + posterior_mean_coef2 * sample + prev_unknown_part = prev_unknown_part + std_dev * noise + #pred_sample_direction = self.sqrt(1 - alpha_prod_prev - std_dev ** 2) * model_output + #prev_unknown_part = self.sqrt(alpha_prod_prev) * pred_original_sample + pred_sample_direction + #prev_unknown_part = prev_unknown_part + std_dev * noise + + prev_sample = mask * prev_known_part + (1 - mask) * prev_unknown_part - if eta > 0: - device = model_output.device if torch.is_tensor(model_output) else "cpu" - noise = torch.randn(model_output.shape, generator=generator).to(device) - variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + return {"prev_sample": prev_sample} - if not torch.is_tensor(model_output): - variance = variance.numpy() + def undo_step(self, sample, timestep, generator=None): + beta = self.betas[timestep] - prev_sample = prev_sample + variance + noise = torch.randn(sample.shape, generator=generator).to(sample.device) + next_sample = self.sqrt(1 - beta) * sample + self.sqrt(beta) * noise - return {"prev_sample": prev_sample} + return next_sample def add_noise(self, original_samples, noise, timesteps): - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples + raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index b0cd4bda104f..40a15b6a85c4 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -54,6 +54,9 @@ def log(self, tensor): raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + def sqrt(self, tensor): + return tensor**0.5 + def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]): """ Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. 
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index cd5767c4d0a0..cbc793e11a8c 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -33,6 +33,8 @@
     LDMTextToImagePipeline,
     PNDMPipeline,
     PNDMScheduler,
+    RePaintPipeline,
+    RePaintScheduler,
     ScoreSdeVePipeline,
     ScoreSdeVeScheduler,
     UNet2DModel,
@@ -920,3 +922,29 @@ def test_ddpm_ddim_equality_batched(self):
 
         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_images - ddim_images).max() < 1e-1
+
+    #@slow
+    def test_repaint_celebahq(self):
+        from datasets import load_dataset
+
+        dataset = load_dataset('huggan/CelebA-HQ', split='train', streaming=True)
+        original_image = next(iter(dataset))["image"].resize((256, 256))
+        original_image = torch.tensor(np.array(original_image)).permute(2, 0, 1).unsqueeze(0)
+        original_image = (original_image / 255.0) * 2 - 1
+        mask = torch.zeros_like(original_image)
+        mask[:, :, :128, :] = 1  # keep the top half of the image, inpaint the bottom half
+
+        model_id = "google/ddpm-ema-celebahq-256"
+        unet = UNet2DModel.from_pretrained(model_id)
+        scheduler = RePaintScheduler.from_config(model_id)
+
+        repaint = RePaintPipeline(unet=unet, scheduler=scheduler)
+
+        generator = torch.manual_seed(0)
+        image = repaint(original_image, mask, generator=generator, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

From 8a9ec8b0cd6ab4c19fd9cd0b9dc92e7efb7c7b4c Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 1 Sep 2022 17:16:14 +0200
Subject: [PATCH 3/3] move tests

---
 src/diffusers/__init__.py            | 10 ++++++-
 .../schedulers/scheduling_repaint.py | 17 ++++-------
 tests/test_pipelines.py              | 28 +++++++++++++++++++
 3 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 6a79f6915a02..a9aceb227e79 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -15,7 +15,15 @@
     get_scheduler,
 )
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .pipelines import (
+    DDIMPipeline,
+    DDPMPipeline,
+    KarrasVePipeline,
+    LDMPipeline,
+    PNDMPipeline,
+    RePaintPipeline,
+    ScoreSdeVePipeline,
+)
 from .schedulers import (
     DDIMScheduler,
     DDPMScheduler,
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index da11eb9f1e10..d07ae01daa45 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -151,20 +151,13 @@ def step(
         prev_known_part = self.sqrt(alpha_prod) * original_sample + self.sqrt(1 - alpha_prod) * noise
 
         # add noise to the unknown pixels of the image
-        posterior_mean_coef1 = (
-            beta * self.sqrt(alpha_prod_prev) /
-            (1.0 - alpha_prod)
-        )
-        posterior_mean_coef2 = (
-            (1.0 - alpha_prod_prev)
-            * self.sqrt(alpha)
-            / (1.0 - alpha_prod)
-        )
+        posterior_mean_coef1 = beta * self.sqrt(alpha_prod_prev) / (1.0 - alpha_prod)
+        posterior_mean_coef2 = (1.0 - alpha_prod_prev) * self.sqrt(alpha) / (1.0 - alpha_prod)
         prev_unknown_part = posterior_mean_coef1 * pred_original_sample + posterior_mean_coef2 * sample
         prev_unknown_part = prev_unknown_part + std_dev * noise
-        #pred_sample_direction = self.sqrt(1 - alpha_prod_prev - std_dev ** 2) * model_output
-        #prev_unknown_part = self.sqrt(alpha_prod_prev) * pred_original_sample + pred_sample_direction
-        #prev_unknown_part = prev_unknown_part + std_dev * noise
+        # pred_sample_direction = self.sqrt(1 - alpha_prod_prev - std_dev ** 2) * model_output
+        # prev_unknown_part = self.sqrt(alpha_prod_prev) * pred_original_sample + pred_sample_direction
+        # prev_unknown_part = prev_unknown_part + std_dev * noise
 
         prev_sample = mask * prev_known_part + (1 - mask) * prev_unknown_part
 
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 011604775558..d5482bb54b1b 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -35,6 +35,8 @@
     LMSDiscreteScheduler,
     PNDMPipeline,
     PNDMScheduler,
+    RePaintPipeline,
+    RePaintScheduler,
     ScoreSdeVePipeline,
     ScoreSdeVeScheduler,
     StableDiffusionImg2ImgPipeline,
@@ -896,3 +898,29 @@ def test_stable_diffusion_in_paint_pipeline(self):
 
         assert sampled_array.shape == (512, 768, 3)
         assert np.max(np.abs(sampled_array - expected_array)) < 1e-3
+
+    @slow
+    def test_repaint_celebahq(self):
+        from datasets import load_dataset
+
+        dataset = load_dataset("huggan/CelebA-HQ", split="train", streaming=True)
+        original_image = next(iter(dataset))["image"].resize((256, 256))
+        original_image = torch.tensor(np.array(original_image)).permute(2, 0, 1).unsqueeze(0)
+        original_image = (original_image / 255.0) * 2 - 1
+        mask = torch.zeros_like(original_image)
+        mask[:, :, :128, :] = 1  # keep the top half of the image, inpaint the bottom half
+
+        model_id = "google/ddpm-ema-celebahq-256"
+        unet = UNet2DModel.from_pretrained(model_id)
+        scheduler = RePaintScheduler.from_config(model_id)
+
+        repaint = RePaintPipeline(unet=unet, scheduler=scheduler)
+
+        generator = torch.manual_seed(0)
+        image = repaint(original_image, mask, generator=generator, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
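
Taken together, the series turns a plain unconditional DDPM checkpoint into a mask-conditioned inpainter. A hypothetical end-to-end usage sketch mirroring the slow test above (the checkpoint id comes from that test; "face.png" and "inpainted.png" are placeholder paths):

    import numpy as np
    import torch
    from PIL import Image
    from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel

    model_id = "google/ddpm-ema-celebahq-256"  # same checkpoint the test uses
    unet = UNet2DModel.from_pretrained(model_id)
    scheduler = RePaintScheduler.from_config(model_id)
    pipe = RePaintPipeline(unet=unet, scheduler=scheduler)

    # Load the input and normalize it to [-1, 1], NCHW, exactly as the test does.
    original_image = Image.open("face.png").convert("RGB").resize((256, 256))
    original_image = torch.tensor(np.array(original_image)).permute(2, 0, 1).unsqueeze(0)
    original_image = (original_image / 255.0) * 2 - 1

    # mask == 1 marks known pixels to keep; mask == 0 marks the region to inpaint.
    mask = torch.zeros_like(original_image)
    mask[:, :, :128, :] = 1  # keep the top half, regenerate the bottom half

    generator = torch.manual_seed(0)
    result = pipe(
        original_image,
        mask,
        num_inference_steps=250,
        jump_length=10,
        jump_n_sample=10,
        generator=generator,
    )
    result["sample"][0].save("inpainted.png")  # output_type defaults to "pil"

The mask convention follows `step()`, which blends `mask * prev_known_part + (1 - mask) * prev_unknown_part`: the mask selects what is kept, not the hole.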