From c5e223492c6f7ed9b09b6808654e832c674425b5 Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 29 May 2024 17:05:32 +0800 Subject: [PATCH 01/24] refactor --- .../compile/__init__.py | 9 + .../{ => compile}/compile_ldm.py | 0 .../{ => compile}/compile_sgm.py | 0 .../compile/compile_utils.py | 74 ++++++++ .../{ => compile}/compile_vae.py | 0 .../compile/onediff_compiled_graph.py | 29 +++ onediff_sd_webui_extensions/onediff_hijack.py | 3 +- onediff_sd_webui_extensions/onediff_shared.py | 13 ++ .../scripts/onediff.py | 179 +++++++----------- onediff_sd_webui_extensions/ui_utils.py | 72 ++++++- 10 files changed, 262 insertions(+), 117 deletions(-) create mode 100644 onediff_sd_webui_extensions/compile/__init__.py rename onediff_sd_webui_extensions/{ => compile}/compile_ldm.py (100%) rename onediff_sd_webui_extensions/{ => compile}/compile_sgm.py (100%) create mode 100644 onediff_sd_webui_extensions/compile/compile_utils.py rename onediff_sd_webui_extensions/{ => compile}/compile_vae.py (100%) create mode 100644 onediff_sd_webui_extensions/compile/onediff_compiled_graph.py create mode 100644 onediff_sd_webui_extensions/onediff_shared.py diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py new file mode 100644 index 000000000..4d225f4c6 --- /dev/null +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -0,0 +1,9 @@ +# from .compile_ldm import SD21CompileCtx, compile_ldm_unet +from .compile_ldm import SD21CompileCtx + +# from .compile_sgm import compile_sgm_unet +from .compile_vae import VaeCompileCtx + +# from .compile_utils import compile_unet, get_compiled_unet +from .compile_utils import get_compiled_graph +from .onediff_compiled_graph import OneDiffCompiledGraph diff --git a/onediff_sd_webui_extensions/compile_ldm.py b/onediff_sd_webui_extensions/compile/compile_ldm.py similarity index 100% rename from onediff_sd_webui_extensions/compile_ldm.py rename to onediff_sd_webui_extensions/compile/compile_ldm.py diff --git a/onediff_sd_webui_extensions/compile_sgm.py b/onediff_sd_webui_extensions/compile/compile_sgm.py similarity index 100% rename from onediff_sd_webui_extensions/compile_sgm.py rename to onediff_sd_webui_extensions/compile/compile_sgm.py diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py new file mode 100644 index 000000000..66c5fc503 --- /dev/null +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -0,0 +1,74 @@ +import os +from typing import Dict + +# import modules.shared as shared +import warnings +from typing import Union, Dict +from pathlib import Path + +from .compile_ldm import compile_ldm_unet +from .compile_sgm import compile_sgm_unet +from .onediff_compiled_graph import OneDiffCompiledGraph +from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM +from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM +from onediff.optimization.quant_optimizer import ( + quantize_model, + varify_can_use_quantization, +) +from onediff.utils import logger +from onediff_shared import graph_dict + +from modules.sd_models import select_checkpoint + + +def compile_unet( + unet_model, quantization=False, *, options=None, +): + if isinstance(unet_model, UNetModelLDM): + compiled_unet = compile_ldm_unet(unet_model, options=options) + elif isinstance(unet_model, UNetModelSGM): + compiled_unet = compile_sgm_unet(unet_model, options=options) + else: + warnings.warn( + f"Unsupported model type: {type(unet_model)} for 
compilation , skip", + RuntimeWarning, + ) + compiled_unet = unet_model + # In OneDiff Community, quantization can be True when called by api + if quantization and varify_can_use_quantization(): + calibrate_info = get_calibrate_info( + f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" + ) + compiled_unet = quantize_model( + compiled_unet, inplace=False, calibrate_info=calibrate_info + ) + return compiled_unet + + +def get_calibrate_info(filename: str) -> Union[None, Dict]: + calibration_path = Path(select_checkpoint().filename).parent / filename + if not calibration_path.exists(): + return None + + logger.info(f"Got calibrate info at {str(calibration_path)}") + calibrate_info = {} + with open(calibration_path, "r") as f: + for line in f.readlines(): + line = line.strip() + items = line.split(" ") + calibrate_info[items[0]] = [ + float(items[1]), + int(items[2]), + [float(x) for x in items[3].split(",")], + ] + return calibrate_info + + +def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: + if sd_model.sd_model_hash in graph_dict: + return graph_dict[sd_model.sd_model_hash] + else: + compiled_unet = compile_unet( + sd_model.model.diffusion_model, quantization=quantization + ) + return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/compile_vae.py b/onediff_sd_webui_extensions/compile/compile_vae.py similarity index 100% rename from onediff_sd_webui_extensions/compile_vae.py rename to onediff_sd_webui_extensions/compile/compile_vae.py diff --git a/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py b/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py new file mode 100644 index 000000000..efeaf6cfc --- /dev/null +++ b/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py @@ -0,0 +1,29 @@ +import dataclasses +import torch +from onediff.infer_compiler import DeployableModule +from modules import sd_models_types + + +@dataclasses.dataclass +class OneDiffCompiledGraph: + name: str = None + filename: str = None + sha: str = None + eager_module: torch.nn.Module = None + graph_module: DeployableModule = None + quantized: bool = False + + def __init__( + self, + sd_model: sd_models_types.WebuiSdModel = None, + graph_module: DeployableModule = None, + quantized=False, + ): + if sd_model is None: + return + self.name = sd_model.sd_checkpoint_info.name + self.filename = sd_model.sd_checkpoint_info.filename + self.sha = sd_model.sd_model_hash + self.eager_module = sd_model.model.diffusion_model + self.graph_module = graph_module + self.quantized = quantized diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py index c8da677c6..65241da36 100644 --- a/onediff_sd_webui_extensions/onediff_hijack.py +++ b/onediff_sd_webui_extensions/onediff_hijack.py @@ -1,5 +1,4 @@ -import compile_ldm -import compile_sgm +from compile import compile_ldm, compile_sgm import oneflow diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py new file mode 100644 index 000000000..a2b04c834 --- /dev/null +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -0,0 +1,13 @@ +from typing import Dict +from compile.onediff_compiled_graph import OneDiffCompiledGraph + +# from compile_utils import OneDiffCompiledGraph + +current_unet_graph = OneDiffCompiledGraph() +graph_dict = dict() +current_unet_type = { + "is_sdxl": False, + "is_sd2": False, + "is_sd1": False, + "is_ssd": False, +} diff --git 
a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 5e5766c04..b39caa716 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -7,9 +7,12 @@ import gradio as gr import modules.scripts as scripts import modules.shared as shared -from compile_ldm import SD21CompileCtx, compile_ldm_unet -from compile_sgm import compile_sgm_unet -from compile_vae import VaeCompileCtx +from compile import ( + SD21CompileCtx, + VaeCompileCtx, + get_compiled_graph, + OneDiffCompiledGraph, +) from modules import script_callbacks from modules.processing import process_images from modules.sd_models import select_checkpoint @@ -22,6 +25,9 @@ get_all_compiler_caches, hints_message, refresh_all_compiler_caches, + check_structure_change_and_update, + load_graph, + save_graph, ) from onediff import __version__ as onediff_version @@ -30,11 +36,13 @@ varify_can_use_quantization, ) from onediff.utils import logger, parse_boolean_from_env +import onediff_shared """oneflow_compiled UNetModel""" -compiled_unet = None -is_unet_quantized = False -compiled_ckpt_name = None +# compiled_unet = {} +# compiled_unet = None +# is_unet_quantized = False +# compiled_ckpt_name = None def generate_graph_path(ckpt_name: str, model_name: str) -> str: @@ -68,43 +76,18 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]: return calibrate_info -def compile_unet( - unet_model, quantization=False, *, options=None, -): - from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM - from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM - - if isinstance(unet_model, UNetModelLDM): - compiled_unet = compile_ldm_unet(unet_model, options=options) - elif isinstance(unet_model, UNetModelSGM): - compiled_unet = compile_sgm_unet(unet_model, options=options) - else: - warnings.warn( - f"Unsupported model type: {type(unet_model)} for compilation , skip", - RuntimeWarning, - ) - compiled_unet = unet_model - # In OneDiff Community, quantization can be True when called by api - if quantization and varify_can_use_quantization(): - calibrate_info = get_calibrate_info( - f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" - ) - compiled_unet = quantize_model( - compiled_unet, inplace=False, calibrate_info=calibrate_info - ) - return compiled_unet - - class UnetCompileCtx(object): """The unet model is stored in a global variable. The global variables need to be replaced with compiled_unet before process_images is run, and then the original model restored so that subsequent reasoning with onediff disabled meets expectations. """ + def __init__(self, compiled_unet): + self.compiled_unet = compiled_unet + def __enter__(self): self._original_model = shared.sd_model.model.diffusion_model - global compiled_unet - shared.sd_model.model.diffusion_model = compiled_unet + shared.sd_model.model.diffusion_model = self.compiled_unet def __exit__(self, exc_type, exc_val, exc_tb): shared.sd_model.model.diffusion_model = self._original_model @@ -112,16 +95,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): class Script(scripts.Script): - current_type = None - def title(self): return "onediff_diffusion_model" def ui(self, is_img2img): - """this function should create gradio UI elements. See https://gradio.app/docs/#components - The return value should be an array of all components that are used in processing. - Values of those returned components will be passed to run() and process() functions. 
- """ with gr.Row(): # TODO: set choices as Tuple[str, str] after the version of gradio specified webui upgrades compiler_cache = gr.Dropdown( @@ -142,7 +119,11 @@ def ui(self, is_img2img): label="always_recompile", visible=parse_boolean_from_env("ONEDIFF_DEBUG"), ) - gr.HTML(hints_message, elem_id="hintMessage", visible=not varify_can_use_quantization()) + gr.HTML( + hints_message, + elem_id="hintMessage", + visible=not varify_can_use_quantization(), + ) is_quantized = gr.components.Checkbox( label="Model Quantization(int8) Speed Up", visible=varify_can_use_quantization(), @@ -150,30 +131,7 @@ def ui(self, is_img2img): return [is_quantized, compiler_cache, save_cache_name, always_recompile] def show(self, is_img2img): - return True - - def check_model_change(self, model): - is_changed = False - - def get_model_type(model): - return { - "is_sdxl": model.is_sdxl, - "is_sd2": model.is_sd2, - "is_sd1": model.is_sd1, - "is_ssd": model.is_ssd, - } - - if self.current_type is None: - is_changed = True - else: - for key, v in self.current_type.items(): - if v != getattr(model, key): - is_changed = True - break - - if is_changed is True: - self.current_type = get_model_type(model) - return is_changed + return scripts.AlwaysVisible def run( self, @@ -184,67 +142,44 @@ def run( always_recompile=False, ): - global compiled_unet, compiled_ckpt_name, is_unet_quantized - current_checkpoint = shared.opts.sd_model_checkpoint - original_diffusion_model = shared.sd_model.model.diffusion_model - - ckpt_changed = current_checkpoint != compiled_ckpt_name - model_changed = self.check_model_change(shared.sd_model) - quantization_changed = quantization != is_unet_quantized + current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name + ckpt_changed = ( + shared.sd_model.sd_checkpoint_info.name + != onediff_shared.current_unet_graph.name + ) + structure_changed = check_structure_change_and_update( + onediff_shared.current_unet_type, shared.sd_model + ) + quantization_changed = ( + quantization != onediff_shared.current_unet_graph.quantized + ) need_recompile = ( ( quantization and ckpt_changed ) # always recompile when switching ckpt with 'int8 speed model' enabled - or model_changed # always recompile when switching model to another structure + or structure_changed # always recompile when switching model to another structure or quantization_changed # always recompile when switching model from non-quantized to quantized (and vice versa) or always_recompile ) - - is_unet_quantized = quantization - compiled_ckpt_name = current_checkpoint if need_recompile: - compiled_unet = compile_unet( - original_diffusion_model, quantization=quantization + onediff_shared.current_unet_graph = get_compiled_graph( + shared.sd_model, quantization ) - - # Due to the version of gradio compatible with sd-webui, the CompilerCache dropdown box always returns a string - if compiler_cache not in [None, "None"]: - compiler_cache_path = all_compiler_caches_path() + f"/{compiler_cache}" - if not Path(compiler_cache_path).exists(): - raise FileNotFoundError( - f"Cannot find cache {compiler_cache_path}, please make sure it exists" - ) - try: - compiled_unet.load_graph(compiler_cache_path, run_warmup=True) - except zipfile.BadZipFile: - raise RuntimeError( - "Load cache failed. Please make sure that the --disable-safe-unpickle parameter is added when starting the webui" - ) - except Exception as e: - raise RuntimeError( - f"Load cache failed ({e}). 
Please make sure cache has the same sd version (or unet architure) with current checkpoint" - ) - + load_graph(onediff_shared.current_unet_graph, compiler_cache) else: logger.info( - f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile" + f"Model {current_checkpoint_name} has same sd type of graph type {onediff_shared.current_unet_type}, skip compile" ) - with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): + # register graph + onediff_shared.graph_dict[shared.sd_model.sd_model_hash] = OneDiffCompiledGraph( + shared.sd_model, graph_module=onediff_shared.current_unet_graph.graph_module + ) + with UnetCompileCtx( + onediff_shared.current_unet_graph.graph_module + ), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): proc = process_images(p) - - if saved_cache_name != "": - if not os.access(str(all_compiler_caches_path()), os.W_OK): - raise PermissionError( - f"The directory {all_compiler_caches_path()} does not have write permissions, and compiler cache cannot be written to this directory. \ - Please change it in the settings to a directory with write permissions" - ) - if not Path(all_compiler_caches_path()).exists(): - Path(all_compiler_caches_path()).mkdir() - saved_cache_name = all_compiler_caches_path() + f"/{saved_cache_name}" - if not Path(saved_cache_name).exists(): - compiled_unet.save_graph(saved_cache_name) - + save_graph(onediff_shared.current_unet_graph, saved_cache_name) return proc @@ -260,5 +195,23 @@ def on_ui_settings(): ) +def cfg_denoisers_callback(params): + # print(f"current checkpoint: {shared.opts.sd_model_checkpoint}") + # import ipdb; ipdb.set_trace() + if "refiner" in shared.sd_model.sd_checkpoint_info.name: + pass + # import ipdb; ipdb.set_trace() + # shared.sd_model.model.diffusion_model + + print(f"current checkpoint info: {shared.sd_model.sd_checkpoint_info.name}") + # shared.sd_model.model.diffusion_model = compile_unet( + # shared.sd_model.model.diffusion_model + # ) + + # have to check if onediff enabled + # print('onediff denoiser callback') + + script_callbacks.on_ui_settings(on_ui_settings) +script_callbacks.on_cfg_denoiser(cfg_denoisers_callback) onediff_do_hijack() diff --git a/onediff_sd_webui_extensions/ui_utils.py b/onediff_sd_webui_extensions/ui_utils.py index 7e442be4a..a23efbdf1 100644 --- a/onediff_sd_webui_extensions/ui_utils.py +++ b/onediff_sd_webui_extensions/ui_utils.py @@ -1,7 +1,12 @@ +import os from pathlib import Path from textwrap import dedent +from onediff.infer_compiler import DeployableModule +from zipfile import BadZipFile +import onediff_shared -hints_message = dedent("""\ +hints_message = dedent( + """\
@@ -21,7 +26,8 @@ https://github.com/siliconflow/onediff/issues
-""") +""" +) all_compiler_caches = [] @@ -46,3 +52,65 @@ def refresh_all_compiler_caches(path: Path = None): global all_compiler_caches path = path or all_compiler_caches_path() all_compiler_caches = [f.stem for f in Path(path).iterdir() if f.is_file()] + + +def check_structure_change_and_update(current_type: dict[str, bool], model): + def get_model_type(model): + return { + "is_sdxl": model.is_sdxl, + "is_sd2": model.is_sd2, + "is_sd1": model.is_sd1, + "is_ssd": model.is_ssd, + } + + changed = current_type != get_model_type(model) + current_type.update(**get_model_type(model)) + return changed + + +def load_graph(compiled_unet: DeployableModule, compiler_cache: str): + from compile import OneDiffCompiledGraph + + if isinstance(compiled_unet, OneDiffCompiledGraph): + compiled_unet = compiled_unet.graph_module + + if compiler_cache in [None, "None"]: + return + + compiler_cache_path = all_compiler_caches_path() + f"/{compiler_cache}" + if not Path(compiler_cache_path).exists(): + raise FileNotFoundError( + f"Cannot find cache {compiler_cache_path}, please make sure it exists" + ) + try: + compiled_unet.load_graph(compiler_cache_path, run_warmup=True) + except BadZipFile: + raise RuntimeError( + "Load cache failed. Please make sure that the --disable-safe-unpickle parameter is added when starting the webui" + ) + except Exception as e: + raise RuntimeError( + f"Load cache failed ({e}). Please make sure cache has the same sd version (or unet architure) with current checkpoint" + ) + return compiled_unet + + +def save_graph(compiled_unet: DeployableModule, saved_cache_name: str = ""): + from compile import OneDiffCompiledGraph + + if isinstance(compiled_unet, OneDiffCompiledGraph): + compiled_unet = compiled_unet.graph_module + + if saved_cache_name in ["", None]: + return + + if not os.access(str(all_compiler_caches_path()), os.W_OK): + raise PermissionError( + f"The directory {all_compiler_caches_path()} does not have write permissions, and compiler cache cannot be written to this directory. 
\ + Please change it in the settings to a directory with write permissions" + ) + if not Path(all_compiler_caches_path()).exists(): + Path(all_compiler_caches_path()).mkdir() + saved_cache_name = all_compiler_caches_path() + f"/{saved_cache_name}" + if not Path(saved_cache_name).exists(): + compiled_unet.save_graph(saved_cache_name) From e4332cf7dec6cefaaa14ce29aab57f590b3ce469 Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 29 May 2024 17:07:56 +0800 Subject: [PATCH 02/24] move mock utils --- .../{ => compile}/sd_webui_onediff_utils.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename onediff_sd_webui_extensions/{ => compile}/sd_webui_onediff_utils.py (100%) diff --git a/onediff_sd_webui_extensions/sd_webui_onediff_utils.py b/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py similarity index 100% rename from onediff_sd_webui_extensions/sd_webui_onediff_utils.py rename to onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py From 686d5333248e6ea6decaf5817179aebd99a0520b Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 4 Jun 2024 15:58:20 +0800 Subject: [PATCH 03/24] fix bug of refiner --- .../compile/__init__.py | 6 +- .../compile/compile_ldm.py | 2 +- .../compile/compile_sgm.py | 2 +- .../compile/compile_utils.py | 14 +-- .../compile/onediff_compiled_graph.py | 4 +- onediff_sd_webui_extensions/onediff_hijack.py | 2 +- onediff_sd_webui_extensions/onediff_lora.py | 118 ++++++++++++++++++ onediff_sd_webui_extensions/onediff_shared.py | 5 +- .../scripts/onediff.py | 36 ++++-- onediff_sd_webui_extensions/ui_utils.py | 14 ++- 10 files changed, 176 insertions(+), 27 deletions(-) diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py index 4d225f4c6..c08ce8c49 100644 --- a/onediff_sd_webui_extensions/compile/__init__.py +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -1,9 +1,9 @@ # from .compile_ldm import SD21CompileCtx, compile_ldm_unet from .compile_ldm import SD21CompileCtx -# from .compile_sgm import compile_sgm_unet -from .compile_vae import VaeCompileCtx - # from .compile_utils import compile_unet, get_compiled_unet from .compile_utils import get_compiled_graph + +# from .compile_sgm import compile_sgm_unet +from .compile_vae import VaeCompileCtx from .onediff_compiled_graph import OneDiffCompiledGraph diff --git a/onediff_sd_webui_extensions/compile/compile_ldm.py b/onediff_sd_webui_extensions/compile/compile_ldm.py index e87f7f696..9847e91b1 100644 --- a/onediff_sd_webui_extensions/compile/compile_ldm.py +++ b/onediff_sd_webui_extensions/compile/compile_ldm.py @@ -9,7 +9,7 @@ from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import shared -from sd_webui_onediff_utils import ( +from .sd_webui_onediff_utils import ( CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding, diff --git a/onediff_sd_webui_extensions/compile/compile_sgm.py b/onediff_sd_webui_extensions/compile/compile_sgm.py index 154b3dc5c..4a6ad6d7e 100644 --- a/onediff_sd_webui_extensions/compile/compile_sgm.py +++ b/onediff_sd_webui_extensions/compile/compile_sgm.py @@ -1,5 +1,5 @@ import oneflow as flow -from sd_webui_onediff_utils import ( +from .sd_webui_onediff_utils import ( CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding, diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index 66c5fc503..26b4fa39c 100644 --- 
a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -1,24 +1,23 @@ import os -from typing import Dict # import modules.shared as shared import warnings -from typing import Union, Dict from pathlib import Path +from typing import Dict, Union -from .compile_ldm import compile_ldm_unet -from .compile_sgm import compile_sgm_unet -from .onediff_compiled_graph import OneDiffCompiledGraph from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM +from modules.sd_models import select_checkpoint from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM + from onediff.optimization.quant_optimizer import ( quantize_model, varify_can_use_quantization, ) from onediff.utils import logger -from onediff_shared import graph_dict -from modules.sd_models import select_checkpoint +from .compile_ldm import compile_ldm_unet +from .compile_sgm import compile_sgm_unet +from .onediff_compiled_graph import OneDiffCompiledGraph def compile_unet( @@ -65,6 +64,7 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]: def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: + from onediff_shared import graph_dict if sd_model.sd_model_hash in graph_dict: return graph_dict[sd_model.sd_model_hash] else: diff --git a/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py b/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py index efeaf6cfc..d6a09aca3 100644 --- a/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py +++ b/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py @@ -1,8 +1,10 @@ import dataclasses + import torch -from onediff.infer_compiler import DeployableModule from modules import sd_models_types +from onediff.infer_compiler import DeployableModule + @dataclasses.dataclass class OneDiffCompiledGraph: diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py index 65241da36..b6df91af0 100644 --- a/onediff_sd_webui_extensions/onediff_hijack.py +++ b/onediff_sd_webui_extensions/onediff_hijack.py @@ -1,5 +1,5 @@ -from compile import compile_ldm, compile_sgm import oneflow +from compile import compile_ldm, compile_sgm # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/sd_hijack_unet.py#L8 diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py index 0bee88e9d..0d8ccfa80 100644 --- a/onediff_sd_webui_extensions/onediff_lora.py +++ b/onediff_sd_webui_extensions/onediff_lora.py @@ -1,10 +1,17 @@ import torch +from typing import Mapping, Any from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.backends.oneflow.param_utils import ( update_graph_related_tensor, ) +from onediff_shared import onediff_enabled + +from modules import sd_models +from modules.sd_hijack_utils import CondFunc +from compile import OneDiffCompiledGraph + class HijackLoraActivate: def __init__(self): @@ -57,3 +64,114 @@ def activate(self, p, params_list): activate._onediff_hijacked = True return activate + + +# class HijackLoadModelWeights: +# # def __init__(self): +# # from modules import extra_networks + +# # if "lora" in extra_networks.extra_network_registry: +# # cls_extra_network_lora = type(extra_networks.extra_network_registry["lora"]) +# # else: +# # cls_extra_network_lora = None +# # self.lora_class = cls_extra_network_lora + +# def __enter__(self): +# self.orig_func = sd_models.load_model_weights 
+# sd_models.load_model_weights = onediff_hijack_load_model_weights + +# def __exit__(self, exc_type, exc_val, exc_tb): +# sd_models.load_model_weights = self.orig_func + +def onediff_hijack_load_model_weights(orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer): + # load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer) + sd_model_hash = checkpoint_info.calculate_shorthash() + import onediff_shared + cached_model: OneDiffCompiledGraph = onediff_shared.graph_dict.get(sd_model_hash, None) + if cached_model is not None: + model.model.diffusion_model = cached_model.graph_module + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("model.diffusion_model.")} + return orig_func(model, checkpoint_info, state_dict, timer) + + +def onediff_hijack_load_state_dict(orig_func, self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(self.parameters()).is_meta: + return orig_func(self, state_dict, strict, assign=True) + else: + return orig_func(self, state_dict, strict, assign) + + +def onediff_hijaced_LoadStateDictOnMeta___enter__(orig_func, self): + from modules import shared + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): + used_param_keys = [] + + for name, param in module._parameters.items(): + if param is None: + continue + + key = prefix + name + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) + used_param_keys.append(key) + + if param.is_meta: + dtype = sd_param.dtype if sd_param is not None else param.dtype + module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + + for name in module._buffers: + key = prefix + name + + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param + used_param_keys.append(key) + + original(module, state_dict, prefix, *args, **kwargs) + + for key in used_param_keys: + state_dict.pop(key, None) + + # def load_state_dict(original, module, state_dict, strict=True): + def load_state_dict(original, module, state_dict, strict=True): + """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help + because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with + all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. + + In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd). + + The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads + the function and does not call the original) the state dict will just fail to load because weights + would be on the meta device. 
+ """ + + if state_dict is sd: + state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} + + # ------------------- DIFF HERE ------------------- + # original(module, state_dict, strict=strict) + if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(module.parameters()).is_meta: + assign = True + else: + assign = False + # orig_func(original, module, state_dict, strict=strict, assign=assign) + original(module, state_dict, strict=strict, assign=assign) + + module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) + module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) + group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) + + +CondFunc("modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__", onediff_hijaced_LoadStateDictOnMeta___enter__, lambda _, *args, **kwargs: onediff_enabled) +CondFunc("modules.sd_models.load_model_weights", onediff_hijack_load_model_weights, lambda _, *args, **kwargs: onediff_enabled) \ No newline at end of file diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index a2b04c834..9bdd82678 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -1,13 +1,16 @@ from typing import Dict + from compile.onediff_compiled_graph import OneDiffCompiledGraph # from compile_utils import OneDiffCompiledGraph current_unet_graph = OneDiffCompiledGraph() -graph_dict = dict() +graph_dict: Dict[str, OneDiffCompiledGraph] = dict() +refiner_dict: Dict[str, str] = dict() current_unet_type = { "is_sdxl": False, "is_sd2": False, "is_sd1": False, "is_ssd": False, } +onediff_enabled = True \ No newline at end of file diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index b39caa716..4e27db5d5 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -1,4 +1,5 @@ import os +import torch import warnings import zipfile from pathlib import Path @@ -7,11 +8,13 @@ import gradio as gr import modules.scripts as scripts import modules.shared as shared +import modules.sd_models as sd_models +import onediff_shared from compile import ( + OneDiffCompiledGraph, SD21CompileCtx, VaeCompileCtx, get_compiled_graph, - OneDiffCompiledGraph, ) from modules import script_callbacks from modules.processing import process_images @@ -22,12 +25,13 @@ from oneflow 
import __version__ as oneflow_version from ui_utils import ( all_compiler_caches_path, + check_structure_change_and_update, get_all_compiler_caches, hints_message, - refresh_all_compiler_caches, - check_structure_change_and_update, load_graph, + refresh_all_compiler_caches, save_graph, + onediff_enabled, ) from onediff import __version__ as onediff_version @@ -36,7 +40,6 @@ varify_can_use_quantization, ) from onediff.utils import logger, parse_boolean_from_env -import onediff_shared """oneflow_compiled UNetModel""" # compiled_unet = {} @@ -82,12 +85,13 @@ class UnetCompileCtx(object): and then the original model restored so that subsequent reasoning with onediff disabled meets expectations. """ - def __init__(self, compiled_unet): - self.compiled_unet = compiled_unet + # def __init__(self, compiled_unet): + # self.compiled_unet = compiled_unet def __enter__(self): self._original_model = shared.sd_model.model.diffusion_model - shared.sd_model.model.diffusion_model = self.compiled_unet + # onediff_shared.current_unet_graph.graph_module + shared.sd_model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module def __exit__(self, exc_type, exc_val, exc_tb): shared.sd_model.model.diffusion_model = self._original_model @@ -131,7 +135,7 @@ def ui(self, is_img2img): return [is_quantized, compiler_cache, save_cache_name, always_recompile] def show(self, is_img2img): - return scripts.AlwaysVisible + return True def run( self, @@ -141,6 +145,11 @@ def run( saved_cache_name="", always_recompile=False, ): + # restore checkpoint_info from refiner to base model + if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: + p.override_settings.pop('sd_model_checkpoint', None) + sd_models.reload_model_weights() + torch.cuda.empty_cache() current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name ckpt_changed = ( @@ -175,9 +184,8 @@ def run( onediff_shared.graph_dict[shared.sd_model.sd_model_hash] = OneDiffCompiledGraph( shared.sd_model, graph_module=onediff_shared.current_unet_graph.graph_module ) - with UnetCompileCtx( - onediff_shared.current_unet_graph.graph_module - ), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): + + with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(), onediff_enabled(): proc = process_images(p) save_graph(onediff_shared.current_unet_graph, saved_cache_name) return proc @@ -196,9 +204,15 @@ def on_ui_settings(): def cfg_denoisers_callback(params): + # check refiner model # print(f"current checkpoint: {shared.opts.sd_model_checkpoint}") # import ipdb; ipdb.set_trace() if "refiner" in shared.sd_model.sd_checkpoint_info.name: + # onediff_shared.current_unet_graph = get_compiled_graph( + # shared.sd_model, quantization + # ) + # load_graph(onediff_shared.current_unet_graph, compiler_cache) + # import ipdb; ipdb.set_trace() pass # import ipdb; ipdb.set_trace() # shared.sd_model.model.diffusion_model diff --git a/onediff_sd_webui_extensions/ui_utils.py b/onediff_sd_webui_extensions/ui_utils.py index a23efbdf1..b4fbf369e 100644 --- a/onediff_sd_webui_extensions/ui_utils.py +++ b/onediff_sd_webui_extensions/ui_utils.py @@ -1,10 +1,12 @@ import os from pathlib import Path from textwrap import dedent -from onediff.infer_compiler import DeployableModule from zipfile import BadZipFile + import onediff_shared +from onediff.infer_compiler import DeployableModule + hints_message = dedent( """\
@@ -114,3 +116,13 @@ def save_graph(compiled_unet: DeployableModule, saved_cache_name: str = ""): saved_cache_name = all_compiler_caches_path() + f"/{saved_cache_name}" if not Path(saved_cache_name).exists(): compiled_unet.save_graph(saved_cache_name) + + +from contextlib import contextmanager +@contextmanager +def onediff_enabled(): + onediff_shared.onediff_enabled = True + try: + yield + finally: + onediff_shared.onediff_enabled = False From 156724c0c78a845bdfb78c4eecd912923e77c0d3 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 4 Jun 2024 16:14:06 +0800 Subject: [PATCH 04/24] refine, format --- .../compile/__init__.py | 12 ++- .../compile/compile_ldm.py | 7 +- .../compile/compile_sgm.py | 11 ++- .../compile/compile_utils.py | 4 +- onediff_sd_webui_extensions/onediff_lora.py | 60 +++++++----- onediff_sd_webui_extensions/onediff_shared.py | 2 +- .../scripts/onediff.py | 94 ++++--------------- onediff_sd_webui_extensions/ui_utils.py | 2 +- 8 files changed, 72 insertions(+), 120 deletions(-) diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py index c08ce8c49..90afcaceb 100644 --- a/onediff_sd_webui_extensions/compile/__init__.py +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -1,9 +1,11 @@ -# from .compile_ldm import SD21CompileCtx, compile_ldm_unet from .compile_ldm import SD21CompileCtx - -# from .compile_utils import compile_unet, get_compiled_unet from .compile_utils import get_compiled_graph - -# from .compile_sgm import compile_sgm_unet from .compile_vae import VaeCompileCtx from .onediff_compiled_graph import OneDiffCompiledGraph + +__all__ = [ + "get_compiled_graph", + "SD21CompileCtx", + "VaeCompileCtx", + "OneDiffCompiledGraph", +] diff --git a/onediff_sd_webui_extensions/compile/compile_ldm.py b/onediff_sd_webui_extensions/compile/compile_ldm.py index 9847e91b1..7b04e16aa 100644 --- a/onediff_sd_webui_extensions/compile/compile_ldm.py +++ b/onediff_sd_webui_extensions/compile/compile_ldm.py @@ -9,15 +9,16 @@ from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import shared + +from onediff.infer_compiler import oneflow_compile +from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register + from .sd_webui_onediff_utils import ( CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding, ) -from onediff.infer_compiler import oneflow_compile -from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register - __all__ = ["compile_ldm_unet"] diff --git a/onediff_sd_webui_extensions/compile/compile_sgm.py b/onediff_sd_webui_extensions/compile/compile_sgm.py index 4a6ad6d7e..09b86be59 100644 --- a/onediff_sd_webui_extensions/compile/compile_sgm.py +++ b/onediff_sd_webui_extensions/compile/compile_sgm.py @@ -1,9 +1,4 @@ import oneflow as flow -from .sd_webui_onediff_utils import ( - CrossAttentionOflow, - GroupNorm32Oflow, - timestep_embedding, -) from sgm.modules.attention import ( BasicTransformerBlock, CrossAttention, @@ -15,6 +10,12 @@ from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +from .sd_webui_onediff_utils import ( + CrossAttentionOflow, + GroupNorm32Oflow, + timestep_embedding, +) + __all__ = ["compile_sgm_unet"] diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index 26b4fa39c..42d53bc40 100644 --- 
a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -1,6 +1,3 @@ -import os - -# import modules.shared as shared import warnings from pathlib import Path from typing import Dict, Union @@ -65,6 +62,7 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]: def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: from onediff_shared import graph_dict + if sd_model.sd_model_hash in graph_dict: return graph_dict[sd_model.sd_model_hash] else: diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py index 0d8ccfa80..a11705867 100644 --- a/onediff_sd_webui_extensions/onediff_lora.py +++ b/onediff_sd_webui_extensions/onediff_lora.py @@ -66,41 +66,44 @@ def activate(self, p, params_list): return activate -# class HijackLoadModelWeights: -# # def __init__(self): -# # from modules import extra_networks - -# # if "lora" in extra_networks.extra_network_registry: -# # cls_extra_network_lora = type(extra_networks.extra_network_registry["lora"]) -# # else: -# # cls_extra_network_lora = None -# # self.lora_class = cls_extra_network_lora - -# def __enter__(self): -# self.orig_func = sd_models.load_model_weights -# sd_models.load_model_weights = onediff_hijack_load_model_weights - -# def __exit__(self, exc_type, exc_val, exc_tb): -# sd_models.load_model_weights = self.orig_func - -def onediff_hijack_load_model_weights(orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer): +def onediff_hijack_load_model_weights( + orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer +): # load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer) sd_model_hash = checkpoint_info.calculate_shorthash() import onediff_shared - cached_model: OneDiffCompiledGraph = onediff_shared.graph_dict.get(sd_model_hash, None) + + cached_model: OneDiffCompiledGraph = onediff_shared.graph_dict.get( + sd_model_hash, None + ) if cached_model is not None: model.model.diffusion_model = cached_model.graph_module - state_dict = {k: v for k, v in state_dict.items() if not k.startswith("model.diffusion_model.")} + state_dict = { + k: v + for k, v in state_dict.items() + if not k.startswith("model.diffusion_model.") + } return orig_func(model, checkpoint_info, state_dict, timer) -def onediff_hijack_load_state_dict(orig_func, self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): - if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(self.parameters()).is_meta: +def onediff_hijack_load_state_dict( + orig_func, + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, +): + if ( + len(state_dict) > 0 + and next(iter(state_dict.values())).is_cuda + and next(self.parameters()).is_meta + ): return orig_func(self, state_dict, strict, assign=True) else: return orig_func(self, state_dict, strict, assign) +# fmt: off def onediff_hijaced_LoadStateDictOnMeta___enter__(orig_func, self): from modules import shared if shared.cmd_opts.disable_model_loading_ram_optimization: @@ -171,7 +174,16 @@ def load_state_dict(original, module, state_dict, strict=True): mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: 
load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) +# fmt: on -CondFunc("modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__", onediff_hijaced_LoadStateDictOnMeta___enter__, lambda _, *args, **kwargs: onediff_enabled) -CondFunc("modules.sd_models.load_model_weights", onediff_hijack_load_model_weights, lambda _, *args, **kwargs: onediff_enabled) \ No newline at end of file +CondFunc( + "modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__", + onediff_hijaced_LoadStateDictOnMeta___enter__, + lambda _, *args, **kwargs: onediff_enabled, +) +CondFunc( + "modules.sd_models.load_model_weights", + onediff_hijack_load_model_weights, + lambda _, *args, **kwargs: onediff_enabled, +) diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index 9bdd82678..233f0c887 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -13,4 +13,4 @@ "is_sd1": False, "is_ssd": False, } -onediff_enabled = True \ No newline at end of file +onediff_enabled = True diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 4e27db5d5..890cff67e 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -1,15 +1,11 @@ -import os -import torch -import warnings -import zipfile from pathlib import Path -from typing import Dict, Union import gradio as gr import modules.scripts as scripts -import modules.shared as shared import modules.sd_models as sd_models +import modules.shared as shared import onediff_shared +import torch from compile import ( OneDiffCompiledGraph, SD21CompileCtx, @@ -18,65 +14,23 @@ ) from modules import script_callbacks from modules.processing import process_images -from modules.sd_models import select_checkpoint from modules.ui_common import create_refresh_button from onediff_hijack import do_hijack as onediff_do_hijack from onediff_lora import HijackLoraActivate -from oneflow import __version__ as oneflow_version from ui_utils import ( - all_compiler_caches_path, check_structure_change_and_update, get_all_compiler_caches, hints_message, load_graph, + onediff_enabled, refresh_all_compiler_caches, save_graph, - onediff_enabled, ) -from onediff import __version__ as onediff_version -from onediff.optimization.quant_optimizer import ( - quantize_model, - varify_can_use_quantization, -) +from onediff.optimization.quant_optimizer import varify_can_use_quantization from onediff.utils import logger, parse_boolean_from_env """oneflow_compiled UNetModel""" -# compiled_unet = {} -# compiled_unet = None -# is_unet_quantized = False -# compiled_ckpt_name = None - - -def generate_graph_path(ckpt_name: str, model_name: str) -> str: - base_output_dir = shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples - save_ckpt_graphs_path = os.path.join(base_output_dir, "graphs", ckpt_name) - os.makedirs(save_ckpt_graphs_path, exist_ok=True) - - file_name = f"{model_name}_graph_{onediff_version}_oneflow_{oneflow_version}" - - graph_file_path = os.path.join(save_ckpt_graphs_path, file_name) - - return graph_file_path - - -def get_calibrate_info(filename: str) -> Union[None, Dict]: - calibration_path = Path(select_checkpoint().filename).parent / filename - if not 
calibration_path.exists(): - return None - - logger.info(f"Got calibrate info at {str(calibration_path)}") - calibrate_info = {} - with open(calibration_path, "r") as f: - for line in f.readlines(): - line = line.strip() - items = line.split(" ") - calibrate_info[items[0]] = [ - float(items[1]), - int(items[2]), - [float(x) for x in items[3].split(",")], - ] - return calibrate_info class UnetCompileCtx(object): @@ -85,13 +39,11 @@ class UnetCompileCtx(object): and then the original model restored so that subsequent reasoning with onediff disabled meets expectations. """ - # def __init__(self, compiled_unet): - # self.compiled_unet = compiled_unet - def __enter__(self): self._original_model = shared.sd_model.model.diffusion_model - # onediff_shared.current_unet_graph.graph_module - shared.sd_model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module + shared.sd_model.model.diffusion_model = ( + onediff_shared.current_unet_graph.graph_module + ) def __exit__(self, exc_type, exc_val, exc_tb): shared.sd_model.model.diffusion_model = self._original_model @@ -146,8 +98,13 @@ def run( always_recompile=False, ): # restore checkpoint_info from refiner to base model - if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: - p.override_settings.pop('sd_model_checkpoint', None) + if ( + sd_models.checkpoint_aliases.get( + p.override_settings.get("sd_model_checkpoint") + ) + is None + ): + p.override_settings.pop("sd_model_checkpoint", None) sd_models.reload_model_weights() torch.cuda.empty_cache() @@ -204,28 +161,9 @@ def on_ui_settings(): def cfg_denoisers_callback(params): - # check refiner model - # print(f"current checkpoint: {shared.opts.sd_model_checkpoint}") - # import ipdb; ipdb.set_trace() - if "refiner" in shared.sd_model.sd_checkpoint_info.name: - # onediff_shared.current_unet_graph = get_compiled_graph( - # shared.sd_model, quantization - # ) - # load_graph(onediff_shared.current_unet_graph, compiler_cache) - # import ipdb; ipdb.set_trace() - pass - # import ipdb; ipdb.set_trace() - # shared.sd_model.model.diffusion_model - - print(f"current checkpoint info: {shared.sd_model.sd_checkpoint_info.name}") - # shared.sd_model.model.diffusion_model = compile_unet( - # shared.sd_model.model.diffusion_model - # ) - - # have to check if onediff enabled - # print('onediff denoiser callback') + pass script_callbacks.on_ui_settings(on_ui_settings) -script_callbacks.on_cfg_denoiser(cfg_denoisers_callback) +# script_callbacks.on_cfg_denoiser(cfg_denoisers_callback) onediff_do_hijack() diff --git a/onediff_sd_webui_extensions/ui_utils.py b/onediff_sd_webui_extensions/ui_utils.py index b4fbf369e..bdb875a38 100644 --- a/onediff_sd_webui_extensions/ui_utils.py +++ b/onediff_sd_webui_extensions/ui_utils.py @@ -1,4 +1,5 @@ import os +from contextlib import contextmanager from pathlib import Path from textwrap import dedent from zipfile import BadZipFile @@ -118,7 +119,6 @@ def save_graph(compiled_unet: DeployableModule, saved_cache_name: str = ""): compiled_unet.save_graph(saved_cache_name) -from contextlib import contextmanager @contextmanager def onediff_enabled(): onediff_shared.onediff_enabled = True From 7b51da0b3ac3ea60d432df4316241b57508939ac Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 4 Jun 2024 17:01:16 +0800 Subject: [PATCH 05/24] add test --- tests/sd-webui/test_api.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/sd-webui/test_api.py b/tests/sd-webui/test_api.py index c745ad86d..2fbc40cfb 100644 --- 
a/tests/sd-webui/test_api.py +++ b/tests/sd-webui/test_api.py @@ -79,3 +79,14 @@ def test_onediff_load_graph(url_txt2img): } data = {**get_base_args(), **script_args} post_request_and_check(url_txt2img, data) + + +def test_onediff_refiner(url_txt2img): + extra_args = { + "refiner_checkpoint" :"sd_xl_refiner_1.0.safetensors [7440042bbd]", + "refiner_switch_at" : 0.8, + } + data = {**get_base_args(), **extra_args} + # loop 5 times for checking model switching between base and refiner + for _ in range(5): + post_request_and_check(url_txt2img, data) From 0843f459251a52627c67bf70cabbd93707702ef0 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 4 Jun 2024 23:12:27 +0800 Subject: [PATCH 06/24] fix cuda memory of refiner --- .../compile/compile_utils.py | 14 +++----- onediff_sd_webui_extensions/onediff_lora.py | 32 +++++++++++-------- onediff_sd_webui_extensions/onediff_shared.py | 6 ++-- .../scripts/onediff.py | 16 +++------- tests/sd-webui/test_api.py | 1 + 5 files changed, 31 insertions(+), 38 deletions(-) diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index 42d53bc40..89339832f 100644 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -5,6 +5,7 @@ from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM from modules.sd_models import select_checkpoint from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM +from ui_utils import check_structure_change_and_update from onediff.optimization.quant_optimizer import ( quantize_model, @@ -61,12 +62,7 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]: def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: - from onediff_shared import graph_dict - - if sd_model.sd_model_hash in graph_dict: - return graph_dict[sd_model.sd_model_hash] - else: - compiled_unet = compile_unet( - sd_model.model.diffusion_model, quantization=quantization - ) - return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) + compiled_unet = compile_unet( + sd_model.model.diffusion_model, quantization=quantization + ) + return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py index a11705867..fb8e8b817 100644 --- a/onediff_sd_webui_extensions/onediff_lora.py +++ b/onediff_sd_webui_extensions/onediff_lora.py @@ -1,17 +1,15 @@ +from typing import Any, Mapping + import torch -from typing import Mapping, Any +from modules import sd_models +from modules.sd_hijack_utils import CondFunc +from onediff_shared import onediff_enabled from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.backends.oneflow.param_utils import ( update_graph_related_tensor, ) -from onediff_shared import onediff_enabled - -from modules import sd_models -from modules.sd_hijack_utils import CondFunc -from compile import OneDiffCompiledGraph - class HijackLoraActivate: def __init__(self): @@ -60,7 +58,11 @@ def activate(self, p, params_list): continue networks.network_apply_weights(sub_module) if isinstance(sub_module, torch.nn.Conv2d): - update_graph_related_tensor(sub_module) + # TODO(WangYi): refine here + try: + update_graph_related_tensor(sub_module) + except: + pass activate._onediff_hijacked = True return activate @@ -73,16 +75,20 @@ def onediff_hijack_load_model_weights( sd_model_hash = checkpoint_info.calculate_shorthash() import onediff_shared - 
cached_model: OneDiffCompiledGraph = onediff_shared.graph_dict.get( - sd_model_hash, None - ) - if cached_model is not None: - model.model.diffusion_model = cached_model.graph_module + if onediff_shared.current_unet_graph.sha == sd_model_hash: + model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module state_dict = { k: v for k, v in state_dict.items() if not k.startswith("model.diffusion_model.") } + + # for stable-diffusion-webui/modules/sd_models.py:load_model_weights model.is_ssd check + state_dict[ + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight" + ] = model.get_parameter( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight" + ) return orig_func(model, checkpoint_info, state_dict, timer) diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index 233f0c887..a5dcd563a 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -2,10 +2,8 @@ from compile.onediff_compiled_graph import OneDiffCompiledGraph -# from compile_utils import OneDiffCompiledGraph - current_unet_graph = OneDiffCompiledGraph() -graph_dict: Dict[str, OneDiffCompiledGraph] = dict() +current_quantization = False refiner_dict: Dict[str, str] = dict() current_unet_type = { "is_sdxl": False, @@ -13,4 +11,4 @@ "is_sd1": False, "is_ssd": False, } -onediff_enabled = True +onediff_enabled = False diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 890cff67e..0ab98eab2 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -5,13 +5,9 @@ import modules.sd_models as sd_models import modules.shared as shared import onediff_shared +import oneflow as flow import torch -from compile import ( - OneDiffCompiledGraph, - SD21CompileCtx, - VaeCompileCtx, - get_compiled_graph, -) +from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph from modules import script_callbacks from modules.processing import process_images from modules.ui_common import create_refresh_button @@ -97,7 +93,7 @@ def run( saved_cache_name="", always_recompile=False, ): - # restore checkpoint_info from refiner to base model + # restore checkpoint_info from refiner to base model if necessary if ( sd_models.checkpoint_aliases.get( p.override_settings.get("sd_model_checkpoint") @@ -107,6 +103,7 @@ def run( p.override_settings.pop("sd_model_checkpoint", None) sd_models.reload_model_weights() torch.cuda.empty_cache() + flow.cuda.empty_cache() current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name ckpt_changed = ( @@ -137,11 +134,6 @@ def run( f"Model {current_checkpoint_name} has same sd type of graph type {onediff_shared.current_unet_type}, skip compile" ) - # register graph - onediff_shared.graph_dict[shared.sd_model.sd_model_hash] = OneDiffCompiledGraph( - shared.sd_model, graph_module=onediff_shared.current_unet_graph.graph_module - ) - with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(), onediff_enabled(): proc = process_images(p) save_graph(onediff_shared.current_unet_graph, saved_cache_name) diff --git a/tests/sd-webui/test_api.py b/tests/sd-webui/test_api.py index 2fbc40cfb..9c6d32fdc 100644 --- a/tests/sd-webui/test_api.py +++ b/tests/sd-webui/test_api.py @@ -83,6 +83,7 @@ def test_onediff_load_graph(url_txt2img): def test_onediff_refiner(url_txt2img): extra_args = { + "sd_model_checkpoint": 
"sd_xl_base_1.0.safetensors", "refiner_checkpoint" :"sd_xl_refiner_1.0.safetensors [7440042bbd]", "refiner_switch_at" : 0.8, } From 345da80d6de630114d4c1654989585b13e29d16d Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 5 Jun 2024 12:43:53 +0800 Subject: [PATCH 07/24] refine --- onediff_sd_webui_extensions/README.md | 2 + .../compile/compile_utils.py | 1 - onediff_sd_webui_extensions/onediff_hijack.py | 133 ++++++++++++++++++ onediff_sd_webui_extensions/onediff_lora.py | 132 ----------------- onediff_sd_webui_extensions/onediff_shared.py | 3 - .../scripts/onediff.py | 6 +- tests/sd-webui/test_api.py | 3 +- 7 files changed, 141 insertions(+), 139 deletions(-) diff --git a/onediff_sd_webui_extensions/README.md b/onediff_sd_webui_extensions/README.md index e4a0e3f3a..0e7b14d14 100644 --- a/onediff_sd_webui_extensions/README.md +++ b/onediff_sd_webui_extensions/README.md @@ -4,8 +4,10 @@ - [Installation Guide](#installation-guide) - [Extensions Usage](#extensions-usage) - [Fast Model Switching](#fast-model-switching) + - [Compiler cache saving and loading](#compiler-cache-saving-and-loading) - [LoRA](#lora) - [Quantization](#quantization) +- [Use OneDiff by API](#use-onediff-by-api) - [Contact](#contact) ## Performance of Community Edition diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index 89339832f..9d39fbc96 100644 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -5,7 +5,6 @@ from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM from modules.sd_models import select_checkpoint from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM -from ui_utils import check_structure_change_and_update from onediff.optimization.quant_optimizer import ( quantize_model, diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py index b6df91af0..355180202 100644 --- a/onediff_sd_webui_extensions/onediff_hijack.py +++ b/onediff_sd_webui_extensions/onediff_hijack.py @@ -1,5 +1,11 @@ +from typing import Any, Mapping + import oneflow +import torch from compile import compile_ldm, compile_sgm +from modules import sd_models +from modules.sd_hijack_utils import CondFunc +from onediff_shared import onediff_enabled # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/sd_hijack_unet.py#L8 @@ -94,3 +100,130 @@ def undo_hijack(): name="send_model_to_cpu", new_name="__onediff_original_send_model_to_cpu", ) + + +def onediff_hijack_load_model_weights( + orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer +): + # load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer) + sd_model_hash = checkpoint_info.calculate_shorthash() + import onediff_shared + + if onediff_shared.current_unet_graph.sha == sd_model_hash: + model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module + state_dict = { + k: v + for k, v in state_dict.items() + if not k.startswith("model.diffusion_model.") + } + + # for stable-diffusion-webui/modules/sd_models.py:load_model_weights model.is_ssd check + state_dict[ + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight" + ] = model.get_parameter( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight" + ) + return orig_func(model, checkpoint_info, state_dict, timer) + + +def 
onediff_hijack_load_state_dict( + orig_func, + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, +): + if ( + len(state_dict) > 0 + and next(iter(state_dict.values())).is_cuda + and next(self.parameters()).is_meta + ): + return orig_func(self, state_dict, strict, assign=True) + else: + return orig_func(self, state_dict, strict, assign) + + +# fmt: off +def onediff_hijaced_LoadStateDictOnMeta___enter__(orig_func, self): + from modules import shared + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): + used_param_keys = [] + + for name, param in module._parameters.items(): + if param is None: + continue + + key = prefix + name + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) + used_param_keys.append(key) + + if param.is_meta: + dtype = sd_param.dtype if sd_param is not None else param.dtype + module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + + for name in module._buffers: + key = prefix + name + + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param + used_param_keys.append(key) + + original(module, state_dict, prefix, *args, **kwargs) + + for key in used_param_keys: + state_dict.pop(key, None) + + # def load_state_dict(original, module, state_dict, strict=True): + def load_state_dict(original, module, state_dict, strict=True): + """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help + because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with + all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. + + In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd). + + The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads + the function and does not call the original) the state dict will just fail to load because weights + would be on the meta device. 
+ """ + + if state_dict is sd: + state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} + + # ------------------- DIFF HERE ------------------- + # original(module, state_dict, strict=strict) + if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(module.parameters()).is_meta: + assign = True + else: + assign = False + # orig_func(original, module, state_dict, strict=strict, assign=assign) + original(module, state_dict, strict=strict, assign=assign) + + module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) + module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) + group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) +# fmt: on + + +CondFunc( + "modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__", + onediff_hijaced_LoadStateDictOnMeta___enter__, + lambda _, *args, **kwargs: onediff_enabled, +) +CondFunc( + "modules.sd_models.load_model_weights", + onediff_hijack_load_model_weights, + lambda _, *args, **kwargs: onediff_enabled, +) diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py index fb8e8b817..a1f4da8da 100644 --- a/onediff_sd_webui_extensions/onediff_lora.py +++ b/onediff_sd_webui_extensions/onediff_lora.py @@ -1,9 +1,4 @@ -from typing import Any, Mapping - import torch -from modules import sd_models -from modules.sd_hijack_utils import CondFunc -from onediff_shared import onediff_enabled from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.backends.oneflow.param_utils import ( @@ -66,130 +61,3 @@ def activate(self, p, params_list): activate._onediff_hijacked = True return activate - - -def onediff_hijack_load_model_weights( - orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer -): - # load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer) - sd_model_hash = checkpoint_info.calculate_shorthash() - import onediff_shared - - if onediff_shared.current_unet_graph.sha == sd_model_hash: - model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module - state_dict = { - k: v - for k, v in state_dict.items() - if not k.startswith("model.diffusion_model.") - } - - # for stable-diffusion-webui/modules/sd_models.py:load_model_weights model.is_ssd check - state_dict[ - "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight" - ] = model.get_parameter( - 
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight" - ) - return orig_func(model, checkpoint_info, state_dict, timer) - - -def onediff_hijack_load_state_dict( - orig_func, - self, - state_dict: Mapping[str, Any], - strict: bool = True, - assign: bool = False, -): - if ( - len(state_dict) > 0 - and next(iter(state_dict.values())).is_cuda - and next(self.parameters()).is_meta - ): - return orig_func(self, state_dict, strict, assign=True) - else: - return orig_func(self, state_dict, strict, assign) - - -# fmt: off -def onediff_hijaced_LoadStateDictOnMeta___enter__(orig_func, self): - from modules import shared - if shared.cmd_opts.disable_model_loading_ram_optimization: - return - - sd = self.state_dict - device = self.device - - def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): - used_param_keys = [] - - for name, param in module._parameters.items(): - if param is None: - continue - - key = prefix + name - sd_param = sd.pop(key, None) - if sd_param is not None: - state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) - used_param_keys.append(key) - - if param.is_meta: - dtype = sd_param.dtype if sd_param is not None else param.dtype - module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) - - for name in module._buffers: - key = prefix + name - - sd_param = sd.pop(key, None) - if sd_param is not None: - state_dict[key] = sd_param - used_param_keys.append(key) - - original(module, state_dict, prefix, *args, **kwargs) - - for key in used_param_keys: - state_dict.pop(key, None) - - # def load_state_dict(original, module, state_dict, strict=True): - def load_state_dict(original, module, state_dict, strict=True): - """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help - because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with - all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. - - In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd). - - The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads - the function and does not call the original) the state dict will just fail to load because weights - would be on the meta device. 
- """ - - if state_dict is sd: - state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} - - # ------------------- DIFF HERE ------------------- - # original(module, state_dict, strict=strict) - if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(module.parameters()).is_meta: - assign = True - else: - assign = False - # orig_func(original, module, state_dict, strict=strict, assign=assign) - original(module, state_dict, strict=strict, assign=assign) - - module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) - module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) - linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) - conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) - mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) - layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) - group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) -# fmt: on - - -CondFunc( - "modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__", - onediff_hijaced_LoadStateDictOnMeta___enter__, - lambda _, *args, **kwargs: onediff_enabled, -) -CondFunc( - "modules.sd_models.load_model_weights", - onediff_hijack_load_model_weights, - lambda _, *args, **kwargs: onediff_enabled, -) diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index a5dcd563a..8d9e4cf15 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -1,10 +1,7 @@ -from typing import Dict - from compile.onediff_compiled_graph import OneDiffCompiledGraph current_unet_graph = OneDiffCompiledGraph() current_quantization = False -refiner_dict: Dict[str, str] = dict() current_unet_type = { "is_sdxl": False, "is_sd2": False, diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 0ab98eab2..0561469d8 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -6,9 +6,9 @@ import modules.shared as shared import onediff_shared import oneflow as flow -import torch from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph from modules import script_callbacks +from modules.devices import torch_gc from modules.processing import process_images from modules.ui_common import create_refresh_button from onediff_hijack import do_hijack as onediff_do_hijack @@ -102,7 +102,7 @@ def run( ): p.override_settings.pop("sd_model_checkpoint", None) sd_models.reload_model_weights() - torch.cuda.empty_cache() + torch_gc() flow.cuda.empty_cache() current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name @@ -137,6 +137,8 @@ def run( 
with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(), onediff_enabled(): proc = process_images(p) save_graph(onediff_shared.current_unet_graph, saved_cache_name) + torch_gc() + flow.cuda.empty_cache() return proc diff --git a/tests/sd-webui/test_api.py b/tests/sd-webui/test_api.py index 9c6d32fdc..0ec72553c 100644 --- a/tests/sd-webui/test_api.py +++ b/tests/sd-webui/test_api.py @@ -1,3 +1,4 @@ +import os import numpy as np import pytest from PIL import Image @@ -89,5 +90,5 @@ def test_onediff_refiner(url_txt2img): } data = {**get_base_args(), **extra_args} # loop 5 times for checking model switching between base and refiner - for _ in range(5): + for _ in range(3): post_request_and_check(url_txt2img, data) From 03b3a89ee357c4b7a8ae4990da962602fa48afcc Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 6 Jun 2024 11:37:45 +0800 Subject: [PATCH 08/24] api test add model --- tests/sd-webui/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/sd-webui/utils.py b/tests/sd-webui/utils.py index 4dc28773b..f0f520f2e 100644 --- a/tests/sd-webui/utils.py +++ b/tests/sd-webui/utils.py @@ -30,6 +30,7 @@ def get_base_args() -> Dict[str, Any]: return { "prompt": "1girl", "negative_prompt": "", + "sd_model_checkpoint": "checkpoints/AWPainting_v1.2.safetensors", "seed": SEED, "steps": NUM_STEPS, "width": WIDTH, From e3acdbb830b47a4982a323f915520d6f33ffabe9 Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 13 Jun 2024 16:48:38 +0800 Subject: [PATCH 09/24] support controlnet unet (controlnet model not supported now) --- .../compile/compile_utils.py | 6 +- .../compile/sd_webui_onediff_utils.py | 25 +- .../onediff_controlnet.py | 1008 +++++++++++++++++ onediff_sd_webui_extensions/onediff_shared.py | 4 + .../{ui_utils.py => onediff_utils.py} | 27 +- .../scripts/onediff.py | 40 +- 6 files changed, 1083 insertions(+), 27 deletions(-) create mode 100644 onediff_sd_webui_extensions/onediff_controlnet.py rename onediff_sd_webui_extensions/{ui_utils.py => onediff_utils.py} (89%) diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index 9d39fbc96..451fc26ba 100644 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -61,7 +61,11 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]: def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: + diffusion_model = sd_model.model.diffusion_model + # for controlnet + if "forward" in diffusion_model.__dict__: + diffusion_model.__dict__.pop("forward") compiled_unet = compile_unet( - sd_model.model.diffusion_model, quantization=quantization + diffusion_model, quantization=quantization ) return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py b/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py index db338fbf1..93aad2f49 100644 --- a/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py +++ b/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py @@ -13,17 +13,20 @@ def forward(self, x): # https://github.com/Stability-AI/generative-models/blob/059d8e9cd9c55aea1ef2ece39abf605efb8b7cc9/sgm/modules/diffusionmodules/util.py#L207 -def timestep_embedding(timesteps, dim, max_period=10000): - half = dim // 2 - freqs = flow.exp( - -math.log(max_period) - * flow.arange(start=0, end=half, dtype=flow.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * 
freqs[None] - embedding = flow.cat([flow.cos(args), flow.sin(args)], dim=-1) - if dim % 2: - embedding = flow.cat([embedding, flow.zeros_like(embedding[:, :1])], dim=-1) +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + if not repeat_only: + half = dim // 2 + freqs = flow.exp( + -math.log(max_period) + * flow.arange(start=0, end=half, dtype=flow.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = flow.cat([flow.cos(args), flow.sin(args)], dim=-1) + if dim % 2: + embedding = flow.cat([embedding, flow.zeros_like(embedding[:, :1])], dim=-1) + else: + raise NotImplementedError("repeat_only=True is not implemented in timestep_embedding") return embedding diff --git a/onediff_sd_webui_extensions/onediff_controlnet.py b/onediff_sd_webui_extensions/onediff_controlnet.py new file mode 100644 index 000000000..6e3899a7d --- /dev/null +++ b/onediff_sd_webui_extensions/onediff_controlnet.py @@ -0,0 +1,1008 @@ +import onediff_shared +import oneflow as flow +import torch +import torch as th +from compile import OneDiffCompiledGraph +from compile.sd_webui_onediff_utils import (CrossAttentionOflow, + GroupNorm32Oflow, + timestep_embedding) +from ldm.modules.attention import BasicTransformerBlock, CrossAttention +from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel +from ldm.modules.diffusionmodules.util import GroupNorm32 +from modules import devices +from modules.sd_hijack_utils import CondFunc +from onediff_utils import singleton_decorator + +from onediff.infer_compiler import oneflow_compile +from onediff.infer_compiler.backends.oneflow.transform import (proxy_class, + register) + + +def torch_aligned_adding(base, x, require_channel_alignment): + if isinstance(x, float): + if x == 0.0: + return base + return base + x + + if require_channel_alignment: + zeros = torch.zeros_like(base) + zeros[:, : x.shape[1], ...] = x + x = zeros + + # resize to sample resolution + base_h, base_w = base.shape[-2:] + xh, xw = x.shape[-2:] + + if xh > 1 or xw > 1: + if base_h != xh or base_w != xw: + # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') + x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") + + return base + x + + +def oneflow_aligned_adding(base, x, require_channel_alignment): + if isinstance(x, float): + return base + x + + if require_channel_alignment: + zeros = flow.zeros_like(base) + zeros[:, : x.shape[1], ...] 
= x + x = zeros + + # resize to sample resolution + base_h, base_w = base.shape[-2:] + xh, xw = x.shape[-2:] + + if xh > 1 or xw > 1 and (base_h != xh or base_w != xw): + # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') + x = flow.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") + return base + x + + +cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) + + +class TorchOnediffControlNetModel(torch.nn.Module): + def __init__(self, unet): + super().__init__() + self.time_embed = unet.time_embed + self.input_blocks = unet.input_blocks + self.label_emb = getattr(unet, "label_emb", None) + self.middle_block = unet.middle_block + self.output_blocks = unet.output_blocks + self.out = unet.out + self.model_channels = unet.model_channels + + def forward( + self, + x, + timesteps, + context, + y, + total_t2i_adapter_embedding, + total_controlnet_embedding, + is_sdxl, + require_inpaint_hijack, + ): + from ldm.modules.diffusionmodules.util import timestep_embedding + + hs = [] + with th.no_grad(): + t_emb = cond_cast_unet( + timestep_embedding(timesteps, self.model_channels, repeat_only=False) + ) + emb = self.time_embed(t_emb) + + if is_sdxl: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for i, module in enumerate(self.input_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = module(h, emb, context) + + t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] + + if i in t2i_injection: + h = torch_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + hs.append(h) + + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = self.middle_block(h, emb, context) + + # U-Net Middle Block + h = torch_aligned_adding( + h, total_controlnet_embedding.pop(), require_inpaint_hijack + ) + + if len(total_t2i_adapter_embedding) > 0 and is_sdxl: + h = torch_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + # U-Net Decoder + for i, module in enumerate(self.output_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = th.cat( + [ + h, + torch_aligned_adding( + hs.pop(), + total_controlnet_embedding.pop(), + require_inpaint_hijack, + ), + ], + dim=1, + ) + h = module(h, emb, context) + + # U-Net Output + h = h.type(x.dtype) + h = self.out(h) + + return h + + +class OneFlowOnediffControlNetModel(proxy_class(UNetModel)): + def forward( + self, + x, + timesteps, + context, + y, + total_t2i_adapter_embedding, + total_controlnet_embedding, + is_sdxl, + require_inpaint_hijack, + ): + x = x.half() + if y is not None: + y = y.half() + context = context.half() + hs = [] + with flow.no_grad(): + t_emb = cond_cast_unet( + timestep_embedding(timesteps, self.model_channels, repeat_only=False) + ) + emb = self.time_embed(t_emb.half()) + + if is_sdxl: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for i, module in enumerate(self.input_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = module(h, emb, context) + + t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] + + if i in t2i_injection: + h = oneflow_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + hs.append(h) + + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = self.middle_block(h, emb, context) + + # U-Net Middle Block + h = oneflow_aligned_adding( + h, 
total_controlnet_embedding.pop(), require_inpaint_hijack + ) + + if len(total_t2i_adapter_embedding) > 0 and is_sdxl: + h = oneflow_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + # U-Net Decoder + for i, module in enumerate(self.output_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = flow.cat( + [ + h, + oneflow_aligned_adding( + hs.pop(), + total_controlnet_embedding.pop(), + require_inpaint_hijack, + ), + ], + dim=1, + ) + h = h.half() + h = module(h, emb, context) + + # U-Net Output + h = h.type(x.dtype) + h = self.out(h) + + return h + + +def compile_controlnet_ldm_unet(sd_model, unet_model, *, options=None): + for module in unet_model.modules(): + if isinstance(module, BasicTransformerBlock): + module.checkpoint = False + if isinstance(module, ResBlock): + module.use_checkpoint = False + # return oneflow_compile(unet_model, options=options) + compiled_model = oneflow_compile(unet_model, options=options) + compiled_graph = OneDiffCompiledGraph(sd_model, compiled_model) + compiled_graph.eager_module = unet_model + compiled_graph.name += "_controlnet" + return compiled_graph + + +torch2oflow_class_map = { + CrossAttention: CrossAttentionOflow, + GroupNorm32: GroupNorm32Oflow, + TorchOnediffControlNetModel: OneFlowOnediffControlNetModel, +} +register(package_names=["scripts.hook"], torch2oflow_class_map=torch2oflow_class_map) + + +def check_if_controlnet_ext_loaded() -> bool: + from modules import extensions + + return "sd-webui-controlnet" in extensions.loaded_extensions + + +def hijacked_main_entry(self, p): + self._original_controlnet_main_entry(p) + sd_ldm = p.sd_model + unet = sd_ldm.model.diffusion_model + + if onediff_shared.controlnet_compiled is False: + # if not getattr(self, "compiled", False): + from onediff_controlnet import TorchOnediffControlNetModel + onediff_model = TorchOnediffControlNetModel(unet) + onediff_shared.current_unet_graph = compile_controlnet_ldm_unet( + sd_ldm, onediff_model + ) + onediff_shared.controlnet_compiled = True + else: + pass + + + + +def get_controlnet_script(p): + for script in p.scripts.scripts: + if script.__module__ == "controlnet.py": + return script + return None + + +def check_if_controlnet_enabled(p): + controlnet_script_class = get_controlnet_script(p) + if controlnet_script_class is None: + return False + return len(controlnet_script_class.get_enabled_units(p)) != 0 + + +@singleton_decorator +def create_condfunc(p): + CondFunc( + "scripts.hook.UnetHook.hook", hijacked_hook, lambda _, *arg, **kwargs: True + ) + # get controlnet script + controlnet_script = get_controlnet_script(p) + if controlnet_script is None: + return + + controlnet_script._original_controlnet_main_entry = ( + controlnet_script.controlnet_main_entry + ) + controlnet_script.controlnet_main_entry = hijacked_main_entry.__get__( + controlnet_script + ) + + + +def hijacked_hook( + orig_func, + self, + model, + sd_ldm, + control_params, + process, + batch_option_uint_separate=False, + batch_option_style_align=False, +): + from modules import devices, lowvram, scripts, shared + from scripts.controlnet_sparsectrl import SparseCtrl + from scripts.enums import AutoMachine, ControlModelType, HiResFixOption + from scripts.hook import (AbstractLowScaleModel, blur, mark_prompt_context, + predict_noise_from_start, predict_q_sample, + predict_start_from_noise, register_schedule, + torch_dfs, unmark_prompt_context) + from scripts.ipadapter.ipadapter_model import ImageEmbed + from scripts.logging import 
logger + + self.model = model + self.sd_ldm = sd_ldm + self.control_params = control_params + + model_is_sdxl = getattr(self.sd_ldm, "is_sdxl", False) + + outer = self + + def process_sample(*args, **kwargs): + # ControlNet must know whether a prompt is conditional prompt (positive prompt) or unconditional conditioning prompt (negative prompt). + # You can use the hook.py's `mark_prompt_context` to mark the prompts that will be seen by ControlNet. + # Let us say XXX is a MulticondLearnedConditioning or a ComposableScheduledPromptConditioning or a ScheduledPromptConditioning or a list of these components, + # if XXX is a positive prompt, you should call mark_prompt_context(XXX, positive=True) + # if XXX is a negative prompt, you should call mark_prompt_context(XXX, positive=False) + # After you mark the prompts, the ControlNet will know which prompt is cond/uncond and works as expected. + # After you mark the prompts, the mismatch errors will disappear. + mark_prompt_context(kwargs.get("conditioning", []), positive=True) + mark_prompt_context( + kwargs.get("unconditional_conditioning", []), positive=False + ) + mark_prompt_context(getattr(process, "hr_c", []), positive=True) + mark_prompt_context(getattr(process, "hr_uc", []), positive=False) + return process.sample_before_CN_hack(*args, **kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + is_sdxl = y is not None and model_is_sdxl + total_t2i_adapter_embedding = [0.0] * 4 + if is_sdxl: + total_controlnet_embedding = [0.0] * 10 + else: + total_controlnet_embedding = [0.0] * 13 + require_inpaint_hijack = False + is_in_high_res_fix = False + batch_size = int(x.shape[0]) + + # Handle cond-uncond marker + ( + cond_mark, + outer.current_uc_indices, + outer.current_c_indices, + context, + ) = unmark_prompt_context(context) + outer.model.cond_mark = cond_mark + # logger.info(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices)) + + # Revision + if is_sdxl: + revision_y1280 = 0 + + for param in outer.control_params: + if param.guidance_stopped: + continue + if param.control_model_type == ControlModelType.ReVision: + if param.vision_hint_count is None: + k = ( + torch.Tensor( + [int(param.preprocessor["threshold_a"] * 1000)] + ) + .to(param.hint_cond) + .long() + .clip(0, 999) + ) + param.vision_hint_count = outer.revision_q_sampler.q_sample( + param.hint_cond, k + ) + revision_emb = param.vision_hint_count + if isinstance(revision_emb, torch.Tensor): + revision_y1280 += revision_emb * param.weight + + if isinstance(revision_y1280, torch.Tensor): + y[:, :1280] = revision_y1280 * cond_mark[:, :, 0, 0] + if any( + "ignore_prompt" in param.preprocessor["name"] + for param in outer.control_params + ) or ( + getattr(process, "prompt", "") == "" + and getattr(process, "negative_prompt", "") == "" + ): + context = torch.zeros_like(context) + + # High-res fix + for param in outer.control_params: + # select which hint_cond to use + if param.used_hint_cond is None: + param.used_hint_cond = param.hint_cond + param.used_hint_cond_latent = None + param.used_hint_inpaint_hijack = None + + # has high-res fix + if ( + isinstance(param.hr_hint_cond, torch.Tensor) + and x.ndim == 4 + and param.hint_cond.ndim == 4 + and param.hr_hint_cond.ndim == 4 + ): + _, _, h_lr, w_lr = param.hint_cond.shape + _, _, h_hr, w_hr = param.hr_hint_cond.shape + _, _, h, w = x.shape + h, w = h * 8, w * 8 + if abs(h - h_lr) < abs(h - h_hr): + is_in_high_res_fix = False + if param.used_hint_cond is not 
param.hint_cond: + param.used_hint_cond = param.hint_cond + param.used_hint_cond_latent = None + param.used_hint_inpaint_hijack = None + else: + is_in_high_res_fix = True + if param.used_hint_cond is not param.hr_hint_cond: + param.used_hint_cond = param.hr_hint_cond + param.used_hint_cond_latent = None + param.used_hint_inpaint_hijack = None + + self.is_in_high_res_fix = is_in_high_res_fix + outer.is_in_high_res_fix = is_in_high_res_fix + + # Convert control image to latent + for param in outer.control_params: + if param.used_hint_cond_latent is not None: + continue + if ( + param.control_model_type not in [ControlModelType.AttentionInjection] + and "colorfix" not in param.preprocessor["name"] + and "inpaint_only" not in param.preprocessor["name"] + ): + continue + param.used_hint_cond_latent = outer.call_vae_using_process( + process, param.used_hint_cond, batch_size=batch_size + ) + + # vram + for param in outer.control_params: + if getattr(param.control_model, "disable_memory_management", False): + continue + + if param.control_model is not None: + if ( + outer.lowvram + and is_sdxl + and hasattr(param.control_model, "aggressive_lowvram") + ): + param.control_model.aggressive_lowvram() + elif hasattr(param.control_model, "fullvram"): + param.control_model.fullvram() + elif hasattr(param.control_model, "to"): + param.control_model.to(devices.get_device_for("controlnet")) + + # handle prompt token control + for param in outer.control_params: + if param.guidance_stopped or param.disabled_by_hr_option( + self.is_in_high_res_fix + ): + continue + + if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]: + continue + + control = param.control_model( + x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context + ) + control = torch.cat([control.clone() for _ in range(batch_size)], dim=0) + control *= param.weight + control *= cond_mark[:, :, :, 0] + context = torch.cat([context, control.clone()], dim=1) + + # handle ControlNet / T2I_Adapter + for param_index, param in enumerate(outer.control_params): + if param.guidance_stopped or param.disabled_by_hr_option( + self.is_in_high_res_fix + ): + continue + + if not ( + param.control_model_type.is_controlnet + or param.control_model_type == ControlModelType.T2I_Adapter + ): + continue + + # inpaint model workaround + x_in = x + control_model = param.control_model.control_model + + if param.control_model_type.is_controlnet: + if ( + x.shape[1] != control_model.input_blocks[0][0].in_channels + and x.shape[1] == 9 + ): + # inpaint_model: 4 data + 4 downscaled image + 1 mask + x_in = x[:, :4, ...] 
+ require_inpaint_hijack = True + + assert ( + param.used_hint_cond is not None + ), "Controlnet is enabled but no input image is given" + + hint = param.used_hint_cond + if param.control_model_type == ControlModelType.InstantID: + assert isinstance(param.control_context_override, ImageEmbed) + controlnet_context = param.control_context_override.eval(cond_mark).to( + x.device, dtype=x.dtype + ) + else: + controlnet_context = context + + # ControlNet inpaint protocol + if hint.shape[1] == 4 and not isinstance(control_model, SparseCtrl): + c = hint[:, 0:3, :, :] + m = hint[:, 3:4, :, :] + m = (m > 0.5).float() + hint = c * (1 - m) - m + + control = param.control_model( + x=x_in, hint=hint, timesteps=timesteps, context=controlnet_context, y=y + ) + + if is_sdxl: + control_scales = [param.weight] * 10 + else: + control_scales = [param.weight] * 13 + + if param.cfg_injection or param.global_average_pooling: + if param.control_model_type == ControlModelType.T2I_Adapter: + control = [ + torch.cat([c.clone() for _ in range(batch_size)], dim=0) + for c in control + ] + control = [c * cond_mark for c in control] + + high_res_fix_forced_soft_injection = False + + if is_in_high_res_fix: + if "canny" in param.preprocessor["name"]: + high_res_fix_forced_soft_injection = True + if "mlsd" in param.preprocessor["name"]: + high_res_fix_forced_soft_injection = True + + if param.soft_injection or high_res_fix_forced_soft_injection: + # important! use the soft weights with high-res fix can significantly reduce artifacts. + if param.control_model_type == ControlModelType.T2I_Adapter: + control_scales = [ + param.weight * x for x in (0.25, 0.62, 0.825, 1.0) + ] + elif param.control_model_type.is_controlnet: + control_scales = [ + param.weight * (0.825 ** float(12 - i)) for i in range(13) + ] + + if is_sdxl and param.control_model_type.is_controlnet: + control_scales = control_scales[:10] + + if param.advanced_weighting is not None: + logger.info(f"Advanced weighting enabled. 
{param.advanced_weighting}") + if param.soft_injection or high_res_fix_forced_soft_injection: + logger.warn("Advanced weighting overwrites soft_injection effect.") + control_scales = param.advanced_weighting + + control = [ + param.apply_effective_region_mask(c * scale) + for c, scale in zip(control, control_scales) + ] + if param.global_average_pooling: + control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control] + + for idx, item in enumerate(control): + target = None + if param.control_model_type.is_controlnet: + target = total_controlnet_embedding + if param.control_model_type == ControlModelType.T2I_Adapter: + target = total_t2i_adapter_embedding + if target is not None: + if batch_option_uint_separate: + for pi, ci in enumerate(outer.current_c_indices): + if pi % len(outer.control_params) != param_index: + item[ci] = 0 + for pi, ci in enumerate(outer.current_uc_indices): + if pi % len(outer.control_params) != param_index: + item[ci] = 0 + target[idx] = item + target[idx] + else: + target[idx] = item + target[idx] + + # Replace x_t to support inpaint models + for param in outer.control_params: + if not isinstance(param.used_hint_cond, torch.Tensor): + continue + if param.used_hint_cond.ndim < 2 or param.used_hint_cond.shape[1] != 4: + continue + if x.shape[1] != 9: + continue + if param.used_hint_inpaint_hijack is None: + mask_pixel = param.used_hint_cond[:, 3:4, :, :] + image_pixel = param.used_hint_cond[:, 0:3, :, :] + mask_pixel = (mask_pixel > 0.5).to(mask_pixel.dtype) + masked_latent = outer.call_vae_using_process( + process, image_pixel, batch_size, mask=mask_pixel + ) + mask_latent = torch.nn.functional.max_pool2d(mask_pixel, (8, 8)) + if mask_latent.shape[0] != batch_size: + mask_latent = torch.cat( + [mask_latent.clone() for _ in range(batch_size)], dim=0 + ) + param.used_hint_inpaint_hijack = torch.cat( + [mask_latent, masked_latent], dim=1 + ) + param.used_hint_inpaint_hijack.to(x.dtype).to(x.device) + x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1) + + # vram + for param in outer.control_params: + if param.control_model is not None: + if outer.lowvram: + param.control_model.to("cpu") + + # A1111 fix for medvram. + if shared.cmd_opts.medvram or ( + getattr(shared.cmd_opts, "medvram_sdxl", False) and is_sdxl + ): + try: + # Trigger the register_forward_pre_hook + outer.sd_ldm.model() + except Exception as e: + logger.debug("register_forward_pre_hook") + logger.debug(e) + + # Clear attention and AdaIn cache + for module in outer.attn_module_list: + module.bank = [] + module.style_cfgs = [] + for module in outer.gn_module_list: + module.mean_bank = [] + module.var_bank = [] + module.style_cfgs = [] + + # Handle attention and AdaIn control + for param in outer.control_params: + if param.guidance_stopped or param.disabled_by_hr_option( + self.is_in_high_res_fix + ): + continue + + if param.used_hint_cond_latent is None: + continue + + if param.control_model_type not in [ControlModelType.AttentionInjection]: + continue + + ref_xt = predict_q_sample( + outer.sd_ldm, + param.used_hint_cond_latent, + torch.round(timesteps.float()).long(), + ) + + # Inpaint Hijack + if x.shape[1] == 9: + ref_xt = torch.cat( + [ + ref_xt, + torch.zeros_like(ref_xt)[:, 0:1, :, :], + param.used_hint_cond_latent, + ], + dim=1, + ) + + outer.current_style_fidelity = float(param.preprocessor["threshold_a"]) + outer.current_style_fidelity = max( + 0.0, min(1.0, outer.current_style_fidelity) + ) + + if is_sdxl: + # sdxl's attention hacking is highly unstable. 
+ # We have no other methods but to reduce the style_fidelity a bit. + # By default, 0.5 ** 3.0 = 0.125 + outer.current_style_fidelity = outer.current_style_fidelity ** 3.0 + + if param.cfg_injection: + outer.current_style_fidelity = 1.0 + elif param.soft_injection or is_in_high_res_fix: + outer.current_style_fidelity = 0.0 + + control_name = param.preprocessor["name"] + + if control_name in ["reference_only", "reference_adain+attn"]: + outer.attention_auto_machine = AutoMachine.Write + outer.attention_auto_machine_weight = param.weight + + if control_name in ["reference_adain", "reference_adain+attn"]: + outer.gn_auto_machine = AutoMachine.Write + outer.gn_auto_machine_weight = param.weight + + if is_sdxl: + outer.original_forward( + x=ref_xt.to(devices.dtype_unet), + timesteps=timesteps.to(devices.dtype_unet), + context=context.to(devices.dtype_unet), + y=y, + ) + else: + outer.original_forward( + x=ref_xt.to(devices.dtype_unet), + timesteps=timesteps.to(devices.dtype_unet), + context=context.to(devices.dtype_unet), + ) + + outer.attention_auto_machine = AutoMachine.Read + outer.gn_auto_machine = AutoMachine.Read + + h = onediff_shared.current_unet_graph.graph_module( + x, + timesteps, + context, + y, + total_t2i_adapter_embedding, + total_controlnet_embedding, + is_sdxl, + require_inpaint_hijack, + ) + + # Post-processing for color fix + for param in outer.control_params: + if param.used_hint_cond_latent is None: + continue + if "colorfix" not in param.preprocessor["name"]: + continue + + k = int(param.preprocessor["threshold_a"]) + if is_in_high_res_fix and not param.disabled_by_hr_option( + self.is_in_high_res_fix + ): + k *= 2 + + # Inpaint hijack + xt = x[:, :4, :, :] + + x0_origin = param.used_hint_cond_latent + t = torch.round(timesteps.float()).long() + x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h) + x0 = x0_prd - blur(x0_prd, k) + blur(x0_origin, k) + + if "+sharp" in param.preprocessor["name"]: + detail_weight = float(param.preprocessor["threshold_b"]) * 0.01 + neg = detail_weight * blur(x0, k) + (1 - detail_weight) * x0 + x0 = cond_mark * x0 + (1 - cond_mark) * neg + + eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0) + + w = max(0.0, min(1.0, float(param.weight))) + h = eps_prd * w + h * (1 - w) + + # Post-processing for restore + for param in outer.control_params: + if param.used_hint_cond_latent is None: + continue + if "inpaint_only" not in param.preprocessor["name"]: + continue + if param.used_hint_cond.shape[1] != 4: + continue + + # Inpaint hijack + xt = x[:, :4, :, :] + + mask = param.used_hint_cond[:, 3:4, :, :] + mask = torch.nn.functional.max_pool2d( + mask, (10, 10), stride=(8, 8), padding=1 + ) + + x0_origin = param.used_hint_cond_latent + t = torch.round(timesteps.float()).long() + x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h) + x0 = x0_prd * mask + x0_origin * (1 - mask) + eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0) + + w = max(0.0, min(1.0, float(param.weight))) + h = eps_prd * w + h * (1 - w) + + return h + + def move_all_control_model_to_cpu(): + for param in getattr(outer, "control_params", []) or []: + if isinstance(param.control_model, torch.nn.Module): + param.control_model.to("cpu") + + def forward_webui(*args, **kwargs): + forward_func = None + if "forward" in onediff_shared.current_unet_graph.graph_module._torch_module.__dict__: + forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop("forward") + _original_forward_func = 
onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop("_original_forward") + # webui will handle other compoments + try: + if shared.cmd_opts.lowvram: + lowvram.send_everything_to_cpu() + return forward(*args, **kwargs) + except Exception as e: + move_all_control_model_to_cpu() + raise e + finally: + if outer.lowvram: + move_all_control_model_to_cpu() + if forward_func is not None: + onediff_shared.current_unet_graph.graph_module._torch_module.forward = forward_func + onediff_shared.current_unet_graph.graph_module._torch_module._original_forward = _original_forward_func + + def hacked_basic_transformer_inner_forward(self, x, context=None): + x_norm1 = self.norm1(x) + self_attn1 = None + if self.disable_self_attn: + # Do not use self-attention + self_attn1 = self.attn1(x_norm1, context=context) + else: + # Use self-attention + self_attention_context = x_norm1 + if outer.attention_auto_machine == AutoMachine.Write: + if outer.attention_auto_machine_weight > self.attn_weight: + self.bank.append(self_attention_context.detach().clone()) + self.style_cfgs.append(outer.current_style_fidelity) + if outer.attention_auto_machine == AutoMachine.Read: + if len(self.bank) > 0: + style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs)) + self_attn1_uc = self.attn1( + x_norm1, + context=torch.cat([self_attention_context] + self.bank, dim=1), + ) + self_attn1_c = self_attn1_uc.clone() + if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5: + self_attn1_c[outer.current_uc_indices] = self.attn1( + x_norm1[outer.current_uc_indices], + context=self_attention_context[outer.current_uc_indices], + ) + self_attn1 = ( + style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc + ) + self.bank = [] + self.style_cfgs = [] + if ( + outer.attention_auto_machine == AutoMachine.StyleAlign + and not outer.is_in_high_res_fix + ): + # very VRAM hungry - disable at high_res_fix + + def shared_attn1(inner_x): + BB, FF, CC = inner_x.shape + return self.attn1(inner_x.reshape(1, BB * FF, CC)).reshape( + BB, FF, CC + ) + + uc_layer = shared_attn1(x_norm1[outer.current_uc_indices]) + c_layer = shared_attn1(x_norm1[outer.current_c_indices]) + self_attn1 = torch.zeros_like(x_norm1).to(uc_layer) + self_attn1[outer.current_uc_indices] = uc_layer + self_attn1[outer.current_c_indices] = c_layer + del uc_layer, c_layer + if self_attn1 is None: + self_attn1 = self.attn1(x_norm1, context=self_attention_context) + + x = self_attn1.to(x.dtype) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + def hacked_group_norm_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward_cn_hijack(*args, **kwargs) + y = None + if outer.gn_auto_machine == AutoMachine.Write: + if outer.gn_auto_machine_weight > self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + self.style_cfgs.append(outer.current_style_fidelity) + if outer.gn_auto_machine == AutoMachine.Read: + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs)) + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + y_uc = (((x - mean) / std) * std_acc) + mean_acc + y_c = 
y_uc.clone() + if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5: + y_c[outer.current_uc_indices] = x.to(y_c.dtype)[ + outer.current_uc_indices + ] + y = style_cfg * y_c + (1.0 - style_cfg) * y_uc + self.mean_bank = [] + self.var_bank = [] + self.style_cfgs = [] + if y is None: + y = x + return y.to(x.dtype) + + if getattr(process, "sample_before_CN_hack", None) is None: + process.sample_before_CN_hack = process.sample + process.sample = process_sample + + model._original_forward = model.forward + outer.original_forward = model.forward + model.forward = forward_webui.__get__(model, UNetModel) + + if model_is_sdxl: + register_schedule(sd_ldm) + outer.revision_q_sampler = AbstractLowScaleModel() + + need_attention_hijack = False + + for param in outer.control_params: + if param.control_model_type in [ControlModelType.AttentionInjection]: + need_attention_hijack = True + + if batch_option_style_align: + need_attention_hijack = True + outer.attention_auto_machine = AutoMachine.StyleAlign + outer.gn_auto_machine = AutoMachine.StyleAlign + + all_modules = torch_dfs(model) + + if need_attention_hijack: + attn_modules = [ + module + for module in all_modules + if isinstance(module, BasicTransformerBlock) + or isinstance(module, BasicTransformerBlockSGM) + ] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + if getattr(module, "_original_inner_forward_cn_hijack", None) is None: + module._original_inner_forward_cn_hijack = module._forward + module._forward = hacked_basic_transformer_inner_forward.__get__( + module, BasicTransformerBlock + ) + module.bank = [] + module.style_cfgs = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + gn_modules = [model.middle_block] + model.middle_block.gn_weight = 0 + + if model_is_sdxl: + input_block_indices = [4, 5, 7, 8] + output_block_indices = [0, 1, 2, 3, 4, 5] + else: + input_block_indices = [4, 5, 7, 8, 10, 11] + output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7] + + for w, i in enumerate(input_block_indices): + module = model.input_blocks[i] + module.gn_weight = 1.0 - float(w) / float(len(input_block_indices)) + gn_modules.append(module) + + for w, i in enumerate(output_block_indices): + module = model.output_blocks[i] + module.gn_weight = float(w) / float(len(output_block_indices)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward_cn_hijack", None) is None: + module.original_forward_cn_hijack = module.forward + module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module) + module.mean_bank = [] + module.var_bank = [] + module.style_cfgs = [] + module.gn_weight *= 2 + + outer.attn_module_list = attn_modules + outer.gn_module_list = gn_modules + else: + for module in all_modules: + _original_inner_forward_cn_hijack = getattr( + module, "_original_inner_forward_cn_hijack", None + ) + original_forward_cn_hijack = getattr( + module, "original_forward_cn_hijack", None + ) + if _original_inner_forward_cn_hijack is not None: + module._forward = _original_inner_forward_cn_hijack + if original_forward_cn_hijack is not None: + module.forward = original_forward_cn_hijack + outer.attn_module_list = [] + outer.gn_module_list = [] + + scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler) diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index 8d9e4cf15..e06a51b24 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ 
b/onediff_sd_webui_extensions/onediff_shared.py @@ -9,3 +9,7 @@ "is_ssd": False, } onediff_enabled = False + +# controlnet +controlnet_compiled = False +current_is_controlnet = False diff --git a/onediff_sd_webui_extensions/ui_utils.py b/onediff_sd_webui_extensions/onediff_utils.py similarity index 89% rename from onediff_sd_webui_extensions/ui_utils.py rename to onediff_sd_webui_extensions/onediff_utils.py index bdb875a38..441bcdfc7 100644 --- a/onediff_sd_webui_extensions/ui_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -1,4 +1,5 @@ import os +from functools import wraps from contextlib import contextmanager from pathlib import Path from textwrap import dedent @@ -119,10 +120,22 @@ def save_graph(compiled_unet: DeployableModule, saved_cache_name: str = ""): compiled_unet.save_graph(saved_cache_name) -@contextmanager -def onediff_enabled(): - onediff_shared.onediff_enabled = True - try: - yield - finally: - onediff_shared.onediff_enabled = False +def onediff_enabled_decorator(func): + @wraps(func) + def wrapper(*arg, **kwargs): + onediff_shared.onediff_enabled = True + try: + return func(*arg, **kwargs) + finally: + onediff_shared.onediff_enabled = False + return wrapper + + +def singleton_decorator(func): + has_been_called = False + def wrapper(*args, **kwargs): + nonlocal has_been_called + if not has_been_called: + has_been_called = True + return func(*args, **kwargs) + return wrapper diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 0561469d8..5e7f23513 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -13,14 +13,15 @@ from modules.ui_common import create_refresh_button from onediff_hijack import do_hijack as onediff_do_hijack from onediff_lora import HijackLoraActivate -from ui_utils import ( +import onediff_controlnet +from onediff_utils import ( check_structure_change_and_update, get_all_compiler_caches, hints_message, load_graph, - onediff_enabled, refresh_all_compiler_caches, save_graph, + onediff_enabled_decorator, ) from onediff.optimization.quant_optimizer import varify_can_use_quantization @@ -34,19 +35,25 @@ class UnetCompileCtx(object): The global variables need to be replaced with compiled_unet before process_images is run, and then the original model restored so that subsequent reasoning with onediff disabled meets expectations. 
""" + def __init__(self, enabled): + self.enabled = enabled def __enter__(self): + if not self.enabled: + return self._original_model = shared.sd_model.model.diffusion_model shared.sd_model.model.diffusion_model = ( onediff_shared.current_unet_graph.graph_module ) def __exit__(self, exc_type, exc_val, exc_tb): + if not self.enabled: + return shared.sd_model.model.diffusion_model = self._original_model - return False class Script(scripts.Script): + def title(self): return "onediff_diffusion_model" @@ -85,6 +92,7 @@ def ui(self, is_img2img): def show(self, is_img2img): return True + @onediff_enabled_decorator def run( self, p, @@ -93,6 +101,10 @@ def run( saved_cache_name="", always_recompile=False, ): + controlnet_enabled = onediff_controlnet.check_if_controlnet_enabled(p) + if controlnet_enabled: + onediff_controlnet.create_condfunc(p) + # restore checkpoint_info from refiner to base model if necessary if ( sd_models.checkpoint_aliases.get( @@ -116,27 +128,39 @@ def run( quantization_changed = ( quantization != onediff_shared.current_unet_graph.quantized ) + controlnet_enabled_status_changed = ( + controlnet_enabled != onediff_shared.current_is_controlnet + ) need_recompile = ( ( quantization and ckpt_changed ) # always recompile when switching ckpt with 'int8 speed model' enabled or structure_changed # always recompile when switching model to another structure or quantization_changed # always recompile when switching model from non-quantized to quantized (and vice versa) + or controlnet_enabled_status_changed or always_recompile ) if need_recompile: - onediff_shared.current_unet_graph = get_compiled_graph( - shared.sd_model, quantization - ) - load_graph(onediff_shared.current_unet_graph, compiler_cache) + if not controlnet_enabled: + onediff_shared.current_unet_graph = get_compiled_graph( + shared.sd_model, quantization + ) + load_graph(onediff_shared.current_unet_graph, compiler_cache) else: logger.info( f"Model {current_checkpoint_name} has same sd type of graph type {onediff_shared.current_unet_type}, skip compile" ) - with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(), onediff_enabled(): + with UnetCompileCtx(not controlnet_enabled), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): proc = process_images(p) save_graph(onediff_shared.current_unet_graph, saved_cache_name) + + if controlnet_enabled: + onediff_shared.current_is_controlnet = True + else: + onediff_shared.controlnet_compiled = False + onediff_shared.current_is_controlnet = False + torch_gc() flow.cuda.empty_cache() return proc From b66fed59c4d7bc6bce49a3d0be6d50136633788e Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 17 Jun 2024 18:30:43 +0800 Subject: [PATCH 10/24] refine --- .../compile/compile_utils.py | 4 +- .../compile/sd_webui_onediff_utils.py | 4 +- .../onediff_controlnet.py | 113 ++++++++++++++---- onediff_sd_webui_extensions/onediff_shared.py | 3 +- onediff_sd_webui_extensions/onediff_utils.py | 13 +- .../scripts/onediff.py | 26 ++-- tests/sd-webui/test_api.py | 1 - tests/sd-webui/utils.py | 1 - 8 files changed, 113 insertions(+), 52 deletions(-) diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index 451fc26ba..d79278be2 100644 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -65,7 +65,5 @@ def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: # for controlnet if "forward" in diffusion_model.__dict__: 
diffusion_model.__dict__.pop("forward") - compiled_unet = compile_unet( - diffusion_model, quantization=quantization - ) + compiled_unet = compile_unet(diffusion_model, quantization=quantization) return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py b/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py index 93aad2f49..c77f5c3d1 100644 --- a/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py +++ b/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py @@ -26,7 +26,9 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): if dim % 2: embedding = flow.cat([embedding, flow.zeros_like(embedding[:, :1])], dim=-1) else: - raise NotImplementedError("repeat_only=True is not implemented in timestep_embedding") + raise NotImplementedError( + "repeat_only=True is not implemented in timestep_embedding" + ) return embedding diff --git a/onediff_sd_webui_extensions/onediff_controlnet.py b/onediff_sd_webui_extensions/onediff_controlnet.py index 6e3899a7d..fa25c8fbb 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet.py +++ b/onediff_sd_webui_extensions/onediff_controlnet.py @@ -1,11 +1,15 @@ +from functools import wraps + import onediff_shared import oneflow as flow import torch import torch as th from compile import OneDiffCompiledGraph -from compile.sd_webui_onediff_utils import (CrossAttentionOflow, - GroupNorm32Oflow, - timestep_embedding) +from compile.sd_webui_onediff_utils import ( + CrossAttentionOflow, + GroupNorm32Oflow, + timestep_embedding, +) from ldm.modules.attention import BasicTransformerBlock, CrossAttention from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 @@ -14,10 +18,10 @@ from onediff_utils import singleton_decorator from onediff.infer_compiler import oneflow_compile -from onediff.infer_compiler.backends.oneflow.transform import (proxy_class, - register) +from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238 def torch_aligned_adding(base, x, require_channel_alignment): if isinstance(x, float): if x == 0.0: @@ -41,8 +45,12 @@ def torch_aligned_adding(base, x, require_channel_alignment): return base + x +# Due to the tracing mechanism in OneFlow, it's crucial to ensure that +# the same conditional branches are taken during the first run as in subsequent runs. +# Therefore, certain "optimizations" have been modified. 
def oneflow_aligned_adding(base, x, require_channel_alignment): if isinstance(x, float): + # remove `if x == 0.0: return base` here return base + x if require_channel_alignment: @@ -226,13 +234,31 @@ def forward( return h +def onediff_controlnet_decorator(func): + @wraps(func) + def wrapper(self, p, *arg, **kwargs): + try: + onediff_shared.controlnet_enabled = check_if_controlnet_enabled(p) + if onediff_shared.controlnet_enabled: + hijack_controlnet_extension(p) + return func(self, p, *arg, **kwargs) + finally: + if onediff_shared.controlnet_enabled: + onediff_shared.previous_is_controlnet = True + else: + onediff_shared.controlnet_compiled = False + onediff_shared.previous_is_controlnet = False + + return wrapper + + def compile_controlnet_ldm_unet(sd_model, unet_model, *, options=None): for module in unet_model.modules(): if isinstance(module, BasicTransformerBlock): module.checkpoint = False if isinstance(module, ResBlock): module.use_checkpoint = False - # return oneflow_compile(unet_model, options=options) + # TODO: refine here compiled_model = oneflow_compile(unet_model, options=options) compiled_graph = OneDiffCompiledGraph(sd_model, compiled_model) compiled_graph.eager_module = unet_model @@ -260,8 +286,6 @@ def hijacked_main_entry(self, p): unet = sd_ldm.model.diffusion_model if onediff_shared.controlnet_compiled is False: - # if not getattr(self, "compiled", False): - from onediff_controlnet import TorchOnediffControlNetModel onediff_model = TorchOnediffControlNetModel(unet) onediff_shared.current_unet_graph = compile_controlnet_ldm_unet( sd_ldm, onediff_model @@ -271,8 +295,6 @@ def hijacked_main_entry(self, p): pass - - def get_controlnet_script(p): for script in p.scripts.scripts: if script.__module__ == "controlnet.py": @@ -282,15 +304,21 @@ def get_controlnet_script(p): def check_if_controlnet_enabled(p): controlnet_script_class = get_controlnet_script(p) - if controlnet_script_class is None: - return False - return len(controlnet_script_class.get_enabled_units(p)) != 0 + return ( + controlnet_script_class is not None + and len(controlnet_script_class.get_enabled_units(p)) != 0 + ) +# When OneDiff is initializing, the controlnet extension has not yet been loaded. +# Therefore, this function should be called during image generation +# rather than during the initialization of the OneDiff. @singleton_decorator -def create_condfunc(p): +def hijack_controlnet_extension(p): CondFunc( - "scripts.hook.UnetHook.hook", hijacked_hook, lambda _, *arg, **kwargs: True + "scripts.hook.UnetHook.hook", + hijacked_controlnet_hook, + lambda _, *arg, **kwargs: onediff_shared.onediff_enabled, ) # get controlnet script controlnet_script = get_controlnet_script(p) @@ -305,8 +333,18 @@ def create_condfunc(p): ) +# We were intended to only hack the closure function `forward` +# in the member function `hook` of the UnetHook class in the ControlNet extension. +# But due to certain limitations, we were unable to directly only hack +# the closure function `forward` within the `hook` method. +# So we have to hack the entire member function `hook` in the `UnetHook` class. -def hijacked_hook( +# The function largely retains its original content, +# with modifications specifically made within the `forward` function. 
+# To identify the altered parts, you can search for the tag "modified by OneDiff" + +# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L442 +def hijacked_controlnet_hook( orig_func, self, model, @@ -319,10 +357,17 @@ def hijacked_hook( from modules import devices, lowvram, scripts, shared from scripts.controlnet_sparsectrl import SparseCtrl from scripts.enums import AutoMachine, ControlModelType, HiResFixOption - from scripts.hook import (AbstractLowScaleModel, blur, mark_prompt_context, - predict_noise_from_start, predict_q_sample, - predict_start_from_noise, register_schedule, - torch_dfs, unmark_prompt_context) + from scripts.hook import ( + AbstractLowScaleModel, + blur, + mark_prompt_context, + predict_noise_from_start, + predict_q_sample, + predict_start_from_noise, + register_schedule, + torch_dfs, + unmark_prompt_context, + ) from scripts.ipadapter.ipadapter_model import ImageEmbed from scripts.logging import logger @@ -731,6 +776,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): outer.attention_auto_machine = AutoMachine.Read outer.gn_auto_machine = AutoMachine.Read + # modified by OneDiff h = onediff_shared.current_unet_graph.graph_module( x, timesteps, @@ -807,10 +853,20 @@ def move_all_control_model_to_cpu(): param.control_model.to("cpu") def forward_webui(*args, **kwargs): + # ------ modified by OneDiff below ------ forward_func = None - if "forward" in onediff_shared.current_unet_graph.graph_module._torch_module.__dict__: - forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop("forward") - _original_forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop("_original_forward") + if ( + "forward" + in onediff_shared.current_unet_graph.graph_module._torch_module.__dict__ + ): + forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop( + "forward" + ) + _original_forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop( + "_original_forward" + ) + # ------ modified by OneDiff above ------ + # webui will handle other compoments try: if shared.cmd_opts.lowvram: @@ -822,9 +878,16 @@ def forward_webui(*args, **kwargs): finally: if outer.lowvram: move_all_control_model_to_cpu() + + # ------ modified by OneDiff below ------ if forward_func is not None: - onediff_shared.current_unet_graph.graph_module._torch_module.forward = forward_func - onediff_shared.current_unet_graph.graph_module._torch_module._original_forward = _original_forward_func + onediff_shared.current_unet_graph.graph_module._torch_module.forward = ( + forward_func + ) + onediff_shared.current_unet_graph.graph_module._torch_module._original_forward = ( + _original_forward_func + ) + # ------ modified by OneDiff above ------ def hacked_basic_transformer_inner_forward(self, x, context=None): x_norm1 = self.norm1(x) diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index e06a51b24..75da3d953 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -11,5 +11,6 @@ onediff_enabled = False # controlnet +controlnet_enabled = False controlnet_compiled = False -current_is_controlnet = False +previous_is_controlnet = False diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index 441bcdfc7..1ea53bebe 100644 --- 
a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -1,11 +1,13 @@ import os -from functools import wraps from contextlib import contextmanager +from functools import wraps from pathlib import Path from textwrap import dedent from zipfile import BadZipFile import onediff_shared +import oneflow as flow +from modules.devices import torch_gc from onediff.infer_compiler import DeployableModule @@ -122,20 +124,25 @@ def save_graph(compiled_unet: DeployableModule, saved_cache_name: str = ""): def onediff_enabled_decorator(func): @wraps(func) - def wrapper(*arg, **kwargs): + def wrapper(self, p, *arg, **kwargs): onediff_shared.onediff_enabled = True try: - return func(*arg, **kwargs) + return func(self, p, *arg, **kwargs) finally: onediff_shared.onediff_enabled = False + torch_gc() + flow.cuda.empty_cache() + return wrapper def singleton_decorator(func): has_been_called = False + def wrapper(*args, **kwargs): nonlocal has_been_called if not has_been_called: has_been_called = True return func(*args, **kwargs) + return wrapper diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 5aa866cfe..180b31040 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -4,6 +4,7 @@ import modules.scripts as scripts import modules.sd_models as sd_models import modules.shared as shared +import onediff_controlnet import onediff_shared import oneflow as flow from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph @@ -13,15 +14,14 @@ from modules.ui_common import create_refresh_button from onediff_hijack import do_hijack as onediff_do_hijack from onediff_lora import HijackLoraActivate -import onediff_controlnet from onediff_utils import ( check_structure_change_and_update, get_all_compiler_caches, hints_message, load_graph, + onediff_enabled_decorator, refresh_all_compiler_caches, save_graph, - onediff_enabled_decorator, ) from onediff.optimization.quant_optimizer import varify_can_use_quantization @@ -35,6 +35,7 @@ class UnetCompileCtx(object): The global variables need to be replaced with compiled_unet before process_images is run, and then the original model restored so that subsequent reasoning with onediff disabled meets expectations. 
""" + def __init__(self, enabled): self.enabled = enabled @@ -92,6 +93,7 @@ def show(self, is_img2img): return True @onediff_enabled_decorator + @onediff_controlnet.onediff_controlnet_decorator def run( self, p, @@ -100,10 +102,6 @@ def run( saved_cache_name="", always_recompile=False, ): - controlnet_enabled = onediff_controlnet.check_if_controlnet_enabled(p) - if controlnet_enabled: - onediff_controlnet.create_condfunc(p) - # restore checkpoint_info from refiner to base model if necessary if ( sd_models.checkpoint_aliases.get( @@ -128,7 +126,7 @@ def run( quantization != onediff_shared.current_unet_graph.quantized ) controlnet_enabled_status_changed = ( - controlnet_enabled != onediff_shared.current_is_controlnet + onediff_shared.controlnet_enabled != onediff_shared.previous_is_controlnet ) need_recompile = ( ( @@ -140,7 +138,7 @@ def run( or always_recompile ) if need_recompile: - if not controlnet_enabled: + if not onediff_shared.controlnet_enabled: onediff_shared.current_unet_graph = get_compiled_graph( shared.sd_model, quantization ) @@ -150,18 +148,12 @@ def run( f"Model {current_checkpoint_name} has same sd type of graph type {onediff_shared.current_unet_type}, skip compile" ) - with UnetCompileCtx(not controlnet_enabled), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): + with UnetCompileCtx( + not onediff_shared.controlnet_enabled + ), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): proc = process_images(p) save_graph(onediff_shared.current_unet_graph, saved_cache_name) - if controlnet_enabled: - onediff_shared.current_is_controlnet = True - else: - onediff_shared.controlnet_compiled = False - onediff_shared.current_is_controlnet = False - - torch_gc() - flow.cuda.empty_cache() return proc diff --git a/tests/sd-webui/test_api.py b/tests/sd-webui/test_api.py index accd2036d..fa7550abe 100644 --- a/tests/sd-webui/test_api.py +++ b/tests/sd-webui/test_api.py @@ -21,7 +21,6 @@ get_threshold, ) -THRESHOLD = 0.97 @pytest.fixture(scope="session", autouse=True) def change_model(): diff --git a/tests/sd-webui/utils.py b/tests/sd-webui/utils.py index 658829571..3a1bbaedd 100644 --- a/tests/sd-webui/utils.py +++ b/tests/sd-webui/utils.py @@ -30,7 +30,6 @@ def get_base_args() -> Dict[str, Any]: return { "prompt": "1girl", "negative_prompt": "", - "sd_model_checkpoint": "checkpoints/AWPainting_v1.2.safetensors", "seed": SEED, "steps": NUM_STEPS, "width": WIDTH, From e35667280fc2bb488055258f37fabd110b17197e Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 18 Jun 2024 11:23:49 +0800 Subject: [PATCH 11/24] support recompile when switching model --- onediff_sd_webui_extensions/onediff_controlnet.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onediff_sd_webui_extensions/onediff_controlnet.py b/onediff_sd_webui_extensions/onediff_controlnet.py index fa25c8fbb..13b8100a2 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet.py +++ b/onediff_sd_webui_extensions/onediff_controlnet.py @@ -1,6 +1,7 @@ from functools import wraps import onediff_shared +from onediff_utils import check_structure_change_and_update import oneflow as flow import torch import torch as th @@ -285,7 +286,10 @@ def hijacked_main_entry(self, p): sd_ldm = p.sd_model unet = sd_ldm.model.diffusion_model - if onediff_shared.controlnet_compiled is False: + structure_changed = check_structure_change_and_update( + onediff_shared.current_unet_type, sd_ldm + ) + if onediff_shared.controlnet_compiled is False or structure_changed: onediff_model = TorchOnediffControlNetModel(unet) 
onediff_shared.current_unet_graph = compile_controlnet_ldm_unet( sd_ldm, onediff_model From 7674aff212cd877212c0a4be1b4be4d9b035eac7 Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 19 Jun 2024 15:02:02 +0800 Subject: [PATCH 12/24] support nexfort --- .../compile/compile_nexfort_backend.py | 64 +++++++++++++++++++ .../compile/compile_utils.py | 20 ++++-- .../onediff_controlnet.py | 3 +- 3 files changed, 80 insertions(+), 7 deletions(-) create mode 100644 onediff_sd_webui_extensions/compile/compile_nexfort_backend.py diff --git a/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py b/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py new file mode 100644 index 000000000..85eee1441 --- /dev/null +++ b/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py @@ -0,0 +1,64 @@ +import onediff_shared +import torch + +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel +from ldm.modules.diffusionmodules.util import timestep_embedding +from modules.sd_hijack_utils import CondFunc +from onediff_utils import singleton_decorator + +from onediff.infer_compiler import compile + + +def nexfort_compile_ldm_unet(unet_model, *, options=None): + create_cond_func() + if not isinstance(unet_model, UNetModel): + return + for module in unet_model.modules(): + if isinstance(module, BasicTransformerBlock): + module.checkpoint = False + if isinstance(module, ResBlock): + module.use_checkpoint = False + unet_model.convert_to_fp16() + return compile(unet_model, backend="nexfort", options=options) + + +@torch.autocast("cuda", enabled=False) +def onediff_nexfort_unet_ldm_forward( + orig_func, self, x, timesteps=None, context=None, y=None, **kwargs +): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels).half() + emb = self.time_embed(t_emb) + x = x.half() + context = context.half() + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +@singleton_decorator +def create_cond_func(): + CondFunc( + "ldm.modules.diffusionmodules.openaimodel.UNetModel.forward", + onediff_nexfort_unet_ldm_forward, + lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled, + ) + + +# def init_cond_func(): diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index d79278be2..9303527f2 100644 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -60,10 +60,20 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]: return calibrate_info -def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: +def get_compiled_graph( + sd_model, quantization, backend="nexfort" +) -> OneDiffCompiledGraph: diffusion_model = sd_model.model.diffusion_model - # for controlnet - if "forward" in diffusion_model.__dict__: - diffusion_model.__dict__.pop("forward") - compiled_unet = compile_unet(diffusion_model, quantization=quantization) + + if backend == "oneflow": + # for controlnet + if 
"forward" in diffusion_model.__dict__: + diffusion_model.__dict__.pop("forward") + compiled_unet = compile_unet(diffusion_model, quantization=quantization) + elif backend == "nexfort": + from .compile_nexfort_backend import nexfort_compile_ldm_unet + + compiled_unet = nexfort_compile_ldm_unet(diffusion_model) + else: + raise NotImplementedError return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/onediff_controlnet.py b/onediff_sd_webui_extensions/onediff_controlnet.py index 13b8100a2..d2faecceb 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet.py +++ b/onediff_sd_webui_extensions/onediff_controlnet.py @@ -1,7 +1,6 @@ from functools import wraps import onediff_shared -from onediff_utils import check_structure_change_and_update import oneflow as flow import torch import torch as th @@ -16,7 +15,7 @@ from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import devices from modules.sd_hijack_utils import CondFunc -from onediff_utils import singleton_decorator +from onediff_utils import check_structure_change_and_update, singleton_decorator from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register From 7a24a0279ac943e3ffa9238265bf2f15b6b3cf5a Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 20 Jun 2024 16:16:05 +0800 Subject: [PATCH 13/24] use torch functional sdpa --- onediff_sd_webui_extensions/compile/compile_nexfort_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py b/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py index 85eee1441..a068fa099 100644 --- a/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py +++ b/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py @@ -5,6 +5,7 @@ from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel from ldm.modules.diffusionmodules.util import timestep_embedding from modules.sd_hijack_utils import CondFunc +from modules.sd_hijack import apply_optimizations from onediff_utils import singleton_decorator from onediff.infer_compiler import compile @@ -12,6 +13,7 @@ def nexfort_compile_ldm_unet(unet_model, *, options=None): create_cond_func() + apply_optimizations("sdp-no-mem - scaled dot product without memory efficient attention") if not isinstance(unet_model, UNetModel): return for module in unet_model.modules(): From e18b54a3d08cb50f1dccabd6d07b58994fbd9f70 Mon Sep 17 00:00:00 2001 From: WangYi Date: Fri, 21 Jun 2024 01:36:11 +0800 Subject: [PATCH 14/24] refactor compile --- .../compile/__init__.py | 8 +- .../compile/compile.py | 44 +++++ .../compile/compile_nexfort_backend.py | 66 -------- .../compile/compile_utils.py | 134 +++++++-------- .../compile/nexfort/compile.py | 14 ++ .../compile/nexfort/utils.py | 152 ++++++++++++++++++ .../compile/onediff_compiled_graph.py | 31 ---- .../mock/common.py} | 0 .../{compile_ldm.py => oneflow/mock/ldm.py} | 49 +----- .../{compile_sgm.py => oneflow/mock/sgm.py} | 28 +--- .../compile/oneflow/utils.py | 11 ++ .../compile/quantization.py | 28 ++++ onediff_sd_webui_extensions/compile/sd2.py | 19 +++ onediff_sd_webui_extensions/compile/utils.py | 50 ++++++ .../compile/{compile_vae.py => vae.py} | 0 .../onediff_controlnet.py | 2 +- onediff_sd_webui_extensions/onediff_hijack.py | 21 ++- onediff_sd_webui_extensions/onediff_shared.py | 2 +- onediff_sd_webui_extensions/onediff_utils.py | 19 +++ .../scripts/onediff.py | 4 + 20 files changed, 438 
insertions(+), 244 deletions(-) create mode 100644 onediff_sd_webui_extensions/compile/compile.py delete mode 100644 onediff_sd_webui_extensions/compile/compile_nexfort_backend.py create mode 100644 onediff_sd_webui_extensions/compile/nexfort/compile.py create mode 100644 onediff_sd_webui_extensions/compile/nexfort/utils.py delete mode 100644 onediff_sd_webui_extensions/compile/onediff_compiled_graph.py rename onediff_sd_webui_extensions/compile/{sd_webui_onediff_utils.py => oneflow/mock/common.py} (100%) rename onediff_sd_webui_extensions/compile/{compile_ldm.py => oneflow/mock/ldm.py} (65%) rename onediff_sd_webui_extensions/compile/{compile_sgm.py => oneflow/mock/sgm.py} (77%) create mode 100644 onediff_sd_webui_extensions/compile/oneflow/utils.py create mode 100644 onediff_sd_webui_extensions/compile/quantization.py create mode 100644 onediff_sd_webui_extensions/compile/sd2.py create mode 100644 onediff_sd_webui_extensions/compile/utils.py rename onediff_sd_webui_extensions/compile/{compile_vae.py => vae.py} (100%) diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py index 90afcaceb..89454fd4c 100644 --- a/onediff_sd_webui_extensions/compile/__init__.py +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -1,7 +1,7 @@ -from .compile_ldm import SD21CompileCtx -from .compile_utils import get_compiled_graph -from .compile_vae import VaeCompileCtx -from .onediff_compiled_graph import OneDiffCompiledGraph +from .compile import get_compiled_graph +from .sd2 import SD21CompileCtx +from .utils import OneDiffCompiledGraph +from .vae import VaeCompileCtx __all__ = [ "get_compiled_graph", diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py new file mode 100644 index 000000000..31407dca8 --- /dev/null +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -0,0 +1,44 @@ +from modules.sd_hijack import apply_optimizations + +from onediff.infer_compiler import compile, oneflow_compile + +from .utils import disable_unet_checkpointing, OneDiffCompiledGraph + + +def get_compiled_graph(sd_model, quantization, *, options=None) -> OneDiffCompiledGraph: + diffusion_model = sd_model.model.diffusion_model + # TODO: quantization + if quantization is True: + raise + compiled_unet = onediff_compile(diffusion_model, options=options) + return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) + + +def onediff_compile(unet_model, *, backend="oneflow", options=None): + if backend == "oneflow": + # for controlnet + if "forward" in unet_model.__dict__: + unet_model.__dict__.pop("forward") + return compile_unet_oneflow(unet_model, options=options) + elif backend == "nexfort": + return compile_unet_nexfort(unet_model, options=options) + else: + raise NotImplementedError(f"Can't find backend {backend} for OneDiff") + + +def compile_unet_oneflow(unet_model, *, options=None): + from .oneflow.utils import init_oneflow_backend + + init_oneflow_backend() + disable_unet_checkpointing(unet_model) + return oneflow_compile(unet_model, options=options) + + +def compile_unet_nexfort(unet_model, *, options=None): + from .nexfort.utils import init_nexfort_backend + + init_nexfort_backend() + apply_optimizations("nexfort") + disable_unet_checkpointing(unet_model) + unet_model.convert_to_fp16() + return compile(unet_model, backend="nexfort", options=options) diff --git a/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py b/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py deleted file 
mode 100644 index a068fa099..000000000 --- a/onediff_sd_webui_extensions/compile/compile_nexfort_backend.py +++ /dev/null @@ -1,66 +0,0 @@ -import onediff_shared -import torch - -from ldm.modules.attention import BasicTransformerBlock -from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel -from ldm.modules.diffusionmodules.util import timestep_embedding -from modules.sd_hijack_utils import CondFunc -from modules.sd_hijack import apply_optimizations -from onediff_utils import singleton_decorator - -from onediff.infer_compiler import compile - - -def nexfort_compile_ldm_unet(unet_model, *, options=None): - create_cond_func() - apply_optimizations("sdp-no-mem - scaled dot product without memory efficient attention") - if not isinstance(unet_model, UNetModel): - return - for module in unet_model.modules(): - if isinstance(module, BasicTransformerBlock): - module.checkpoint = False - if isinstance(module, ResBlock): - module.use_checkpoint = False - unet_model.convert_to_fp16() - return compile(unet_model, backend="nexfort", options=options) - - -@torch.autocast("cuda", enabled=False) -def onediff_nexfort_unet_ldm_forward( - orig_func, self, x, timesteps=None, context=None, y=None, **kwargs -): - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels).half() - emb = self.time_embed(t_emb) - x = x.half() - context = context.half() - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - h = x - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) - - -@singleton_decorator -def create_cond_func(): - CondFunc( - "ldm.modules.diffusionmodules.openaimodel.UNetModel.forward", - onediff_nexfort_unet_ldm_forward, - lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled, - ) - - -# def init_cond_func(): diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index 9303527f2..5ff3930d5 100644 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -1,79 +1,79 @@ -import warnings -from pathlib import Path -from typing import Dict, Union +# import warnings +# from pathlib import Path +# from typing import Dict, Union -from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM -from modules.sd_models import select_checkpoint -from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM +# from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM +# from modules.sd_models import select_checkpoint +# from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM -from onediff.optimization.quant_optimizer import ( - quantize_model, - varify_can_use_quantization, -) -from onediff.utils import logger +# from onediff.optimization.quant_optimizer import ( +# quantize_model, +# varify_can_use_quantization, +# ) +# from onediff.utils import logger -from .compile_ldm import compile_ldm_unet -from .compile_sgm import compile_sgm_unet -from .onediff_compiled_graph import OneDiffCompiledGraph +# from .compile_ldm import compile_ldm_unet +# from 
.compile_sgm import compile_sgm_unet +# from .utils import OneDiffCompiledGraph -def compile_unet( - unet_model, quantization=False, *, options=None, -): - if isinstance(unet_model, UNetModelLDM): - compiled_unet = compile_ldm_unet(unet_model, options=options) - elif isinstance(unet_model, UNetModelSGM): - compiled_unet = compile_sgm_unet(unet_model, options=options) - else: - warnings.warn( - f"Unsupported model type: {type(unet_model)} for compilation , skip", - RuntimeWarning, - ) - compiled_unet = unet_model - # In OneDiff Community, quantization can be True when called by api - if quantization and varify_can_use_quantization(): - calibrate_info = get_calibrate_info( - f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" - ) - compiled_unet = quantize_model( - compiled_unet, inplace=False, calibrate_info=calibrate_info - ) - return compiled_unet +# def compile_unet( +# unet_model, quantization=False, *, options=None, +# ): +# if isinstance(unet_model, UNetModelLDM): +# compiled_unet = compile_ldm_unet(unet_model, options=options) +# elif isinstance(unet_model, UNetModelSGM): +# compiled_unet = compile_sgm_unet(unet_model, options=options) +# else: +# warnings.warn( +# f"Unsupported model type: {type(unet_model)} for compilation , skip", +# RuntimeWarning, +# ) +# compiled_unet = unet_model +# # In OneDiff Community, quantization can be True when called by api +# if quantization and varify_can_use_quantization(): +# calibrate_info = get_calibrate_info( +# f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" +# ) +# compiled_unet = quantize_model( +# compiled_unet, inplace=False, calibrate_info=calibrate_info +# ) +# return compiled_unet -def get_calibrate_info(filename: str) -> Union[None, Dict]: - calibration_path = Path(select_checkpoint().filename).parent / filename - if not calibration_path.exists(): - return None +# def get_calibrate_info(filename: str) -> Union[None, Dict]: +# calibration_path = Path(select_checkpoint().filename).parent / filename +# if not calibration_path.exists(): +# return None - logger.info(f"Got calibrate info at {str(calibration_path)}") - calibrate_info = {} - with open(calibration_path, "r") as f: - for line in f.readlines(): - line = line.strip() - items = line.split(" ") - calibrate_info[items[0]] = [ - float(items[1]), - int(items[2]), - [float(x) for x in items[3].split(",")], - ] - return calibrate_info +# logger.info(f"Got calibrate info at {str(calibration_path)}") +# calibrate_info = {} +# with open(calibration_path, "r") as f: +# for line in f.readlines(): +# line = line.strip() +# items = line.split(" ") +# calibrate_info[items[0]] = [ +# float(items[1]), +# int(items[2]), +# [float(x) for x in items[3].split(",")], +# ] +# return calibrate_info -def get_compiled_graph( - sd_model, quantization, backend="nexfort" -) -> OneDiffCompiledGraph: - diffusion_model = sd_model.model.diffusion_model +# def get_compiled_graph( +# sd_model, quantization, backend="nexfort" +# ) -> OneDiffCompiledGraph: +# diffusion_model = sd_model.model.diffusion_model - if backend == "oneflow": - # for controlnet - if "forward" in diffusion_model.__dict__: - diffusion_model.__dict__.pop("forward") - compiled_unet = compile_unet(diffusion_model, quantization=quantization) - elif backend == "nexfort": - from .compile_nexfort_backend import nexfort_compile_ldm_unet +# if backend == "oneflow": +# # for controlnet +# if "forward" in diffusion_model.__dict__: +# diffusion_model.__dict__.pop("forward") +# compiled_unet = compile_unet(diffusion_model, 
quantization=quantization) +# elif backend == "nexfort": +# from .nexfort.compile import nexfort_compile_ldm_unet - compiled_unet = nexfort_compile_ldm_unet(diffusion_model) - else: - raise NotImplementedError - return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) +# compiled_unet = nexfort_compile_ldm_unet(diffusion_model) +# else: +# raise NotImplementedError +# return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/compile/nexfort/compile.py b/onediff_sd_webui_extensions/compile/nexfort/compile.py new file mode 100644 index 000000000..0e299135c --- /dev/null +++ b/onediff_sd_webui_extensions/compile/nexfort/compile.py @@ -0,0 +1,14 @@ +from modules.sd_hijack import apply_optimizations +from onediff_utils import disable_unet_checkpointing + +from onediff.infer_compiler import compile + +from .utils import init_nexfort + + +def nexfort_compile_ldm_unet(unet_model, *, options=None): + init_nexfort() + apply_optimizations("nexfort") + disable_unet_checkpointing(unet_model) + unet_model.convert_to_fp16() + return compile(unet_model, backend="nexfort", options=options) diff --git a/onediff_sd_webui_extensions/compile/nexfort/utils.py b/onediff_sd_webui_extensions/compile/nexfort/utils.py new file mode 100644 index 000000000..a677e72de --- /dev/null +++ b/onediff_sd_webui_extensions/compile/nexfort/utils.py @@ -0,0 +1,152 @@ +from typing import List + +import ldm.modules.attention +import onediff_shared +import sgm.modules.attention +import torch +from ldm.modules.diffusionmodules.util import timestep_embedding +from modules import shared +from modules.hypernetworks import hypernetwork +from modules.sd_hijack_optimizations import SdOptimization +from modules.sd_hijack_utils import CondFunc +from onediff_utils import singleton_decorator + + +@singleton_decorator +def init_nexfort_backend(): + CondFunc( + "ldm.modules.diffusionmodules.openaimodel.UNetModel.forward", + onediff_nexfort_unet_ldm_forward, + lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled, + ) + + CondFunc( + "sgm.modules.diffusionmodules.openaimodel.UNetModel.forward", + onediff_nexfort_unet_sgm_forward, + lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled, + ) + + +@torch.autocast("cuda", enabled=False) +def onediff_nexfort_unet_sgm_forward( + orig_func, self, x, timesteps=None, context=None, y=None, **kwargs +): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels).half() + emb = self.time_embed(t_emb) + x = x.half() + context = context.half() if context is not None else context + y = y.half() if y is not None else y + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + return self.out(h) + + +@torch.autocast("cuda", enabled=False) +def onediff_nexfort_unet_ldm_forward( + orig_func, self, x, timesteps=None, context=None, y=None, **kwargs +): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels).half() + emb = self.time_embed(t_emb) + x = x.half() + 
context = context.half() + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): + batch_size, sequence_length, inner_dim = x.shape + + if mask is not None: + mask = self.prepare_attention_mask(mask, sequence_length, batch_size) + mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) + + h = self.heads + q_in = self.to_q(x) + + context = x if context is None else context + + context_k, context_v = hypernetwork.apply_hypernetworks( + shared.loaded_hypernetworks, context + ) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + head_dim = inner_dim // h + q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + + del q_in, k_in, v_in + + dtype = q.dtype + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class SdOptimizationNexfort(SdOptimization): + name = "nexfort" + cmd_opt = "nexfort" + priority = 10 + + def is_available(self): + try: + import nexfort + except ImportError: + return False + finally: + return True + + def apply(self): + ldm.modules.attention.CrossAttention.forward = ( + scaled_dot_product_attention_forward + ) + sgm.modules.attention.CrossAttention.forward = ( + scaled_dot_product_attention_forward + ) + + +def add_nexfort_optimizer(res: List): + res.append(SdOptimizationNexfort()) diff --git a/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py b/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py deleted file mode 100644 index d6a09aca3..000000000 --- a/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py +++ /dev/null @@ -1,31 +0,0 @@ -import dataclasses - -import torch -from modules import sd_models_types - -from onediff.infer_compiler import DeployableModule - - -@dataclasses.dataclass -class OneDiffCompiledGraph: - name: str = None - filename: str = None - sha: str = None - eager_module: torch.nn.Module = None - graph_module: DeployableModule = None - quantized: bool = False - - def __init__( - self, - sd_model: sd_models_types.WebuiSdModel = None, - graph_module: DeployableModule = None, - quantized=False, - ): - if sd_model is None: - return - self.name = sd_model.sd_checkpoint_info.name - self.filename = sd_model.sd_checkpoint_info.filename - self.sha = sd_model.sd_model_hash - self.eager_module = sd_model.model.diffusion_model - self.graph_module = graph_module - self.quantized = quantized diff --git a/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py b/onediff_sd_webui_extensions/compile/oneflow/mock/common.py similarity index 100% rename from onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py rename 
to onediff_sd_webui_extensions/compile/oneflow/mock/common.py diff --git a/onediff_sd_webui_extensions/compile/compile_ldm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py similarity index 65% rename from onediff_sd_webui_extensions/compile/compile_ldm.py rename to onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py index 7b04e16aa..2a53fd55b 100644 --- a/onediff_sd_webui_extensions/compile/compile_ldm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py @@ -1,30 +1,17 @@ -import os - import oneflow as flow -from ldm.modules.attention import ( - BasicTransformerBlock, - CrossAttention, - SpatialTransformer, -) -from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel +from ldm.modules.attention import CrossAttention, SpatialTransformer +from ldm.modules.diffusionmodules.openaimodel import UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 -from modules import shared -from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register -from .sd_webui_onediff_utils import ( - CrossAttentionOflow, - GroupNorm32Oflow, - timestep_embedding, -) - -__all__ = ["compile_ldm_unet"] +from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding # https://github.com/Stability-AI/stablediffusion/blob/b4bdae9916f628461e1e4edbc62aafedebb9f7ed/ldm/modules/diffusionmodules/openaimodel.py#L775 class UNetModelOflow(proxy_class(UNetModel)): def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + self.convert_to_fp16() assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -80,31 +67,5 @@ def forward(self, x, context=None): SpatialTransformer: SpatialTransformerOflow, UNetModel: UNetModelOflow, } -register(package_names=["ldm"], torch2oflow_class_map=torch2oflow_class_map) - - -def compile_ldm_unet(unet_model, *, options=None): - if not isinstance(unet_model, UNetModel): - return - for module in unet_model.modules(): - if isinstance(module, BasicTransformerBlock): - module.checkpoint = False - if isinstance(module, ResBlock): - module.use_checkpoint = False - return oneflow_compile(unet_model, options=options) - -class SD21CompileCtx(object): - """to avoid results for NaN when the model is v2-1_768-ema-pruned""" - - _var_name = "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION" - - def __enter__(self): - self._original = os.getenv(self._var_name) - if shared.opts.sd_model_checkpoint.startswith("v2-1"): - os.environ[self._var_name] = "0" - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._original is not None: - os.environ[self._var_name] = self._original - return False +register(package_names=["ldm"], torch2oflow_class_map=torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/compile_sgm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py similarity index 77% rename from onediff_sd_webui_extensions/compile/compile_sgm.py rename to onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py index 09b86be59..fabd6bcdc 100644 --- a/onediff_sd_webui_extensions/compile/compile_sgm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py @@ -1,22 +1,11 @@ import oneflow as flow -from sgm.modules.attention import ( - BasicTransformerBlock, - CrossAttention, - SpatialTransformer, -) -from sgm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel +from sgm.modules.attention import CrossAttention, SpatialTransformer +from 
sgm.modules.diffusionmodules.openaimodel import UNetModel from sgm.modules.diffusionmodules.util import GroupNorm32 -from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register -from .sd_webui_onediff_utils import ( - CrossAttentionOflow, - GroupNorm32Oflow, - timestep_embedding, -) - -__all__ = ["compile_sgm_unet"] +from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding # https://github.com/Stability-AI/generative-models/blob/059d8e9cd9c55aea1ef2ece39abf605efb8b7cc9/sgm/modules/diffusionmodules/openaimodel.py#L816 @@ -79,14 +68,3 @@ def forward(self, x, context=None): UNetModel: UNetModelOflow, } register(package_names=["sgm"], torch2oflow_class_map=torch2oflow_class_map) - - -def compile_sgm_unet(unet_model, *, options=None): - if not isinstance(unet_model, UNetModel): - return - for module in unet_model.modules(): - if isinstance(module, BasicTransformerBlock): - module.checkpoint = False - if isinstance(module, ResBlock): - module.use_checkpoint = False - return oneflow_compile(unet_model, options=options) diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py new file mode 100644 index 000000000..65e666aee --- /dev/null +++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py @@ -0,0 +1,11 @@ +from onediff_utils import singleton_decorator + +from onediff.infer_compiler.backends.oneflow.transform import register + +from .mock import ldm, sgm + + +@singleton_decorator +def init_oneflow_backend(): + register(package_names=["ldm"], torch2oflow_class_map=ldm.torch2oflow_class_map) + register(package_names=["ldm"], torch2oflow_class_map=sgm.torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/quantization.py b/onediff_sd_webui_extensions/compile/quantization.py new file mode 100644 index 000000000..1168d360a --- /dev/null +++ b/onediff_sd_webui_extensions/compile/quantization.py @@ -0,0 +1,28 @@ +import warnings +from pathlib import Path +from typing import Dict, Union + +from modules.sd_models import select_checkpoint + +from onediff.utils import logger + +from .utils import OneDiffCompiledGraph + +def get_calibrate_info(filename: str) -> Union[None, Dict]: + calibration_path = Path(select_checkpoint().filename).parent / filename + if not calibration_path.exists(): + return None + + logger.info(f"Got calibrate info at {str(calibration_path)}") + calibrate_info = {} + with open(calibration_path, "r") as f: + for line in f.readlines(): + line = line.strip() + items = line.split(" ") + calibrate_info[items[0]] = [ + float(items[1]), + int(items[2]), + [float(x) for x in items[3].split(",")], + ] + return calibrate_info + diff --git a/onediff_sd_webui_extensions/compile/sd2.py b/onediff_sd_webui_extensions/compile/sd2.py new file mode 100644 index 000000000..a1516a7c1 --- /dev/null +++ b/onediff_sd_webui_extensions/compile/sd2.py @@ -0,0 +1,19 @@ +import os + +from modules import shared + + +class SD21CompileCtx(object): + """to avoid results for NaN when the model is v2-1_768-ema-pruned""" + + _var_name = "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION" + + def __enter__(self): + self._original = os.getenv(self._var_name) + if shared.opts.sd_model_checkpoint.startswith("v2-1"): + os.environ[self._var_name] = "0" + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._original is not None: + os.environ[self._var_name] = self._original + return False diff --git 
a/onediff_sd_webui_extensions/compile/utils.py b/onediff_sd_webui_extensions/compile/utils.py new file mode 100644 index 000000000..5eda9cd9e --- /dev/null +++ b/onediff_sd_webui_extensions/compile/utils.py @@ -0,0 +1,50 @@ +import dataclasses +from typing import Union + +import torch +from ldm.modules.diffusionmodules.openaimodel import UNetModel as LdmUNetModel +from modules import sd_models_types +from sgm.modules.diffusionmodules.openaimodel import UNetModel as SgmUNetModel + +from onediff.infer_compiler import DeployableModule + + +def disable_unet_checkpointing( + unet_model: Union[LdmUNetModel, SgmUNetModel] +) -> Union[LdmUNetModel, SgmUNetModel]: + from ldm.modules.attention import BasicTransformerBlock as LdmBasicTransformerBlock + from ldm.modules.diffusionmodules.openaimodel import ResBlock as LdmResBlock + from sgm.modules.attention import BasicTransformerBlock as SgmBasicTransformerBlock + from sgm.modules.diffusionmodules.openaimodel import ResBlock as SgmResBlock + + for module in unet_model.modules(): + if isinstance(module, (LdmBasicTransformerBlock, SgmBasicTransformerBlock)): + module.checkpoint = False + if isinstance(module, (LdmResBlock, SgmResBlock)): + module.use_checkpoint = False + return unet_model + + +@dataclasses.dataclass +class OneDiffCompiledGraph: + name: str = None + filename: str = None + sha: str = None + eager_module: torch.nn.Module = None + graph_module: DeployableModule = None + quantized: bool = False + + def __init__( + self, + sd_model: sd_models_types.WebuiSdModel = None, + graph_module: DeployableModule = None, + quantized=False, + ): + if sd_model is None: + return + self.name = sd_model.sd_checkpoint_info.name + self.filename = sd_model.sd_checkpoint_info.filename + self.sha = sd_model.sd_model_hash + self.eager_module = sd_model.model.diffusion_model + self.graph_module = graph_module + self.quantized = quantized diff --git a/onediff_sd_webui_extensions/compile/compile_vae.py b/onediff_sd_webui_extensions/compile/vae.py similarity index 100% rename from onediff_sd_webui_extensions/compile/compile_vae.py rename to onediff_sd_webui_extensions/compile/vae.py diff --git a/onediff_sd_webui_extensions/onediff_controlnet.py b/onediff_sd_webui_extensions/onediff_controlnet.py index d2faecceb..a49e7df16 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet.py +++ b/onediff_sd_webui_extensions/onediff_controlnet.py @@ -5,7 +5,7 @@ import torch import torch as th from compile import OneDiffCompiledGraph -from compile.sd_webui_onediff_utils import ( +from compile.oneflow.mock.common import ( CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding, diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py index 355180202..56dbd7436 100644 --- a/onediff_sd_webui_extensions/onediff_hijack.py +++ b/onediff_sd_webui_extensions/onediff_hijack.py @@ -2,7 +2,7 @@ import oneflow import torch -from compile import compile_ldm, compile_sgm +from compile.oneflow.mock import ldm, sgm from modules import sd_models from modules.sd_hijack_utils import CondFunc from onediff_shared import onediff_enabled @@ -66,8 +66,8 @@ def unhijack_function(module, name, new_name): def do_hijack(): - compile_ldm.flow = hijack_flow - compile_sgm.flow = hijack_flow + ldm.flow = hijack_flow + sgm.flow = hijack_flow from modules import script_callbacks, sd_models script_callbacks.on_script_unloaded(undo_hijack) @@ -86,8 +86,8 @@ def do_hijack(): def undo_hijack(): - compile_ldm.flow = oneflow - compile_sgm.flow = oneflow + ldm.flow 
= oneflow + sgm.flow = oneflow from modules import sd_models unhijack_function( @@ -227,3 +227,14 @@ def load_state_dict(original, module, state_dict, strict=True): onediff_hijack_load_model_weights, lambda _, *args, **kwargs: onediff_enabled, ) + + +def hijack_devices_manual_cast(orig_func, *args, **kwargs): + yield None + + +# CondFunc( +# "devices.manual_cast", +# hijack_devices_manual_cast, +# lambda _, *args, **kwargs: onediff_enabled, +# ) diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index 75da3d953..ec04f637b 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -1,4 +1,4 @@ -from compile.onediff_compiled_graph import OneDiffCompiledGraph +from compile import OneDiffCompiledGraph current_unet_graph = OneDiffCompiledGraph() current_quantization = False diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index 1ea53bebe..c0b5e8c4a 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -3,11 +3,14 @@ from functools import wraps from pathlib import Path from textwrap import dedent +from typing import Union from zipfile import BadZipFile import onediff_shared import oneflow as flow +from ldm.modules.diffusionmodules.openaimodel import UNetModel as LdmUNetModel from modules.devices import torch_gc +from sgm.modules.diffusionmodules.openaimodel import UNetModel as SgmUNetModel from onediff.infer_compiler import DeployableModule @@ -146,3 +149,19 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper + + +# def disable_unet_checkpointing( +# unet_model: Union[LdmUNetModel, SgmUNetModel] +# ) -> Union[LdmUNetModel, SgmUNetModel]: +# from ldm.modules.attention import BasicTransformerBlock as LdmBasicTransformerBlock +# from ldm.modules.diffusionmodules.openaimodel import ResBlock as LdmResBlock +# from sgm.modules.attention import BasicTransformerBlock as SgmBasicTransformerBlock +# from sgm.modules.diffusionmodules.openaimodel import ResBlock as SgmResBlock + +# for module in unet_model.modules(): +# if isinstance(module, (LdmBasicTransformerBlock, SgmBasicTransformerBlock)): +# module.checkpoint = False +# if isinstance(module, (LdmResBlock, SgmResBlock)): +# module.use_checkpoint = False +# return unet_model diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 180b31040..4e2fd22ca 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -176,3 +176,7 @@ def cfg_denoisers_callback(params): script_callbacks.on_ui_settings(on_ui_settings) # script_callbacks.on_cfg_denoiser(cfg_denoisers_callback) onediff_do_hijack() + +from compile.nexfort.utils import add_nexfort_optimizer + +script_callbacks.on_list_optimizers(add_nexfort_optimizer) From aa604e5aed6b212b3b320556386e19909461bb9f Mon Sep 17 00:00:00 2001 From: WangYi Date: Fri, 21 Jun 2024 01:42:56 +0800 Subject: [PATCH 15/24] refine --- .../compile/compile.py | 2 +- .../compile/compile_utils.py | 79 ------------------- .../compile/nexfort/compile.py | 14 ---- .../compile/oneflow/mock/ldm.py | 1 - .../compile/oneflow/utils.py | 2 +- .../compile/quantization.py | 2 +- onediff_sd_webui_extensions/onediff_hijack.py | 11 --- onediff_sd_webui_extensions/onediff_utils.py | 20 ----- .../scripts/onediff.py | 2 +- 9 files changed, 4 insertions(+), 129 deletions(-) delete mode 100644 
onediff_sd_webui_extensions/compile/compile_utils.py delete mode 100644 onediff_sd_webui_extensions/compile/nexfort/compile.py diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py index 31407dca8..f58d2e2d8 100644 --- a/onediff_sd_webui_extensions/compile/compile.py +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -2,7 +2,7 @@ from onediff.infer_compiler import compile, oneflow_compile -from .utils import disable_unet_checkpointing, OneDiffCompiledGraph +from .utils import OneDiffCompiledGraph, disable_unet_checkpointing def get_compiled_graph(sd_model, quantization, *, options=None) -> OneDiffCompiledGraph: diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py deleted file mode 100644 index 5ff3930d5..000000000 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ /dev/null @@ -1,79 +0,0 @@ -# import warnings -# from pathlib import Path -# from typing import Dict, Union - -# from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM -# from modules.sd_models import select_checkpoint -# from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM - -# from onediff.optimization.quant_optimizer import ( -# quantize_model, -# varify_can_use_quantization, -# ) -# from onediff.utils import logger - -# from .compile_ldm import compile_ldm_unet -# from .compile_sgm import compile_sgm_unet -# from .utils import OneDiffCompiledGraph - - -# def compile_unet( -# unet_model, quantization=False, *, options=None, -# ): -# if isinstance(unet_model, UNetModelLDM): -# compiled_unet = compile_ldm_unet(unet_model, options=options) -# elif isinstance(unet_model, UNetModelSGM): -# compiled_unet = compile_sgm_unet(unet_model, options=options) -# else: -# warnings.warn( -# f"Unsupported model type: {type(unet_model)} for compilation , skip", -# RuntimeWarning, -# ) -# compiled_unet = unet_model -# # In OneDiff Community, quantization can be True when called by api -# if quantization and varify_can_use_quantization(): -# calibrate_info = get_calibrate_info( -# f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" -# ) -# compiled_unet = quantize_model( -# compiled_unet, inplace=False, calibrate_info=calibrate_info -# ) -# return compiled_unet - - -# def get_calibrate_info(filename: str) -> Union[None, Dict]: -# calibration_path = Path(select_checkpoint().filename).parent / filename -# if not calibration_path.exists(): -# return None - -# logger.info(f"Got calibrate info at {str(calibration_path)}") -# calibrate_info = {} -# with open(calibration_path, "r") as f: -# for line in f.readlines(): -# line = line.strip() -# items = line.split(" ") -# calibrate_info[items[0]] = [ -# float(items[1]), -# int(items[2]), -# [float(x) for x in items[3].split(",")], -# ] -# return calibrate_info - - -# def get_compiled_graph( -# sd_model, quantization, backend="nexfort" -# ) -> OneDiffCompiledGraph: -# diffusion_model = sd_model.model.diffusion_model - -# if backend == "oneflow": -# # for controlnet -# if "forward" in diffusion_model.__dict__: -# diffusion_model.__dict__.pop("forward") -# compiled_unet = compile_unet(diffusion_model, quantization=quantization) -# elif backend == "nexfort": -# from .nexfort.compile import nexfort_compile_ldm_unet - -# compiled_unet = nexfort_compile_ldm_unet(diffusion_model) -# else: -# raise NotImplementedError -# return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git 
a/onediff_sd_webui_extensions/compile/nexfort/compile.py b/onediff_sd_webui_extensions/compile/nexfort/compile.py deleted file mode 100644 index 0e299135c..000000000 --- a/onediff_sd_webui_extensions/compile/nexfort/compile.py +++ /dev/null @@ -1,14 +0,0 @@ -from modules.sd_hijack import apply_optimizations -from onediff_utils import disable_unet_checkpointing - -from onediff.infer_compiler import compile - -from .utils import init_nexfort - - -def nexfort_compile_ldm_unet(unet_model, *, options=None): - init_nexfort() - apply_optimizations("nexfort") - disable_unet_checkpointing(unet_model) - unet_model.convert_to_fp16() - return compile(unet_model, backend="nexfort", options=options) diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py index 2a53fd55b..9667fa505 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py @@ -11,7 +11,6 @@ # https://github.com/Stability-AI/stablediffusion/blob/b4bdae9916f628461e1e4edbc62aafedebb9f7ed/ldm/modules/diffusionmodules/openaimodel.py#L775 class UNetModelOflow(proxy_class(UNetModel)): def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - self.convert_to_fp16() assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py index 65e666aee..fa87ca33f 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/utils.py +++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py @@ -8,4 +8,4 @@ @singleton_decorator def init_oneflow_backend(): register(package_names=["ldm"], torch2oflow_class_map=ldm.torch2oflow_class_map) - register(package_names=["ldm"], torch2oflow_class_map=sgm.torch2oflow_class_map) + register(package_names=["sgm"], torch2oflow_class_map=sgm.torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/quantization.py b/onediff_sd_webui_extensions/compile/quantization.py index 1168d360a..802e61591 100644 --- a/onediff_sd_webui_extensions/compile/quantization.py +++ b/onediff_sd_webui_extensions/compile/quantization.py @@ -8,6 +8,7 @@ from .utils import OneDiffCompiledGraph + def get_calibrate_info(filename: str) -> Union[None, Dict]: calibration_path = Path(select_checkpoint().filename).parent / filename if not calibration_path.exists(): @@ -25,4 +26,3 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]: [float(x) for x in items[3].split(",")], ] return calibrate_info - diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py index 56dbd7436..47b5c457c 100644 --- a/onediff_sd_webui_extensions/onediff_hijack.py +++ b/onediff_sd_webui_extensions/onediff_hijack.py @@ -227,14 +227,3 @@ def load_state_dict(original, module, state_dict, strict=True): onediff_hijack_load_model_weights, lambda _, *args, **kwargs: onediff_enabled, ) - - -def hijack_devices_manual_cast(orig_func, *args, **kwargs): - yield None - - -# CondFunc( -# "devices.manual_cast", -# hijack_devices_manual_cast, -# lambda _, *args, **kwargs: onediff_enabled, -# ) diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index c0b5e8c4a..f3839dbd8 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -1,16 +1,12 @@ import os -from contextlib import contextmanager 
from functools import wraps from pathlib import Path from textwrap import dedent -from typing import Union from zipfile import BadZipFile import onediff_shared import oneflow as flow -from ldm.modules.diffusionmodules.openaimodel import UNetModel as LdmUNetModel from modules.devices import torch_gc -from sgm.modules.diffusionmodules.openaimodel import UNetModel as SgmUNetModel from onediff.infer_compiler import DeployableModule @@ -149,19 +145,3 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper - - -# def disable_unet_checkpointing( -# unet_model: Union[LdmUNetModel, SgmUNetModel] -# ) -> Union[LdmUNetModel, SgmUNetModel]: -# from ldm.modules.attention import BasicTransformerBlock as LdmBasicTransformerBlock -# from ldm.modules.diffusionmodules.openaimodel import ResBlock as LdmResBlock -# from sgm.modules.attention import BasicTransformerBlock as SgmBasicTransformerBlock -# from sgm.modules.diffusionmodules.openaimodel import ResBlock as SgmResBlock - -# for module in unet_model.modules(): -# if isinstance(module, (LdmBasicTransformerBlock, SgmBasicTransformerBlock)): -# module.checkpoint = False -# if isinstance(module, (LdmResBlock, SgmResBlock)): -# module.use_checkpoint = False -# return unet_model diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 4e2fd22ca..487ff6da5 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -8,6 +8,7 @@ import onediff_shared import oneflow as flow from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph +from compile.nexfort.utils import add_nexfort_optimizer from modules import script_callbacks from modules.devices import torch_gc from modules.processing import process_images @@ -177,6 +178,5 @@ def cfg_denoisers_callback(params): # script_callbacks.on_cfg_denoiser(cfg_denoisers_callback) onediff_do_hijack() -from compile.nexfort.utils import add_nexfort_optimizer script_callbacks.on_list_optimizers(add_nexfort_optimizer) From 70bf929012ce93b9a168b20c24a26c76995addb1 Mon Sep 17 00:00:00 2001 From: WangYi Date: Fri, 21 Jun 2024 12:32:24 +0800 Subject: [PATCH 16/24] support quant and refine --- .../compile/compile.py | 46 +++++++++++++------ .../compile/quantization.py | 16 ++++++- .../scripts/onediff.py | 15 +++++- tests/sd-webui/test_api.py | 1 + 4 files changed, 61 insertions(+), 17 deletions(-) diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py index f58d2e2d8..4872cc578 100644 --- a/onediff_sd_webui_extensions/compile/compile.py +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -3,38 +3,56 @@ from onediff.infer_compiler import compile, oneflow_compile from .utils import OneDiffCompiledGraph, disable_unet_checkpointing +from .quantization import quant_unet_oneflow -def get_compiled_graph(sd_model, quantization, *, options=None) -> OneDiffCompiledGraph: +def get_compiled_graph( + sd_model, *, backend, quantization=None, options=None +) -> OneDiffCompiledGraph: diffusion_model = sd_model.model.diffusion_model - # TODO: quantization - if quantization is True: - raise - compiled_unet = onediff_compile(diffusion_model, options=options) + compiled_unet = onediff_compile( + diffusion_model, backend=backend, quantization=quantization, options=options + ) return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) -def onediff_compile(unet_model, *, backend="oneflow", options=None): +def onediff_compile(unet_model, *, 
quantization=False, backend="oneflow", options=None): if backend == "oneflow": - # for controlnet - if "forward" in unet_model.__dict__: - unet_model.__dict__.pop("forward") - return compile_unet_oneflow(unet_model, options=options) + return compile_unet_oneflow( + unet_model, quantization=quantization, options=options + ) elif backend == "nexfort": - return compile_unet_nexfort(unet_model, options=options) + return compile_unet_nexfort( + unet_model, quantization=quantization, options=options + ) else: raise NotImplementedError(f"Can't find backend {backend} for OneDiff") -def compile_unet_oneflow(unet_model, *, options=None): +def compile_unet_oneflow(unet_model, *, quantization=False, options=None): from .oneflow.utils import init_oneflow_backend + # 1. register mock map for converting torch to oneflow init_oneflow_backend() + + # 2. (for controlnet) remove attr forward to prevent mock failing + if "forward" in unet_model.__dict__: + unet_model.__dict__.pop("forward") + + # 3. disable checkpoint to prevent mock failing disable_unet_checkpointing(unet_model) - return oneflow_compile(unet_model, options=options) + + compiled_unet_model = oneflow_compile(unet_model, options=options) + if quantization: + compiled_unet_model = quant_unet_oneflow(compiled_unet_model) + return compiled_unet_model -def compile_unet_nexfort(unet_model, *, options=None): +def compile_unet_nexfort(unet_model, *, quantization=False, options=None): + if quantization: + raise NotImplementedError( + "Quantization for nexfort backend is not implemented yet." + ) from .nexfort.utils import init_nexfort_backend init_nexfort_backend() diff --git a/onediff_sd_webui_extensions/compile/quantization.py b/onediff_sd_webui_extensions/compile/quantization.py index 802e61591..b5a08d7ac 100644 --- a/onediff_sd_webui_extensions/compile/quantization.py +++ b/onediff_sd_webui_extensions/compile/quantization.py @@ -1,12 +1,24 @@ -import warnings from pathlib import Path from typing import Dict, Union from modules.sd_models import select_checkpoint from onediff.utils import logger +from onediff.optimization.quant_optimizer import ( + quantize_model, + varify_can_use_quantization, +) -from .utils import OneDiffCompiledGraph + +def quant_unet_oneflow(compiled_unet): + if varify_can_use_quantization(): + calibrate_info = get_calibrate_info( + f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" + ) + compiled_unet = quantize_model( + compiled_unet, inplace=False, calibrate_info=calibrate_info + ) + return compiled_unet def get_calibrate_info(filename: str) -> Union[None, Dict]: diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 487ff6da5..fd9a88c4a 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -102,6 +102,7 @@ def run( compiler_cache=None, saved_cache_name="", always_recompile=False, + backend=None, ): # restore checkpoint_info from refiner to base model if necessary if ( @@ -115,6 +116,8 @@ def run( torch_gc() flow.cuda.empty_cache() + backend = backend or shared.opts.onediff_compiler_backend + current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name ckpt_changed = ( shared.sd_model.sd_checkpoint_info.name @@ -141,7 +144,7 @@ def run( if need_recompile: if not onediff_shared.controlnet_enabled: onediff_shared.current_unet_graph = get_compiled_graph( - shared.sd_model, quantization + shared.sd_model, quantization=quantization, backend=backend, ) 
load_graph(onediff_shared.current_unet_graph, compiler_cache) else: @@ -168,6 +171,16 @@ def on_ui_settings(): section=section, ), ) + shared.opts.add_option( + "onediff_compiler_backend", + shared.OptionInfo( + "oneflow", + "Backend for onediff compiler", + gr.Radio, + {"choices": ["oneflow", "nexfort"]}, + section=section, + ), + ) def cfg_denoisers_callback(params): diff --git a/tests/sd-webui/test_api.py b/tests/sd-webui/test_api.py index fa7550abe..f4cee6601 100644 --- a/tests/sd-webui/test_api.py +++ b/tests/sd-webui/test_api.py @@ -1,6 +1,7 @@ import os import numpy as np import pytest +from pathlib import Path from PIL import Image from utils import ( IMG2IMG_API_ENDPOINT, From 2c650e7a128ceb2e3d411944984b04be1ea131d2 Mon Sep 17 00:00:00 2001 From: WangYi Date: Sat, 22 Jun 2024 09:27:57 +0800 Subject: [PATCH 17/24] refine --- onediff_sd_webui_extensions/onediff_utils.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index 2f9307b95..dc3a42927 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -61,23 +61,8 @@ def refresh_all_compiler_caches(path: Path = None): all_compiler_caches = [f.stem for f in Path(path).iterdir() if f.is_file()] -<<<<<<< HEAD -def check_structure_change_and_update(current_type: dict[str, bool], model): - def get_model_type(model): - return { - "is_sdxl": model.is_sdxl, - "is_sd2": model.is_sd2, - "is_sd1": model.is_sd1, - "is_ssd": model.is_ssd, - } - - changed = current_type != get_model_type(model) - current_type.update(**get_model_type(model)) - return changed -======= def check_structure_change(current_type: dict[str, bool], model): return current_type != get_model_type(model) ->>>>>>> main def load_graph(compiled_unet: DeployableModule, compiler_cache: str): From 448d88d8605513fa46c6eae01ba4f118784047de Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 24 Jun 2024 09:59:21 +0800 Subject: [PATCH 18/24] remove oneflow for nexfort backend --- .../compile/compile.py | 2 +- .../compile/oneflow/mock/controlnet.py | 233 ++++++++++++++ .../compile/oneflow/mock/ldm.py | 2 - .../compile/oneflow/mock/sgm.py | 1 - .../compile/oneflow/utils.py | 17 +- .../compile/quantization.py | 2 +- .../onediff_controlnet/__init__.py | 6 + .../onediff_controlnet/compile.py | 43 +++ .../hijack.py} | 300 +----------------- .../onediff_controlnet/utils.py | 23 ++ onediff_sd_webui_extensions/onediff_shared.py | 1 + onediff_sd_webui_extensions/onediff_utils.py | 9 +- .../scripts/onediff.py | 14 +- 13 files changed, 348 insertions(+), 305 deletions(-) create mode 100644 onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py create mode 100644 onediff_sd_webui_extensions/onediff_controlnet/__init__.py create mode 100644 onediff_sd_webui_extensions/onediff_controlnet/compile.py rename onediff_sd_webui_extensions/{onediff_controlnet.py => onediff_controlnet/hijack.py} (77%) create mode 100644 onediff_sd_webui_extensions/onediff_controlnet/utils.py diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py index 4872cc578..725377ca2 100644 --- a/onediff_sd_webui_extensions/compile/compile.py +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -2,8 +2,8 @@ from onediff.infer_compiler import compile, oneflow_compile -from .utils import OneDiffCompiledGraph, disable_unet_checkpointing from .quantization import quant_unet_oneflow +from .utils import 
OneDiffCompiledGraph, disable_unet_checkpointing def get_compiled_graph( diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py new file mode 100644 index 000000000..3c8299040 --- /dev/null +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py @@ -0,0 +1,233 @@ +import oneflow as flow +import torch +import torch as th +from compile.oneflow.mock.common import ( + CrossAttentionOflow, + GroupNorm32Oflow, + timestep_embedding, +) +from ldm.modules.attention import BasicTransformerBlock, CrossAttention +from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel +from ldm.modules.diffusionmodules.util import GroupNorm32 +from modules import devices + +from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register + +cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) + +# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238 +def torch_aligned_adding(base, x, require_channel_alignment): + if isinstance(x, float): + if x == 0.0: + return base + return base + x + + if require_channel_alignment: + zeros = torch.zeros_like(base) + zeros[:, : x.shape[1], ...] = x + x = zeros + + # resize to sample resolution + base_h, base_w = base.shape[-2:] + xh, xw = x.shape[-2:] + + if xh > 1 or xw > 1: + if base_h != xh or base_w != xw: + # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') + x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") + + return base + x + + +# Due to the tracing mechanism in OneFlow, it's crucial to ensure that +# the same conditional branches are taken during the first run as in subsequent runs. +# Therefore, certain "optimizations" have been modified. +def oneflow_aligned_adding(base, x, require_channel_alignment): + if isinstance(x, float): + # remove `if x == 0.0: return base` here + return base + x + + if require_channel_alignment: + zeros = flow.zeros_like(base) + zeros[:, : x.shape[1], ...] 
= x + x = zeros + + # resize to sample resolution + base_h, base_w = base.shape[-2:] + xh, xw = x.shape[-2:] + + if xh > 1 or xw > 1 and (base_h != xh or base_w != xw): + # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') + x = flow.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") + return base + x + + +class TorchOnediffControlNetModel(torch.nn.Module): + def __init__(self, unet): + super().__init__() + self.time_embed = unet.time_embed + self.input_blocks = unet.input_blocks + self.label_emb = getattr(unet, "label_emb", None) + self.middle_block = unet.middle_block + self.output_blocks = unet.output_blocks + self.out = unet.out + self.model_channels = unet.model_channels + + def forward( + self, + x, + timesteps, + context, + y, + total_t2i_adapter_embedding, + total_controlnet_embedding, + is_sdxl, + require_inpaint_hijack, + ): + from ldm.modules.diffusionmodules.util import timestep_embedding + + hs = [] + with th.no_grad(): + t_emb = cond_cast_unet( + timestep_embedding(timesteps, self.model_channels, repeat_only=False) + ) + emb = self.time_embed(t_emb) + + if is_sdxl: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for i, module in enumerate(self.input_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = module(h, emb, context) + + t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] + + if i in t2i_injection: + h = torch_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + hs.append(h) + + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = self.middle_block(h, emb, context) + + # U-Net Middle Block + h = torch_aligned_adding( + h, total_controlnet_embedding.pop(), require_inpaint_hijack + ) + + if len(total_t2i_adapter_embedding) > 0 and is_sdxl: + h = torch_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + # U-Net Decoder + for i, module in enumerate(self.output_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = th.cat( + [ + h, + torch_aligned_adding( + hs.pop(), + total_controlnet_embedding.pop(), + require_inpaint_hijack, + ), + ], + dim=1, + ) + h = module(h, emb, context) + + # U-Net Output + h = h.type(x.dtype) + h = self.out(h) + + return h + + +class OneFlowOnediffControlNetModel(proxy_class(UNetModel)): + def forward( + self, + x, + timesteps, + context, + y, + total_t2i_adapter_embedding, + total_controlnet_embedding, + is_sdxl, + require_inpaint_hijack, + ): + x = x.half() + if y is not None: + y = y.half() + context = context.half() + hs = [] + with flow.no_grad(): + t_emb = cond_cast_unet( + timestep_embedding(timesteps, self.model_channels, repeat_only=False) + ) + emb = self.time_embed(t_emb.half()) + + if is_sdxl: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for i, module in enumerate(self.input_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = module(h, emb, context) + + t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] + + if i in t2i_injection: + h = oneflow_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + hs.append(h) + + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = self.middle_block(h, emb, context) + + # U-Net Middle Block + h = oneflow_aligned_adding( + h, total_controlnet_embedding.pop(), require_inpaint_hijack + ) + + if 
len(total_t2i_adapter_embedding) > 0 and is_sdxl: + h = oneflow_aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + # U-Net Decoder + for i, module in enumerate(self.output_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = flow.cat( + [ + h, + oneflow_aligned_adding( + hs.pop(), + total_controlnet_embedding.pop(), + require_inpaint_hijack, + ), + ], + dim=1, + ) + h = h.half() + h = module(h, emb, context) + + # U-Net Output + h = h.type(x.dtype) + h = self.out(h) + + return h + + +torch2oflow_class_map = { + CrossAttention: CrossAttentionOflow, + GroupNorm32: GroupNorm32Oflow, + TorchOnediffControlNetModel: OneFlowOnediffControlNetModel, +} +# register(package_names=["scripts.hook"], torch2oflow_class_map=torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py index 9667fa505..8d5a4a4bb 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py @@ -66,5 +66,3 @@ def forward(self, x, context=None): SpatialTransformer: SpatialTransformerOflow, UNetModel: UNetModelOflow, } - -register(package_names=["ldm"], torch2oflow_class_map=torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py index fabd6bcdc..4efaa82d4 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py @@ -67,4 +67,3 @@ def forward(self, x, context=None): SpatialTransformer: SpatialTransformerOflow, UNetModel: UNetModelOflow, } -register(package_names=["sgm"], torch2oflow_class_map=torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py index fa87ca33f..115a25ac3 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/utils.py +++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py @@ -2,10 +2,23 @@ from onediff.infer_compiler.backends.oneflow.transform import register -from .mock import ldm, sgm - @singleton_decorator def init_oneflow_backend(): + try: + import oneflow as flow + except ImportError: + raise RuntimeError( + "Backend oneflow for OneDiff is invalid, please make sure you have installed OneFlow" + ) + + from .mock import controlnet, ldm, sgm + register(package_names=["ldm"], torch2oflow_class_map=ldm.torch2oflow_class_map) register(package_names=["sgm"], torch2oflow_class_map=sgm.torch2oflow_class_map) + register( + package_names=["scripts.hook"], + torch2oflow_class_map=controlnet.torch2oflow_class_map, + ) + + diff --git a/onediff_sd_webui_extensions/compile/quantization.py b/onediff_sd_webui_extensions/compile/quantization.py index b5a08d7ac..9bd52755f 100644 --- a/onediff_sd_webui_extensions/compile/quantization.py +++ b/onediff_sd_webui_extensions/compile/quantization.py @@ -3,11 +3,11 @@ from modules.sd_models import select_checkpoint -from onediff.utils import logger from onediff.optimization.quant_optimizer import ( quantize_model, varify_can_use_quantization, ) +from onediff.utils import logger def quant_unet_oneflow(compiled_unet): diff --git a/onediff_sd_webui_extensions/onediff_controlnet/__init__.py b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py new file mode 100644 index 000000000..35605949e --- /dev/null +++ b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py @@ -0,0 +1,6 @@ +from 
.compile import onediff_controlnet_decorator, compile_controlnet_ldm_unet +from .hijack import hijack_controlnet_extension + +__all__ = [ + "onediff_controlnet_decorator", +] \ No newline at end of file diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py b/onediff_sd_webui_extensions/onediff_controlnet/compile.py new file mode 100644 index 000000000..732988d6b --- /dev/null +++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py @@ -0,0 +1,43 @@ +from .utils import check_if_controlnet_enabled +from modules import shared +from .hijack import hijack_controlnet_extension +import onediff_shared +from functools import wraps +from onediff.infer_compiler import oneflow_compile +from compile.utils import disable_unet_checkpointing +def onediff_controlnet_decorator(func): + @wraps(func) + # TODO: restore hijacked func here + def wrapper(self, p, *arg, **kwargs): + try: + onediff_shared.controlnet_enabled = check_if_controlnet_enabled(p) + if onediff_shared.controlnet_enabled: + hijack_controlnet_extension(p) + return func(self, p, *arg, **kwargs) + finally: + if onediff_shared.controlnet_enabled: + onediff_shared.previous_is_controlnet = True + else: + onediff_shared.controlnet_compiled = False + onediff_shared.previous_is_controlnet = False + + return wrapper + + +def compile_controlnet_ldm_unet(sd_model, unet_model, *, backend=None, options=None): + backend = backend or shared.opts.onediff_compiler_backend + from ldm.modules.attention import BasicTransformerBlock as BasicTransformerBlockLDM + from ldm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockLDM + from sgm.modules.attention import BasicTransformerBlock as BasicTransformerBlockSGM + from sgm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockSGM + + if backend == "oneflow": + disable_unet_checkpointing(unet_model) + compiled_model = oneflow_compile(unet_model, options=options) + elif backend == "nexfort": + raise NotImplementedError("nexfort backend for controlnet is not implemented yet") + # TODO: refine here + compiled_graph = OneDiffCompiledGraph(sd_model, compiled_model) + compiled_graph.eager_module = unet_model + compiled_graph.name += "_controlnet" + return compiled_graph \ No newline at end of file diff --git a/onediff_sd_webui_extensions/onediff_controlnet.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py similarity index 77% rename from onediff_sd_webui_extensions/onediff_controlnet.py rename to onediff_sd_webui_extensions/onediff_controlnet/hijack.py index 9537114a8..453fd650d 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py @@ -1,291 +1,18 @@ -from functools import wraps - import onediff_shared -from onediff_utils import check_structure_change -import oneflow as flow import torch -import torch as th -from compile import OneDiffCompiledGraph -from compile.oneflow.mock.common import ( - CrossAttentionOflow, - GroupNorm32Oflow, - timestep_embedding, -) +from modules.sd_hijack_utils import CondFunc +from onediff_utils import check_structure_change, singleton_decorator from ldm.modules.attention import BasicTransformerBlock, CrossAttention from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 -from modules import devices -from modules.sd_hijack_utils import CondFunc -from onediff_utils import singleton_decorator - -from onediff.infer_compiler import oneflow_compile -from 
onediff.infer_compiler.backends.oneflow.transform import proxy_class, register - - -# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238 -def torch_aligned_adding(base, x, require_channel_alignment): - if isinstance(x, float): - if x == 0.0: - return base - return base + x - - if require_channel_alignment: - zeros = torch.zeros_like(base) - zeros[:, : x.shape[1], ...] = x - x = zeros - - # resize to sample resolution - base_h, base_w = base.shape[-2:] - xh, xw = x.shape[-2:] - - if xh > 1 or xw > 1: - if base_h != xh or base_w != xw: - # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') - x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") - - return base + x - - -# Due to the tracing mechanism in OneFlow, it's crucial to ensure that -# the same conditional branches are taken during the first run as in subsequent runs. -# Therefore, certain "optimizations" have been modified. -def oneflow_aligned_adding(base, x, require_channel_alignment): - if isinstance(x, float): - # remove `if x == 0.0: return base` here - return base + x - - if require_channel_alignment: - zeros = flow.zeros_like(base) - zeros[:, : x.shape[1], ...] = x - x = zeros - - # resize to sample resolution - base_h, base_w = base.shape[-2:] - xh, xw = x.shape[-2:] - - if xh > 1 or xw > 1 and (base_h != xh or base_w != xw): - # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') - x = flow.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") - return base + x - - -cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) - - -class TorchOnediffControlNetModel(torch.nn.Module): - def __init__(self, unet): - super().__init__() - self.time_embed = unet.time_embed - self.input_blocks = unet.input_blocks - self.label_emb = getattr(unet, "label_emb", None) - self.middle_block = unet.middle_block - self.output_blocks = unet.output_blocks - self.out = unet.out - self.model_channels = unet.model_channels - - def forward( - self, - x, - timesteps, - context, - y, - total_t2i_adapter_embedding, - total_controlnet_embedding, - is_sdxl, - require_inpaint_hijack, - ): - from ldm.modules.diffusionmodules.util import timestep_embedding - - hs = [] - with th.no_grad(): - t_emb = cond_cast_unet( - timestep_embedding(timesteps, self.model_channels, repeat_only=False) - ) - emb = self.time_embed(t_emb) - - if is_sdxl: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - for i, module in enumerate(self.input_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = module(h, emb, context) - - t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] - - if i in t2i_injection: - h = torch_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - hs.append(h) - - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = self.middle_block(h, emb, context) - - # U-Net Middle Block - h = torch_aligned_adding( - h, total_controlnet_embedding.pop(), require_inpaint_hijack - ) - - if len(total_t2i_adapter_embedding) > 0 and is_sdxl: - h = torch_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - # U-Net Decoder - for i, module in enumerate(self.output_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = th.cat( - [ - h, - torch_aligned_adding( - hs.pop(), - total_controlnet_embedding.pop(), - 
require_inpaint_hijack, - ), - ], - dim=1, - ) - h = module(h, emb, context) - - # U-Net Output - h = h.type(x.dtype) - h = self.out(h) - - return h - - -class OneFlowOnediffControlNetModel(proxy_class(UNetModel)): - def forward( - self, - x, - timesteps, - context, - y, - total_t2i_adapter_embedding, - total_controlnet_embedding, - is_sdxl, - require_inpaint_hijack, - ): - x = x.half() - if y is not None: - y = y.half() - context = context.half() - hs = [] - with flow.no_grad(): - t_emb = cond_cast_unet( - timestep_embedding(timesteps, self.model_channels, repeat_only=False) - ) - emb = self.time_embed(t_emb.half()) - - if is_sdxl: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - for i, module in enumerate(self.input_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = module(h, emb, context) - - t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] - - if i in t2i_injection: - h = oneflow_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - hs.append(h) - - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = self.middle_block(h, emb, context) - - # U-Net Middle Block - h = oneflow_aligned_adding( - h, total_controlnet_embedding.pop(), require_inpaint_hijack - ) - - if len(total_t2i_adapter_embedding) > 0 and is_sdxl: - h = oneflow_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - # U-Net Decoder - for i, module in enumerate(self.output_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = flow.cat( - [ - h, - oneflow_aligned_adding( - hs.pop(), - total_controlnet_embedding.pop(), - require_inpaint_hijack, - ), - ], - dim=1, - ) - h = h.half() - h = module(h, emb, context) - - # U-Net Output - h = h.type(x.dtype) - h = self.out(h) - - return h - - -def onediff_controlnet_decorator(func): - @wraps(func) - def wrapper(self, p, *arg, **kwargs): - try: - onediff_shared.controlnet_enabled = check_if_controlnet_enabled(p) - if onediff_shared.controlnet_enabled: - hijack_controlnet_extension(p) - return func(self, p, *arg, **kwargs) - finally: - if onediff_shared.controlnet_enabled: - onediff_shared.previous_is_controlnet = True - else: - onediff_shared.controlnet_compiled = False - onediff_shared.previous_is_controlnet = False - - return wrapper - - -def compile_controlnet_ldm_unet(sd_model, unet_model, *, options=None): - from sgm.modules.attention import BasicTransformerBlock as BasicTransformerBlockSGM - from ldm.modules.attention import BasicTransformerBlock as BasicTransformerBlockLDM - from sgm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockSGM - from ldm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockLDM - for module in unet_model.modules(): - if isinstance(module, (BasicTransformerBlockLDM, BasicTransformerBlockSGM)): - module.checkpoint = False - if isinstance(module, (ResBlockLDM, ResBlockSGM)): - module.use_checkpoint = False - # TODO: refine here - compiled_model = oneflow_compile(unet_model, options=options) - compiled_graph = OneDiffCompiledGraph(sd_model, compiled_model) - compiled_graph.eager_module = unet_model - compiled_graph.name += "_controlnet" - return compiled_graph - - -torch2oflow_class_map = { - CrossAttention: CrossAttentionOflow, - GroupNorm32: GroupNorm32Oflow, - TorchOnediffControlNetModel: OneFlowOnediffControlNetModel, -} -register(package_names=["scripts.hook"], torch2oflow_class_map=torch2oflow_class_map) - - -def 
check_if_controlnet_ext_loaded() -> bool: - from modules import extensions - - return "sd-webui-controlnet" in extensions.loaded_extensions +from .utils import get_controlnet_script def hijacked_main_entry(self, p): + from compile.oneflow.mock.controlnet import TorchOnediffControlNetModel + from .compile import compile_controlnet_ldm_unet + self._original_controlnet_main_entry(p) sd_ldm = p.sd_model unet = sd_ldm.model.diffusion_model @@ -303,21 +30,6 @@ def hijacked_main_entry(self, p): pass -def get_controlnet_script(p): - for script in p.scripts.scripts: - if script.__module__ == "controlnet.py": - return script - return None - - -def check_if_controlnet_enabled(p): - controlnet_script_class = get_controlnet_script(p) - return ( - controlnet_script_class is not None - and len(controlnet_script_class.get_enabled_units(p)) != 0 - ) - - # When OneDiff is initializing, the controlnet extension has not yet been loaded. # Therefore, this function should be called during image generation # rather than during the initialization of the OneDiff. diff --git a/onediff_sd_webui_extensions/onediff_controlnet/utils.py b/onediff_sd_webui_extensions/onediff_controlnet/utils.py new file mode 100644 index 000000000..025773642 --- /dev/null +++ b/onediff_sd_webui_extensions/onediff_controlnet/utils.py @@ -0,0 +1,23 @@ +from functools import wraps +import onediff_shared + +def check_if_controlnet_ext_loaded() -> bool: + from modules import extensions + + return "sd-webui-controlnet" in extensions.loaded_extensions + + +def get_controlnet_script(p): + for script in p.scripts.scripts: + if script.__module__ == "controlnet.py": + return script + return None + + +def check_if_controlnet_enabled(p): + controlnet_script_class = get_controlnet_script(p) + return ( + controlnet_script_class is not None + and len(controlnet_script_class.get_enabled_units(p)) != 0 + ) + diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index bd9041c82..3af881a8c 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -9,6 +9,7 @@ "is_ssd": False, } onediff_enabled = False +onediff_backend = "oneflow" # controlnet controlnet_enabled = False diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index dc3a42927..dbab0f354 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -7,8 +7,8 @@ import onediff_shared import oneflow as flow -from modules.devices import torch_gc from modules import shared +from modules.devices import torch_gc from onediff.infer_compiler import DeployableModule @@ -139,6 +139,7 @@ def wrapper(*args, **kwargs): return wrapper + def get_model_type(model): return { "is_sdxl": model.is_sdxl, @@ -146,3 +147,9 @@ def get_model_type(model): "is_sd1": model.is_sd1, "is_ssd": model.is_ssd, } + + +def onediff_gc(): + torch_gc() + if shared.opts.onediff_compiler_backend == "oneflow": + flow.cuda.empty_cache() diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 0ad3467b4..2ee4cd229 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -6,7 +6,6 @@ import modules.shared as shared import onediff_controlnet import onediff_shared -import oneflow as flow from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph from compile.nexfort.utils import add_nexfort_optimizer from 
modules import script_callbacks @@ -21,6 +20,7 @@ hints_message, load_graph, onediff_enabled_decorator, + onediff_gc, refresh_all_compiler_caches, save_graph, ) @@ -113,11 +113,19 @@ def run( ): p.override_settings.pop("sd_model_checkpoint", None) sd_models.reload_model_weights() - torch_gc() - flow.cuda.empty_cache() + onediff_gc() backend = backend or shared.opts.onediff_compiler_backend + if backend == "oneflow": + from compile.oneflow.utils import init_oneflow_backend + + init_oneflow_backend() + elif backend == "nexfort": + from compile.nexfort.utils import init_nexfort_backend + + init_nexfort_backend() + current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name ckpt_changed = ( shared.sd_model.sd_checkpoint_info.name From 28dbbf1c2be107a9b9ddb070891db00510c9863c Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 24 Jun 2024 11:07:31 +0800 Subject: [PATCH 19/24] fix bug --- onediff_sd_webui_extensions/onediff_controlnet/compile.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py b/onediff_sd_webui_extensions/onediff_controlnet/compile.py index 732988d6b..bf399368e 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/compile.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py @@ -4,7 +4,9 @@ import onediff_shared from functools import wraps from onediff.infer_compiler import oneflow_compile -from compile.utils import disable_unet_checkpointing +from compile.utils import disable_unet_checkpointing, OneDiffCompiledGraph + + def onediff_controlnet_decorator(func): @wraps(func) # TODO: restore hijacked func here From cb7903ae1c1e4f27d352cc263dd92dc7a4a62961 Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 24 Jun 2024 12:16:10 +0800 Subject: [PATCH 20/24] refine --- .../compile/__init__.py | 12 ++++++- .../compile/backend.py | 6 ++++ .../compile/compile.py | 21 ++++++++--- .../compile/nexfort/utils.py | 9 ++--- .../compile/oneflow/mock/controlnet.py | 7 ++-- .../compile/oneflow/mock/ldm.py | 2 +- .../compile/oneflow/mock/sgm.py | 2 +- .../compile/oneflow/utils.py | 7 ++-- .../compile/quantization.py | 2 +- onediff_sd_webui_extensions/compile/utils.py | 16 ++++++++- .../onediff_controlnet/__init__.py | 4 +-- .../onediff_controlnet/compile.py | 35 +++++++++++-------- .../onediff_controlnet/hijack.py | 5 +-- .../onediff_controlnet/utils.py | 3 +- onediff_sd_webui_extensions/onediff_shared.py | 4 +-- onediff_sd_webui_extensions/onediff_utils.py | 13 ++++--- .../scripts/onediff.py | 22 +++++++----- 17 files changed, 114 insertions(+), 56 deletions(-) create mode 100644 onediff_sd_webui_extensions/compile/backend.py diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py index 89454fd4c..9b83b18d5 100644 --- a/onediff_sd_webui_extensions/compile/__init__.py +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -1,6 +1,12 @@ +from .backend import OneDiffBackend from .compile import get_compiled_graph from .sd2 import SD21CompileCtx -from .utils import OneDiffCompiledGraph +from .utils import ( + OneDiffCompiledGraph, + get_onediff_backend, + is_nexfort_backend, + is_oneflow_backend, +) from .vae import VaeCompileCtx __all__ = [ @@ -8,4 +14,8 @@ "SD21CompileCtx", "VaeCompileCtx", "OneDiffCompiledGraph", + "OneDiffBackend", + "get_onediff_backend", + "is_oneflow_backend", + "is_nexfort_backend", ] diff --git a/onediff_sd_webui_extensions/compile/backend.py b/onediff_sd_webui_extensions/compile/backend.py new file mode 100644 index 
000000000..e4c592a1a --- /dev/null +++ b/onediff_sd_webui_extensions/compile/backend.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class OneDiffBackend(Enum): + ONEFLOW = "oneflow" + NEXFORT = "nexfort" diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py index 725377ca2..ed97ba940 100644 --- a/onediff_sd_webui_extensions/compile/compile.py +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -2,8 +2,15 @@ from onediff.infer_compiler import compile, oneflow_compile +from compile import OneDiffBackend + from .quantization import quant_unet_oneflow -from .utils import OneDiffCompiledGraph, disable_unet_checkpointing +from .utils import ( + OneDiffCompiledGraph, + disable_unet_checkpointing, + is_nexfort_backend, + is_oneflow_backend, +) def get_compiled_graph( @@ -16,12 +23,18 @@ def get_compiled_graph( return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) -def onediff_compile(unet_model, *, quantization=False, backend="oneflow", options=None): - if backend == "oneflow": +def onediff_compile( + unet_model, + *, + quantization: bool = False, + backend: OneDiffBackend = None, + options=None, +): + if is_oneflow_backend(backend): return compile_unet_oneflow( unet_model, quantization=quantization, options=options ) - elif backend == "nexfort": + elif is_nexfort_backend(backend): return compile_unet_nexfort( unet_model, quantization=quantization, options=options ) diff --git a/onediff_sd_webui_extensions/compile/nexfort/utils.py b/onediff_sd_webui_extensions/compile/nexfort/utils.py index a677e72de..84113b6c7 100644 --- a/onediff_sd_webui_extensions/compile/nexfort/utils.py +++ b/onediff_sd_webui_extensions/compile/nexfort/utils.py @@ -11,6 +11,8 @@ from modules.sd_hijack_utils import CondFunc from onediff_utils import singleton_decorator +from onediff.utils.import_utils import is_nexfort_available + @singleton_decorator def init_nexfort_backend(): @@ -132,12 +134,7 @@ class SdOptimizationNexfort(SdOptimization): priority = 10 def is_available(self): - try: - import nexfort - except ImportError: - return False - finally: - return True + is_nexfort_available() def apply(self): ldm.modules.attention.CrossAttention.forward = ( diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py index 3c8299040..9933b08f8 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py @@ -6,15 +6,16 @@ GroupNorm32Oflow, timestep_embedding, ) -from ldm.modules.attention import BasicTransformerBlock, CrossAttention -from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel +from ldm.modules.attention import CrossAttention +from ldm.modules.diffusionmodules.openaimodel import UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import devices -from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +from onediff.infer_compiler.backends.oneflow.transform import proxy_class cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) + # https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238 def torch_aligned_adding(base, x, require_channel_alignment): if isinstance(x, float): diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py index 8d5a4a4bb..d8d1ab8c8 100644 --- 
a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py @@ -3,7 +3,7 @@ from ldm.modules.diffusionmodules.openaimodel import UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 -from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +from onediff.infer_compiler.backends.oneflow.transform import proxy_class from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py index 4efaa82d4..b840eaf21 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py @@ -3,7 +3,7 @@ from sgm.modules.diffusionmodules.openaimodel import UNetModel from sgm.modules.diffusionmodules.util import GroupNorm32 -from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +from onediff.infer_compiler.backends.oneflow.transform import proxy_class from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py index 115a25ac3..6a7aca974 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/utils.py +++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py @@ -1,13 +1,12 @@ from onediff_utils import singleton_decorator from onediff.infer_compiler.backends.oneflow.transform import register +from onediff.utils.import_utils import is_oneflow_available @singleton_decorator def init_oneflow_backend(): - try: - import oneflow as flow - except ImportError: + if not is_oneflow_available(): raise RuntimeError( "Backend oneflow for OneDiff is invalid, please make sure you have installed OneFlow" ) @@ -20,5 +19,3 @@ def init_oneflow_backend(): package_names=["scripts.hook"], torch2oflow_class_map=controlnet.torch2oflow_class_map, ) - - diff --git a/onediff_sd_webui_extensions/compile/quantization.py b/onediff_sd_webui_extensions/compile/quantization.py index b5a08d7ac..9bd52755f 100644 --- a/onediff_sd_webui_extensions/compile/quantization.py +++ b/onediff_sd_webui_extensions/compile/quantization.py @@ -3,11 +3,11 @@ from modules.sd_models import select_checkpoint -from onediff.utils import logger from onediff.optimization.quant_optimizer import ( quantize_model, varify_can_use_quantization, ) +from onediff.utils import logger def quant_unet_oneflow(compiled_unet): diff --git a/onediff_sd_webui_extensions/compile/utils.py b/onediff_sd_webui_extensions/compile/utils.py index 5eda9cd9e..084c2aa10 100644 --- a/onediff_sd_webui_extensions/compile/utils.py +++ b/onediff_sd_webui_extensions/compile/utils.py @@ -3,11 +3,13 @@ import torch from ldm.modules.diffusionmodules.openaimodel import UNetModel as LdmUNetModel -from modules import sd_models_types +from modules import sd_models_types, shared from sgm.modules.diffusionmodules.openaimodel import UNetModel as SgmUNetModel from onediff.infer_compiler import DeployableModule +from .backend import OneDiffBackend + def disable_unet_checkpointing( unet_model: Union[LdmUNetModel, SgmUNetModel] @@ -25,6 +27,18 @@ def disable_unet_checkpointing( return unet_model +def get_onediff_backend() -> OneDiffBackend: + return OneDiffBackend(shared.opts.onediff_compiler_backend) + + +def is_oneflow_backend(backend: Union[OneDiffBackend, None] = None) -> bool: + return (backend or get_onediff_backend()) == 
OneDiffBackend.ONEFLOW + + +def is_nexfort_backend(backend: Union[OneDiffBackend, None] = None) -> bool: + return (backend or get_onediff_backend()) == OneDiffBackend.NEXFORT + + @dataclasses.dataclass class OneDiffCompiledGraph: name: str = None diff --git a/onediff_sd_webui_extensions/onediff_controlnet/__init__.py b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py index 35605949e..c3ac917d5 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/__init__.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py @@ -1,6 +1,6 @@ -from .compile import onediff_controlnet_decorator, compile_controlnet_ldm_unet +from .compile import compile_controlnet_ldm_unet, onediff_controlnet_decorator from .hijack import hijack_controlnet_extension __all__ = [ "onediff_controlnet_decorator", -] \ No newline at end of file +] diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py b/onediff_sd_webui_extensions/onediff_controlnet/compile.py index bf399368e..559ef53d9 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/compile.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py @@ -1,10 +1,19 @@ -from .utils import check_if_controlnet_enabled -from modules import shared -from .hijack import hijack_controlnet_extension -import onediff_shared from functools import wraps + +import onediff_shared +from compile.utils import ( + OneDiffCompiledGraph, + disable_unet_checkpointing, + get_onediff_backend, + is_nexfort_backend, + is_oneflow_backend, +) +from modules import shared + from onediff.infer_compiler import oneflow_compile -from compile.utils import disable_unet_checkpointing, OneDiffCompiledGraph + +from .hijack import hijack_controlnet_extension +from .utils import check_if_controlnet_enabled def onediff_controlnet_decorator(func): @@ -27,19 +36,17 @@ def wrapper(self, p, *arg, **kwargs): def compile_controlnet_ldm_unet(sd_model, unet_model, *, backend=None, options=None): - backend = backend or shared.opts.onediff_compiler_backend - from ldm.modules.attention import BasicTransformerBlock as BasicTransformerBlockLDM - from ldm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockLDM - from sgm.modules.attention import BasicTransformerBlock as BasicTransformerBlockSGM - from sgm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockSGM + backend = backend or get_onediff_backend() - if backend == "oneflow": + if is_oneflow_backend(): disable_unet_checkpointing(unet_model) compiled_model = oneflow_compile(unet_model, options=options) - elif backend == "nexfort": - raise NotImplementedError("nexfort backend for controlnet is not implemented yet") + elif is_nexfort_backend(): + raise NotImplementedError( + "nexfort backend for controlnet is not implemented yet" + ) # TODO: refine here compiled_graph = OneDiffCompiledGraph(sd_model, compiled_model) compiled_graph.eager_module = unet_model compiled_graph.name += "_controlnet" - return compiled_graph \ No newline at end of file + return compiled_graph diff --git a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py index 7ceb9d67a..8e17ca7aa 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py @@ -1,7 +1,5 @@ import onediff_shared import torch -from modules.sd_hijack_utils import CondFunc -from onediff_utils import check_structure_change, singleton_decorator import torch as th from compile import OneDiffCompiledGraph from 
compile.oneflow.mock.common import ( @@ -12,12 +10,15 @@ from ldm.modules.attention import BasicTransformerBlock, CrossAttention from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 +from modules.sd_hijack_utils import CondFunc +from onediff_utils import check_structure_change, singleton_decorator from .utils import get_controlnet_script def hijacked_main_entry(self, p): from compile.oneflow.mock.controlnet import TorchOnediffControlNetModel + from .compile import compile_controlnet_ldm_unet self._original_controlnet_main_entry(p) diff --git a/onediff_sd_webui_extensions/onediff_controlnet/utils.py b/onediff_sd_webui_extensions/onediff_controlnet/utils.py index 025773642..5cfbd22e4 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/utils.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/utils.py @@ -1,6 +1,8 @@ from functools import wraps + import onediff_shared + def check_if_controlnet_ext_loaded() -> bool: from modules import extensions @@ -20,4 +22,3 @@ def check_if_controlnet_enabled(p): controlnet_script_class is not None and len(controlnet_script_class.get_enabled_units(p)) != 0 ) - diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index 3af881a8c..f6673521a 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -1,4 +1,4 @@ -from compile import OneDiffCompiledGraph +from compile import OneDiffBackend, OneDiffCompiledGraph current_unet_graph = OneDiffCompiledGraph() current_quantization = False @@ -9,7 +9,7 @@ "is_ssd": False, } onediff_enabled = False -onediff_backend = "oneflow" +onediff_backend = OneDiffBackend.NEXFORT # controlnet controlnet_enabled = False diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index dbab0f354..cae03d4a0 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -1,12 +1,16 @@ import os -from contextlib import contextmanager from functools import wraps from pathlib import Path from textwrap import dedent from zipfile import BadZipFile import onediff_shared -import oneflow as flow + +from onediff.utils.import_utils import is_oneflow_available + +if is_oneflow_available(): + import oneflow as flow +from compile import is_oneflow_backend from modules import shared from modules.devices import torch_gc @@ -123,7 +127,8 @@ def wrapper(self, p, *arg, **kwargs): onediff_shared.onediff_enabled = False onediff_shared.previous_unet_type.update(**get_model_type(shared.sd_model)) torch_gc() - flow.cuda.empty_cache() + if is_oneflow_backend(): + flow.cuda.empty_cache() return wrapper @@ -151,5 +156,5 @@ def get_model_type(model): def onediff_gc(): torch_gc() - if shared.opts.onediff_compiler_backend == "oneflow": + if is_oneflow_backend(): flow.cuda.empty_cache() diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 0372ecb98..945be6ea0 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -6,10 +6,17 @@ import modules.shared as shared import onediff_controlnet import onediff_shared -from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph +from compile import ( + OneDiffBackend, + SD21CompileCtx, + VaeCompileCtx, + get_compiled_graph, + get_onediff_backend, + is_nexfort_backend, + is_oneflow_backend, +) from 
compile.nexfort.utils import add_nexfort_optimizer from modules import script_callbacks -from modules.devices import torch_gc from modules.processing import process_images from modules.ui_common import create_refresh_button from onediff_hijack import do_hijack as onediff_do_hijack @@ -115,19 +122,18 @@ def run( sd_models.reload_model_weights() onediff_gc() - backend = backend or shared.opts.onediff_compiler_backend + backend = backend or get_onediff_backend() - if backend == "oneflow": + # init backend + if is_oneflow_backend(backend): from compile.oneflow.utils import init_oneflow_backend init_oneflow_backend() - elif backend == "nexfort": + elif is_nexfort_backend(backend): from compile.nexfort.utils import init_nexfort_backend init_nexfort_backend() - backend = backend or shared.opts.onediff_compiler_backend - current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name ckpt_changed = ( shared.sd_model.sd_checkpoint_info.name @@ -187,7 +193,7 @@ def on_ui_settings(): "oneflow", "Backend for onediff compiler", gr.Radio, - {"choices": ["oneflow", "nexfort"]}, + {"choices": [OneDiffBackend.ONEFLOW, OneDiffBackend.NEXFORT]}, section=section, ), ) From 190a4eb903c312a0226071248cab4e54fed995a1 Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 24 Jun 2024 21:51:57 +0800 Subject: [PATCH 21/24] nexfort support controlnet --- .../compile/__init__.py | 2 + .../compile/backend.py | 6 + .../compile/compile.py | 7 +- .../compile/compile_utils.py | 2 +- .../compile/nexfort/utils.py | 17 ++- .../compile/oneflow/mock/controlnet.py | 139 +----------------- .../compile/oneflow/utils.py | 6 +- onediff_sd_webui_extensions/compile/utils.py | 17 ++- .../onediff_controlnet/__init__.py | 3 +- .../onediff_controlnet/compile.py | 32 ++-- .../onediff_controlnet/hijack.py | 53 +++---- .../onediff_controlnet/model.py | 124 ++++++++++++++++ .../onediff_controlnet/utils.py | 5 - onediff_sd_webui_extensions/onediff_utils.py | 26 +++- 14 files changed, 240 insertions(+), 199 deletions(-) create mode 100644 onediff_sd_webui_extensions/onediff_controlnet/model.py diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py index 9b83b18d5..4ccdb685e 100644 --- a/onediff_sd_webui_extensions/compile/__init__.py +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -6,6 +6,7 @@ get_onediff_backend, is_nexfort_backend, is_oneflow_backend, + init_backend, ) from .vae import VaeCompileCtx @@ -18,4 +19,5 @@ "get_onediff_backend", "is_oneflow_backend", "is_nexfort_backend", + "init_backend", ] diff --git a/onediff_sd_webui_extensions/compile/backend.py b/onediff_sd_webui_extensions/compile/backend.py index e4c592a1a..c12a0ed6f 100644 --- a/onediff_sd_webui_extensions/compile/backend.py +++ b/onediff_sd_webui_extensions/compile/backend.py @@ -4,3 +4,9 @@ class OneDiffBackend(Enum): ONEFLOW = "oneflow" NEXFORT = "nexfort" + + def __str__(self): + return self.value + + def __repr__(self): + return f"<{self.__class__.__name__}.{self.name}: {self.value}>" diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py index ed97ba940..5b151256e 100644 --- a/onediff_sd_webui_extensions/compile/compile.py +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -14,13 +14,13 @@ def get_compiled_graph( - sd_model, *, backend, quantization=None, options=None + sd_model, unet_model=None, *, backend=None, quantization=None, options=None ) -> OneDiffCompiledGraph: - diffusion_model = sd_model.model.diffusion_model + diffusion_model = 
unet_model or sd_model.model.diffusion_model compiled_unet = onediff_compile( diffusion_model, backend=backend, quantization=quantization, options=options ) - return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) + return OneDiffCompiledGraph(sd_model, diffusion_model, compiled_unet, quantization) def onediff_compile( @@ -62,6 +62,7 @@ def compile_unet_oneflow(unet_model, *, quantization=False, options=None): def compile_unet_nexfort(unet_model, *, quantization=False, options=None): + # TODO: support nexfort quant if quantization: raise NotImplementedError( "Quantization for nexfort backend is not implemented yet." diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py index d79278be2..1e40a186b 100644 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ b/onediff_sd_webui_extensions/compile/compile_utils.py @@ -66,4 +66,4 @@ def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: if "forward" in diffusion_model.__dict__: diffusion_model.__dict__.pop("forward") compiled_unet = compile_unet(diffusion_model, quantization=quantization) - return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) + return OneDiffCompiledGraph(sd_model, diffusion_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/compile/nexfort/utils.py b/onediff_sd_webui_extensions/compile/nexfort/utils.py index 84113b6c7..b25a91313 100644 --- a/onediff_sd_webui_extensions/compile/nexfort/utils.py +++ b/onediff_sd_webui_extensions/compile/nexfort/utils.py @@ -28,6 +28,21 @@ def init_nexfort_backend(): lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled, ) + def hijack_groupnorm32_forward(orig_func, self, x): + return super(type(self), self).forward(x) + # return self.forward(x) + + CondFunc( + "ldm.modules.diffusionmodules.util.GroupNorm32.forward", + hijack_groupnorm32_forward, + lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled, + ) + CondFunc( + "sgm.modules.diffusionmodules.util.GroupNorm32.forward", + hijack_groupnorm32_forward, + lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled, + ) + @torch.autocast("cuda", enabled=False) def onediff_nexfort_unet_sgm_forward( @@ -134,7 +149,7 @@ class SdOptimizationNexfort(SdOptimization): priority = 10 def is_available(self): - is_nexfort_available() + return is_nexfort_available() def apply(self): ldm.modules.attention.CrossAttention.forward = ( diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py index 9933b08f8..dbf58e808 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py @@ -1,14 +1,6 @@ import oneflow as flow -import torch -import torch as th -from compile.oneflow.mock.common import ( - CrossAttentionOflow, - GroupNorm32Oflow, - timestep_embedding, -) -from ldm.modules.attention import CrossAttention +from compile.oneflow.mock.common import timestep_embedding from ldm.modules.diffusionmodules.openaimodel import UNetModel -from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import devices from onediff.infer_compiler.backends.oneflow.transform import proxy_class @@ -16,34 +8,10 @@ cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) -# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238 -def torch_aligned_adding(base, x, 
require_channel_alignment): - if isinstance(x, float): - if x == 0.0: - return base - return base + x - - if require_channel_alignment: - zeros = torch.zeros_like(base) - zeros[:, : x.shape[1], ...] = x - x = zeros - - # resize to sample resolution - base_h, base_w = base.shape[-2:] - xh, xw = x.shape[-2:] - - if xh > 1 or xw > 1: - if base_h != xh or base_w != xw: - # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') - x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") - - return base + x - - # Due to the tracing mechanism in OneFlow, it's crucial to ensure that # the same conditional branches are taken during the first run as in subsequent runs. # Therefore, certain "optimizations" have been modified. -def oneflow_aligned_adding(base, x, require_channel_alignment): +def aligned_adding(base, x, require_channel_alignment): if isinstance(x, float): # remove `if x == 0.0: return base` here return base + x @@ -63,91 +31,6 @@ def oneflow_aligned_adding(base, x, require_channel_alignment): return base + x -class TorchOnediffControlNetModel(torch.nn.Module): - def __init__(self, unet): - super().__init__() - self.time_embed = unet.time_embed - self.input_blocks = unet.input_blocks - self.label_emb = getattr(unet, "label_emb", None) - self.middle_block = unet.middle_block - self.output_blocks = unet.output_blocks - self.out = unet.out - self.model_channels = unet.model_channels - - def forward( - self, - x, - timesteps, - context, - y, - total_t2i_adapter_embedding, - total_controlnet_embedding, - is_sdxl, - require_inpaint_hijack, - ): - from ldm.modules.diffusionmodules.util import timestep_embedding - - hs = [] - with th.no_grad(): - t_emb = cond_cast_unet( - timestep_embedding(timesteps, self.model_channels, repeat_only=False) - ) - emb = self.time_embed(t_emb) - - if is_sdxl: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - for i, module in enumerate(self.input_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = module(h, emb, context) - - t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] - - if i in t2i_injection: - h = torch_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - hs.append(h) - - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = self.middle_block(h, emb, context) - - # U-Net Middle Block - h = torch_aligned_adding( - h, total_controlnet_embedding.pop(), require_inpaint_hijack - ) - - if len(total_t2i_adapter_embedding) > 0 and is_sdxl: - h = torch_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - # U-Net Decoder - for i, module in enumerate(self.output_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = th.cat( - [ - h, - torch_aligned_adding( - hs.pop(), - total_controlnet_embedding.pop(), - require_inpaint_hijack, - ), - ], - dim=1, - ) - h = module(h, emb, context) - - # U-Net Output - h = h.type(x.dtype) - h = self.out(h) - - return h - - class OneFlowOnediffControlNetModel(proxy_class(UNetModel)): def forward( self, @@ -183,7 +66,7 @@ def forward( t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] if i in t2i_injection: - h = oneflow_aligned_adding( + h = aligned_adding( h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack ) @@ -193,12 +76,10 @@ def forward( h = self.middle_block(h, emb, context) # U-Net Middle Block - h = oneflow_aligned_adding( - h, total_controlnet_embedding.pop(), 
require_inpaint_hijack - ) + h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack) if len(total_t2i_adapter_embedding) > 0 and is_sdxl: - h = oneflow_aligned_adding( + h = aligned_adding( h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack ) @@ -208,7 +89,7 @@ def forward( h = flow.cat( [ h, - oneflow_aligned_adding( + aligned_adding( hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack, @@ -224,11 +105,3 @@ def forward( h = self.out(h) return h - - -torch2oflow_class_map = { - CrossAttention: CrossAttentionOflow, - GroupNorm32: GroupNorm32Oflow, - TorchOnediffControlNetModel: OneFlowOnediffControlNetModel, -} -# register(package_names=["scripts.hook"], torch2oflow_class_map=torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py index 6a7aca974..9f168c048 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/utils.py +++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py @@ -11,11 +11,7 @@ def init_oneflow_backend(): "Backend oneflow for OneDiff is invalid, please make sure you have installed OneFlow" ) - from .mock import controlnet, ldm, sgm + from .mock import ldm, sgm register(package_names=["ldm"], torch2oflow_class_map=ldm.torch2oflow_class_map) register(package_names=["sgm"], torch2oflow_class_map=sgm.torch2oflow_class_map) - register( - package_names=["scripts.hook"], - torch2oflow_class_map=controlnet.torch2oflow_class_map, - ) diff --git a/onediff_sd_webui_extensions/compile/utils.py b/onediff_sd_webui_extensions/compile/utils.py index 084c2aa10..abbc44391 100644 --- a/onediff_sd_webui_extensions/compile/utils.py +++ b/onediff_sd_webui_extensions/compile/utils.py @@ -39,6 +39,20 @@ def is_nexfort_backend(backend: Union[OneDiffBackend, None] = None) -> bool: return (backend or get_onediff_backend()) == OneDiffBackend.NEXFORT +def init_backend(backend: Union[OneDiffBackend, None] = None): + backend = backend or get_onediff_backend() + if is_oneflow_backend(backend): + from .oneflow.utils import init_oneflow_backend + + init_oneflow_backend() + elif is_nexfort_backend(backend): + from .nexfort.utils import init_nexfort_backend + + init_nexfort_backend() + else: + raise NotImplementedError(f"invalid backend {backend}") + + @dataclasses.dataclass class OneDiffCompiledGraph: name: str = None @@ -51,6 +65,7 @@ class OneDiffCompiledGraph: def __init__( self, sd_model: sd_models_types.WebuiSdModel = None, + unet_model=None, graph_module: DeployableModule = None, quantized=False, ): @@ -59,6 +74,6 @@ def __init__( self.name = sd_model.sd_checkpoint_info.name self.filename = sd_model.sd_checkpoint_info.filename self.sha = sd_model.sd_model_hash - self.eager_module = sd_model.model.diffusion_model + self.eager_module = unet_model or sd_model.model.diffusion_model self.graph_module = graph_module self.quantized = quantized diff --git a/onediff_sd_webui_extensions/onediff_controlnet/__init__.py b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py index c3ac917d5..ce278f071 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/__init__.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py @@ -1,5 +1,4 @@ -from .compile import compile_controlnet_ldm_unet, onediff_controlnet_decorator -from .hijack import hijack_controlnet_extension +from .compile import onediff_controlnet_decorator __all__ = [ "onediff_controlnet_decorator", diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py 
b/onediff_sd_webui_extensions/onediff_controlnet/compile.py index 559ef53d9..a8bd14038 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/compile.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py @@ -1,20 +1,19 @@ from functools import wraps +import networks import onediff_shared from compile.utils import ( - OneDiffCompiledGraph, disable_unet_checkpointing, - get_onediff_backend, is_nexfort_backend, is_oneflow_backend, ) -from modules import shared - -from onediff.infer_compiler import oneflow_compile +from compile import get_compiled_graph from .hijack import hijack_controlnet_extension from .utils import check_if_controlnet_enabled +from compile.oneflow.mock.controlnet import OneFlowOnediffControlNetModel + def onediff_controlnet_decorator(func): @wraps(func) @@ -36,17 +35,24 @@ def wrapper(self, p, *arg, **kwargs): def compile_controlnet_ldm_unet(sd_model, unet_model, *, backend=None, options=None): - backend = backend or get_onediff_backend() - if is_oneflow_backend(): - disable_unet_checkpointing(unet_model) - compiled_model = oneflow_compile(unet_model, options=options) - elif is_nexfort_backend(): - raise NotImplementedError( - "nexfort backend for controlnet is not implemented yet" + from onediff.infer_compiler.backends.oneflow.transform import register + from .model import OnediffControlNetModel + + register( + package_names=["scripts.hook"], + torch2oflow_class_map={ + OnediffControlNetModel: OneFlowOnediffControlNetModel, + }, ) + elif is_nexfort_backend(): + # TODO: restore LoRA here + if networks.originals is not None: + networks.originals.undo() # TODO: refine here - compiled_graph = OneDiffCompiledGraph(sd_model, compiled_model) + compiled_graph = get_compiled_graph( + sd_model, unet_model, backend=backend, options=options + ) compiled_graph.eager_module = unet_model compiled_graph.name += "_controlnet" return compiled_graph diff --git a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py index 8e17ca7aa..675e74712 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py @@ -1,15 +1,7 @@ import onediff_shared import torch -import torch as th -from compile import OneDiffCompiledGraph -from compile.oneflow.mock.common import ( - CrossAttentionOflow, - GroupNorm32Oflow, - timestep_embedding, -) -from ldm.modules.attention import BasicTransformerBlock, CrossAttention -from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel -from ldm.modules.diffusionmodules.util import GroupNorm32 +from compile import is_oneflow_backend +from ldm.modules.diffusionmodules.openaimodel import UNetModel from modules.sd_hijack_utils import CondFunc from onediff_utils import check_structure_change, singleton_decorator @@ -17,7 +9,7 @@ def hijacked_main_entry(self, p): - from compile.oneflow.mock.controlnet import TorchOnediffControlNetModel + from .model import OnediffControlNetModel from .compile import compile_controlnet_ldm_unet @@ -28,8 +20,9 @@ def hijacked_main_entry(self, p): structure_changed = check_structure_change( onediff_shared.previous_unet_type, sd_ldm ) + # TODO: restore here if onediff_shared.controlnet_compiled is False or structure_changed: - onediff_model = TorchOnediffControlNetModel(unet) + onediff_model = OnediffControlNetModel(unet) onediff_shared.current_unet_graph = compile_controlnet_ldm_unet( sd_ldm, onediff_model ) @@ -504,7 +497,10 @@ def forward(self, x, timesteps=None, context=None, 
y=None, **kwargs): outer.attention_auto_machine = AutoMachine.Read outer.gn_auto_machine = AutoMachine.Read - # modified by OneDiff + # ------ modified by OneDiff below ------ + x = x.half() + context = context.half() + y = y.half() if y is not None else y h = onediff_shared.current_unet_graph.graph_module( x, timesteps, @@ -515,6 +511,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): is_sdxl, require_inpaint_hijack, ) + # ------ modified by OneDiff above ------ # Post-processing for color fix for param in outer.control_params: @@ -583,16 +580,13 @@ def move_all_control_model_to_cpu(): def forward_webui(*args, **kwargs): # ------ modified by OneDiff below ------ forward_func = None - if ( - "forward" - in onediff_shared.current_unet_graph.graph_module._torch_module.__dict__ - ): - forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop( - "forward" - ) - _original_forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop( - "_original_forward" - ) + graph_module = onediff_shared.current_unet_graph.graph_module + if is_oneflow_backend(): + if "forward" in graph_module._torch_module.__dict__: + forward_func = graph_module._torch_module.__dict__.pop("forward") + _original_forward_func = graph_module._torch_module.__dict__.pop( + "_original_forward" + ) # ------ modified by OneDiff above ------ # webui will handle other compoments @@ -608,13 +602,12 @@ def forward_webui(*args, **kwargs): move_all_control_model_to_cpu() # ------ modified by OneDiff below ------ - if forward_func is not None: - onediff_shared.current_unet_graph.graph_module._torch_module.forward = ( - forward_func - ) - onediff_shared.current_unet_graph.graph_module._torch_module._original_forward = ( - _original_forward_func - ) + if is_oneflow_backend(): + if forward_func is not None: + graph_module._torch_module.forward = forward_func + graph_module._torch_module._original_forward = ( + _original_forward_func + ) # ------ modified by OneDiff above ------ def hacked_basic_transformer_inner_forward(self, x, context=None): diff --git a/onediff_sd_webui_extensions/onediff_controlnet/model.py b/onediff_sd_webui_extensions/onediff_controlnet/model.py new file mode 100644 index 000000000..4c577a013 --- /dev/null +++ b/onediff_sd_webui_extensions/onediff_controlnet/model.py @@ -0,0 +1,124 @@ +import torch +import torch as th +from modules import devices + +cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) + + +# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238 +def aligned_adding(base, x, require_channel_alignment): + if isinstance(x, float): + if x == 0.0: + return base + return base + x + + if require_channel_alignment: + zeros = torch.zeros_like(base) + zeros[:, : x.shape[1], ...] 
= x + x = zeros + + # resize to sample resolution + base_h, base_w = base.shape[-2:] + xh, xw = x.shape[-2:] + + if xh > 1 or xw > 1: + if base_h != xh or base_w != xw: + # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') + x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") + + return base + x.half() + + +class OnediffControlNetModel(torch.nn.Module): + def __init__(self, unet): + super().__init__() + self.time_embed = unet.time_embed + self.input_blocks = unet.input_blocks + self.label_emb = getattr(unet, "label_emb", None) + self.middle_block = unet.middle_block + self.output_blocks = unet.output_blocks + self.out = unet.out + self.model_channels = unet.model_channels + # import ipdb; ipdb.set_trace() + self.convert_to_fp16 = unet.convert_to_fp16.__get__(self) + # print("something") + + @torch.autocast(device_type="cuda", enabled=False) + def forward( + self, + x, + timesteps, + context, + y, + total_t2i_adapter_embedding, + total_controlnet_embedding, + is_sdxl, + require_inpaint_hijack, + ): + from ldm.modules.diffusionmodules.util import timestep_embedding + + # cast to half + x = x.half() + context = context.half() + if y is not None: + y = y.half() + + hs = [] + with th.no_grad(): + t_emb = cond_cast_unet( + timestep_embedding(timesteps, self.model_channels, repeat_only=False) + ) + + t_emb = t_emb.half() + emb = self.time_embed(t_emb).half() + + if is_sdxl: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y).half() + + h = x + for i, module in enumerate(self.input_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = module(h, emb, context) + + t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] + + if i in t2i_injection: + h = aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + hs.append(h) + + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = self.middle_block(h, emb, context) + + # U-Net Middle Block + h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack) + + if len(total_t2i_adapter_embedding) > 0 and is_sdxl: + h = aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + # U-Net Decoder + for i, module in enumerate(self.output_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = th.cat( + [ + h, + aligned_adding( + hs.pop(), + total_controlnet_embedding.pop(), + require_inpaint_hijack, + ), + ], + dim=1, + ) + h = module(h, emb, context) + + # U-Net Output + h = h.type(x.dtype) + h = self.out(h) + + return h diff --git a/onediff_sd_webui_extensions/onediff_controlnet/utils.py b/onediff_sd_webui_extensions/onediff_controlnet/utils.py index 5cfbd22e4..37cda8489 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/utils.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/utils.py @@ -1,8 +1,3 @@ -from functools import wraps - -import onediff_shared - - def check_if_controlnet_ext_loaded() -> bool: from modules import extensions diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index cae03d4a0..c7e0e59c3 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -15,6 +15,7 @@ from modules.devices import torch_gc from onediff.infer_compiler import DeployableModule +from compile import init_backend hints_message = dedent( """\ @@ -119,16 +120,31 @@ def save_graph(compiled_unet: DeployableModule, 
saved_cache_name: str = ""): def onediff_enabled_decorator(func): @wraps(func) - def wrapper(self, p, *arg, **kwargs): + def wrapper( + self, + p, + quantization=False, + compiler_cache=None, + saved_cache_name="", + always_recompile=False, + backend=None, + ): onediff_shared.onediff_enabled = True + init_backend(backend) try: - return func(self, p, *arg, **kwargs) + return func( + self, + p, + quantization=False, + compiler_cache=None, + saved_cache_name="", + always_recompile=False, + backend=None, + ) finally: onediff_shared.onediff_enabled = False onediff_shared.previous_unet_type.update(**get_model_type(shared.sd_model)) - torch_gc() - if is_oneflow_backend(): - flow.cuda.empty_cache() + onediff_gc() return wrapper From c386c74267327b77e87f17518977d71fe8edade7 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 25 Jun 2024 00:15:58 +0800 Subject: [PATCH 22/24] finally launch without oneflow --- .../compile/__init__.py | 2 +- .../compile/compile.py | 6 ++-- .../compile/oneflow/mock/ldm.py | 10 +++++- .../compile/oneflow/mock/sgm.py | 11 ++++++- .../compile/oneflow/mock/vae.py | 12 +++++++ .../compile/oneflow/utils.py | 3 +- onediff_sd_webui_extensions/compile/vae.py | 26 +++++++-------- .../onediff_controlnet/compile.py | 9 ++--- .../onediff_controlnet/hijack.py | 3 +- onediff_sd_webui_extensions/onediff_hijack.py | 33 ------------------- onediff_sd_webui_extensions/onediff_lora.py | 10 +++--- onediff_sd_webui_extensions/onediff_utils.py | 20 +++++++++-- .../scripts/onediff.py | 4 ++- 13 files changed, 81 insertions(+), 68 deletions(-) create mode 100644 onediff_sd_webui_extensions/compile/oneflow/mock/vae.py diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py index 4ccdb685e..60827fd87 100644 --- a/onediff_sd_webui_extensions/compile/__init__.py +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -4,9 +4,9 @@ from .utils import ( OneDiffCompiledGraph, get_onediff_backend, + init_backend, is_nexfort_backend, is_oneflow_backend, - init_backend, ) from .vae import VaeCompileCtx diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py index 5b151256e..22a4d8628 100644 --- a/onediff_sd_webui_extensions/compile/compile.py +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -1,10 +1,8 @@ +from compile import OneDiffBackend from modules.sd_hijack import apply_optimizations from onediff.infer_compiler import compile, oneflow_compile -from compile import OneDiffBackend - -from .quantization import quant_unet_oneflow from .utils import ( OneDiffCompiledGraph, disable_unet_checkpointing, @@ -57,6 +55,8 @@ def compile_unet_oneflow(unet_model, *, quantization=False, options=None): compiled_unet_model = oneflow_compile(unet_model, options=options) if quantization: + from .quantization import quant_unet_oneflow + compiled_unet_model = quant_unet_oneflow(compiled_unet_model) return compiled_unet_model diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py index d8d1ab8c8..8e9295ca5 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py @@ -8,6 +8,14 @@ from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding +def cat(tensors, *args, **kwargs): + if len(tensors) == 2: + a, b = tensors + a = flow.nn.functional.interpolate_like(a, like=b, mode="nearest") + tensors = (a, b) + return flow.cat(tensors, *args, **kwargs) + + # 
https://github.com/Stability-AI/stablediffusion/blob/b4bdae9916f628461e1e4edbc62aafedebb9f7ed/ldm/modules/diffusionmodules/openaimodel.py#L775 class UNetModelOflow(proxy_class(UNetModel)): def forward(self, x, timesteps=None, context=None, y=None, **kwargs): @@ -28,7 +36,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: - h = flow.cat([h, hs.pop()], dim=1) + h = cat([h, hs.pop()], dim=1) h = module(h, emb, context) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py index b840eaf21..a071bd7e5 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py @@ -8,6 +8,14 @@ from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding +def cat(tensors, *args, **kwargs): + if len(tensors) == 2: + a, b = tensors + a = flow.nn.functional.interpolate_like(a, like=b, mode="nearest") + tensors = (a, b) + return flow.cat(tensors, *args, **kwargs) + + # https://github.com/Stability-AI/generative-models/blob/059d8e9cd9c55aea1ef2ece39abf605efb8b7cc9/sgm/modules/diffusionmodules/openaimodel.py#L816 class UNetModelOflow(proxy_class(UNetModel)): def forward(self, x, timesteps=None, context=None, y=None, **kwargs): @@ -29,7 +37,8 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: - h = flow.cat([h, hs.pop()], dim=1) + # h = flow.cat([h, hs.pop()], dim=1) + h = cat([h, hs.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) return self.out(h) diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py b/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py new file mode 100644 index 000000000..329ba6e39 --- /dev/null +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py @@ -0,0 +1,12 @@ +from modules.sd_vae_approx import VAEApprox + +from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register + + +class VAEApproxOflow(proxy_class(VAEApprox)): + pass + + +torch2oflow_class_map = { + VAEApprox: VAEApproxOflow, +} diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py index 9f168c048..006dfd894 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/utils.py +++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py @@ -11,7 +11,8 @@ def init_oneflow_backend(): "Backend oneflow for OneDiff is invalid, please make sure you have installed OneFlow" ) - from .mock import ldm, sgm + from .mock import ldm, sgm, vae register(package_names=["ldm"], torch2oflow_class_map=ldm.torch2oflow_class_map) register(package_names=["sgm"], torch2oflow_class_map=sgm.torch2oflow_class_map) + register(package_names=["modules"], torch2oflow_class_map=vae.torch2oflow_class_map) diff --git a/onediff_sd_webui_extensions/compile/vae.py b/onediff_sd_webui_extensions/compile/vae.py index f3dd03204..8b4f045ff 100644 --- a/onediff_sd_webui_extensions/compile/vae.py +++ b/onediff_sd_webui_extensions/compile/vae.py @@ -1,29 +1,20 @@ +from compile.utils import get_onediff_backend from modules import shared from modules.sd_vae_approx import VAEApprox from modules.sd_vae_approx import model as get_vae_model from modules.sd_vae_approx import sd_vae_approx_models -from onediff.infer_compiler import 
oneflow_compile -from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +from onediff.infer_compiler import compile + +# from compile import get_compiled_graph __all__ = ["VaeCompileCtx"] compiled_models = {} -class VAEApproxOflow(proxy_class(VAEApprox)): - pass - - -torch2oflow_class_map = { - VAEApprox: VAEApproxOflow, -} - -register(package_names=["modules"], torch2oflow_class_map=torch2oflow_class_map) - - class VaeCompileCtx(object): - def __init__(self, options=None): + def __init__(self, backend=None, options=None): self._options = options # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/75336dfc84cae280036bc52a6805eb10d9ae30ba/modules/sd_vae_approx.py#L43 self._model_name = ( @@ -32,14 +23,19 @@ def __init__(self, options=None): else "model.pt" ) self._original_model = get_vae_model() + self.backend = backend def __enter__(self): + # TODO: support nexfort here if self._original_model is None: return global compiled_models model = compiled_models.get(self._model_name) + backend = str(self.backend or get_onediff_backend()) if model is None: - model = oneflow_compile(self._original_model, options=self._options) + model = compile( + self._original_model, backend=backend, options=self._options + ) compiled_models[self._model_name] = model sd_vae_approx_models[self._model_name] = model diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py b/onediff_sd_webui_extensions/onediff_controlnet/compile.py index a8bd14038..38b706e66 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/compile.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py @@ -1,19 +1,17 @@ from functools import wraps -import networks +import networks import onediff_shared +from compile import get_compiled_graph from compile.utils import ( disable_unet_checkpointing, is_nexfort_backend, is_oneflow_backend, ) -from compile import get_compiled_graph from .hijack import hijack_controlnet_extension from .utils import check_if_controlnet_enabled -from compile.oneflow.mock.controlnet import OneFlowOnediffControlNetModel - def onediff_controlnet_decorator(func): @wraps(func) @@ -36,7 +34,10 @@ def wrapper(self, p, *arg, **kwargs): def compile_controlnet_ldm_unet(sd_model, unet_model, *, backend=None, options=None): if is_oneflow_backend(): + from compile.oneflow.mock.controlnet import OneFlowOnediffControlNetModel + from onediff.infer_compiler.backends.oneflow.transform import register + from .model import OnediffControlNetModel register( diff --git a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py index 675e74712..fc8a68109 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py @@ -9,9 +9,8 @@ def hijacked_main_entry(self, p): - from .model import OnediffControlNetModel - from .compile import compile_controlnet_ldm_unet + from .model import OnediffControlNetModel self._original_controlnet_main_entry(p) sd_ldm = p.sd_model diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py index 47b5c457c..0c6dc5e1b 100644 --- a/onediff_sd_webui_extensions/onediff_hijack.py +++ b/onediff_sd_webui_extensions/onediff_hijack.py @@ -1,40 +1,11 @@ from typing import Any, Mapping -import oneflow import torch -from compile.oneflow.mock import ldm, sgm from modules import sd_models from modules.sd_hijack_utils import CondFunc from onediff_shared import onediff_enabled -# 
https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/sd_hijack_unet.py#L8 -class OneFlowHijackForUnet: - """ - This is oneflow, but with cat that resizes tensors to appropriate dimensions if they do not match; - this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 - """ - - def __getattr__(self, item): - if item == "cat": - return self.cat - if hasattr(oneflow, item): - return getattr(oneflow, item) - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{item}'" - ) - - def cat(self, tensors, *args, **kwargs): - if len(tensors) == 2: - a, b = tensors - a = oneflow.nn.functional.interpolate_like(a, like=b, mode="nearest") - tensors = (a, b) - return oneflow.cat(tensors, *args, **kwargs) - - -hijack_flow = OneFlowHijackForUnet() - - def unload_model_weights(sd_model=None, info=None): from modules import devices, lowvram, shared @@ -66,8 +37,6 @@ def unhijack_function(module, name, new_name): def do_hijack(): - ldm.flow = hijack_flow - sgm.flow = hijack_flow from modules import script_callbacks, sd_models script_callbacks.on_script_unloaded(undo_hijack) @@ -86,8 +55,6 @@ def do_hijack(): def undo_hijack(): - ldm.flow = oneflow - sgm.flow = oneflow from modules import sd_models unhijack_function( diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py index a1f4da8da..5d8f6a4df 100644 --- a/onediff_sd_webui_extensions/onediff_lora.py +++ b/onediff_sd_webui_extensions/onediff_lora.py @@ -1,9 +1,7 @@ import torch +from compile.utils import is_oneflow_backend from onediff.infer_compiler import DeployableModule -from onediff.infer_compiler.backends.oneflow.param_utils import ( - update_graph_related_tensor, -) class HijackLoraActivate: @@ -52,8 +50,12 @@ def activate(self, p, params_list): ): continue networks.network_apply_weights(sub_module) - if isinstance(sub_module, torch.nn.Conv2d): + if is_oneflow_backend() and isinstance(sub_module, torch.nn.Conv2d): # TODO(WangYi): refine here + from onediff.infer_compiler.backends.oneflow.param_utils import ( + update_graph_related_tensor, + ) + try: update_graph_related_tensor(sub_module) except: diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index c7e0e59c3..823282f57 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -5,17 +5,18 @@ from zipfile import BadZipFile import onediff_shared +from importlib_metadata import version from onediff.utils.import_utils import is_oneflow_available if is_oneflow_available(): import oneflow as flow -from compile import is_oneflow_backend + +from compile import init_backend, is_oneflow_backend from modules import shared from modules.devices import torch_gc from onediff.infer_compiler import DeployableModule -from compile import init_backend hints_message = dedent( """\ @@ -174,3 +175,18 @@ def onediff_gc(): torch_gc() if is_oneflow_backend(): flow.cuda.empty_cache() + + +def varify_can_use_quantization(): + try: + import oneflow + + if version("oneflow") < "0.9.1": + return False + except ImportError as e: + return False + try: + import onediff_quant + except ImportError as e: + return False + return hasattr(oneflow._C, "dynamic_quantization") diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 945be6ea0..8713249df 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py 
+++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -21,6 +21,8 @@ from modules.ui_common import create_refresh_button from onediff_hijack import do_hijack as onediff_do_hijack from onediff_lora import HijackLoraActivate + +# from onediff.optimization.quant_optimizer import varify_can_use_quantization from onediff_utils import ( check_structure_change, get_all_compiler_caches, @@ -30,9 +32,9 @@ onediff_gc, refresh_all_compiler_caches, save_graph, + varify_can_use_quantization, ) -from onediff.optimization.quant_optimizer import varify_can_use_quantization from onediff.utils import logger, parse_boolean_from_env """oneflow_compiled UNetModel""" From 31b1accde3a801e204be4b77b1aed15ce44bda98 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 25 Jun 2024 14:44:38 +0800 Subject: [PATCH 23/24] refine --- onediff_sd_webui_extensions/README.md | 11 ++- .../compile/compile_utils.py | 69 ------------------- .../compile/oneflow/mock/vae.py | 3 +- onediff_sd_webui_extensions/compile/vae.py | 2 - .../onediff_controlnet/compile.py | 2 - .../onediff_controlnet/hijack.py | 13 +++- onediff_sd_webui_extensions/onediff_utils.py | 11 +-- 7 files changed, 24 insertions(+), 87 deletions(-) delete mode 100644 onediff_sd_webui_extensions/compile/compile_utils.py diff --git a/onediff_sd_webui_extensions/README.md b/onediff_sd_webui_extensions/README.md index 0e7b14d14..573b77dff 100644 --- a/onediff_sd_webui_extensions/README.md +++ b/onediff_sd_webui_extensions/README.md @@ -40,21 +40,18 @@ ln -s "$(pwd)/onediff/onediff_sd_webui_extensions" "$(pwd)/stable-diffusion-webu cd stable-diffusion-webui # Install all of stable-diffusion-webui's dependencies. -venv_dir=- bash webui.sh --port=8080 - -# Exit webui server and upgrade some of the components that conflict with onediff. -cd repositories/generative-models && git checkout 9d759324 && cd - -pip install -U einops==0.7.0 +# If you install as root user, append `-f` to the end of the command line. +venv_dir=- bash webui.sh ``` ## Run stable-diffusion-webui service ```bash cd stable-diffusion-webui -python webui.py --port 8080 +python webui.py --port 7860 ``` -Accessing http://server:8080/ from a web browser. +Accessing http://server:7860/ from a web browser. 
## Extensions Usage diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py deleted file mode 100644 index 1e40a186b..000000000 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import warnings -from pathlib import Path -from typing import Dict, Union - -from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM -from modules.sd_models import select_checkpoint -from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM - -from onediff.optimization.quant_optimizer import ( - quantize_model, - varify_can_use_quantization, -) -from onediff.utils import logger - -from .compile_ldm import compile_ldm_unet -from .compile_sgm import compile_sgm_unet -from .onediff_compiled_graph import OneDiffCompiledGraph - - -def compile_unet( - unet_model, quantization=False, *, options=None, -): - if isinstance(unet_model, UNetModelLDM): - compiled_unet = compile_ldm_unet(unet_model, options=options) - elif isinstance(unet_model, UNetModelSGM): - compiled_unet = compile_sgm_unet(unet_model, options=options) - else: - warnings.warn( - f"Unsupported model type: {type(unet_model)} for compilation , skip", - RuntimeWarning, - ) - compiled_unet = unet_model - # In OneDiff Community, quantization can be True when called by api - if quantization and varify_can_use_quantization(): - calibrate_info = get_calibrate_info( - f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" - ) - compiled_unet = quantize_model( - compiled_unet, inplace=False, calibrate_info=calibrate_info - ) - return compiled_unet - - -def get_calibrate_info(filename: str) -> Union[None, Dict]: - calibration_path = Path(select_checkpoint().filename).parent / filename - if not calibration_path.exists(): - return None - - logger.info(f"Got calibrate info at {str(calibration_path)}") - calibrate_info = {} - with open(calibration_path, "r") as f: - for line in f.readlines(): - line = line.strip() - items = line.split(" ") - calibrate_info[items[0]] = [ - float(items[1]), - int(items[2]), - [float(x) for x in items[3].split(",")], - ] - return calibrate_info - - -def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: - diffusion_model = sd_model.model.diffusion_model - # for controlnet - if "forward" in diffusion_model.__dict__: - diffusion_model.__dict__.pop("forward") - compiled_unet = compile_unet(diffusion_model, quantization=quantization) - return OneDiffCompiledGraph(sd_model, diffusion_model, compiled_unet, quantization) diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py b/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py index 329ba6e39..10dfe64c9 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py @@ -1,8 +1,9 @@ from modules.sd_vae_approx import VAEApprox -from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +from onediff.infer_compiler.backends.oneflow.transform import proxy_class +# Prevent re-importing modules.shared, which incorrectly initializes all its variables. 
class VAEApproxOflow(proxy_class(VAEApprox)): pass diff --git a/onediff_sd_webui_extensions/compile/vae.py b/onediff_sd_webui_extensions/compile/vae.py index 8b4f045ff..6e553b97a 100644 --- a/onediff_sd_webui_extensions/compile/vae.py +++ b/onediff_sd_webui_extensions/compile/vae.py @@ -1,12 +1,10 @@ from compile.utils import get_onediff_backend from modules import shared -from modules.sd_vae_approx import VAEApprox from modules.sd_vae_approx import model as get_vae_model from modules.sd_vae_approx import sd_vae_approx_models from onediff.infer_compiler import compile -# from compile import get_compiled_graph __all__ = ["VaeCompileCtx"] diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py b/onediff_sd_webui_extensions/onediff_controlnet/compile.py index 38b706e66..0eedf3ff3 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/compile.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py @@ -4,7 +4,6 @@ import onediff_shared from compile import get_compiled_graph from compile.utils import ( - disable_unet_checkpointing, is_nexfort_backend, is_oneflow_backend, ) @@ -54,6 +53,5 @@ def compile_controlnet_ldm_unet(sd_model, unet_model, *, backend=None, options=N compiled_graph = get_compiled_graph( sd_model, unet_model, backend=backend, options=options ) - compiled_graph.eager_module = unet_model compiled_graph.name += "_controlnet" return compiled_graph diff --git a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py index fc8a68109..27d49644d 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py @@ -9,9 +9,10 @@ def hijacked_main_entry(self, p): - from .compile import compile_controlnet_ldm_unet from .model import OnediffControlNetModel + from .compile import compile_controlnet_ldm_unet + self._original_controlnet_main_entry(p) sd_ldm = p.sd_model unet = sd_ldm.model.diffusion_model @@ -53,6 +54,16 @@ def hijack_controlnet_extension(p): ) +def unhijack_controlnet_extension(p): + controlnet_script = get_controlnet_script(p) + if controlnet_script is None: + return + + if hasattr(controlnet_script, "_original_controlnet_main_entry"): + controlnet_script.controlnet_main_entry = controlnet_script._original_controlnet_main_entry + delattr(controlnet_script, "_original_controlnet_main_entry") + + # We were intended to only hack the closure function `forward` # in the member function `hook` of the UnetHook class in the ControlNet extension. 
# But due to certain limitations, we were unable to directly only hack diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index 823282f57..3cee852f5 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -15,6 +15,7 @@ from compile import init_backend, is_oneflow_backend from modules import shared from modules.devices import torch_gc +from modules.sd_hijack_utils import CondFunc from onediff.infer_compiler import DeployableModule @@ -136,11 +137,11 @@ def wrapper( return func( self, p, - quantization=False, - compiler_cache=None, - saved_cache_name="", - always_recompile=False, - backend=None, + quantization=quantization, + compiler_cache=compiler_cache, + saved_cache_name=saved_cache_name, + always_recompile=always_recompile, + backend=backend, ) finally: onediff_shared.onediff_enabled = False From 392516f3e8f9114ffc03f4394c49b7da87fa798f Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 25 Jun 2024 17:25:39 +0800 Subject: [PATCH 24/24] refine, support lora --- onediff_sd_webui_extensions/compile/vae.py | 2 -- .../onediff_controlnet/compile.py | 12 +++-------- .../onediff_controlnet/hijack.py | 16 +++++++------- onediff_sd_webui_extensions/onediff_shared.py | 4 ++-- onediff_sd_webui_extensions/onediff_utils.py | 6 +++++- .../scripts/onediff.py | 21 ++++--------------- 6 files changed, 22 insertions(+), 39 deletions(-) diff --git a/onediff_sd_webui_extensions/compile/vae.py b/onediff_sd_webui_extensions/compile/vae.py index 6e553b97a..172578501 100644 --- a/onediff_sd_webui_extensions/compile/vae.py +++ b/onediff_sd_webui_extensions/compile/vae.py @@ -5,7 +5,6 @@ from onediff.infer_compiler import compile - __all__ = ["VaeCompileCtx"] compiled_models = {} @@ -24,7 +23,6 @@ def __init__(self, backend=None, options=None): self.backend = backend def __enter__(self): - # TODO: support nexfort here if self._original_model is None: return global compiled_models diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py b/onediff_sd_webui_extensions/onediff_controlnet/compile.py index 0eedf3ff3..6a8d54e1d 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/compile.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py @@ -1,12 +1,8 @@ from functools import wraps -import networks import onediff_shared from compile import get_compiled_graph -from compile.utils import ( - is_nexfort_backend, - is_oneflow_backend, -) +from compile.utils import is_nexfort_backend, is_oneflow_backend from .hijack import hijack_controlnet_extension from .utils import check_if_controlnet_enabled @@ -46,10 +42,8 @@ def compile_controlnet_ldm_unet(sd_model, unet_model, *, backend=None, options=N }, ) elif is_nexfort_backend(): - # TODO: restore LoRA here - if networks.originals is not None: - networks.originals.undo() - # TODO: refine here + # nothing need to do + pass compiled_graph = get_compiled_graph( sd_model, unet_model, backend=backend, options=options ) diff --git a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py index 27d49644d..6f7df7871 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py @@ -9,26 +9,24 @@ def hijacked_main_entry(self, p): - from .model import OnediffControlNetModel - + self._original_controlnet_main_entry(p) from .compile import compile_controlnet_ldm_unet + from .model import OnediffControlNetModel - 
self._original_controlnet_main_entry(p) + if not onediff_shared.onediff_enabled: + return sd_ldm = p.sd_model unet = sd_ldm.model.diffusion_model structure_changed = check_structure_change( onediff_shared.previous_unet_type, sd_ldm ) - # TODO: restore here - if onediff_shared.controlnet_compiled is False or structure_changed: + if not onediff_shared.controlnet_compiled or structure_changed: onediff_model = OnediffControlNetModel(unet) onediff_shared.current_unet_graph = compile_controlnet_ldm_unet( sd_ldm, onediff_model ) onediff_shared.controlnet_compiled = True - else: - pass # When OneDiff is initializing, the controlnet extension has not yet been loaded. @@ -60,7 +58,9 @@ def unhijack_controlnet_extension(p): return if hasattr(controlnet_script, "_original_controlnet_main_entry"): - controlnet_script.controlnet_main_entry = controlnet_script._original_controlnet_main_entry + controlnet_script.controlnet_main_entry = ( + controlnet_script._original_controlnet_main_entry + ) delattr(controlnet_script, "_original_controlnet_main_entry") diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index f6673521a..d642d4270 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -1,4 +1,4 @@ -from compile import OneDiffBackend, OneDiffCompiledGraph +from compile import OneDiffCompiledGraph current_unet_graph = OneDiffCompiledGraph() current_quantization = False @@ -9,7 +9,7 @@ "is_ssd": False, } onediff_enabled = False -onediff_backend = OneDiffBackend.NEXFORT +onediff_backend = None # controlnet controlnet_enabled = False diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index 3cee852f5..29b029bf1 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -4,6 +4,7 @@ from textwrap import dedent from zipfile import BadZipFile +import networks import onediff_shared from importlib_metadata import version @@ -15,7 +16,6 @@ from compile import init_backend, is_oneflow_backend from modules import shared from modules.devices import torch_gc -from modules.sd_hijack_utils import CondFunc from onediff.infer_compiler import DeployableModule @@ -132,6 +132,8 @@ def wrapper( backend=None, ): onediff_shared.onediff_enabled = True + if networks.originals is not None: + networks.originals.undo() init_backend(backend) try: return func( @@ -144,6 +146,8 @@ def wrapper( backend=backend, ) finally: + if networks.originals is not None: + networks.originals.__init__() onediff_shared.onediff_enabled = False onediff_shared.previous_unet_type.update(**get_model_type(shared.sd_model)) onediff_gc() diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 8713249df..06dd0fc11 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -12,8 +12,6 @@ VaeCompileCtx, get_compiled_graph, get_onediff_backend, - is_nexfort_backend, - is_oneflow_backend, ) from compile.nexfort.utils import add_nexfort_optimizer from modules import script_callbacks @@ -125,17 +123,6 @@ def run( onediff_gc() backend = backend or get_onediff_backend() - - # init backend - if is_oneflow_backend(backend): - from compile.oneflow.utils import init_oneflow_backend - - init_oneflow_backend() - elif is_nexfort_backend(backend): - from compile.nexfort.utils import init_nexfort_backend - - init_nexfort_backend() - 
current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name ckpt_changed = ( shared.sd_model.sd_checkpoint_info.name @@ -170,9 +157,9 @@ def run( f"Model {current_checkpoint_name} has same sd type of graph type {onediff_shared.previous_unet_type}, skip compile" ) - with UnetCompileCtx( - not onediff_shared.controlnet_enabled - ), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): + with UnetCompileCtx(not onediff_shared.controlnet_enabled), VaeCompileCtx( + backend=backend + ), SD21CompileCtx(), HijackLoraActivate(): proc = process_images(p) save_graph(onediff_shared.current_unet_graph, saved_cache_name) @@ -193,7 +180,7 @@ def on_ui_settings(): "onediff_compiler_backend", shared.OptionInfo( "oneflow", - "Backend for onediff compiler", + "Backend for onediff compiler (if you switch backend, you need to restart webui service)", gr.Radio, {"choices": [OneDiffBackend.ONEFLOW, OneDiffBackend.NEXFORT]}, section=section,
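A note on the backend setting touched above: the Radio choices are now `OneDiffBackend` members, while other parts of the series pass the chosen backend on as a plain string (for example `str(self.backend or get_onediff_backend())` in `VaeCompileCtx`), which is what the enum's new `__str__` enables. The snippet below is only a self-contained sketch of that round trip; the `to_backend` helper is hypothetical and not part of this patch.

from enum import Enum


class OneDiffBackend(Enum):
    ONEFLOW = "oneflow"
    NEXFORT = "nexfort"

    def __str__(self):
        return self.value


def to_backend(value) -> "OneDiffBackend":
    # Hypothetical helper: accept either an enum member or the raw string
    # form used by the settings UI, and normalize it to the enum.
    if isinstance(value, OneDiffBackend):
        return value
    return OneDiffBackend(str(value))


# The Radio shows "oneflow"/"nexfort"; both forms resolve to the same member.
assert str(OneDiffBackend.NEXFORT) == "nexfort"
assert to_backend("nexfort") is OneDiffBackend.NEXFORT
assert to_backend(OneDiffBackend.ONEFLOW) is OneDiffBackend.ONEFLOW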