[IP Adapters] feat: allow low_cpu_mem_usage in ip adapter loading #6946

Merged: 5 commits merged into main from ip-adapter-low-cpu-mem-usage on Feb 15, 2024

Conversation

sayakpaul (Member) commented on Feb 12, 2024

What does this PR do?

This PR adds low_cpu_mem_usage support in load_ip_adapter() to speed up loading time.

Script to test:

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch
import time
import argparse

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")

def run_sd(low_cpu_mem_usage):
    start = time.time()
    pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    pipeline.load_ip_adapter(
        "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", low_cpu_mem_usage=low_cpu_mem_usage
    )
    end = time.time()
    print(f"Loading time -- {(end - start):.3f} seconds")

    _ = pipeline(
        prompt="best quality, high quality", 
        ip_adapter_image=image,
        negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
        num_inference_steps=2,
    )


def run_sdxl(low_cpu_mem_usage):
    start = time.time()
    pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
    pipeline.load_ip_adapter(
        "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin", low_cpu_mem_usage=low_cpu_mem_usage
    )
    end = time.time()
    print(f"Loading time -- {(end - start):.3f} seconds")
    
    _ = pipeline(
        prompt="best quality, high quality", 
        ip_adapter_image=image,
        negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
        num_inference_steps=2,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_sd", action="store_true")
    parser.add_argument("--run_sdxl", action="store_true")
    parser.add_argument("--low_cpu_mem_usage", action="store_true")
    args = parser.parse_args()

    if args.run_sd and args.run_sdxl:
        raise ValueError("Both `run_sd` and `run_sdxl` cannot be True.")

    if not args.run_sd and not args.run_sdxl:
        raise ValueError("Both `run_sd` and `run_sdxl` cannot be False.")

    fn_to_run = run_sd if args.run_sd else run_sdxl
    fn_to_run(low_cpu_mem_usage=args.low_cpu_mem_usage)

On average, passing low_cpu_mem_usage=True in load_ip_adapter() saves about 2-3 seconds.

Will add documentation once the PR is approved.

TODO

  • Documentation

@sayakpaul sayakpaul requested a review from yiyixuxu February 12, 2024 10:22
@HuggingFaceDocBuilderDev commented

The docs for this PR live on a preview endpoint; all documentation changes will be reflected there. The docs are available until 30 days after the last update.

yiyixuxu (Collaborator) left a comment

Thanks! This is great!
I left a comment: I think we only need to warn once, when we actually apply low_cpu_mem_usage.
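The "warn once" suggestion can be implemented with a module-level guard so the advisory fires the first time `low_cpu_mem_usage` actually takes effect, rather than once per layer or per call. A minimal sketch (the function name, flag, and message are hypothetical, not the actual diffusers code):

```python
import warnings

_warned_low_cpu_mem_usage = False  # hypothetical module-level guard


def warn_low_cpu_mem_usage_once(applied: bool) -> None:
    """Warn only the first time low_cpu_mem_usage is actually applied
    (hypothetical helper, not the diffusers implementation)."""
    global _warned_low_cpu_mem_usage
    if applied and not _warned_low_cpu_mem_usage:
        warnings.warn(
            "low_cpu_mem_usage was applied while loading the IP-Adapter weights."
        )
        _warned_low_cpu_mem_usage = True
```

Repeated calls then emit at most one warning per process, and calls where the flag was not applied emit none.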

Review threads (resolved): src/diffusers/loaders/unet.py (outdated), src/diffusers/loaders/ip_adapter.py
sayakpaul (Member, Author) commented

Documentation has been dealt with as well.

@sayakpaul sayakpaul merged commit e6d1728 into main Feb 15, 2024
15 checks passed
@sayakpaul sayakpaul deleted the ip-adapter-low-cpu-mem-usage branch February 15, 2024 10:07