
Support of ip-adapter to the StableDiffusionControlNetInpaintPipeline #5887

Merged: 3 commits into huggingface:main, Nov 29, 2023

Conversation

@juancopi81 (Contributor) commented Nov 21, 2023:

What does this PR do?

This PR adds IP-Adapter support to the StableDiffusionControlNetInpaintPipeline. IP-Adapter support was introduced in #5713, and help extending it to more pipelines was requested in #5884. I think it is very cool to have this support (see the example below).

This PR refers to #5884

This is how it works:

from diffusers import (
    StableDiffusionControlNetInpaintPipeline,
    ControlNetModel,
    LCMScheduler,
)
import torch
from diffusers.utils import load_image
from PIL import Image
import cv2
import numpy as np

def get_canny_filter(image):
    if not isinstance(image, np.ndarray):
        image = np.array(image) 
        
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny",
    torch_dtype=torch.float16,
)

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "jayparmr/icbinp_v8_inpaint_v2",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    requires_safety_checker=False,
    safety_checker=None,
)
pipe.to("cuda")

# Load images
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
cn_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
cn_image = get_canny_filter(cn_image)
# Resize images
image = image.resize((512, 768))
mask = mask.resize((512, 768))
cn_image = cn_image.resize((512, 768))

# Load ip-adapter
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# Generate image
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipe(
    prompt="best quality, high quality",
    image=image,
    mask_image=mask,
    control_image=cn_image,
    ip_adapter_image=ip_image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_images_per_prompt=1, 
    num_inference_steps=50,
    generator=generator,
    strength=0.5,
    controlnet_conditioning_scale=0.5,
).images
images[0].save("juancopi_test_4_out.png")
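As a small aside (not from the original PR description): the influence of the reference image can also be tuned with set_ip_adapter_scale, the same method used in the test snippet later in this thread; the value below is only an illustrative choice.

# Optional: control how strongly the IP-Adapter reference image steers the result.
# 0.0 disables the adapter, 1.0 applies it at full strength; 0.7 is an arbitrary
# illustrative value. Call this before running the pipeline.
pipe.set_ip_adapter_scale(0.7)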

Initial image | Mask | ip-adapter image | canny image | output
(comparison images attached to the original PR)

Final output:
(output image attached to the original PR)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you read our philosophy doc (important for complex PRs)?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case: [IP-Adapter] adding IP adapter support to all ControlNet and T2I pipelines (#5884).
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests? Some tests are failing, but I wanted to check whether the threshold is OK:
    FAILED tests/pipelines/controlnet/test_controlnet_inpaint.py::ControlNetInpaintPipelineFastTests::test_save_load_local - AssertionError: 0.0016172826 not less than 0.0005
    FAILED tests/pipelines/controlnet/test_controlnet_inpaint.py::ControlNetSimpleInpaintPipelineFastTests::test_save_load_local - AssertionError: 0.00053209066 not less than 0.0005
    FAILED tests/pipelines/controlnet/test_controlnet_inpaint.py::MultiControlNetInpaintPipelineFastTests::test_save_load_local - AssertionError: 0.00051498413 not less than 0.0005

Who can review?

Hi @yiyixuxu, please let me know if you have any comments 😃 I hope it is OK; I am happy to change anything you need.

@HuggingFaceDocBuilderDev commented Nov 21, 2023:

The documentation is not available anymore as the PR was closed or merged.

@yiyixuxu (Collaborator) left a comment:

thank you!

@juancopi81 requested a review from yiyixuxu on November 22, 2023 at 12:47.
@@ -342,6 +344,7 @@ def init_weights(m):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
@patrickvonplaten (Contributor) commented on the diff:

Let's also add one test related to IP Adapters here :-)

@patrickvonplaten (Contributor) left a comment:

@yiyixuxu wdyt?

@wwirving commented:
Hi @juancopi81 @yiyixuxu - Can we expect this PR to support StableDiffusionXLControlNetInpaintPipeline?

@juancopi81 (Contributor, Author) commented:
Hi @wwirving, not really; this PR only covers the StableDiffusionControlNetInpaintPipeline.

@yiyixuxu (Collaborator) left a comment:
thanks! looks great :)
can we add a test like @patrickvonplaten requested?

@juancopi81 (Contributor, Author) commented Nov 28, 2023:

Sure @patrickvonplaten and @yiyixuxu, I'll work on that. Just to be sure, are you expecting something like this:

class MultiControlNetInpaintPipelineFastTests(
    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
):
    pipeline_class = StableDiffusionControlNetInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = ...

        ### NEW CODE ####
        torch.manual_seed(0)
        image_encoder_config = CLIPVisionConfig(
            image_size=64,
            num_channels=3,
            hidden_size=32,
            projection_dim=32,
            num_hidden_layers=2,
            num_attention_heads=4,
            intermediate_size=37,
            dropout=0.1,
            attention_dropout=0.1,
            initializer_range=0.02,
            scope=None,
        )
        image_encoder = CLIPVisionModelWithProjection(config=image_encoder_config)

        torch.manual_seed(0)
        feature_extractor_config = {
            "image_size": 64,
            "num_channels": 3,
            "do_resize": True,
            "size": {"shortest_edge": 20},
            "do_center_crop": True,
            "crop_size": {"height": 64, "width": 64},
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "do_convert_rgb": True
        }
        feature_extractor = CLIPImageProcessor(**feature_extractor_config)

        components = {
            "unet": unet,
            "controlnet": controlnet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": feature_extractor,
            "image_encoder": image_encoder,
        }
        return components

Anything else? Maybe a specific test for the IP-Adapter, like:

    def test_with_ip_adapter_image_encoder(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        pipe.set_ip_adapter_scale(1.0)
        output_1 = pipe(**inputs)[0]

        inputs = self.get_dummy_inputs(torch_device)
        pipe.set_ip_adapter_scale(0.0)
        output_2 = pipe(**inputs)[0]

        # make sure that all outputs are different
        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
        self.assertEqual(output_2.shape, (1, 64, 64, 3))

Right now this test:

assert np.sum(np.abs(output_1 - output_2)) > 1e-3

fails 😢, but I could look into it if that is what you expect, so I wanted to ask first.

EDIT: I changed the code so the feature_extractor is also initialized in the get_dummy_components function.
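A hedged side note on why that assertion can come out equal: with only the dummy components and no call to pipe.load_ip_adapter(...), the UNet carries no IP-Adapter attention processors, so changing the adapter scale has nothing to act on and both runs produce the same image. A minimal sketch of that check (assuming the pipe built from get_dummy_components above; not part of the PR):

from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0

# If no IP-Adapter weights were loaded, none of the UNet attention processors are
# IP-Adapter variants, so set_ip_adapter_scale cannot influence the output.
has_ip_layers = any(
    isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
    for proc in pipe.unet.attn_processors.values()
)
print("IP-Adapter attention processors present:", has_ip_layers)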

@patrickvonplaten (Contributor) commented:

Actually, I think he should be good here without adapting all the ControlNet tests. We also didn't add one in #5713.

@patrickvonplaten merged commit 9f7b2cf into huggingface:main on Nov 29, 2023.
@juancopi81 deleted the ip-adapter-controlnet-inpaint-pipeline-sd branch on November 29, 2023 at 15:09.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…huggingface#5887)

* Change pipeline_controlnet_inpaint.py to add ip-adapter support. Changes are similar to those in pipeline_controlnet

* Change tests for the StableDiffusionControlNetInpaintPipeline by adding image_encoder: None

* Update src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024 (same commit message as above).