Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ControlNet Pipeline #585

Merged
merged 12 commits into from
Jan 30, 2024
74 changes: 74 additions & 0 deletions examples/stable-diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,80 @@ python text_to_image_generation.py \
> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster.
> You can enable this mode with `--use_hpu_graphs`.

### ControlNet

ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models ](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala.
It is a type of model for controlling StableDiffusion by conditioning the model with an additional input image.

Here is how to generate images conditioned by canny edge model:
```bash
pip install -r requirements.txt
python text_to_image_generation.py \
--model_name_or_path runwayml/stable-diffusion-v1-5 \
--controlnet_model_name_or_path lllyasviel/sd-controlnet-canny \
--prompts "futuristic-looking woman" \
--control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \
--num_images_per_prompt 20 \
--batch_size 4 \
--image_save_dir /tmp/controlnet_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16
```

Here is how to generate images conditioned by canny edge model and with multiple prompts:
```bash
pip install -r requirements.txt
python text_to_image_generation.py \
--model_name_or_path runwayml/stable-diffusion-v1-5 \
--controlnet_model_name_or_path lllyasviel/sd-controlnet-canny \
--prompts "futuristic-looking woman" "a rusty robot" \
--control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \
--num_images_per_prompt 10 \
--batch_size 4 \
--image_save_dir /tmp/controlnet_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16
```

Here is how to generate images conditioned by open pose model:
```bash
pip install -r requirements.txt
python text_to_image_generation.py \
--model_name_or_path runwayml/stable-diffusion-v1-5 \
--controlnet_model_name_or_path lllyasviel/sd-controlnet-openpose \
--prompts "Chef in the kitchen" \
--control_image https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png \
--control_preprocessing_type "none" \
--num_images_per_prompt 20 \
--batch_size 4 \
--image_save_dir /tmp/controlnet_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16
```

Here is how to generate images with conditioned by canny edge model using Stable Diffusion 2
```bash
pip install -r requirements.txt
python text_to_image_generation.py \
--model_name_or_path stabilityai/stable-diffusion-2-1 \
--controlnet_model_name_or_path thibaud/controlnet-sd21-canny-diffusers \
--control_image https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png \
--control_preprocessing_type "none" \
--prompts "bird" \
--seed 0 \
--num_images_per_prompt 10 \
--batch_size 2 \
--image_save_dir /tmp/controlnet-2-1_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion-2
```

## Textual Inversion

Expand Down
1 change: 1 addition & 0 deletions examples/stable-diffusion/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
opencv-python
89 changes: 86 additions & 3 deletions examples/stable-diffusion/text_to_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sys
from pathlib import Path

import numpy as np
import torch

from optimum.habana.diffusers import (
Expand Down Expand Up @@ -53,6 +54,13 @@ def main():
help="Path to pre-trained model",
)

parser.add_argument(
"--controlnet_model_name_or_path",
default="lllyasviel/sd-controlnet-canny",
type=str,
help="Path to pre-trained model",
)

parser.add_argument(
"--scheduler",
default="ddim",
Expand Down Expand Up @@ -83,6 +91,21 @@ def main():
default=None,
help="The second prompt or prompts to guide the image generation (applicable to SDXL).",
)
parser.add_argument(
"--control_image",
type=str,
default=None,
help=("Path to the controlnet conditioning image"),
)
parser.add_argument(
"--control_preprocessing_type",
type=str,
default="canny",
help=(
"The type of preprocessing to apply on contol image. Only `canny` is supported."
" Defaults to `canny`. Set to unsupported value to disable preprocessing."
),
)
parser.add_argument(
"--num_images_per_prompt", type=int, default=1, help="The number of images to generate per prompt."
)
Expand Down Expand Up @@ -179,7 +202,18 @@ def main():
parser.add_argument(
"--ldm3d", action="store_true", help="Use LDM3D to generate an image and a depth map from a given text prompt."
)

parser.add_argument(
"--profiling_warmup_steps",
default=0,
type=int,
help="Number of steps to ignore for profiling.",
)
parser.add_argument(
"--profiling_steps",
default=0,
type=int,
help="Number of steps to capture for profiling.",
)
args = parser.parse_args()

# Set image resolution
Expand All @@ -188,10 +222,33 @@ def main():
res["width"] = args.width
res["height"] = args.height

# ControlNet
if args.control_image is not None:
from diffusers.utils import load_image
from PIL import Image

# get control image
control_image = load_image(args.control_image)
if args.control_preprocessing_type == "canny":
import cv2

image = np.array(control_image)
# get canny image
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
control_image = Image.fromarray(image)

# Import selected pipeline
sdxl_models = ["stable-diffusion-xl-base-1.0", "sdxl-turbo"]

if any(model in args.model_name_or_path for model in sdxl_models):
if args.control_image is not None:
from diffusers import ControlNetModel

from optimum.habana.diffusers import GaudiStableDiffusionControlNetPipeline

sdxl = False
elif any(model in args.model_name_or_path for model in sdxl_models):
from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline

sdxl = True
Expand Down Expand Up @@ -237,7 +294,33 @@ def main():
kwargs["torch_dtype"] = torch.bfloat16

# Generate images
if sdxl:
if args.control_image is not None:
model_dtype = torch.bfloat16 if args.bf16 else None
controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=model_dtype)
pipeline = GaudiStableDiffusionControlNetPipeline.from_pretrained(
args.model_name_or_path,
controlnet=controlnet,
**kwargs,
)

# Set seed before running the model
set_seed(args.seed)

outputs = pipeline(
prompt=args.prompts,
image=control_image,
num_images_per_prompt=args.num_images_per_prompt,
batch_size=args.batch_size,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
negative_prompt=args.negative_prompts,
eta=args.eta,
output_type=args.output_type,
profiling_warmup_steps=args.profiling_warmup_steps,
profiling_steps=args.profiling_steps,
**res,
)
elif sdxl:
pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
args.model_name_or_path,
**kwargs,
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline
from .pipelines.pipeline_utils import GaudiDiffusionPipeline
from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline
from .pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import GaudiStableDiffusionLDM3DPipeline
Expand Down
Loading
Loading