From cbaab7e6455d3d2c08ab9a0fea7c3f6f2fd11948 Mon Sep 17 00:00:00 2001
From: "Gokhale, Neelesh"
Date: Fri, 19 Jan 2024 18:26:48 +0000
Subject: [PATCH] Add more ControlNet examples and MultiControlNet tests

---
 examples/stable-diffusion/README.md |  34 ++++
 tests/test_diffusers.py             | 293 ++++++++++++++++++++++++++++
 2 files changed, 327 insertions(+)

diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md
index d569bd6a15..795a0e206d 100644
--- a/examples/stable-diffusion/README.md
+++ b/examples/stable-diffusion/README.md
@@ -206,6 +206,7 @@ python text_to_image_generation.py \
 ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models
 ](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala.
 It is a type of model for controlling StableDiffusion by conditioning the model with an additional input image.
+
 Here is how to generate images conditioned by canny edge model:
 ```bash
 pip install -r requirements.txt
@@ -223,6 +224,39 @@ python text_to_image_generation.py \
     --bf16
 ```
+Here is how to generate images conditioned by the canny edge model with multiple prompts:
+```bash
+pip install -r requirements.txt
+python text_to_image_generation.py \
+    --model_name_or_path runwayml/stable-diffusion-v1-5 \
+    --controlnet_model_name_or_path lllyasviel/sd-controlnet-canny \
+    --prompts "futuristic-looking woman" "a rusty robot" \
+    --control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \
+    --num_images_per_prompt 10 \
+    --batch_size 4 \
+    --image_save_dir /tmp/controlnet_images \
+    --use_habana \
+    --use_hpu_graphs \
+    --gaudi_config Habana/stable-diffusion \
+    --bf16
+```
+
+Here is how to generate images conditioned by the open pose model:
+```bash
+pip install -r requirements.txt
+python text_to_image_generation.py \
+    --model_name_or_path runwayml/stable-diffusion-v1-5 \
+    --controlnet_model_name_or_path lllyasviel/sd-controlnet-openpose \
+    --prompts "Chef in the kitchen" \
+    --control_image https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png \
+    --num_images_per_prompt 20 \
+    --batch_size 4 \
+    --image_save_dir /tmp/controlnet_images \
+    --use_habana \
+    --use_hpu_graphs \
+    --gaudi_config Habana/stable-diffusion \
+    --bf16
+```
 
 ## Textual Inversion
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index 0a0724cf98..134d5cb4fe 100644
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -32,6 +32,7 @@
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from transformers.testing_utils import slow
+from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
 
 from optimum.habana import GaudiConfig
 from optimum.habana.diffusers import (
@@ -1182,6 +1183,7 @@ def init_weights(m):
                 torch.nn.init.normal(m.weight)
                 m.bias.data.fill_(1.0)
 
+        torch.manual_seed(0)
         controlnet = ControlNetModel(
             block_out_channels=(4, 8),
             layers_per_block=2,
@@ -1427,3 +1429,294 @@ def test_stable_diffusion_controlnet_hpu_graphs(self):
         self.assertEqual(len(images), 10)
         self.assertEqual(images[-1].shape, (64, 64, 3))
+
+class GaudiStableDiffusionMultiControlNetPipelineTester(TestCase):
+    """
+    Tests the StableDiffusionControlNetPipeline for Gaudi with multiple ControlNets (MultiControlNetModel).
+ """ + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=2, + sample_size=32, + time_cond_proj_dim=time_cond_proj_dim, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + norm_num_groups=1, + ) + + def init_weights(m): + if isinstance(m, torch.nn.Conv2d): + torch.nn.init.normal(m.weight) + m.bias.data.fill_(1.0) + + torch.manual_seed(0) + controlnet1 = ControlNetModel( + block_out_channels=(4, 8), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + norm_num_groups=1, + ) + controlnet1.controlnet_down_blocks.apply(init_weights) + + torch.manual_seed(0) + controlnet2 = ControlNetModel( + block_out_channels=(4, 8), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + norm_num_groups=1, + ) + controlnet2.controlnet_down_blocks.apply(init_weights) + + scheduler = GaudiDDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[4, 8], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + controlnet = MultiControlNetModel([controlnet1, controlnet2]) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + generator = torch.Generator(device=device).manual_seed(seed) + controlnet_embedder_scale_factor = 2 + images = [ + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + ] + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + "image": images, + } + return inputs + + def test_stable_diffusion_multicontrolnet_num_images_per_prompt(self): + components = self.get_dummy_components() + gaudi_config = GaudiConfig() + + sd_pipe = GaudiStableDiffusionControlNetPipeline( + use_habana=True, + gaudi_config=gaudi_config, + **components, + ) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device="cpu") + prompt = inputs["prompt"] + # Test num_images_per_prompt=1 (default) + images = sd_pipe(**inputs).images + + 
+        self.assertEqual(len(images), 1)
+        self.assertEqual(images[0].shape, (64, 64, 3))
+
+        # Test num_images_per_prompt=1 (default) for several prompts
+        num_prompts = 3
+        inputs["prompt"] = [prompt] * num_prompts
+        images = sd_pipe(**inputs).images
+
+        self.assertEqual(len(images), num_prompts)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+        # Test num_images_per_prompt for single prompt
+        num_images_per_prompt = 2
+        inputs["prompt"] = prompt
+        images = sd_pipe(
+            num_images_per_prompt=num_images_per_prompt,
+            **inputs,
+        ).images
+
+        self.assertEqual(len(images), num_images_per_prompt)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+        # Test num_images_per_prompt for several prompts
+        num_prompts = 2
+        inputs["prompt"] = [prompt] * num_prompts
+        images = sd_pipe(
+            num_images_per_prompt=num_images_per_prompt,
+            **inputs,
+        ).images
+
+        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+    def test_stable_diffusion_multicontrolnet_batch_sizes(self):
+        components = self.get_dummy_components()
+        gaudi_config = GaudiConfig()
+
+        sd_pipe = GaudiStableDiffusionControlNetPipeline(
+            use_habana=True,
+            gaudi_config=gaudi_config,
+            **components,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device="cpu")
+        prompt = inputs["prompt"]
+        # Test batch_size > 1 where batch_size is a divider of the total number of generated images
+        batch_size = 3
+        num_images_per_prompt = batch_size**2
+        images = sd_pipe(
+            batch_size=batch_size,
+            num_images_per_prompt=num_images_per_prompt,
+            **inputs,
+        ).images
+        self.assertEqual(len(images), num_images_per_prompt)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+        # Same test for several prompts
+        num_prompts = 3
+        inputs["prompt"] = [prompt] * num_prompts
+
+        images = sd_pipe(
+            batch_size=batch_size,
+            num_images_per_prompt=num_images_per_prompt,
+            **inputs,
+        ).images
+
+        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+        inputs["prompt"] = prompt
+        # Test batch_size when it is not a divider of the total number of generated images for a single prompt
+        num_images_per_prompt = 7
+        images = sd_pipe(
+            batch_size=batch_size,
+            num_images_per_prompt=num_images_per_prompt,
+            **inputs,
+        ).images
+
+        self.assertEqual(len(images), num_images_per_prompt)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+        # Same test for several prompts
+        num_prompts = 2
+        inputs["prompt"] = [prompt] * num_prompts
+        images = sd_pipe(
+            batch_size=batch_size,
+            num_images_per_prompt=num_images_per_prompt,
+            **inputs,
+        ).images
+
+        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+    def test_stable_diffusion_multicontrolnet_bf16(self):
+        """Test that stable diffusion works with bf16"""
+        components = self.get_dummy_components()
+        gaudi_config = GaudiConfig()
+
+        sd_pipe = GaudiStableDiffusionControlNetPipeline(
+            use_habana=True,
+            gaudi_config=gaudi_config,
+            **components,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device="cpu")
+        image = sd_pipe(**inputs).images[0]
+
+        self.assertEqual(image.shape, (64, 64, 3))
+
+    def test_stable_diffusion_multicontrolnet_default(self):
+        components = self.get_dummy_components()
+
+        sd_pipe = GaudiStableDiffusionControlNetPipeline(
+            use_habana=True,
+            gaudi_config="Habana/stable-diffusion",
+            **components,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device="cpu")
+        inputs["prompt"] = [inputs["prompt"]] * 2
+        images = sd_pipe(
+            batch_size=3,
+            num_images_per_prompt=5,
+            **inputs,
+        ).images
+
+        self.assertEqual(len(images), 10)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
+
+    def test_stable_diffusion_multicontrolnet_hpu_graphs(self):
+        components = self.get_dummy_components()
+
+        sd_pipe = GaudiStableDiffusionControlNetPipeline(
+            use_habana=True,
+            use_hpu_graphs=True,
+            gaudi_config="Habana/stable-diffusion",
+            **components,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device="cpu")
+        inputs["prompt"] = [inputs["prompt"]] * 2
+
+        images = sd_pipe(
+            batch_size=3,
+            num_images_per_prompt=5,
+            **inputs,
+        ).images
+
+        self.assertEqual(len(images), 10)
+        self.assertEqual(images[-1].shape, (64, 64, 3))
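
For reference, the MultiControlNet path that the new tests exercise can also be driven directly from Python rather than through `text_to_image_generation.py`. The snippet below is a minimal sketch, not part of the patch: it assumes `GaudiStableDiffusionControlNetPipeline.from_pretrained` accepts the same `use_habana`, `use_hpu_graphs`, and `gaudi_config` keyword arguments as the other Gaudi pipelines, and the conditioning-image URLs are placeholders (each entry must already be the control map expected by its ControlNet).

```python
import torch
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils import load_image

from optimum.habana.diffusers import GaudiStableDiffusionControlNetPipeline

# Two ControlNets, using the same model IDs as the README examples above.
controlnet = MultiControlNetModel(
    [
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.bfloat16),
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.bfloat16),
    ]
)

# Assumption: from_pretrained takes the usual Gaudi pipeline kwargs shown in the README.
pipe = GaudiStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
    torch_dtype=torch.bfloat16,
)

# One conditioning image per ControlNet, in the same order as the models above.
# Placeholder URLs: each image must already be the control map its ControlNet expects.
canny_map = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
pose_map = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
)

images = pipe(
    prompt="futuristic-looking woman",
    image=[canny_map, pose_map],
    num_images_per_prompt=4,
    batch_size=2,
).images
images[0].save("multicontrolnet_sample.png")
```

Wrapping the two ControlNets in `MultiControlNetModel` and passing one conditioning image per model mirrors the component and input setup in `get_dummy_components`/`get_dummy_inputs`, so this follows the same code path the new tests cover.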