Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ControlNet API #162

Merged
merged 4 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ jobs:
curl -Lo "$filename" "$url"
fi
done
- name: Download ControlNet models
run: |
declare -a urls=(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth"
)

for url in "${urls[@]}"; do
filename="models/ControlNet/${url##*/}" # Extracts the last part of the URL
if [ ! -f "$filename" ]; then
curl -Lo "$filename" "$url"
fi
done
- name: Start test server
run: >
python -m coverage run
Expand All @@ -71,6 +83,14 @@ jobs:
run: |
wait-for-it --service 127.0.0.1:7860 -t 20
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Run ControlNet tests
run: >
python -m pytest
--junitxml=test/results.xml
--cov ./extensions-builtin/sd_forge_controlnet
--cov-report=xml
--verify-base-url
./extensions-builtin/sd_forge_controlnet/tests
- name: Kill test server
if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ notification.mp3
/.coverage*
/test/test_outputs
/test/results.xml
coverage.xml
coverage.xml
**/tests/**/expectations
108 changes: 108 additions & 0 deletions extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import List

import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image
import gradio as gr

from modules.api import api
from .global_state import (
get_all_preprocessor_names,
get_all_controlnet_names,
get_preprocessor,
)
from .logging import logger


def encode_to_base64(image):
if isinstance(image, str):
return image
elif isinstance(image, Image.Image):
return api.encode_pil_to_base64(image)
elif isinstance(image, np.ndarray):
return encode_np_to_base64(image)
else:
return ""


def encode_np_to_base64(image):
pil = Image.fromarray(image)
return api.encode_pil_to_base64(pil)


def controlnet_api(_: gr.Blocks, app: FastAPI):
@app.get("/controlnet/model_list")
async def model_list(update: bool = True):
up_to_date_model_list = get_all_controlnet_names()
logger.debug(up_to_date_model_list)
return {"model_list": up_to_date_model_list}

@app.get("/controlnet/module_list")
async def module_list(alias_names: bool = False):
module_list = get_all_preprocessor_names()
logger.debug(module_list)

return {
"module_list": module_list,
# TODO: Add back module detail.
# "module_detail": external_code.get_modules_detail(alias_names),
}

@app.post("/controlnet/detect")
async def detect(
controlnet_module: str = Body("none", title="Controlnet Module"),
controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"),
controlnet_processor_res: int = Body(
512, title="Controlnet Processor Resolution"
),
controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"),
controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"),
):
processor_module = get_preprocessor(controlnet_module)
if processor_module is None:
raise HTTPException(status_code=422, detail="Module not available")

if len(controlnet_input_images) == 0:
raise HTTPException(status_code=422, detail="No image selected")

logger.debug(
f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module."
)

results = []
poses = []

for input_image in controlnet_input_images:
img = np.array(api.decode_base64_to_image(input_image)).astype('uint8')

class JsonAcceptor:
def __init__(self) -> None:
self.value = None

def accept(self, json_dict: dict) -> None:
self.value = json_dict

json_acceptor = JsonAcceptor()

results.append(
processor_module(
img,
res=controlnet_processor_res,
thr_a=controlnet_threshold_a,
thr_b=controlnet_threshold_b,
json_pose_callback=json_acceptor.accept,
)[0]
)

if "openpose" in controlnet_module:
assert json_acceptor.value is not None
poses.append(json_acceptor.value)

results64 = list(map(encode_to_base64, results))
res = {"images": results64, "info": "Success"}
if poses:
res["poses"] = poses

return res

Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,6 @@ def is_openpose(module: str):
slider_1=pthr_a,
slider_2=pthr_b,
input_mask=mask,
low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False),
json_pose_callback=json_acceptor.accept
if is_openpose(module)
else None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,15 @@ def pixel_perfect_resolution(
@dataclass
class UiControlNetUnit:
input_mode: InputMode = InputMode.SIMPLE
use_preview_as_input: bool = False,
batch_image_dir: str = '',
batch_mask_dir: str = '',
batch_input_gallery: list = [],
batch_mask_gallery: list = [],
generated_image: Optional[np.ndarray] = None,
mask_image: Optional[np.ndarray] = None,
use_preview_as_input: bool = False
batch_image_dir: str = ''
batch_mask_dir: str = ''
batch_input_gallery: Optional[List[str]] = None
batch_mask_gallery: Optional[List[str]] = None
generated_image: Optional[np.ndarray] = None
mask_image: Optional[np.ndarray] = None
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
# The value is ignored if the generation is not using hires fix.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
enabled: bool = True
module: str = "None"
Expand All @@ -169,6 +171,13 @@ class UiControlNetUnit:
guidance_end: float = 1.0
pixel_perfect: bool = False
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
# ====== Start of API only fields ======
# Whether save the detected map of this unit. Setting this option to False
# prevents saving the detected map or sending detected map along with
# generated images via API. Currently the option is only accessible in API
# calls.
save_detected_map: bool = True
# ====== End of API only fields ======

@staticmethod
def infotext_fields():
Expand All @@ -192,6 +201,23 @@ def infotext_fields():
"hr_option",
)

@staticmethod
def from_dict(d: Dict) -> "UiControlNetUnit":
"""Create UiControlNetUnit from dict. This is primarily used to convert
API json dict to UiControlNetUnit."""
unit = UiControlNetUnit(
**{k: v for k, v in d.items() if k in vars(UiControlNetUnit)}
)
if isinstance(unit.image, str):
img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
unit.image = {
"image": img,
"mask": np.zeros_like(img),
}
if isinstance(unit.mask_image, str):
unit.mask_image = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
return unit


# Backward Compatible
ControlNetUnit = UiControlNetUnit
Expand Down
48 changes: 29 additions & 19 deletions extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
import gradio as gr

from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
prepare_mask, judge_image_type
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessing
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch
from lib_controlnet.enums import HiResFixOption
from lib_controlnet.api import controlnet_api

import numpy as np
import functools
Expand Down Expand Up @@ -67,7 +69,7 @@ def ui(self, is_img2img):
max_models = shared.opts.data.get("control_net_unit_count", 3)
gen_type = "img2img" if is_img2img else "txt2img"
elem_id_tabname = gen_type + "_controlnet"
default_unit = UiControlNetUnit(enabled=False, module="None", model="None")
default_unit = ControlNetUnit(enabled=False, module="None", model="None")
with gr.Group(elem_id=elem_id_tabname):
with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet",
elem_classes=["controlnet"]):
Expand Down Expand Up @@ -95,13 +97,19 @@ def ui(self, is_img2img):
return tuple(controls)

def get_enabled_units(self, units):
# Parse dict from API calls.
units = [
ControlNetUnit.from_dict(unit) if isinstance(unit, dict) else unit
for unit in units
]
assert all(isinstance(unit, ControlNetUnit) for unit in units)
enabled_units = [x for x in units if x.enabled]
return enabled_units

@staticmethod
def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
input_image: np.ndarray,
resize_mode: external_code.ResizeMode,
preprocessor
Expand Down Expand Up @@ -252,7 +260,7 @@ def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int,
@torch.no_grad()
def process_unit_after_click_generate(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):

Expand All @@ -279,8 +287,6 @@ def optional_tqdm(iterable, use_tqdm):
return tqdm(iterable) if use_tqdm else iterable

for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1):
# p.extra_result_images.append(input_image)

if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution(
input_image,
Expand Down Expand Up @@ -319,31 +325,36 @@ def optional_tqdm(iterable, use_tqdm):
hr_option = HiResFixOption.BOTH

alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):
if (
(is_high_res and hr_option.high_res_enabled) or
(not is_high_res and hr_option.low_res_enabled)
) and unit.save_detected_map:
p.extra_result_images.append(img)

if preprocessor_output_is_image:
params.control_cond = []
params.control_cond_for_hr_fix = []

for preprocessor_output in preprocessor_outputs:
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
if hr_option.low_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond))
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))

params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()

if has_high_res_fix:
for preprocessor_output in preprocessor_outputs:
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
if hr_option.high_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond_for_hr_fix), is_high_res=True)
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
else:
params.control_cond_for_hr_fix = params.control_cond
else:
params.control_cond = preprocessor_output
params.control_cond_for_hr_fix = preprocessor_output
p.extra_result_images.append(input_image)
attach_extra_result_image(input_image)

if len(control_masks) > 0:
params.control_mask = []
Expand All @@ -352,15 +363,13 @@ def optional_tqdm(iterable, use_tqdm):
for input_mask in control_masks:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
if hr_option.low_res_enabled:
p.extra_result_images.append(control_mask)
attach_extra_result_image(control_mask)
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask)

if has_high_res_fix:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
if hr_option.high_res_enabled:
p.extra_result_images.append(control_mask_for_hr_fix)
attach_extra_result_image(control_mask_for_hr_fix, is_high_res=True)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)

Expand Down Expand Up @@ -390,7 +399,7 @@ def optional_tqdm(iterable, use_tqdm):
@torch.no_grad()
def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):

Expand Down Expand Up @@ -473,14 +482,14 @@ def process_unit_before_every_sampling(self,
return

@staticmethod
def bound_check_params(unit: external_code.ControlNetUnit) -> None:
def bound_check_params(unit: ControlNetUnit) -> None:
"""
Checks and corrects negative parameters in ControlNetUnit 'unit'.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative.

Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
unit (ControlNetUnit): The ControlNetUnit instance to check.
"""
preprocessor = global_state.get_preprocessor(unit.module)

Expand All @@ -498,7 +507,7 @@ def bound_check_params(unit: external_code.ControlNetUnit) -> None:
@torch.no_grad()
def process_unit_after_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):

Expand Down Expand Up @@ -577,3 +586,4 @@ def on_ui_settings():
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)
script_callbacks.on_before_reload(ControlNetUiGroup.reset)
script_callbacks.on_app_started(controlnet_api)
8 changes: 8 additions & 0 deletions extensions-builtin/sd_forge_controlnet/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import os


def pytest_configure(config):
# We don't want to fail on Py.test command line arguments being
# parsed by webui:
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
os.environ.setdefault("FORGE_CQ_TEST", "1")
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
Loading
Loading