add a from_pipe method to DiffusionPipeline (#7241)
* add from_pipe



---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
4 people authored Apr 1, 2024
1 parent 5266ab7 commit 7956c36
Showing 22 changed files with 675 additions and 63 deletions.
204 changes: 204 additions & 0 deletions docs/source/en/using-diffusers/loading.md
@@ -179,6 +179,210 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
)
```

### Switch loaded pipelines

Many Diffusers pipelines use the same pretrained model as [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`], but they implement specific features to help you achieve better generation results. This guide shows you how to use the `from_pipe` API to create multiple pipelines without increasing memory usage, so you can easily switch between pipelines to use different features.

Let's take an example where we first create a [`StableDiffusionPipeline`] and then reuse the already loaded model components to create a [`StableDiffusionSAGPipeline`] to enhance generation quality.

First, we will generate an image of a bear eating pizza using Stable Diffusion with IP-Adapter:

```python
from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline
import torch
from diffusers.utils import load_image

base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
num_inference_steps = 50
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
prompt = "bear eats pizza"
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_sd.set_ip_adapter_scale(0.6)
pipe_sd.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```

Let's take a look at the image and also print out the memory used:

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_0.png"/>
</div>

```python
def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
```

```bash
Max memory allocated: 4.406213283538818 GB
```

Now, we can use `from_pipe` to switch to the SAG pipeline.

```python
pipe_sag = StableDiffusionSAGPipeline.from_pipe(
    pipe_sd,
)
```

The new pipeline already has the IP-Adapter loaded, so you can pass the same bear image as `ip_adapter_image`:

```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sag = pipe_sag(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
    guidance_scale=1.0,
    sag_scale=0.75,
).images[0]
```

You can see a pretty nice improvement in the output.

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sag_1.png"/>
</div>

Now we have both `StableDiffusionPipeline` and `StableDiffusionSAGPipeline` co-existing with the same loaded model components; you can use them interchangeably without any additional memory.

```python
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
```

```bash
Max memory allocated: 4.406213283538818 GB
```

Let's unload the IP-Adapter from the SAG pipeline. It's important to note that methods like `load_ip_adapter` and `unload_ip_adapter` modify the state of the model components, so using them on one pipeline affects all other pipelines that share the same model components.

```python
pipe_sag.unload_ip_adapter()
```

If you try to use the Stable Diffusion pipeline with the IP-Adapter again, it will fail:

```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```

```bash
AttributeError: 'NoneType' object has no attribute 'image_projection_layers'
```

Please note that pipeline methods may not function properly on a new pipeline created with the `from_pipe` method. For instance, the `enable_model_cpu_offload` method installs hooks on the model components based on a unique offloading sequence for each pipeline, so if the models are executed in a different order in the new pipeline, CPU offloading may not work correctly.

To ensure everything works as expected, re-apply such pipeline methods on the new pipeline created with `from_pipe`, as in the sketch below.
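
A minimal sketch, assuming you want CPU offloading on the SAG pipeline instead of calling `.to("cuda")` on it:

```python
# A sketch: re-apply enable_model_cpu_offload on the pipeline created with
# from_pipe so the offload hooks follow this pipeline's own execution order.
pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd)
pipe_sag.enable_model_cpu_offload()
```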

You can also add or subtract model components when you create new pipelines. Let's now create an AnimateDiff pipeline with an additional `MotionAdapter` module:

```python
from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)

pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
# load the IP-Adapter and LoRA weights again
pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe_animate.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
out = pipe_animate(
    prompt=prompt,
    num_frames=16,
    num_inference_steps=num_inference_steps,
    ip_adapter_image=image,
    generator=generator,
).frames[0]
export_to_gif(out, "out_animate.gif")
```
<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_animate_3.gif"/>
</div>


When creating multiple pipelines with the `from_pipe` method, keep in mind that the total memory requirement is determined by the pipeline with the highest memory usage. Regardless of how many pipelines you create, the total memory requirement stays the same as that of the most memory-intensive pipeline.

For example, we created three pipelines - `StableDiffusionPipeline`, `StableDiffusionSAGPipeline`, and `AnimateDiffPipeline` - and since `AnimateDiffPipeline` has the highest memory requirement, the total memory usage is determined by `AnimateDiffPipeline` alone.

Creating additional pipelines therefore does not increase the total memory requirement; each pipeline can be used interchangeably without any additional memory overhead.
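
As a quick check, you can print the peak memory again after running the AnimateDiff pipeline, reusing the `bytes_to_giga_bytes` helper defined earlier:

```python
# The peak now reflects the most memory-hungry pipeline (AnimateDiff with its
# MotionAdapter), not the sum of all three pipelines.
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
```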


Did you know that you can use `from_pipe` with a community pipeline? Let's look at an example that uses a long negative prompt and prompt weighting with the `lpw_stable_diffusion` community pipeline.

```python
pipe_lpw = DiffusionPipeline.from_pipe(
    pipe_sd,
    custom_pipeline="lpw_stable_diffusion",
).to("cuda")

prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"

generator = torch.Generator(device="cpu").manual_seed(33)
out_lpw = pipe_lpw.text2img(
    prompt,
    negative_prompt=neg_prompt,
    width=512,
    height=512,
    max_embeddings_multiples=3,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_lpw_4.png"/>
</div>

Let's run [`StableDiffusionPipeline`] with the same inputs to compare: the result from the long prompt weighting pipeline is more aligned with the text prompt.

```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=num_inference_steps,
).images[0]
out_sd
```
<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_5.png"/>
</div>


You can easily switch between different pipelines with the `from_pipe` method, similar to turning a feature of your pipeline on and off. To switch between tasks, use `from_pipe` with the `AutoPipeline` classes, which automatically identify the pipeline class based on the task. You can find more information about this feature in the [AutoPipeline guide](https://huggingface.co/docs/diffusers/tutorials/autopipeline).
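
For instance, here is a minimal sketch of switching the loaded checkpoint to an image-to-image pipeline with `AutoPipelineForImage2Image`; the model components are reused rather than reloaded:

```python
from diffusers import AutoPipelineForImage2Image

# A sketch: AutoPipeline picks the matching image-to-image pipeline class for
# the checkpoint already loaded in pipe_sd, without reloading its components.
pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_sd)
```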


## Checkpoint variants

A checkpoint variant is usually a checkpoint whose weights are:
2 changes: 2 additions & 0 deletions examples/community/lpw_stable_diffusion.py
@@ -439,7 +439,9 @@ class StableDiffusionLongPromptWeightingPipeline(
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]

def __init__(
self,
8 changes: 6 additions & 2 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -17,7 +17,7 @@
import torch.nn as nn
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import logging
from ..attention_processor import (
@@ -393,8 +393,11 @@ def from_unet2d(
):
has_motion_adapter = motion_adapter is not None

if has_motion_adapter:
motion_adapter.to(device=unet.device)

# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
config = unet.config
config = dict(unet.config)
config["_class_name"] = cls.__name__

down_blocks = []
@@ -427,6 +430,7 @@ def from_unet2d(
if not config.get("num_attention_heads"):
config["num_attention_heads"] = config["attention_head_dim"]

config = FrozenDict(config)
model = cls.from_config(config)

if not load_weights:
@@ -131,7 +131,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
54 changes: 36 additions & 18 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -292,6 +292,39 @@ def get_class_obj_and_candidates(
return class_obj, class_candidates


def _get_custom_pipeline_class(
custom_pipeline,
repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
):
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
elif repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = repo_id
else:
file_name = CUSTOM_PIPELINE_FILE_NAME

if repo_id is not None and hub_revision is not None:
# if we load the pipeline code from the Hub
# make sure to overwrite the `revision`
revision = hub_revision

return get_class_from_dynamic_module(
custom_pipeline,
module_file=file_name,
class_name=class_name,
cache_dir=cache_dir,
revision=revision,
)


def _get_pipeline_class(
class_obj,
config=None,
@@ -304,25 +337,10 @@ def _get_pipeline_class(
revision=None,
):
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
elif repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = repo_id
else:
file_name = CUSTOM_PIPELINE_FILE_NAME

if repo_id is not None and hub_revision is not None:
# if we load the pipeline code from the Hub
# make sure to overwrite the `revision`
revision = hub_revision

return get_class_from_dynamic_module(
return _get_custom_pipeline_class(
custom_pipeline,
module_file=file_name,
repo_id=repo_id,
hub_revision=hub_revision,
class_name=class_name,
cache_dir=cache_dir,
revision=revision,