From fe10bc6f9abeeddaa4e80ee9db1842be28134f7d Mon Sep 17 00:00:00 2001 From: okotaku Date: Wed, 1 Nov 2023 12:52:29 +0000 Subject: [PATCH 1/6] support noise method --- .pre-commit-config.yaml | 4 +- README.md | 3 + configs/input_perturbation/README.md | 46 +++++++++ ...sion_xl_pokemon_blip_input_perturbation.py | 12 +++ configs/offset_noise/README.md | 40 ++++++++ ..._diffusion_xl_pokemon_blip_offset_noise.py | 12 +++ configs/pyramid_noise/README.md | 40 ++++++++ ...diffusion_xl_pokemon_blip_pyramid_noise.py | 12 +++ diffengine/models/__init__.py | 1 + .../editors/deepfloyd_if/deepfloyd_if.py | 41 ++++---- .../editors/distill_sd/distill_sd_xl.py | 8 +- .../editors/ip_adapter/ip_adapter_xl.py | 16 +-- diffengine/models/editors/ssd_1b/ssd_1b.py | 21 ++-- .../stable_diffusion/stable_diffusion.py | 37 ++++--- .../stable_diffusion_controlnet.py | 8 +- .../stable_diffusion_xl.py | 37 ++++--- .../stable_diffusion_xl_controlnet.py | 8 +- .../stable_diffusion_xl_t2i_adapter.py | 8 +- diffengine/models/utils/__init__.py | 3 + diffengine/models/utils/noise.py | 97 +++++++++++++++++++ pyproject.toml | 3 + tests/test_models/test_utils/test_noise.py | 56 +++++++++++ 22 files changed, 420 insertions(+), 93 deletions(-) create mode 100644 configs/input_perturbation/README.md create mode 100644 configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py create mode 100644 configs/offset_noise/README.md create mode 100644 configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py create mode 100644 configs/pyramid_noise/README.md create mode 100644 configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py create mode 100644 diffengine/models/utils/__init__.py create mode 100644 diffengine/models/utils/noise.py create mode 100644 tests/test_models/test_utils/test_noise.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b6d68f..df48d02 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,9 +18,11 @@ repos: - id: mixed-line-ending args: ["--fix=lf"] - repo: https://github.com/codespell-project/codespell - rev: v2.2.1 + rev: v2.2.4 hooks: - id: codespell + additional_dependencies: + - tomli - repo: https://github.com/executablebooks/mdformat rev: 0.7.9 hooks: diff --git a/README.md b/README.md index c017034..8074f7d 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,9 @@ For detailed user guides and advanced guides, please refer to our [Documentation diff --git a/configs/input_perturbation/README.md b/configs/input_perturbation/README.md new file mode 100644 index 0000000..d3da67c --- /dev/null +++ b/configs/input_perturbation/README.md @@ -0,0 +1,46 @@ +# Input Perturbation + +[Input Perturbation Reduces Exposure Bias in Diffusion Models](https://arxiv.org/abs/2301.11706) + +## Abstract + +Denoising Diffusion Probabilistic Models have shown an impressive generation quality, although their long sampling chain leads to high computational costs. In this paper, we observe that a long sampling chain also leads to an error accumulation phenomenon, which is similar to the exposure bias problem in autoregressive text generation. Specifically, we note that there is a discrepancy between training and testing, since the former is conditioned on the ground truth samples, while the latter is conditioned on the previously generated results. To alleviate this problem, we propose a very simple but effective training regularization, consisting in perturbing the ground truth samples to simulate the inference time prediction errors. We empirically show that, without affecting the recall and precision, the proposed input perturbation leads to a significant improvement in the sample quality while reducing both the training and the inference times. For instance, on CelebA 64×64, we achieve a new state-of-the-art FID score of 1.27, while saving 37.5% of the training time. + +
+ +
+ +## Citation + +``` +@article{ning2023input, + title={Input Perturbation Reduces Exposure Bias in Diffusion Models}, + author={Ning, Mang and Sangineto, Enver and Porrello, Angelo and Calderara, Simone and Cucchiara, Rita}, + journal={arXiv preprint arXiv:2301.11706}, + year={2023} +} +``` + +## Run Training + +Run Training + +``` +# single gpu +$ mim train diffengine ${CONFIG_FILE} +# multi gpus +$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch + +# Example. +$ mim train diffengine configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py +``` + +## Inference with diffusers + +You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers). + +## Results Example + +#### stable_diffusion_xl_pokemon_blip_input_perturbation + +![example1](<>) diff --git a/configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py b/configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py new file mode 100644 index 0000000..d023662 --- /dev/null +++ b/configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py @@ -0,0 +1,12 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl.py", + "../_base_/datasets/pokemon_blip_xl.py", + "../_base_/schedules/stable_diffusion_xl_50e.py", + "../_base_/default_runtime.py", +] + +model = dict(input_perturbation_gamma=0.1) + +train_dataloader = dict(batch_size=1) + +optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times diff --git a/configs/offset_noise/README.md b/configs/offset_noise/README.md new file mode 100644 index 0000000..58ffe03 --- /dev/null +++ b/configs/offset_noise/README.md @@ -0,0 +1,40 @@ +# Offset Noise + +[Diffusion with Offset Noise](https://www.crosslabs.org/blog/diffusion-with-offset-noise) + +## Abstract + +Fine-tuning against a modified noise, enables Stable Diffusion to generate very dark or light images easily. + +
+ +
+ +## Citation + +``` +``` + +## Run Training + +Run Training + +``` +# single gpu +$ mim train diffengine ${CONFIG_FILE} +# multi gpus +$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch + +# Example. +$ mim train diffengine configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py +``` + +## Inference with diffusers + +You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers). + +## Results Example + +#### stable_diffusion_xl_pokemon_blip_offset_noise + +![example1](https://github.com/okotaku/diffengine/assets/24734142/7a3b26ff-618b-46f0-827e-32c2d47cde6f) diff --git a/configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py b/configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py new file mode 100644 index 0000000..a434002 --- /dev/null +++ b/configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py @@ -0,0 +1,12 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl.py", + "../_base_/datasets/pokemon_blip_xl.py", + "../_base_/schedules/stable_diffusion_xl_50e.py", + "../_base_/default_runtime.py", +] + +model = dict(noise_generator=dict(type="OffsetNoise", offset_weight=0.05)) + +train_dataloader = dict(batch_size=1) + +optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times diff --git a/configs/pyramid_noise/README.md b/configs/pyramid_noise/README.md new file mode 100644 index 0000000..d0ccbea --- /dev/null +++ b/configs/pyramid_noise/README.md @@ -0,0 +1,40 @@ +# Pyramid Noise + +[Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) + +## Abstract + +This report proposes a new noising approach that adds multi-resolution noise to an image or latent image during diffusion model training. A model trained with this technique can generate stunning images with a very different aesthetic to the usual diffusion model outputs. This seems like a promising direction for future research. + +
+ +
+ +## Citation + +``` +``` + +## Run Training + +Run Training + +``` +# single gpu +$ mim train diffengine ${CONFIG_FILE} +# multi gpus +$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch + +# Example. +$ mim train diffengine configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py +``` + +## Inference with diffusers + +You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers). + +## Results Example + +#### stable_diffusion_xl_pokemon_blip_pyramid_noise + +![example1](<>) diff --git a/configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py b/configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py new file mode 100644 index 0000000..12dc11e --- /dev/null +++ b/configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py @@ -0,0 +1,12 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl.py", + "../_base_/datasets/pokemon_blip_xl.py", + "../_base_/schedules/stable_diffusion_xl_50e.py", + "../_base_/default_runtime.py", +] + +model = dict(noise_generator=dict(type="PyramidNoise", discount=0.9)) + +train_dataloader = dict(batch_size=1) + +optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times diff --git a/diffengine/models/__init__.py b/diffengine/models/__init__.py index 79a890e..030ba7c 100644 --- a/diffengine/models/__init__.py +++ b/diffengine/models/__init__.py @@ -1,2 +1,3 @@ from .editors import * # noqa: F403 from .losses import * # noqa: F403 +from .utils import * # noqa: F403 diff --git a/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py b/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py index 8eb4df9..e1bc3f6 100644 --- a/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py +++ b/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py @@ -31,10 +31,6 @@ class DeepFloydIF(BaseModel): example. dict(rank=4). Defaults to None. prior_loss_weight (float): The weight of prior preservation loss. It works when training dreambooth with class images. - noise_offset_weight (bool, optional): - The weight of noise offset introduced in - https://www.crosslabs.org/blog/diffusion-with-offset-noise - Defaults to 0. tokenizer_max_length (int): The max length of tokenizer. Defaults to 77. prediction_type (str): The prediction_type that shall be used for @@ -43,6 +39,11 @@ class DeepFloydIF(BaseModel): scheduler: `noise_scheduler.config.prediciton_type` is chosen. data_preprocessor (dict, optional): The pre-process config of :class:`SDDataPreprocessor`. + noise_generator (dict, optional): The noise generator config. + Defaults to ``dict(type='WhiteNoise')``. + input_perturbation_gamma (float): The gamma of input perturbation. + The recommended value is 0.1 for Input Perturbation. + Defaults to 0.0. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -56,16 +57,19 @@ def __init__( loss: dict | None = None, lora_config: dict | None = None, prior_loss_weight: float = 1., - noise_offset_weight: float = 0, tokenizer_max_length: int = 77, prediction_type: str | None = None, data_preprocessor: dict | nn.Module | None = None, + noise_generator: dict | None = None, + input_perturbation_gamma: float = 0.0, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, ) -> None: if data_preprocessor is None: data_preprocessor = {"type": "SDDataPreprocessor"} + if noise_generator is None: + noise_generator = {"type": "WhiteNoise"} if loss is None: loss = {"type": "L2Loss", "loss_weight": 1.0} super().__init__(data_preprocessor=data_preprocessor) @@ -75,13 +79,12 @@ def __init__( self.prior_loss_weight = prior_loss_weight self.gradient_checkpointing = gradient_checkpointing self.tokenizer_max_length = tokenizer_max_length + self.input_perturbation_gamma = input_perturbation_gamma if not isinstance(loss, nn.Module): loss = MODELS.build(loss) self.loss_module: nn.Module = loss - self.enable_noise_offset = noise_offset_weight > 0 - self.noise_offset_weight = noise_offset_weight assert prediction_type in [None, "epsilon", "v_prediction"] self.prediction_type = prediction_type @@ -94,6 +97,7 @@ def __init__( model, subfolder="text_encoder") self.unet = UNet2DConditionModel.from_pretrained( model, subfolder="unet") + self.noise_generator = MODELS.build(noise_generator) self.prepare_model() self.set_lora() @@ -244,6 +248,17 @@ def loss(self, loss_dict["loss"] = loss return loss_dict + def _preprocess_model_input(self, + latents: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor) -> torch.Tensor: + if self.input_perturbation_gamma > 0: + input_noise = self.input_perturbation_gamma * torch.randn_like( + noise) + else: + input_noise = noise + return self.scheduler.add_noise(latents, input_noise, timesteps) + def forward( self, inputs: torch.Tensor, @@ -283,15 +298,7 @@ def forward( model_input = inputs["img"] - noise = torch.randn_like(model_input) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - model_input.shape[0], - model_input.shape[1], - 1, - 1, - device=noise.device) + noise = self.noise_generator(model_input) num_batches = model_input.shape[0] timesteps = torch.randint( @@ -300,7 +307,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_model_input = self.scheduler.add_noise(model_input, noise, + noisy_model_input = self._preprocess_model_input(model_input, noise, timesteps) encoder_hidden_states = self.text_encoder( diff --git a/diffengine/models/editors/distill_sd/distill_sd_xl.py b/diffengine/models/editors/distill_sd/distill_sd_xl.py index c9468ce..52104c1 100644 --- a/diffengine/models/editors/distill_sd/distill_sd_xl.py +++ b/diffengine/models/editors/distill_sd/distill_sd_xl.py @@ -161,11 +161,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) timesteps = torch.randint( 0, @@ -173,7 +169,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) if not self.pre_compute_text_embeddings: inputs["text_one"] = self.tokenizer_one( diff --git a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py index daa23d8..8407df1 100644 --- a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py +++ b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py @@ -242,11 +242,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) timesteps = torch.randint( 0, @@ -254,7 +250,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) prompt_embeds, pooled_prompt_embeds = self.encode_prompt( inputs["text_one"], inputs["text_two"]) @@ -401,11 +397,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) timesteps = torch.randint( 0, @@ -413,7 +405,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) prompt_embeds, pooled_prompt_embeds = self.encode_prompt( inputs["text_one"], inputs["text_two"]) diff --git a/diffengine/models/editors/ssd_1b/ssd_1b.py b/diffengine/models/editors/ssd_1b/ssd_1b.py index b078352..e3b3f35 100644 --- a/diffengine/models/editors/ssd_1b/ssd_1b.py +++ b/diffengine/models/editors/ssd_1b/ssd_1b.py @@ -45,16 +45,14 @@ class SSD1B(StableDiffusionXL): example. dict(rank=4). Defaults to None. prior_loss_weight (float): The weight of prior preservation loss. It works when training dreambooth with class images. - noise_offset_weight (bool, optional): - The weight of noise offset introduced in - https://www.crosslabs.org/blog/diffusion-with-offset-noise - Defaults to 0. prediction_type (str): The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen. data_preprocessor (dict, optional): The pre-process config of :class:`SDXLDataPreprocessor`. + noise_generator (dict, optional): The noise generator config. + Defaults to ``dict(type='WhiteNoise')``. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -73,9 +71,9 @@ def __init__( loss: dict | None = None, lora_config: dict | None = None, prior_loss_weight: float = 1., - noise_offset_weight: float = 0, prediction_type: str | None = None, data_preprocessor: dict | nn.Module | None = None, + noise_generator: dict | None = None, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -90,6 +88,8 @@ def __init__( if data_preprocessor is None: data_preprocessor = {"type": "SDXLDataPreprocessor"} + if noise_generator is None: + noise_generator = {"type": "WhiteNoise"} if loss is None: loss = {"type": "L2Loss", "loss_weight": 1.0} super(StableDiffusionXL, self).__init__(data_preprocessor=data_preprocessor) @@ -106,8 +106,6 @@ def __init__( loss = MODELS.build(loss) self.loss_module: nn.Module = loss - self.enable_noise_offset = noise_offset_weight > 0 - self.noise_offset_weight = noise_offset_weight assert prediction_type in [None, "epsilon", "v_prediction"] self.prediction_type = prediction_type @@ -144,6 +142,7 @@ def __init__( elif student_model_weight == "unet": self.unet = UNet2DConditionModel.from_pretrained( student_model, subfolder="unet") + self.noise_generator = MODELS.build(noise_generator) self.prepare_model() self.set_lora() @@ -257,11 +256,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) timesteps = torch.randint( 0, @@ -269,7 +264,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) if not self.pre_compute_text_embeddings: inputs["text_one"] = self.tokenizer_one( diff --git a/diffengine/models/editors/stable_diffusion/stable_diffusion.py b/diffengine/models/editors/stable_diffusion/stable_diffusion.py index 79b4885..cfa5386 100644 --- a/diffengine/models/editors/stable_diffusion/stable_diffusion.py +++ b/diffengine/models/editors/stable_diffusion/stable_diffusion.py @@ -32,15 +32,16 @@ class StableDiffusion(BaseModel): example. dict(rank=4). Defaults to None. prior_loss_weight (float): The weight of prior preservation loss. It works when training dreambooth with class images. - noise_offset_weight (bool, optional): - The weight of noise offset introduced in - https://www.crosslabs.org/blog/diffusion-with-offset-noise - Defaults to 0. prediction_type (str): The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the data_preprocessor (dict, optional): The pre-process config of :class:`SDDataPreprocessor`. + noise_generator (dict, optional): The noise generator config. + Defaults to ``dict(type='WhiteNoise')``. + input_perturbation_gamma (float): The gamma of input perturbation. + The recommended value is 0.1 for Input Perturbation. + Defaults to 0.0. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -54,15 +55,18 @@ def __init__( loss: dict | None = None, lora_config: dict | None = None, prior_loss_weight: float = 1., - noise_offset_weight: float = 0, prediction_type: str | None = None, data_preprocessor: dict | nn.Module | None = None, + noise_generator: dict | None = None, + input_perturbation_gamma: float = 0.0, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, ) -> None: if data_preprocessor is None: data_preprocessor = {"type": "SDDataPreprocessor"} + if noise_generator is None: + noise_generator = {"type": "WhiteNoise"} if loss is None: loss = {"type": "L2Loss", "loss_weight": 1.0} super().__init__(data_preprocessor=data_preprocessor) @@ -71,13 +75,12 @@ def __init__( self.finetune_text_encoder = finetune_text_encoder self.prior_loss_weight = prior_loss_weight self.gradient_checkpointing = gradient_checkpointing + self.input_perturbation_gamma = input_perturbation_gamma if not isinstance(loss, nn.Module): loss = MODELS.build(loss) self.loss_module: nn.Module = loss - self.enable_noise_offset = noise_offset_weight > 0 - self.noise_offset_weight = noise_offset_weight assert prediction_type in [None, "epsilon", "v_prediction"] self.prediction_type = prediction_type @@ -91,6 +94,7 @@ def __init__( self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae") self.unet = UNet2DConditionModel.from_pretrained( model, subfolder="unet") + self.noise_generator = MODELS.build(noise_generator) self.prepare_model() self.set_lora() @@ -244,6 +248,17 @@ def loss(self, loss_dict["loss"] = loss return loss_dict + def _preprocess_model_input(self, + latents: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor) -> torch.Tensor: + if self.input_perturbation_gamma > 0: + input_noise = self.input_perturbation_gamma * torch.randn_like( + noise) + else: + input_noise = noise + return self.scheduler.add_noise(latents, input_noise, timesteps) + def forward( self, inputs: torch.Tensor, @@ -282,11 +297,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) num_batches = latents.shape[0] timesteps = torch.randint( @@ -295,7 +306,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) encoder_hidden_states = self.text_encoder(inputs["text"])[0] diff --git a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py index 9770718..14f1e7a 100644 --- a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py +++ b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py @@ -209,11 +209,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) num_batches = latents.shape[0] timesteps = torch.randint( @@ -222,7 +218,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) encoder_hidden_states = self.text_encoder(inputs["text"])[0] diff --git a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index ab1cb2f..8c5020e 100644 --- a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -59,16 +59,17 @@ class StableDiffusionXL(BaseModel): example. dict(rank=4). Defaults to None. prior_loss_weight (float): The weight of prior preservation loss. It works when training dreambooth with class images. - noise_offset_weight (bool, optional): - The weight of noise offset introduced in - https://www.crosslabs.org/blog/diffusion-with-offset-noise - Defaults to 0. prediction_type (str): The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen. data_preprocessor (dict, optional): The pre-process config of :class:`SDXLDataPreprocessor`. + noise_generator (dict, optional): The noise generator config. + Defaults to ``dict(type='WhiteNoise')``. + input_perturbation_gamma (float): The gamma of input perturbation. + The recommended value is 0.1 for Input Perturbation. + Defaults to 0.0. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -85,9 +86,10 @@ def __init__( loss: dict | None = None, lora_config: dict | None = None, prior_loss_weight: float = 1., - noise_offset_weight: float = 0, prediction_type: str | None = None, data_preprocessor: dict | nn.Module | None = None, + noise_generator: dict | None = None, + input_perturbation_gamma: float = 0.0, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -95,6 +97,8 @@ def __init__( ) -> None: if data_preprocessor is None: data_preprocessor = {"type": "SDXLDataPreprocessor"} + if noise_generator is None: + noise_generator = {"type": "WhiteNoise"} if loss is None: loss = {"type": "L2Loss", "loss_weight": 1.0} super().__init__(data_preprocessor=data_preprocessor) @@ -104,6 +108,7 @@ def __init__( self.prior_loss_weight = prior_loss_weight self.gradient_checkpointing = gradient_checkpointing self.pre_compute_text_embeddings = pre_compute_text_embeddings + self.input_perturbation_gamma = input_perturbation_gamma if pre_compute_text_embeddings: assert not finetune_text_encoder @@ -111,8 +116,6 @@ def __init__( loss = MODELS.build(loss) self.loss_module: nn.Module = loss - self.enable_noise_offset = noise_offset_weight > 0 - self.noise_offset_weight = noise_offset_weight assert prediction_type in [None, "epsilon", "v_prediction"] self.prediction_type = prediction_type @@ -139,6 +142,7 @@ def __init__( vae_path, subfolder="vae" if vae_model is None else None) self.unet = UNet2DConditionModel.from_pretrained( model, subfolder="unet") + self.noise_generator = MODELS.build(noise_generator) self.prepare_model() self.set_lora() @@ -350,6 +354,17 @@ def loss(self, loss_dict["loss"] = loss return loss_dict + def _preprocess_model_input(self, + latents: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor) -> torch.Tensor: + if self.input_perturbation_gamma > 0: + input_noise = self.input_perturbation_gamma * torch.randn_like( + noise) + else: + input_noise = noise + return self.scheduler.add_noise(latents, input_noise, timesteps) + def forward( self, inputs: torch.Tensor, @@ -382,11 +397,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) timesteps = torch.randint( 0, @@ -394,7 +405,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) if not self.pre_compute_text_embeddings: inputs["text_one"] = self.tokenizer_one( diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py index 58d10ed..5590146 100644 --- a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py +++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py @@ -219,11 +219,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) timesteps = torch.randint( 0, @@ -231,7 +227,7 @@ def forward( device=self.device) timesteps = timesteps.long() - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) prompt_embeds, pooled_prompt_embeds = self.encode_prompt( inputs["text_one"], inputs["text_two"]) diff --git a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py index 16661b5..93c6825 100644 --- a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py +++ b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py @@ -210,11 +210,7 @@ def forward( latents = self.vae.encode(inputs["img"]).latent_dist.sample() latents = latents * self.vae.config.scaling_factor - noise = torch.randn_like(latents) - - if self.enable_noise_offset: - noise = noise + self.noise_offset_weight * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + noise = self.noise_generator(latents) # Cubic sampling to sample a random time step for each image. # For more details about why cubic sampling is used, refer to section @@ -226,7 +222,7 @@ def forward( timesteps = timesteps.clamp( 0, self.scheduler.config.num_train_timesteps - 1) - noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + noisy_latents = self._preprocess_model_input(latents, noise, timesteps) prompt_embeds, pooled_prompt_embeds = self.encode_prompt( inputs["text_one"], inputs["text_two"]) diff --git a/diffengine/models/utils/__init__.py b/diffengine/models/utils/__init__.py new file mode 100644 index 0000000..325fc3b --- /dev/null +++ b/diffengine/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .noise import OffsetNoise, PyramidNoise, WhiteNoise + +__all__ = ["WhiteNoise", "OffsetNoise", "PyramidNoise"] diff --git a/diffengine/models/utils/noise.py b/diffengine/models/utils/noise.py new file mode 100644 index 0000000..324a1f1 --- /dev/null +++ b/diffengine/models/utils/noise.py @@ -0,0 +1,97 @@ +import random + +import torch +from torch import nn + +from diffengine.registry import MODELS + + +@MODELS.register_module() +class WhiteNoise(nn.Module): + """White noise module.""" + + def forward(self, latents: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Generates noise for the given latents. + + Args: + ---- + latents (torch.Tensor): Latent vectors. + """ + return torch.randn_like(latents) + + +@MODELS.register_module() +class OffsetNoise(nn.Module): + """Offset noise module. + + https://www.crosslabs.org/blog/diffusion-with-offset-noise + + Args: + ---- + offset_weight (float): Noise offset weight. Defaults to 0.05. + """ + + def __init__(self, offset_weight: float = 0.05) -> None: + super().__init__() + self.offset_weight = offset_weight + + def forward(self, latents: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Generates noise for the given latents. + + Args: + ---- + latents (torch.Tensor): Latent vectors. + """ + noise = torch.randn_like(latents) + return noise + self.offset_weight * torch.randn( + latents.shape[0], latents.shape[1], 1, 1, device=noise.device) + + +@MODELS.register_module() +class PyramidNoise(nn.Module): + """Pyramid noise module. + + https://wandb.ai/johnowhitaker/multires_noise/reports/ + Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 + + Args: + ---- + discount (float): Noise offset weight. Defaults to 0.9. + random_multiplier (bool): Whether to use random multiplier. + Defaults to True. + """ + + def __init__(self, discount: float = 0.9, + *, + random_multiplier: bool = True) -> None: + super().__init__() + self.discount = discount + self.random_multiplier = random_multiplier + + def forward(self, latents: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Generates noise for the given latents. + + Args: + ---- + latents (torch.Tensor): Latent vectors. + """ + noise = torch.randn_like(latents) + + b, c, w, h = latents.shape + u = nn.Upsample(size=(w, h), mode="bilinear") + for i in range(16): + r = random.random() * 2 + 2 if self.random_multiplier else 2 # noqa: S311 + + w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i))) + noise += u( + torch.randn(b, c, w, h).to(latents)) * self.discount ** i + if w==1 or h==1: + break + + return noise / noise.std() diff --git a/pyproject.toml b/pyproject.toml index 1a675e0..70d1eeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,3 +90,6 @@ python_version = "3.11" no_strict_optional = true ignore_missing_imports = true check_untyped_defs = true + +[tool.codespell] +ignore-words-list = "enver," diff --git a/tests/test_models/test_utils/test_noise.py b/tests/test_models/test_utils/test_noise.py new file mode 100644 index 0000000..59b44f3 --- /dev/null +++ b/tests/test_models/test_utils/test_noise.py @@ -0,0 +1,56 @@ +from unittest import TestCase + +import torch + +from diffengine.models.utils import OffsetNoise, PyramidNoise, WhiteNoise + + +class TestWhiteNoise(TestCase): + + def test_init(self): + _ = WhiteNoise() + + def test_forward(self): + module = WhiteNoise() + latens = torch.randn(1, 4, 16, 16) + noise = module(latens) + assert latens.shape == noise.shape + + +class TestOffsetNoise(TestCase): + + def test_init(self): + module = OffsetNoise() + assert module.offset_weight == 0.05 + + module = OffsetNoise(offset_weight=0.2) + assert module.offset_weight == 0.2 + + def test_forward(self): + module = OffsetNoise() + latens = torch.randn(1, 4, 16, 16) + noise = module(latens) + assert latens.shape == noise.shape + + +class TestPyramidNoise(TestCase): + + def test_init(self): + module = PyramidNoise() + assert module.discount == 0.9 + assert module.random_multiplier + + module = PyramidNoise(discount=0.8, random_multiplier=False) + assert module.discount == 0.8 + assert not module.random_multiplier + + def test_forward(self): + module = PyramidNoise() + latens = torch.randn(1, 4, 16, 16) + noise = module(latens) + assert latens.shape == noise.shape + + module = PyramidNoise(random_multiplier=False) + latens = torch.randn(1, 4, 16, 16) + noise = module(latens) + assert latens.shape == noise.shape From 26e18908ef830c859f0756f5b8ee17204508807f Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 3 Nov 2023 02:02:10 +0000 Subject: [PATCH 2/6] fix input_perturbation --- configs/input_perturbation/README.md | 2 +- configs/pyramid_noise/README.md | 2 +- .../editors/deepfloyd_if/deepfloyd_if.py | 2 +- .../stable_diffusion/stable_diffusion.py | 2 +- .../stable_diffusion_xl.py | 2 +- .../test_deepfloyd_if/test_deepfloyd_if.py | 17 ++++++++++++++++ .../test_stable_diffusion.py | 17 ++++++++++++++++ .../test_stable_diffusion_xl.py | 20 +++++++++++++++++++ 8 files changed, 59 insertions(+), 5 deletions(-) diff --git a/configs/input_perturbation/README.md b/configs/input_perturbation/README.md index d3da67c..76d1805 100644 --- a/configs/input_perturbation/README.md +++ b/configs/input_perturbation/README.md @@ -43,4 +43,4 @@ You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/ru #### stable_diffusion_xl_pokemon_blip_input_perturbation -![example1](<>) +![example1](https://github.com/okotaku/diffengine/assets/24734142/b0a631e7-153c-467a-9cb6-d9155eaa7161) diff --git a/configs/pyramid_noise/README.md b/configs/pyramid_noise/README.md index d0ccbea..c6a362b 100644 --- a/configs/pyramid_noise/README.md +++ b/configs/pyramid_noise/README.md @@ -37,4 +37,4 @@ You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/ru #### stable_diffusion_xl_pokemon_blip_pyramid_noise -![example1](<>) +![example1](https://github.com/okotaku/diffengine/assets/24734142/8ee2f0b1-6ef6-4b5e-a018-8b0acbd73ec9) diff --git a/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py b/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py index e1bc3f6..f079dde 100644 --- a/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py +++ b/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py @@ -253,7 +253,7 @@ def _preprocess_model_input(self, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: if self.input_perturbation_gamma > 0: - input_noise = self.input_perturbation_gamma * torch.randn_like( + input_noise = noise + self.input_perturbation_gamma * torch.randn_like( noise) else: input_noise = noise diff --git a/diffengine/models/editors/stable_diffusion/stable_diffusion.py b/diffengine/models/editors/stable_diffusion/stable_diffusion.py index cfa5386..7539992 100644 --- a/diffengine/models/editors/stable_diffusion/stable_diffusion.py +++ b/diffengine/models/editors/stable_diffusion/stable_diffusion.py @@ -253,7 +253,7 @@ def _preprocess_model_input(self, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: if self.input_perturbation_gamma > 0: - input_noise = self.input_perturbation_gamma * torch.randn_like( + input_noise = noise + self.input_perturbation_gamma * torch.randn_like( noise) else: input_noise = noise diff --git a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index 8c5020e..a49ce90 100644 --- a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -359,7 +359,7 @@ def _preprocess_model_input(self, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: if self.input_perturbation_gamma > 0: - input_noise = self.input_perturbation_gamma * torch.randn_like( + input_noise = noise + self.input_perturbation_gamma * torch.randn_like( noise) else: input_noise = noise diff --git a/tests/test_models/test_editors/test_deepfloyd_if/test_deepfloyd_if.py b/tests/test_models/test_editors/test_deepfloyd_if/test_deepfloyd_if.py index d81422d..4fcf1b8 100644 --- a/tests/test_models/test_editors/test_deepfloyd_if/test_deepfloyd_if.py +++ b/tests/test_models/test_editors/test_deepfloyd_if/test_deepfloyd_if.py @@ -62,6 +62,23 @@ def test_train_step(self): assert log_vars assert isinstance(log_vars["loss"], torch.Tensor) + def test_train_step_input_perturbation(self): + # test load with loss module + StableDiffuser = DeepFloydIF( + "hf-internal-testing/tiny-if-pipe", + input_perturbation_gamma=0.1, + loss=L2Loss(), + data_preprocessor=SDDataPreprocessor()) + + # test train step + data = dict( + inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"])) + optimizer = SGD(StableDiffuser.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + log_vars = StableDiffuser.train_step(data, optim_wrapper) + assert log_vars + assert isinstance(log_vars["loss"], torch.Tensor) + def test_train_step_with_gradient_checkpointing(self): # test load with loss module StableDiffuser = DeepFloydIF( diff --git a/tests/test_models/test_editors/test_stable_diffusion/test_stable_diffusion.py b/tests/test_models/test_editors/test_stable_diffusion/test_stable_diffusion.py index 96d87ef..953dfa7 100644 --- a/tests/test_models/test_editors/test_stable_diffusion/test_stable_diffusion.py +++ b/tests/test_models/test_editors/test_stable_diffusion/test_stable_diffusion.py @@ -62,6 +62,23 @@ def test_train_step(self): assert log_vars assert isinstance(log_vars["loss"], torch.Tensor) + def test_train_step_input_perturbation(self): + # test load with loss module + StableDiffuser = StableDiffusion( + "diffusers/tiny-stable-diffusion-torch", + input_perturbation_gamma=0.1, + loss=L2Loss(), + data_preprocessor=SDDataPreprocessor()) + + # test train step + data = dict( + inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"])) + optimizer = SGD(StableDiffuser.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + log_vars = StableDiffuser.train_step(data, optim_wrapper) + assert log_vars + assert isinstance(log_vars["loss"], torch.Tensor) + def test_train_step_with_gradient_checkpointing(self): # test load with loss module StableDiffuser = StableDiffusion( diff --git a/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py index fd93bcb..e9d1bc9 100644 --- a/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py @@ -87,6 +87,26 @@ def test_train_step(self): assert log_vars assert isinstance(log_vars["loss"], torch.Tensor) + def test_train_step_input_perturbation(self): + # test load with loss module + StableDiffuser = StableDiffusionXL( + "hf-internal-testing/tiny-stable-diffusion-xl-pipe", + input_perturbation_gamma=0.1, + loss=L2Loss(), + data_preprocessor=SDXLDataPreprocessor()) + + # test train step + data = dict( + inputs=dict( + img=[torch.zeros((3, 64, 64))], + text=["a dog"], + time_ids=[torch.zeros((1, 6))])) + optimizer = SGD(StableDiffuser.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + log_vars = StableDiffuser.train_step(data, optim_wrapper) + assert log_vars + assert isinstance(log_vars["loss"], torch.Tensor) + def test_train_step_with_gradient_checkpointing(self): # test load with loss module StableDiffuser = StableDiffusionXL( From 56c070c2c4dc37bd7c0d11b2dc5b14e0e18e859c Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 3 Nov 2023 02:25:58 +0000 Subject: [PATCH 3/6] fix input_perturbation --- diffengine/models/editors/ssd_1b/ssd_1b.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/diffengine/models/editors/ssd_1b/ssd_1b.py b/diffengine/models/editors/ssd_1b/ssd_1b.py index e3b3f35..bccc0b0 100644 --- a/diffengine/models/editors/ssd_1b/ssd_1b.py +++ b/diffengine/models/editors/ssd_1b/ssd_1b.py @@ -53,6 +53,9 @@ class SSD1B(StableDiffusionXL): :class:`SDXLDataPreprocessor`. noise_generator (dict, optional): The noise generator config. Defaults to ``dict(type='WhiteNoise')``. + input_perturbation_gamma (float): The gamma of input perturbation. + The recommended value is 0.1 for Input Perturbation. + Defaults to 0.0. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -74,6 +77,7 @@ def __init__( prediction_type: str | None = None, data_preprocessor: dict | nn.Module | None = None, noise_generator: dict | None = None, + input_perturbation_gamma: float = 0.0, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -99,6 +103,7 @@ def __init__( self.prior_loss_weight = prior_loss_weight self.gradient_checkpointing = gradient_checkpointing self.pre_compute_text_embeddings = pre_compute_text_embeddings + self.input_perturbation_gamma = input_perturbation_gamma if pre_compute_text_embeddings: assert not finetune_text_encoder From 078f606c0ea6de6602cebb297b13047a96e57d94 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 3 Nov 2023 04:26:07 +0000 Subject: [PATCH 4/6] fix test --- .../test_hooks/test_compile_hook.py | 3 +++ .../test_hooks/test_controlnet_save_hook.py | 3 +++ .../test_hooks/test_fast_norm_hook.py | 3 +++ .../test_hooks/test_ip_adapter_save_hook.py | 3 +++ .../test_hooks/test_lora_save_hook.py | 3 +++ .../test_hooks/test_t2i_adapter_save_hook.py | 3 +++ .../test_hooks/test_visualization_hook.py | 3 +++ .../test_editors/test_ssd_1b/test_ssd_1b.py | 23 ++++++++++++++----- 8 files changed, 38 insertions(+), 6 deletions(-) diff --git a/tests/test_engine/test_hooks/test_compile_hook.py b/tests/test_engine/test_hooks/test_compile_hook.py index 522884c..8fca268 100644 --- a/tests/test_engine/test_hooks/test_compile_hook.py +++ b/tests/test_engine/test_hooks/test_compile_hook.py @@ -13,6 +13,7 @@ StableDiffusionXL, ) from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise class TestCompileHook(RunnerTestCase): @@ -26,6 +27,7 @@ def setUp(self) -> None: MODELS.register_module( name="SDXLDataPreprocessor", module=SDXLDataPreprocessor) MODELS.register_module(name="L2Loss", module=L2Loss) + MODELS.register_module(name="WhiteNoise", module=WhiteNoise) return super().setUp() def tearDown(self) -> None: @@ -34,6 +36,7 @@ def tearDown(self) -> None: MODELS.module_dict.pop("SDDataPreprocessor") MODELS.module_dict.pop("SDXLDataPreprocessor") MODELS.module_dict.pop("L2Loss") + MODELS.module_dict.pop("WhiteNoise") return super().tearDown() def test_init(self) -> None: diff --git a/tests/test_engine/test_hooks/test_controlnet_save_hook.py b/tests/test_engine/test_hooks/test_controlnet_save_hook.py index 72c66b6..73da845 100644 --- a/tests/test_engine/test_hooks/test_controlnet_save_hook.py +++ b/tests/test_engine/test_hooks/test_controlnet_save_hook.py @@ -14,6 +14,7 @@ StableDiffusionControlNet, ) from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise class DummyWrapper(BaseModel): @@ -38,6 +39,7 @@ def setUp(self) -> None: name="SDControlNetDataPreprocessor", module=SDControlNetDataPreprocessor) MODELS.register_module(name="L2Loss", module=L2Loss) + MODELS.register_module(name="WhiteNoise", module=WhiteNoise) return super().setUp() def tearDown(self): @@ -45,6 +47,7 @@ def tearDown(self): MODELS.module_dict.pop("StableDiffusionControlNet") MODELS.module_dict.pop("SDControlNetDataPreprocessor") MODELS.module_dict.pop("L2Loss") + MODELS.module_dict.pop("WhiteNoise") return super().tearDown() def test_init(self): diff --git a/tests/test_engine/test_hooks/test_fast_norm_hook.py b/tests/test_engine/test_hooks/test_fast_norm_hook.py index a787053..5685f09 100644 --- a/tests/test_engine/test_hooks/test_fast_norm_hook.py +++ b/tests/test_engine/test_hooks/test_fast_norm_hook.py @@ -13,6 +13,7 @@ StableDiffusionXL, ) from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise try: import apex @@ -31,6 +32,7 @@ def setUp(self) -> None: MODELS.register_module( name="SDXLDataPreprocessor", module=SDXLDataPreprocessor) MODELS.register_module(name="L2Loss", module=L2Loss) + MODELS.register_module(name="WhiteNoise", module=WhiteNoise) return super().setUp() def tearDown(self) -> None: @@ -39,6 +41,7 @@ def tearDown(self) -> None: MODELS.module_dict.pop("SDDataPreprocessor") MODELS.module_dict.pop("SDXLDataPreprocessor") MODELS.module_dict.pop("L2Loss") + MODELS.module_dict.pop("WhiteNoise") return super().tearDown() @unittest.skipIf(apex is None, "apex is not installed") diff --git a/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py b/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py index 027ec11..f4a3d15 100644 --- a/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py +++ b/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py @@ -11,6 +11,7 @@ from diffengine.engine.hooks import IPAdapterSaveHook from diffengine.models.editors import IPAdapterXL, IPAdapterXLDataPreprocessor from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise class DummyWrapper(BaseModel): @@ -34,6 +35,7 @@ def setUp(self) -> None: name="IPAdapterXLDataPreprocessor", module=IPAdapterXLDataPreprocessor) MODELS.register_module(name="L2Loss", module=L2Loss) + MODELS.register_module(name="WhiteNoise", module=WhiteNoise) return super().setUp() def tearDown(self): @@ -41,6 +43,7 @@ def tearDown(self): MODELS.module_dict.pop("IPAdapterXL") MODELS.module_dict.pop("IPAdapterXLDataPreprocessor") MODELS.module_dict.pop("L2Loss") + MODELS.module_dict.pop("WhiteNoise") return super().tearDown() def test_init(self): diff --git a/tests/test_engine/test_hooks/test_lora_save_hook.py b/tests/test_engine/test_hooks/test_lora_save_hook.py index 603eea8..ac0b9c9 100644 --- a/tests/test_engine/test_hooks/test_lora_save_hook.py +++ b/tests/test_engine/test_hooks/test_lora_save_hook.py @@ -16,6 +16,7 @@ StableDiffusionXL, ) from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise class DummyWrapper(BaseModel): @@ -42,6 +43,7 @@ def setUp(self) -> None: MODELS.register_module( name="SDXLDataPreprocessor", module=SDXLDataPreprocessor) MODELS.register_module(name="L2Loss", module=L2Loss) + MODELS.register_module(name="WhiteNoise", module=WhiteNoise) return super().setUp() def tearDown(self): @@ -51,6 +53,7 @@ def tearDown(self): MODELS.module_dict.pop("SDDataPreprocessor") MODELS.module_dict.pop("SDXLDataPreprocessor") MODELS.module_dict.pop("L2Loss") + MODELS.module_dict.pop("WhiteNoise") return super().tearDown() def test_init(self): diff --git a/tests/test_engine/test_hooks/test_t2i_adapter_save_hook.py b/tests/test_engine/test_hooks/test_t2i_adapter_save_hook.py index ee60ac3..cf483e5 100644 --- a/tests/test_engine/test_hooks/test_t2i_adapter_save_hook.py +++ b/tests/test_engine/test_hooks/test_t2i_adapter_save_hook.py @@ -14,6 +14,7 @@ StableDiffusionXLT2IAdapter, ) from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise class DummyWrapper(BaseModel): @@ -39,6 +40,7 @@ def setUp(self) -> None: name="SDXLControlNetDataPreprocessor", module=SDXLControlNetDataPreprocessor) MODELS.register_module(name="L2Loss", module=L2Loss) + MODELS.register_module(name="WhiteNoise", module=WhiteNoise) return super().setUp() def tearDown(self): @@ -46,6 +48,7 @@ def tearDown(self): MODELS.module_dict.pop("StableDiffusionXLT2IAdapter") MODELS.module_dict.pop("SDXLControlNetDataPreprocessor") MODELS.module_dict.pop("L2Loss") + MODELS.module_dict.pop("WhiteNoise") return super().tearDown() def test_init(self): diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index 3c9bdd8..d696a13 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -13,6 +13,7 @@ StableDiffusionControlNet, ) from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise class TestVisualizationHook(RunnerTestCase): @@ -27,6 +28,7 @@ def setUp(self) -> None: name="SDControlNetDataPreprocessor", module=SDControlNetDataPreprocessor) MODELS.register_module(name="L2Loss", module=L2Loss) + MODELS.register_module(name="WhiteNoise", module=WhiteNoise) return super().setUp() def tearDown(self): @@ -35,6 +37,7 @@ def tearDown(self): MODELS.module_dict.pop("StableDiffusionControlNet") MODELS.module_dict.pop("SDControlNetDataPreprocessor") MODELS.module_dict.pop("L2Loss") + MODELS.module_dict.pop("WhiteNoise") return super().tearDown() def test_after_train_epoch(self): diff --git a/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py b/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py index d6713b5..25464c2 100644 --- a/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py +++ b/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py @@ -2,6 +2,7 @@ import pytest import torch +from diffengin.models.utils import WhiteNoise from mmengine.optim import OptimWrapper from torch.optim import SGD @@ -18,6 +19,7 @@ def test_init(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise(), finetune_text_encoder=True) with pytest.raises( @@ -26,6 +28,7 @@ def test_init(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise(), lora_config=dict(rank=8)) with pytest.raises( @@ -34,13 +37,15 @@ def test_init(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise(), student_model_weight="dummy") def test_infer(self): StableDiffuser = SSD1B( "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", - data_preprocessor=SDXLDataPreprocessor()) + data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise()) # test infer result = StableDiffuser.infer( @@ -77,7 +82,8 @@ def test_infer_with_pre_compute_embs(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", pre_compute_text_embeddings=True, - data_preprocessor=SDXLDataPreprocessor()) + data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise()) assert not hasattr(StableDiffuser, "tokenizer_one") assert not hasattr(StableDiffuser, "text_encoder_one") @@ -101,7 +107,8 @@ def test_train_step(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor()) + data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise()) # test train step data = dict( @@ -122,6 +129,7 @@ def test_train_step_with_gradient_checkpointing(self): student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise(), gradient_checkpointing=True) # test train step @@ -143,7 +151,8 @@ def test_train_step_with_pre_compute_embs(self): student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", pre_compute_text_embeddings=True, loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor()) + data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise()) assert not hasattr(StableDiffuser, "tokenizer_one") assert not hasattr(StableDiffuser, "text_encoder_one") @@ -169,7 +178,8 @@ def test_train_step_dreambooth(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor()) + data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise()) # test train step data = dict( @@ -192,7 +202,8 @@ def test_val_and_test_step(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor()) + data_preprocessor=SDXLDataPreprocessor(), + noise_generator=WhiteNoise()) # test val_step with pytest.raises(NotImplementedError, match="val_step is not"): From 16284266bea54846b08c85b96ca0dc2d026a068c Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 3 Nov 2023 04:43:26 +0000 Subject: [PATCH 5/6] fix test --- tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py b/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py index 25464c2..430d059 100644 --- a/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py +++ b/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py @@ -2,12 +2,12 @@ import pytest import torch -from diffengin.models.utils import WhiteNoise from mmengine.optim import OptimWrapper from torch.optim import SGD from diffengine.models.editors import SSD1B, SDXLDataPreprocessor from diffengine.models.losses import L2Loss +from diffengine.models.utils import WhiteNoise class TestSSD1B(TestCase): From 5442062bef0ba628c47aceba294e92411a6c3feb Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 3 Nov 2023 05:06:40 +0000 Subject: [PATCH 6/6] fix test --- .../test_editors/test_ssd_1b/test_ssd_1b.py | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py b/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py index 430d059..e360a8e 100644 --- a/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py +++ b/tests/test_models/test_editors/test_ssd_1b/test_ssd_1b.py @@ -7,7 +7,7 @@ from diffengine.models.editors import SSD1B, SDXLDataPreprocessor from diffengine.models.losses import L2Loss -from diffengine.models.utils import WhiteNoise +from diffengine.models.utils import WhiteNoise # noqa class TestSSD1B(TestCase): @@ -19,7 +19,6 @@ def test_init(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise(), finetune_text_encoder=True) with pytest.raises( @@ -28,7 +27,6 @@ def test_init(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise(), lora_config=dict(rank=8)) with pytest.raises( @@ -37,15 +35,13 @@ def test_init(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise(), student_model_weight="dummy") def test_infer(self): StableDiffuser = SSD1B( "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", - data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise()) + data_preprocessor=SDXLDataPreprocessor()) # test infer result = StableDiffuser.infer( @@ -82,8 +78,7 @@ def test_infer_with_pre_compute_embs(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", pre_compute_text_embeddings=True, - data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise()) + data_preprocessor=SDXLDataPreprocessor()) assert not hasattr(StableDiffuser, "tokenizer_one") assert not hasattr(StableDiffuser, "text_encoder_one") @@ -107,8 +102,7 @@ def test_train_step(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise()) + data_preprocessor=SDXLDataPreprocessor()) # test train step data = dict( @@ -129,7 +123,6 @@ def test_train_step_with_gradient_checkpointing(self): student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise(), gradient_checkpointing=True) # test train step @@ -151,8 +144,7 @@ def test_train_step_with_pre_compute_embs(self): student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", pre_compute_text_embeddings=True, loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise()) + data_preprocessor=SDXLDataPreprocessor()) assert not hasattr(StableDiffuser, "tokenizer_one") assert not hasattr(StableDiffuser, "text_encoder_one") @@ -178,8 +170,7 @@ def test_train_step_dreambooth(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise()) + data_preprocessor=SDXLDataPreprocessor()) # test train step data = dict( @@ -202,8 +193,7 @@ def test_val_and_test_step(self): "hf-internal-testing/tiny-stable-diffusion-xl-pipe", student_model="hf-internal-testing/tiny-stable-diffusion-xl-pipe", loss=L2Loss(), - data_preprocessor=SDXLDataPreprocessor(), - noise_generator=WhiteNoise()) + data_preprocessor=SDXLDataPreprocessor()) # test val_step with pytest.raises(NotImplementedError, match="val_step is not"):