
add inpainting example script #241

Merged: patil-suraj merged 3 commits into huggingface:main on Aug 26, 2022
Conversation

nagolinc
Contributor

This script is a copy of https://github.com/huggingface/diffusers/blob/main/examples/inference/image_to_image.py but for inpainting.

Example usage:


import torch
from torch import autocast
from PIL import Image

from examples.inference.inpainting import StableDiffusionInpaintingPipeline

# NOTE: preprocess() is assumed to be the image-normalization helper from the
# image_to_image example, and `pipe` in the else-branch below is assumed to be a
# separately constructed text-to-image StableDiffusionPipeline.

pipeimg = StableDiffusionInpaintingPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
    cache_dir="/content/drive/MyDrive/AI/StableDiffusion",
).to("cuda")

num_samples = 1


def infer(prompt, init_image, mask_image, strength=0.75):
    if init_image is not None:
        # resize and normalize the init image, then run the inpainting pipeline
        init_image = init_image.resize((512, 512))
        init_image = preprocess(init_image)
        with autocast("cuda"):
            images = pipeimg(
                [prompt] * num_samples,
                init_image=init_image,
                mask_image=mask_image,
                strength=strength,
                guidance_scale=7.5,
            )["sample"]
    else:
        # no init image: fall back to plain text-to-image generation
        with autocast("cuda"):
            images = pipe([prompt] * num_samples, strength=strength, guidance_scale=7.5)["sample"]

    return images


init_img = Image.open(init_img_path)
mask_img = Image.open(mask_img_path)
imgOut = infer(prompt, init_img, mask_img, 0.9)


@ghost

ghost commented Aug 24, 2022

It may be worth adding insights from the RePaint paper to avoid semantic artifacts in the inpainting. All it would take is an additional backtracking step (x_t -> x_{t-1} -> masked addition -> re-noising -> x_t -> x_{t-1}).

See section 4.2 (Resampling) in the paper:

https://arxiv.org/pdf/2201.09865.pdf

https://github.com/andreas128/RePaint
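
(A minimal sketch of that resampling idea, with a jump length of one step, assuming a DDPM-style scheduler that exposes its betas. The names unet, scheduler, mask, init_latents_orig, text_embeddings and n_resample are placeholders, and the dict-style return access follows the diffusers examples of the time; this is not the actual pipeline code.)

import torch

for t in scheduler.timesteps:
    for r in range(n_resample):
        # one reverse step x_t -> x_{t-1}
        noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]
        latents = scheduler.step(noise_pred, t, latents)["prev_sample"]

        # masked addition: keep the known region from the re-noised original latents
        noise = torch.randn_like(latents)
        known = scheduler.add_noise(init_latents_orig, noise, t - 1) if t > 1 else init_latents_orig
        latents = known * mask + latents * (1 - mask)

        # backtrack x_{t-1} -> x_t with one forward-diffusion step, then denoise again
        if r < n_resample - 1:
            beta = scheduler.betas[t]
            latents = torch.sqrt(1.0 - beta) * latents + torch.sqrt(beta) * torch.randn_like(latents)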

@ghost

ghost commented Aug 24, 2022

I believe this inpainting attempt may be defective: it adds the masked initial latent at every time step, when it should be adding noised latents, e.g.:

init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
latents = ( init_latents_proper * mask ) + ( latents * (1-mask) )
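
(For context, a minimal sketch of where this blend sits in a simplified denoising loop; the variable names and the dict-style return access follow the diffusers examples of the time and are assumptions, not the actual pipeline code:)

import torch

noise = torch.randn(init_latents_orig.shape, generator=generator, device=device)
for t in scheduler.timesteps:
    noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]
    latents = scheduler.step(noise_pred, t, latents)["prev_sample"]
    # re-noise the original latents to the current noise level, keep them in the
    # masked (known) region, and keep the freshly denoised latents elsewhere
    init_latents_proper = scheduler.add_noise(init_latents_orig, noise, t)
    latents = init_latents_proper * mask + latents * (1 - mask)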

Results are still semantically poor, whereas before they were all blurry. RePaint may be the only way to go.

@nagolinc
Contributor Author

I believe this inpainting attempt may be defective: it adds the masked initial latent at every time step, when it should be adding noised latents, e.g.:

init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
latents = ( init_latents_proper * mask ) + ( latents * (1-mask) )

Results are still semantically poor, whereas before they were all blurry. RePaint may be the only way to go.

Thanks! That change does improve the results a lot.

I will try to read the RePaint paper when I get a chance (although it looks like work on that is ongoing). This was just me trying the first thing that seemed to work.

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot @nagolinc for adding this example! I just left some nits. Let me know if you want to tackle those, otherwise I would be happy to do it :)

Also, as @jackloomen said, the results are not perfect but are very interesting nonetheless. This would be a cool simple example for in-painting.

I've added a colab here to play with it. https://colab.research.google.com/drive/196L1Kfodck2ZXkdIdLXPCGP2PMwJ2d5z?usp=sharing.

I think we can merge this as an initial simple example for in-painting. @anton-l @patrickvonplaten

@patil-suraj patil-suraj requested a review from anton-l August 25, 2022 07:44
Contributor

@patrickvonplaten patrickvonplaten left a comment

Super nice addition @nagolinc !

The only thing that I'd feel quite strongly about is "API" consistency here, in the sense that both mask_image and init_image should be processed inside the pipeline.

Is it ok if we tweet about your contribution here as it's the first major "community" contribution to the repo? I'd maybe tag you in a tweet (https://twitter.com/nagolinc) if that's ok?

@leszekhanusz
Contributor

leszekhanusz commented Aug 25, 2022

Proposition: could we merge all the different stable diffusion pipelines into a single one?

Currently there are:

  • StableDiffusionPipeline using only prompts
  • StableDiffusionImg2ImgPipeline using an init_image
  • StableDiffusionInpaintingPipeline using an init_image and mask_image, the init_image here having a different purpose than the init_image in the StableDiffusionImg2ImgPipeline

Wouldn't it make more sense to have a single pipeline doing all of the above (including inpainting and img2img at the same time if needed)?

This pipeline would have as arguments:

  • init_image the initial image for img to img
  • inpainting_image the image used for inpainting
  • inpainting_mask the mask used for inpainting

There would be four ways to use it (a rough dispatch sketch follows after the list):

  • no images provided -> like StableDiffusionPipeline
  • init_image provided -> like StableDiffusionImg2ImgPipeline
  • inpainting images provided -> StableDiffusionInpaintingPipeline
  • and finally, if you provide both init_image and the inpainting images, it could do inpainting using img2img processing inside the masked area
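
(A rough sketch of the proposed dispatch, purely to illustrate the idea. It assumes three separately constructed pipeline instances, text2img_pipe, img2img_pipe and inpaint_pipe, and the keyword names and dict-style outputs follow the examples in this thread; none of this is an existing diffusers API:)

def generate(prompt, init_image=None, inpainting_image=None, inpainting_mask=None, **kwargs):
    # Dispatch on which inputs are provided (the combined fourth case, img2img
    # inside the masked area, would need extra handling and is left out here).
    if inpainting_image is not None and inpainting_mask is not None:
        # like StableDiffusionInpaintingPipeline
        return inpaint_pipe(prompt, init_image=inpainting_image, mask_image=inpainting_mask, **kwargs)["sample"]
    if init_image is not None:
        # like StableDiffusionImg2ImgPipeline
        return img2img_pipe(prompt, init_image=init_image, **kwargs)["sample"]
    # like StableDiffusionPipeline
    return text2img_pipe(prompt, **kwargs)["sample"]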

@ghost

ghost commented Aug 25, 2022

If the schedulers could implement an undo_step() operation, we could get RePaint working easily I think.

@nagolinc
Contributor Author

If the schedulers could implement an undo_step() operation, we could get RePaint working easily I think.

It looks like @anton-l has written a RePaint Scheduler here (https://github.com/huggingface/diffusers/blob/aa6da5ad722ff361a599d2196e2be91f06744813/src/diffusers/schedulers/scheduling_repaint.py), so after this is merged, I will see if I can get that working with latent-diffusion.

@nagolinc
Contributor Author

Super nice addition @nagolinc !

The only thing that I'd feel quite strongly about is "API" consistency here, in the sense that both mask_image and init_image should be processed inside the pipeline.

Is it ok if we tweet about your contribution here as it's the first major "community" contribution to the repo? I'd maybe tag you in a tweet (https://twitter.com/nagolinc) if that's ok?

Sounds great!

@brthor

brthor commented Aug 25, 2022

If the schedulers could implement an undo_step() operation, we could get RePaint working easily I think.

@jackloomen

The undo_step() is easy enough to bring into the pipeline loop, but through our tests, resampling the image back to x_t with any jump length greater than 2 collapses the noise into a solid color or slight gradient in the masked areas. Also there doesn't seem to be any discernable quality improvement.

We (@EteriaAI) integrated #243 into the pipeline in this pull request in 3 different implementations, and none of them seem to produce good results.

It's possible there's an error in our implementations though, so I am not attempting to discourage further experimentation here.

EDIT: the noise-sampled technique in this pipeline works okay; we posted a sample image here

@HammadB

HammadB commented Aug 25, 2022

This is awesome. Thanks @nagolinc. Just a thought: the approach here seems similar to https://arxiv.org/pdf/2206.02779.pdf and I wonder if borrowing the insight around fine-tuning the decoder's weights is worth exploring here in the future?

[Screenshot: algorithm from the linked paper, for reference]

@ghost

ghost commented Aug 25, 2022

If the schedulers could implement an undo_step() operation, we could get RePaint working easily I think.

@jackloomen

The undo_step() is easy enough to bring into the pipeline loop, but through our tests, resampling the image back to x_t with any jump length greater than 2 collapses the noise into a solid color or slight gradient in the masked areas. Also there doesn't seem to be any discernable quality improvement.

We (@EteriaAI) integrated #243 into the pipeline in this pull request in 3 different implementations, and none of them seem to produce good results.

It's possible there's an error in our implementations though, so I am not attempting to discourage further experimentation here.

EDIT: the noise-sampled technique in this pipeline works okay; we posted a sample image here

I'm getting the same results, but there should be no reason for this to happen. Worst case scenario, nothing meaningful would change in the results.

I can only assume that undo_step was done incorrectly, or there is some mismatch in what noise level is applied.

@brthor

brthor commented Aug 26, 2022

I can only assume that undo_step was done incorrectly, or there is some mismatch in what noise level is applied.

@jackloomen

undo_step seems to be correct, because you can apply the same operation at any timestep t (in a regular non-RePaint loop) to get more noise in the whole image, which doesn't collapse into a single color.

I output some debug images at each step in the inference process and you can slowly see the noise in the known area decreasing but the noise in the unknown area just collapses to a single color really fast.

Increasing the 'r' or 'j' values exacerbates the problem.

It looks like the jump and undo somehow cause a mismatch in the noise levels between the masked and unmasked parts of the image, and that probably causes some kind of issue with the model.

@nagolinc
Contributor Author

@leszekhanusz

I added a unified pipeline (examples/inference/unified.py) that does plain diffusion if you supply neither an init_image nor a mask, img2img if you just give it an init_image, and inpainting if you provide both.

mask = torch.cat([mask] * batch_size)
# check sizes
if not mask.shape == init_latents.shape:
    raise ValueError(f"The mask and init_image should be the same size!")
Contributor

flake8 reports: F541 f-string is missing placeholders; the f prefix is not necessary.

@leszekhanusz
Contributor

@leszekhanusz

I added a unified pipeline (examples/inference/unified.py) that does plain diffusion if you supply neither an init_image nor a mask, img2img if you just give it an init_image, and inpainting if you provide both.

That's great! It works well, but you can't do pure inpainting now (without image help). I was thinking of having three images as input (init_image (a drawing, for example), inpainting_image and inpainting_mask).

@leszekhanusz
Contributor

To help everyone test this, attached is the same code as a directly executable script with arguments:
unified.zip

@nagolinc
Contributor Author

@leszekhanusz
I added a unified pipeline (examples/inference/unified.py) that does plain diffusion if you supply neither an init_image nor a mask, img2img if you just give it an init_image, and inpainting if you provide both.

That's great! It works well, but you can't do pure inpainting now (without image help). I was thinking of having three images as input (init_image (a drawing, for example), inpainting_image and inpainting_mask).

@leszekhanusz I’m not sure I follow. By pure inpainting do you mean something different than inpainting with strength=1 (so init image has no effect on inpainted regions)?

@patil-suraj
Contributor

Proposition: could we merge all the different stable diffusion pipelines into a single one?

Currently there are:

  • StableDiffusionPipeline using only prompts
  • StableDiffusionImg2ImgPipeline using an init_image
  • StableDiffusionInpaintingPipeline using an init_image and mask_image, the init_image here having a different purpose than the init_image in the StableDiffusionImg2ImgPipeline

Wouldn't it make more sense to have a single pipeline doing all of the above (including inpainting and img2img at the same time if needed)?

This pipeline would have as arguments:

  • init_image the initial image for img to img
  • inpainting_image the image used for inpainting
  • inpainting_mask the mask used for inpainting

There would be four ways to use it:

  • no images provided -> like StableDiffusionPipeline
  • init_image provided -> like StableDiffusionImg2ImgPipeline
  • inpainting images provided -> StableDiffusionInpaintingPipeline
  • and finally, if you provide both init_image and the inpainting images, it could do inpainting using img2img processing inside the masked area

Hi @leszekhanusz, those are really good points!

As stated in the diffusers philosophy, in diffusers we prefer to provide clean, readable examples rather than examples that are optimized and cover many use-cases. This is because we want the examples to be simple, so users can easily follow, understand and customize them according to their needs. Having all of these tasks in one pipeline would complicate the code with multiple if/else branches, which would make it hard for many users to understand and tweak. With pipelines the idea is: one pipeline -> one task.

I'm very much in favor of adding the in-painting example, it's a really nice example, but not in favor of a unified pipeline. This PR is good to merge once we remove the unified example :) Thanks a lot for working on this @nagolinc, great work!

@patil-suraj
Contributor

patil-suraj commented Aug 26, 2022

Also, very nice discussion here. We have a Discord channel for such discussions; feel free to join if you are interested.

@nagolinc
Contributor Author

Also, very nice discussion here. We have a Discord channel for such discussions; feel free to join if you are interested.

@patil-suraj Okay, I've moved the unified pipeline to a different branch (main...nagolinc:diffusers:unified)

Contributor

@patil-suraj patil-suraj left a comment

Awesome! @nagolinc, do you have any colab that we could link from the readme?

No worries if not; we could link the colab that I shared if you want :)

@nagolinc
Contributor Author

Awesome! @nagolinc, do you have any colab that we could link from the readme?

No worries if not; we could link the colab that I shared if you want :)

@patil-suraj this is the notebook I have been using for testing: https://github.com/nagolinc/notebooks/blob/main/inpainting.ipynb

@patil-suraj
Contributor

Great, would it be alright if we use the notebook I shared? I want to use something simpler :)

@nagolinc
Contributor Author

Great, would it be alright if we use the notebook I shared? I want to use something simpler :)

@patil-suraj For sure

@patil-suraj patil-suraj merged commit bb4d605 into huggingface:main Aug 26, 2022
@ghost

ghost commented Aug 26, 2022

Looking closer at it, shouldn't the original latent be noised to t-1 instead of t, and for t=1 not be noised at all?

if t > 1:
    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t-1)
else:
    init_latents_proper = init_latents_orig

@leszekhanusz
Contributor

@leszekhanusz I’m not sure I follow. By pure inpainting do you mean something different than inpainting with strength=1 (so init image has no effect on inpainted regions)?

Yes, that's what I meant. I see now that your approach with only two images is perfectly fine.
But I did some experimentation today and I was not able to do correct inpainting with strength=1.
I've put my results in issue #261

@leszekhanusz
Contributor

Also , very nice discussion here. We have a discord channel for such discussions, feel free to join if you are interested.

Thanks, but it appears that this link is not available to me. Should I join somewhere first?

@ghost

ghost commented Aug 27, 2022

Some results, using some modifications:

import numpy as np
import PIL
import torch
from PIL import Image

def my_preprocess(image):  # here image is a numpy array
    image = Image.fromarray(image)  # remove if input is PIL
    w, h = image.size
    if w > 512:
        h = int(h * (512 / w))
        w = 512
    if h > 512:
        w = int(w * (512 / h))
        h = 512
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64; 32 can sometimes result in tensor mismatch errors

    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    print(image.size)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # scale to [-1, 1]

def my_preprocess_mask(mask):  # here mask is a numpy array
    mask = Image.fromarray(mask)  # remove if input is PIL
    mask = mask.convert("L")
    w, h = mask.size
    if w > 512:
        h = int(h * (512 / w))
        w = 512
    if h > 512:
        w = int(w * (512 / h))
        h = 512
    w, h = map(lambda x: x - x % 64, (w, h))
    w //= 8  # the mask lives in latent space, which is 8x smaller
    h //= 8
    mask = mask.resize((w, h), resample=PIL.Image.LANCZOS)
    print(mask.size)
    # mask = mask.resize((64, 64), resample=PIL.Image.LANCZOS)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))  # repeat across the 4 latent channels
    mask = mask[None].transpose(0, 1, 2, 3)
    mask[np.where(mask != 0.0)] = 1.0  # binarize: make sure mask is actually valid
    mask = torch.from_numpy(mask)
    return mask  # may need 1 - mask depending on goal of mask selection

# masking step inside the pipeline's denoising loop
if t > 1:
    t_noise = torch.randn(latents.shape, generator=generator, device=self.device)
    init_latents_proper = self.scheduler.add_noise(init_latents_orig, t_noise, t - 1)
    latents = init_latents_proper * mask + latents * (1 - mask)
else:
    latents = init_latents_orig * mask + latents * (1 - mask)
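
(For completeness, a hedged usage sketch of how these helpers could feed the pipeline; pipeimg, the keyword names and the dict-style output follow the earlier example in this thread, and the file paths are hypothetical:)

from torch import autocast

init_np = np.array(Image.open("dog.png"))   # hypothetical input image
mask_np = np.array(Image.open("mask.png"))  # hypothetical mask image

init_tensor = my_preprocess(init_np)        # NCHW tensor scaled to [-1, 1]
mask_tensor = my_preprocess_mask(mask_np)   # latent-resolution binary mask

prompt = "golden retriever, dog, depth of field, centered, photo"
with autocast("cuda"):
    images = pipeimg([prompt], init_image=init_tensor, mask_image=mask_tensor,
                     strength=0.75, guidance_scale=7.5)["sample"]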

prompt: golden retriever, dog, depth of field, centered, photo

original: (from prompt)
https://cdn.discordapp.com/attachments/1004159122335354970/1012917323008589865/000.png

The mask covers the dog's face and two grass spots on either side.

inpainted: (using the same prompt)
https://cdn.discordapp.com/attachments/1004159122335354970/1012917324128469054/333.png
https://cdn.discordapp.com/attachments/1004159122335354970/1012917324833095740/555.png
https://cdn.discordapp.com/attachments/1004159122335354970/1012917324447232050/444.png
https://cdn.discordapp.com/attachments/1004159122335354970/1012917323717414972/222.png
https://cdn.discordapp.com/attachments/1004159122335354970/1012917323381878864/111.png

No shared seeds. 60-130 steps.

This was done without resampling anything, just the standard loop. I suspect the results highly depend on whether the original came from the model; when trying custom, non-generated images, the results are always poorer. Using the same prompt is probably also important.

If there is a way to come up with a 'text' embedding based on an arbitrary image, it may improve results for non-SD-generated images.

natolambert pushed a commit that referenced this pull request Sep 7, 2022
* add inpainting

* added proper noising of init_latent as recommended by jackloomen (#241 (comment))

* move image preprocessing inside pipeline and allow non 512x512 mask

PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023

* Upload benchmark results for every test-models workflow (excl. Vulkan) (huggingface#241)