Add more ControlNet examples and MultiControlNet tests
nngokhale committed Jan 19, 2024
1 parent 7129fc7 commit cbaab7e
Showing 2 changed files with 327 additions and 0 deletions.
34 changes: 34 additions & 0 deletions examples/stable-diffusion/README.md
@@ -206,6 +206,7 @@ python text_to_image_generation.py \

ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala.
It is a type of model for controlling Stable Diffusion by conditioning it on an additional input image.

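Conditioning images like the one in the Canny examples are edge maps derived from a source photo. For reference, here is a minimal sketch of how such an edge map can be produced (illustrative only: it assumes `opencv-python` and Pillow are installed, and the file names are hypothetical):

```python
# Minimal sketch: turn a photo into a Canny edge map suitable as a ControlNet
# conditioning image. File names are illustrative, not from this repository.
import cv2
import numpy as np
from PIL import Image

image = np.array(Image.open("input_photo.png").convert("RGB"))
edges = cv2.Canny(image, 100, 200)      # low/high hysteresis thresholds
edges = np.stack([edges] * 3, axis=-1)  # single channel -> 3 channels
Image.fromarray(edges).save("canny_control.png")
```
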
Here is how to generate images conditioned by the Canny edge model:
```bash
pip install -r requirements.txt
@@ -223,6 +224,39 @@ python text_to_image_generation.py \
    --bf16
```

Here is how to generate images conditioned by the Canny edge model with multiple prompts (the pipeline generates `--num_images_per_prompt` images for each prompt, `--batch_size` images at a time):
```bash
pip install -r requirements.txt
python text_to_image_generation.py \
    --model_name_or_path runwayml/stable-diffusion-v1-5 \
    --controlnet_model_name_or_path lllyasviel/sd-controlnet-canny \
    --prompts "futuristic-looking woman" "a rusty robot" \
    --control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \
    --num_images_per_prompt 10 \
    --batch_size 4 \
    --image_save_dir /tmp/controlnet_images \
    --use_habana \
    --use_hpu_graphs \
    --gaudi_config Habana/stable-diffusion \
    --bf16
```

Here is how to generate images conditioned by the OpenPose model:
```bash
pip install -r requirements.txt
python text_to_image_generation.py \
    --model_name_or_path runwayml/stable-diffusion-v1-5 \
    --controlnet_model_name_or_path lllyasviel/sd-controlnet-openpose \
    --prompts "Chef in the kitchen" \
    --control_image https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png \
    --num_images_per_prompt 20 \
    --batch_size 4 \
    --image_save_dir /tmp/controlnet_images \
    --use_habana \
    --use_hpu_graphs \
    --gaudi_config Habana/stable-diffusion \
    --bf16
```
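
Several ControlNets can also be combined, conditioning the model on one control image per ControlNet (this is what the MultiControlNet tests in this commit exercise). A minimal Python sketch of such a pipeline (illustrative only: the file names are hypothetical, and it assumes the `optimum-habana` diffusers API on a Gaudi device):

```python
# Minimal sketch (not an official example): condition generation on two
# ControlNets at once (Canny edges + OpenPose skeleton).
import torch
from diffusers import ControlNetModel
from diffusers.utils import load_image
from optimum.habana.diffusers import GaudiStableDiffusionControlNetPipeline

canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.bfloat16)
pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.bfloat16)

# Passing a list of ControlNets wraps them in a MultiControlNetModel internally.
pipe = GaudiStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=[canny, pose],
    torch_dtype=torch.bfloat16,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)

canny_image = load_image("canny_control.png")  # hypothetical local files
pose_image = load_image("pose.png")
images = pipe(
    prompt="futuristic-looking woman",
    image=[canny_image, pose_image],  # one control image per ControlNet
).images
```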

## Textual Inversion

293 changes: 293 additions & 0 deletions tests/test_diffusers.py
@@ -32,6 +32,7 @@
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.testing_utils import slow
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel

from optimum.habana import GaudiConfig
from optimum.habana.diffusers import (
@@ -1182,6 +1183,7 @@ def init_weights(m):
                torch.nn.init.normal_(m.weight)
                m.bias.data.fill_(1.0)

        torch.manual_seed(0)
        controlnet = ControlNetModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
@@ -1427,3 +1429,294 @@ def test_stable_diffusion_controlnet_hpu_graphs(self):

        self.assertEqual(len(images), 10)
        self.assertEqual(images[-1].shape, (64, 64, 3))


class GaudiStableDiffusionMultiControlNetPipelineTester(TestCase):
    """
    Tests the Gaudi Stable Diffusion ControlNet pipeline with multiple ControlNets (MultiControlNetModel).
    """

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            sample_size=32,
            time_cond_proj_dim=time_cond_proj_dim,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            norm_num_groups=1,
        )

        def init_weights(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.normal_(m.weight)
                m.bias.data.fill_(1.0)

        # Two small ControlNets with identical configs; they are wrapped in a
        # MultiControlNetModel below.
        torch.manual_seed(0)
        controlnet1 = ControlNetModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            in_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
            norm_num_groups=1,
        )
        controlnet1.controlnet_down_blocks.apply(init_weights)

        torch.manual_seed(0)
        controlnet2 = ControlNetModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            in_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
            norm_num_groups=1,
        )
        controlnet2.controlnet_down_blocks.apply(init_weights)

        scheduler = GaudiDDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        controlnet = MultiControlNetModel([controlnet1, controlnet2])

        components = {
            "unet": unet,
            "controlnet": controlnet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        generator = torch.Generator(device=device).manual_seed(seed)
        controlnet_embedder_scale_factor = 2
        # One conditioning image per ControlNet in the MultiControlNetModel
        images = [
            randn_tensor(
                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
                generator=generator,
                device=torch.device(device),
            ),
            randn_tensor(
                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
                generator=generator,
                device=torch.device(device),
            ),
        ]
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
            "image": images,
        }
        return inputs

    def test_stable_diffusion_multicontrolnet_num_images_per_prompt(self):
        components = self.get_dummy_components()
        gaudi_config = GaudiConfig()

        sd_pipe = GaudiStableDiffusionControlNetPipeline(
            use_habana=True,
            gaudi_config=gaudi_config,
            **components,
        )
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device="cpu")
        prompt = inputs["prompt"]
        # Test num_images_per_prompt=1 (default)
        images = sd_pipe(**inputs).images

        self.assertEqual(len(images), 1)
        self.assertEqual(images[0].shape, (64, 64, 3))

        # Test num_images_per_prompt=1 (default) for several prompts
        num_prompts = 3
        inputs["prompt"] = [prompt] * num_prompts
        images = sd_pipe(**inputs).images

        self.assertEqual(len(images), num_prompts)
        self.assertEqual(images[-1].shape, (64, 64, 3))

        # Test num_images_per_prompt for a single prompt
        num_images_per_prompt = 2
        inputs["prompt"] = prompt
        images = sd_pipe(num_images_per_prompt=num_images_per_prompt, **inputs).images

        self.assertEqual(len(images), num_images_per_prompt)
        self.assertEqual(images[-1].shape, (64, 64, 3))

        # Test num_images_per_prompt for several prompts
        num_prompts = 2
        inputs["prompt"] = [prompt] * num_prompts
        images = sd_pipe(num_images_per_prompt=num_images_per_prompt, **inputs).images

        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
        self.assertEqual(images[-1].shape, (64, 64, 3))

    def test_stable_diffusion_multicontrolnet_batch_sizes(self):
        components = self.get_dummy_components()
        gaudi_config = GaudiConfig()

        sd_pipe = GaudiStableDiffusionControlNetPipeline(
            use_habana=True,
            gaudi_config=gaudi_config,
            **components,
        )
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device="cpu")
        prompt = inputs["prompt"]
        # Test batch_size > 1 where batch_size is a divisor of the total number of generated images
        batch_size = 3
        num_images_per_prompt = batch_size**2
        images = sd_pipe(
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
            **inputs,
        ).images
        self.assertEqual(len(images), num_images_per_prompt)
        self.assertEqual(images[-1].shape, (64, 64, 3))

        # Same test for several prompts
        num_prompts = 3
        inputs["prompt"] = [prompt] * num_prompts

        images = sd_pipe(
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
            **inputs,
        ).images

        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
        self.assertEqual(images[-1].shape, (64, 64, 3))

        inputs["prompt"] = prompt
        # Test batch_size when it is not a divisor of the total number of generated images for a single prompt
        num_images_per_prompt = 7
        images = sd_pipe(
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
            **inputs,
        ).images

        self.assertEqual(len(images), num_images_per_prompt)
        self.assertEqual(images[-1].shape, (64, 64, 3))

        # Same test for several prompts
        num_prompts = 2
        inputs["prompt"] = [prompt] * num_prompts
        images = sd_pipe(
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
            **inputs,
        ).images

        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
        self.assertEqual(images[-1].shape, (64, 64, 3))

    def test_stable_diffusion_multicontrolnet_bf16(self):
        """Test that stable diffusion works with bf16"""
        components = self.get_dummy_components()
        # Assumption: enable Torch autocast so the pipeline actually runs in bf16,
        # matching the docstring (a default GaudiConfig() would not enable it).
        gaudi_config = GaudiConfig(use_torch_autocast=True)

        sd_pipe = GaudiStableDiffusionControlNetPipeline(
            use_habana=True,
            gaudi_config=gaudi_config,
            **components,
        )
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device="cpu")
        image = sd_pipe(**inputs).images[0]

        self.assertEqual(image.shape, (64, 64, 3))

    def test_stable_diffusion_multicontrolnet_default(self):
        components = self.get_dummy_components()

        sd_pipe = GaudiStableDiffusionControlNetPipeline(
            use_habana=True,
            gaudi_config="Habana/stable-diffusion",
            **components,
        )
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device="cpu")
        inputs["prompt"] = [inputs["prompt"]] * 2
        images = sd_pipe(
            batch_size=3,
            num_images_per_prompt=5,
            **inputs,
        ).images

        # 2 prompts x 5 images per prompt
        self.assertEqual(len(images), 10)
        self.assertEqual(images[-1].shape, (64, 64, 3))

    def test_stable_diffusion_multicontrolnet_hpu_graphs(self):
        components = self.get_dummy_components()

        sd_pipe = GaudiStableDiffusionControlNetPipeline(
            use_habana=True,
            use_hpu_graphs=True,
            gaudi_config="Habana/stable-diffusion",
            **components,
        )
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device="cpu")
        inputs["prompt"] = [inputs["prompt"]] * 2

        images = sd_pipe(
            batch_size=3,
            num_images_per_prompt=5,
            **inputs,
        ).images

        self.assertEqual(len(images), 10)
        self.assertEqual(images[-1].shape, (64, 64, 3))
