diff --git a/onediff_sd_webui_extensions/README.md b/onediff_sd_webui_extensions/README.md
index 0e7b14d14..573b77dff 100644
--- a/onediff_sd_webui_extensions/README.md
+++ b/onediff_sd_webui_extensions/README.md
@@ -40,21 +40,18 @@ ln -s "$(pwd)/onediff/onediff_sd_webui_extensions" "$(pwd)/stable-diffusion-webu
 cd stable-diffusion-webui
 # Install all of stable-diffusion-webui's dependencies.
-venv_dir=- bash webui.sh --port=8080
-
-# Exit webui server and upgrade some of the components that conflict with onediff.
-cd repositories/generative-models && git checkout 9d759324 && cd -
-pip install -U einops==0.7.0
+# If you are installing as the root user, append `-f` to the end of the command line.
+venv_dir=- bash webui.sh
 ```
 
 ## Run stable-diffusion-webui service
 
 ```bash
 cd stable-diffusion-webui
-python webui.py --port 8080
+python webui.py --port 7860
 ```
 
-Accessing http://server:8080/ from a web browser.
+Access http://server:7860/ from a web browser.
 
 ## Extensions Usage
diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py
index 89454fd4c..60827fd87 100644
--- a/onediff_sd_webui_extensions/compile/__init__.py
+++ b/onediff_sd_webui_extensions/compile/__init__.py
@@ -1,6 +1,13 @@
+from .backend import OneDiffBackend
 from .compile import get_compiled_graph
 from .sd2 import SD21CompileCtx
-from .utils import OneDiffCompiledGraph
+from .utils import (
+    OneDiffCompiledGraph,
+    get_onediff_backend,
+    init_backend,
+    is_nexfort_backend,
+    is_oneflow_backend,
+)
 from .vae import VaeCompileCtx
 
 __all__ = [
@@ -8,4 +15,9 @@
     "SD21CompileCtx",
     "VaeCompileCtx",
     "OneDiffCompiledGraph",
+    "OneDiffBackend",
+    "get_onediff_backend",
+    "is_oneflow_backend",
+    "is_nexfort_backend",
+    "init_backend",
 ]
diff --git a/onediff_sd_webui_extensions/compile/backend.py b/onediff_sd_webui_extensions/compile/backend.py
new file mode 100644
index 000000000..c12a0ed6f
--- /dev/null
+++ b/onediff_sd_webui_extensions/compile/backend.py
@@ -0,0 +1,12 @@
+from enum import Enum
+
+
+class OneDiffBackend(Enum):
+    ONEFLOW = "oneflow"
+    NEXFORT = "nexfort"
+
+    def __str__(self):
+        return self.value
+
+    def __repr__(self):
+        return f"<{self.__class__.__name__}.{self.name}: {self.value}>"
diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py
index 4872cc578..22a4d8628 100644
--- a/onediff_sd_webui_extensions/compile/compile.py
+++ b/onediff_sd_webui_extensions/compile/compile.py
@@ -1,27 +1,38 @@
+from compile import OneDiffBackend
 from modules.sd_hijack import apply_optimizations
 from onediff.infer_compiler import compile, oneflow_compile
 
-from .utils import OneDiffCompiledGraph, disable_unet_checkpointing
-from .quantization import quant_unet_oneflow
+from .utils import (
+    OneDiffCompiledGraph,
+    disable_unet_checkpointing,
+    is_nexfort_backend,
+    is_oneflow_backend,
+)
 
 
 def get_compiled_graph(
-    sd_model, *, backend, quantization=None, options=None
+    sd_model, unet_model=None, *, backend=None, quantization=None, options=None
 ) -> OneDiffCompiledGraph:
-    diffusion_model = sd_model.model.diffusion_model
+    diffusion_model = unet_model or sd_model.model.diffusion_model
     compiled_unet = onediff_compile(
         diffusion_model, backend=backend, quantization=quantization, options=options
     )
-    return OneDiffCompiledGraph(sd_model, compiled_unet, quantization)
+    return OneDiffCompiledGraph(sd_model, diffusion_model, compiled_unet, quantization)
 
 
-def onediff_compile(unet_model, *, quantization=False, backend="oneflow", options=None):
-    if 
backend == "oneflow": +def onediff_compile( + unet_model, + *, + quantization: bool = False, + backend: OneDiffBackend = None, + options=None, +): + if is_oneflow_backend(backend): return compile_unet_oneflow( unet_model, quantization=quantization, options=options ) - elif backend == "nexfort": + elif is_nexfort_backend(backend): return compile_unet_nexfort( unet_model, quantization=quantization, options=options ) @@ -44,11 +55,14 @@ def compile_unet_oneflow(unet_model, *, quantization=False, options=None): compiled_unet_model = oneflow_compile(unet_model, options=options) if quantization: + from .quantization import quant_unet_oneflow + compiled_unet_model = quant_unet_oneflow(compiled_unet_model) return compiled_unet_model def compile_unet_nexfort(unet_model, *, quantization=False, options=None): + # TODO: support nexfort quant if quantization: raise NotImplementedError( "Quantization for nexfort backend is not implemented yet." diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py deleted file mode 100644 index d79278be2..000000000 --- a/onediff_sd_webui_extensions/compile/compile_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import warnings -from pathlib import Path -from typing import Dict, Union - -from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM -from modules.sd_models import select_checkpoint -from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM - -from onediff.optimization.quant_optimizer import ( - quantize_model, - varify_can_use_quantization, -) -from onediff.utils import logger - -from .compile_ldm import compile_ldm_unet -from .compile_sgm import compile_sgm_unet -from .onediff_compiled_graph import OneDiffCompiledGraph - - -def compile_unet( - unet_model, quantization=False, *, options=None, -): - if isinstance(unet_model, UNetModelLDM): - compiled_unet = compile_ldm_unet(unet_model, options=options) - elif isinstance(unet_model, UNetModelSGM): - compiled_unet = compile_sgm_unet(unet_model, options=options) - else: - warnings.warn( - f"Unsupported model type: {type(unet_model)} for compilation , skip", - RuntimeWarning, - ) - compiled_unet = unet_model - # In OneDiff Community, quantization can be True when called by api - if quantization and varify_can_use_quantization(): - calibrate_info = get_calibrate_info( - f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" - ) - compiled_unet = quantize_model( - compiled_unet, inplace=False, calibrate_info=calibrate_info - ) - return compiled_unet - - -def get_calibrate_info(filename: str) -> Union[None, Dict]: - calibration_path = Path(select_checkpoint().filename).parent / filename - if not calibration_path.exists(): - return None - - logger.info(f"Got calibrate info at {str(calibration_path)}") - calibrate_info = {} - with open(calibration_path, "r") as f: - for line in f.readlines(): - line = line.strip() - items = line.split(" ") - calibrate_info[items[0]] = [ - float(items[1]), - int(items[2]), - [float(x) for x in items[3].split(",")], - ] - return calibrate_info - - -def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: - diffusion_model = sd_model.model.diffusion_model - # for controlnet - if "forward" in diffusion_model.__dict__: - diffusion_model.__dict__.pop("forward") - compiled_unet = compile_unet(diffusion_model, quantization=quantization) - return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) diff --git 
a/onediff_sd_webui_extensions/compile/nexfort/utils.py b/onediff_sd_webui_extensions/compile/nexfort/utils.py
index a677e72de..b25a91313 100644
--- a/onediff_sd_webui_extensions/compile/nexfort/utils.py
+++ b/onediff_sd_webui_extensions/compile/nexfort/utils.py
@@ -11,6 +11,8 @@
 from modules.sd_hijack_utils import CondFunc
 from onediff_utils import singleton_decorator
 
+from onediff.utils.import_utils import is_nexfort_available
+
 
 @singleton_decorator
 def init_nexfort_backend():
@@ -26,6 +28,21 @@ def init_nexfort_backend():
         lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled,
     )
 
+    def hijack_groupnorm32_forward(orig_func, self, x):
+        # run the plain nn.GroupNorm forward, skipping GroupNorm32's float32 upcast
+        return super(type(self), self).forward(x)
+
+    CondFunc(
+        "ldm.modules.diffusionmodules.util.GroupNorm32.forward",
+        hijack_groupnorm32_forward,
+        lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled,
+    )
+    CondFunc(
+        "sgm.modules.diffusionmodules.util.GroupNorm32.forward",
+        hijack_groupnorm32_forward,
+        lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled,
+    )
+
 
 @torch.autocast("cuda", enabled=False)
 def onediff_nexfort_unet_sgm_forward(
@@ -132,12 +149,7 @@ class SdOptimizationNexfort(SdOptimization):
     priority = 10
 
     def is_available(self):
-        try:
-            import nexfort
-        except ImportError:
-            return False
-        finally:
-            return True
+        return is_nexfort_available()
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = (
diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py
new file mode 100644
index 000000000..dbf58e808
--- /dev/null
+++ b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py
@@ -0,0 +1,107 @@
+import oneflow as flow
+from compile.oneflow.mock.common import timestep_embedding
+from ldm.modules.diffusionmodules.openaimodel import UNetModel
+from modules import devices
+
+from onediff.infer_compiler.backends.oneflow.transform import proxy_class
+
+cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x)
+
+
+# Due to the tracing mechanism in OneFlow, it's crucial to ensure that
+# the same conditional branches are taken during the first run as in subsequent runs.
+# Therefore, certain "optimizations" have been modified.
+def aligned_adding(base, x, require_channel_alignment):
+    if isinstance(x, float):
+        # remove `if x == 0.0: return base` here
+        return base + x
+
+    if require_channel_alignment:
+        zeros = flow.zeros_like(base)
+        zeros[:, : x.shape[1], ...] 
= x + x = zeros + + # resize to sample resolution + base_h, base_w = base.shape[-2:] + xh, xw = x.shape[-2:] + + if xh > 1 or xw > 1 and (base_h != xh or base_w != xw): + # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') + x = flow.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") + return base + x + + +class OneFlowOnediffControlNetModel(proxy_class(UNetModel)): + def forward( + self, + x, + timesteps, + context, + y, + total_t2i_adapter_embedding, + total_controlnet_embedding, + is_sdxl, + require_inpaint_hijack, + ): + x = x.half() + if y is not None: + y = y.half() + context = context.half() + hs = [] + with flow.no_grad(): + t_emb = cond_cast_unet( + timestep_embedding(timesteps, self.model_channels, repeat_only=False) + ) + emb = self.time_embed(t_emb.half()) + + if is_sdxl: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for i, module in enumerate(self.input_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = module(h, emb, context) + + t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] + + if i in t2i_injection: + h = aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + hs.append(h) + + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = self.middle_block(h, emb, context) + + # U-Net Middle Block + h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack) + + if len(total_t2i_adapter_embedding) > 0 and is_sdxl: + h = aligned_adding( + h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack + ) + + # U-Net Decoder + for i, module in enumerate(self.output_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) + h = flow.cat( + [ + h, + aligned_adding( + hs.pop(), + total_controlnet_embedding.pop(), + require_inpaint_hijack, + ), + ], + dim=1, + ) + h = h.half() + h = module(h, emb, context) + + # U-Net Output + h = h.type(x.dtype) + h = self.out(h) + + return h diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py index 9667fa505..8e9295ca5 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py @@ -3,11 +3,19 @@ from ldm.modules.diffusionmodules.openaimodel import UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 -from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register +from onediff.infer_compiler.backends.oneflow.transform import proxy_class from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding +def cat(tensors, *args, **kwargs): + if len(tensors) == 2: + a, b = tensors + a = flow.nn.functional.interpolate_like(a, like=b, mode="nearest") + tensors = (a, b) + return flow.cat(tensors, *args, **kwargs) + + # https://github.com/Stability-AI/stablediffusion/blob/b4bdae9916f628461e1e4edbc62aafedebb9f7ed/ldm/modules/diffusionmodules/openaimodel.py#L775 class UNetModelOflow(proxy_class(UNetModel)): def forward(self, x, timesteps=None, context=None, y=None, **kwargs): @@ -28,7 +36,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: - h = flow.cat([h, hs.pop()], dim=1) + h = cat([h, hs.pop()], dim=1) h = module(h, emb, context) if self.predict_codebook_ids: return self.id_predictor(h) @@ -66,5 +74,3 @@ def forward(self, x, 
context=None):
     SpatialTransformer: SpatialTransformerOflow,
     UNetModel: UNetModelOflow,
 }
-
-register(package_names=["ldm"], torch2oflow_class_map=torch2oflow_class_map)
diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py
index fabd6bcdc..a071bd7e5 100644
--- a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py
+++ b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py
@@ -3,11 +3,19 @@
 from sgm.modules.diffusionmodules.openaimodel import UNetModel
 from sgm.modules.diffusionmodules.util import GroupNorm32
 
-from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register
+from onediff.infer_compiler.backends.oneflow.transform import proxy_class
 
 from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding
 
 
+def cat(tensors, *args, **kwargs):
+    if len(tensors) == 2:
+        a, b = tensors
+        a = flow.nn.functional.interpolate_like(a, like=b, mode="nearest")
+        tensors = (a, b)
+    return flow.cat(tensors, *args, **kwargs)
+
+
 # https://github.com/Stability-AI/generative-models/blob/059d8e9cd9c55aea1ef2ece39abf605efb8b7cc9/sgm/modules/diffusionmodules/openaimodel.py#L816
 class UNetModelOflow(proxy_class(UNetModel)):
     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
@@ -29,7 +37,8 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         hs.append(h)
         h = self.middle_block(h, emb, context)
         for module in self.output_blocks:
-            h = flow.cat([h, hs.pop()], dim=1)
+            # use the shape-aligning cat() above instead of plain flow.cat
+            h = cat([h, hs.pop()], dim=1)
             h = module(h, emb, context)
         h = h.type(x.dtype)
         return self.out(h)
@@ -67,4 +76,3 @@ def forward(self, x, context=None):
     SpatialTransformer: SpatialTransformerOflow,
     UNetModel: UNetModelOflow,
 }
-register(package_names=["sgm"], torch2oflow_class_map=torch2oflow_class_map)
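The `cat` helper added to both the `ldm.py` and `sgm.py` mocks above replaces plain `flow.cat` in the U-Net decoder: when image dimensions are multiples of 8 rather than 64, a skip connection can differ in spatial size from the upsampled feature map, so the first tensor is resized to match the second before concatenation (the same trick webui's `sd_hijack_unet` plays on `torch.cat`, whose hijack this diff removes from `onediff_hijack.py` below). A torch-semantics sketch of the behavior, not part of the diff — `interpolate_like` is OneFlow-specific, so this hypothetical `aligned_cat` passes an explicit `size`:

```python
# Sketch only: torch equivalent of the mocks' shape-aligning cat().
import torch
import torch.nn.functional as F

def aligned_cat(tensors, dim=1):
    if len(tensors) == 2:
        a, b = tensors
        if a.shape[-2:] != b.shape[-2:]:
            # resize the upsampled feature map to the skip tensor's H x W
            a = F.interpolate(a, size=tuple(b.shape[-2:]), mode="nearest")
        tensors = (a, b)
    return torch.cat(tensors, dim=dim)
```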
diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py b/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py
new file mode 100644
index 000000000..10dfe64c9
--- /dev/null
+++ b/onediff_sd_webui_extensions/compile/oneflow/mock/vae.py
@@ -0,0 +1,13 @@
+from modules.sd_vae_approx import VAEApprox
+
+from onediff.infer_compiler.backends.oneflow.transform import proxy_class
+
+
+# Avoid re-importing modules.shared, which would incorrectly re-initialize its variables.
+class VAEApproxOflow(proxy_class(VAEApprox)):
+    pass
+
+
+torch2oflow_class_map = {
+    VAEApprox: VAEApproxOflow,
+}
diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py
index fa87ca33f..006dfd894 100644
--- a/onediff_sd_webui_extensions/compile/oneflow/utils.py
+++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py
@@ -1,11 +1,18 @@
 from onediff_utils import singleton_decorator
 
 from onediff.infer_compiler.backends.oneflow.transform import register
-
-from .mock import ldm, sgm
+from onediff.utils.import_utils import is_oneflow_available
 
 
 @singleton_decorator
 def init_oneflow_backend():
+    if not is_oneflow_available():
+        raise RuntimeError(
+            "The oneflow backend for OneDiff is unavailable; please make sure OneFlow is installed"
+        )
+
+    from .mock import ldm, sgm, vae
+
     register(package_names=["ldm"], torch2oflow_class_map=ldm.torch2oflow_class_map)
     register(package_names=["sgm"], torch2oflow_class_map=sgm.torch2oflow_class_map)
+    register(package_names=["modules"], torch2oflow_class_map=vae.torch2oflow_class_map)
diff --git a/onediff_sd_webui_extensions/compile/quantization.py b/onediff_sd_webui_extensions/compile/quantization.py
index b5a08d7ac..9bd52755f 100644
--- a/onediff_sd_webui_extensions/compile/quantization.py
+++ b/onediff_sd_webui_extensions/compile/quantization.py
@@ -3,11 +3,11 @@
 
 from modules.sd_models import select_checkpoint
 
-from onediff.utils import logger
 from onediff.optimization.quant_optimizer import (
     quantize_model,
     varify_can_use_quantization,
 )
+from onediff.utils import logger
 
 
 def quant_unet_oneflow(compiled_unet):
diff --git a/onediff_sd_webui_extensions/compile/utils.py b/onediff_sd_webui_extensions/compile/utils.py
index 5eda9cd9e..abbc44391 100644
--- a/onediff_sd_webui_extensions/compile/utils.py
+++ b/onediff_sd_webui_extensions/compile/utils.py
@@ -3,11 +3,13 @@
 
 import torch
 from ldm.modules.diffusionmodules.openaimodel import UNetModel as LdmUNetModel
-from modules import sd_models_types
+from modules import sd_models_types, shared
 from sgm.modules.diffusionmodules.openaimodel import UNetModel as SgmUNetModel
 
 from onediff.infer_compiler import DeployableModule
 
+from .backend import OneDiffBackend
+
 
 def disable_unet_checkpointing(
     unet_model: Union[LdmUNetModel, SgmUNetModel]
@@ -25,6 +27,32 @@ def disable_unet_checkpointing(
     return unet_model
 
 
+def get_onediff_backend() -> OneDiffBackend:
+    return OneDiffBackend(shared.opts.onediff_compiler_backend)
+
+
+def is_oneflow_backend(backend: Union[OneDiffBackend, None] = None) -> bool:
+    return (backend or get_onediff_backend()) == OneDiffBackend.ONEFLOW
+
+
+def is_nexfort_backend(backend: Union[OneDiffBackend, None] = None) -> bool:
+    return (backend or get_onediff_backend()) == OneDiffBackend.NEXFORT
+
+
+def init_backend(backend: Union[OneDiffBackend, None] = None):
+    backend = backend or get_onediff_backend()
+    if is_oneflow_backend(backend):
+        from .oneflow.utils import init_oneflow_backend
+
+        init_oneflow_backend()
+    elif is_nexfort_backend(backend):
+        from .nexfort.utils import init_nexfort_backend
+
+        init_nexfort_backend()
+    else:
+        raise NotImplementedError(f"invalid backend {backend}")
+
+
 @dataclasses.dataclass
 class OneDiffCompiledGraph:
     name: str = None
@@ -37,6 +65,7 @@ class OneDiffCompiledGraph:
     def __init__(
         self,
         sd_model: sd_models_types.WebuiSdModel = None,
+        unet_model=None,
        graph_module: DeployableModule = None,
         quantized=False,
     ):
@@ -45,6 +74,6 @@ def __init__(
         self.name = sd_model.sd_checkpoint_info.name
         self.filename = sd_model.sd_checkpoint_info.filename
         self.sha = sd_model.sd_model_hash
-        self.eager_module = sd_model.model.diffusion_model
+        self.eager_module = unet_model or sd_model.model.diffusion_model
         self.graph_module = graph_module
         self.quantized = quantized
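The helpers above are the backbone of backend selection: every call site can either pass an explicit `OneDiffBackend` or fall back to the `onediff_compiler_backend` webui option. A minimal usage sketch, not part of the diff, assuming the extension directory is on `sys.path` (as sd-webui arranges for extensions) and that the option is set to "oneflow" or "nexfort":

```python
# Sketch only: how the new backend helpers fit together.
from compile import OneDiffBackend, get_onediff_backend, init_backend

backend = get_onediff_backend()  # OneDiffBackend member parsed from shared.opts
init_backend(backend)            # one-time mock/hijack registration (singleton)

assert str(OneDiffBackend.ONEFLOW) == "oneflow"  # __str__ yields the raw value
```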
diff --git a/onediff_sd_webui_extensions/compile/vae.py b/onediff_sd_webui_extensions/compile/vae.py
index f3dd03204..172578501 100644
--- a/onediff_sd_webui_extensions/compile/vae.py
+++ b/onediff_sd_webui_extensions/compile/vae.py
@@ -1,29 +1,17 @@
+from compile.utils import get_onediff_backend
 from modules import shared
-from modules.sd_vae_approx import VAEApprox
 from modules.sd_vae_approx import model as get_vae_model
 from modules.sd_vae_approx import sd_vae_approx_models
 
-from onediff.infer_compiler import oneflow_compile
-from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register
+from onediff.infer_compiler import compile
 
 __all__ = ["VaeCompileCtx"]
 
 compiled_models = {}
 
 
-class VAEApproxOflow(proxy_class(VAEApprox)):
-    pass
-
-
-torch2oflow_class_map = {
-    VAEApprox: VAEApproxOflow,
-}
-
-register(package_names=["modules"], torch2oflow_class_map=torch2oflow_class_map)
-
-
 class VaeCompileCtx(object):
-    def __init__(self, options=None):
+    def __init__(self, backend=None, options=None):
         self._options = options
         # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/75336dfc84cae280036bc52a6805eb10d9ae30ba/modules/sd_vae_approx.py#L43
         self._model_name = (
@@ -32,14 +20,18 @@ def __init__(self, options=None):
             else "model.pt"
         )
         self._original_model = get_vae_model()
+        self.backend = backend
 
     def __enter__(self):
         if self._original_model is None:
             return
         global compiled_models
         model = compiled_models.get(self._model_name)
+        backend = str(self.backend or get_onediff_backend())
         if model is None:
-            model = oneflow_compile(self._original_model, options=self._options)
+            model = compile(
+                self._original_model, backend=backend, options=self._options
+            )
             compiled_models[self._model_name] = model
         sd_vae_approx_models[self._model_name] = model
diff --git a/onediff_sd_webui_extensions/onediff_controlnet/__init__.py b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py
new file mode 100644
index 000000000..ce278f071
--- /dev/null
+++ b/onediff_sd_webui_extensions/onediff_controlnet/__init__.py
@@ -0,0 +1,5 @@
+from .compile import onediff_controlnet_decorator
+
+__all__ = [
+    "onediff_controlnet_decorator",
+]
diff --git a/onediff_sd_webui_extensions/onediff_controlnet/compile.py b/onediff_sd_webui_extensions/onediff_controlnet/compile.py
new file mode 100644
index 000000000..6a8d54e1d
--- /dev/null
+++ b/onediff_sd_webui_extensions/onediff_controlnet/compile.py
@@ -0,0 +1,51 @@
+from functools import wraps
+
+import onediff_shared
+from compile import get_compiled_graph
+from compile.utils import is_nexfort_backend, is_oneflow_backend
+
+from .hijack import hijack_controlnet_extension
+from .utils import check_if_controlnet_enabled
+
+
+def onediff_controlnet_decorator(func):
+    @wraps(func)
+    def wrapper(self, p, *arg, **kwargs):
+        # TODO: restore the hijacked function here
+        try:
+            onediff_shared.controlnet_enabled = check_if_controlnet_enabled(p)
+            if onediff_shared.controlnet_enabled:
+                hijack_controlnet_extension(p)
+            return func(self, p, *arg, **kwargs)
+        finally:
+            if onediff_shared.controlnet_enabled:
+                onediff_shared.previous_is_controlnet = True
+            else:
+                onediff_shared.controlnet_compiled = False
+                onediff_shared.previous_is_controlnet = False
+
+    return wrapper
+
+
+def 
compile_controlnet_ldm_unet(sd_model, unet_model, *, backend=None, options=None):
+    if is_oneflow_backend(backend):
+        from compile.oneflow.mock.controlnet import OneFlowOnediffControlNetModel
+
+        from onediff.infer_compiler.backends.oneflow.transform import register
+
+        from .model import OnediffControlNetModel
+
+        register(
+            package_names=["scripts.hook"],
+            torch2oflow_class_map={
+                OnediffControlNetModel: OneFlowOnediffControlNetModel,
+            },
+        )
+    elif is_nexfort_backend(backend):
+        # nothing to do for the nexfort backend
+        pass
+    compiled_graph = get_compiled_graph(
+        sd_model, unet_model, backend=backend, options=options
+    )
+    compiled_graph.name += "_controlnet"
+    return compiled_graph
diff --git a/onediff_sd_webui_extensions/onediff_controlnet.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py
similarity index 74%
rename from onediff_sd_webui_extensions/onediff_controlnet.py
rename to onediff_sd_webui_extensions/onediff_controlnet/hijack.py
index 9537114a8..6f7df7871 100644
--- a/onediff_sd_webui_extensions/onediff_controlnet.py
+++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py
@@ -1,321 +1,32 @@
-from functools import wraps
-
 import onediff_shared
-from onediff_utils import check_structure_change
-import oneflow as flow
 import torch
-import torch as th
-from compile import OneDiffCompiledGraph
-from compile.oneflow.mock.common import (
-    CrossAttentionOflow,
-    GroupNorm32Oflow,
-    timestep_embedding,
-)
-from ldm.modules.attention import BasicTransformerBlock, CrossAttention
-from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel
-from ldm.modules.diffusionmodules.util import GroupNorm32
-from modules import devices
+from compile import is_oneflow_backend
+from ldm.modules.diffusionmodules.openaimodel import UNetModel
 from modules.sd_hijack_utils import CondFunc
-from onediff_utils import singleton_decorator
-
-from onediff.infer_compiler import oneflow_compile
-from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register
-
-
-# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238
-def torch_aligned_adding(base, x, require_channel_alignment):
-    if isinstance(x, float):
-        if x == 0.0:
-            return base
-        return base + x
-
-    if require_channel_alignment:
-        zeros = torch.zeros_like(base)
-        zeros[:, : x.shape[1], ...] = x
-        x = zeros
-
-    # resize to sample resolution
-    base_h, base_w = base.shape[-2:]
-    xh, xw = x.shape[-2:]
-
-    if xh > 1 or xw > 1:
-        if base_h != xh or base_w != xw:
-            # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.')
-            x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
-
-    return base + x
-
-
-# Due to the tracing mechanism in OneFlow, it's crucial to ensure that
-# the same conditional branches are taken during the first run as in subsequent runs.
-# Therefore, certain "optimizations" have been modified.
-def oneflow_aligned_adding(base, x, require_channel_alignment):
-    if isinstance(x, float):
-        # remove `if x == 0.0: return base` here
-        return base + x
-
-    if require_channel_alignment:
-        zeros = flow.zeros_like(base)
-        zeros[:, : x.shape[1], ...] 
= x - x = zeros - - # resize to sample resolution - base_h, base_w = base.shape[-2:] - xh, xw = x.shape[-2:] - - if xh > 1 or xw > 1 and (base_h != xh or base_w != xw): - # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.') - x = flow.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") - return base + x - - -cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) - - -class TorchOnediffControlNetModel(torch.nn.Module): - def __init__(self, unet): - super().__init__() - self.time_embed = unet.time_embed - self.input_blocks = unet.input_blocks - self.label_emb = getattr(unet, "label_emb", None) - self.middle_block = unet.middle_block - self.output_blocks = unet.output_blocks - self.out = unet.out - self.model_channels = unet.model_channels - - def forward( - self, - x, - timesteps, - context, - y, - total_t2i_adapter_embedding, - total_controlnet_embedding, - is_sdxl, - require_inpaint_hijack, - ): - from ldm.modules.diffusionmodules.util import timestep_embedding - - hs = [] - with th.no_grad(): - t_emb = cond_cast_unet( - timestep_embedding(timesteps, self.model_channels, repeat_only=False) - ) - emb = self.time_embed(t_emb) - - if is_sdxl: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - for i, module in enumerate(self.input_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = module(h, emb, context) - - t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] - - if i in t2i_injection: - h = torch_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - hs.append(h) - - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = self.middle_block(h, emb, context) - - # U-Net Middle Block - h = torch_aligned_adding( - h, total_controlnet_embedding.pop(), require_inpaint_hijack - ) - - if len(total_t2i_adapter_embedding) > 0 and is_sdxl: - h = torch_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - # U-Net Decoder - for i, module in enumerate(self.output_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = th.cat( - [ - h, - torch_aligned_adding( - hs.pop(), - total_controlnet_embedding.pop(), - require_inpaint_hijack, - ), - ], - dim=1, - ) - h = module(h, emb, context) - - # U-Net Output - h = h.type(x.dtype) - h = self.out(h) - - return h - +from onediff_utils import check_structure_change, singleton_decorator -class OneFlowOnediffControlNetModel(proxy_class(UNetModel)): - def forward( - self, - x, - timesteps, - context, - y, - total_t2i_adapter_embedding, - total_controlnet_embedding, - is_sdxl, - require_inpaint_hijack, - ): - x = x.half() - if y is not None: - y = y.half() - context = context.half() - hs = [] - with flow.no_grad(): - t_emb = cond_cast_unet( - timestep_embedding(timesteps, self.model_channels, repeat_only=False) - ) - emb = self.time_embed(t_emb.half()) - - if is_sdxl: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - for i, module in enumerate(self.input_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = module(h, emb, context) - - t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] - - if i in t2i_injection: - h = oneflow_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - hs.append(h) - - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = self.middle_block(h, emb, context) - - # U-Net Middle 
Block - h = oneflow_aligned_adding( - h, total_controlnet_embedding.pop(), require_inpaint_hijack - ) - - if len(total_t2i_adapter_embedding) > 0 and is_sdxl: - h = oneflow_aligned_adding( - h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack - ) - - # U-Net Decoder - for i, module in enumerate(self.output_blocks): - self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) - h = flow.cat( - [ - h, - oneflow_aligned_adding( - hs.pop(), - total_controlnet_embedding.pop(), - require_inpaint_hijack, - ), - ], - dim=1, - ) - h = h.half() - h = module(h, emb, context) - - # U-Net Output - h = h.type(x.dtype) - h = self.out(h) - - return h - - -def onediff_controlnet_decorator(func): - @wraps(func) - def wrapper(self, p, *arg, **kwargs): - try: - onediff_shared.controlnet_enabled = check_if_controlnet_enabled(p) - if onediff_shared.controlnet_enabled: - hijack_controlnet_extension(p) - return func(self, p, *arg, **kwargs) - finally: - if onediff_shared.controlnet_enabled: - onediff_shared.previous_is_controlnet = True - else: - onediff_shared.controlnet_compiled = False - onediff_shared.previous_is_controlnet = False - - return wrapper - - -def compile_controlnet_ldm_unet(sd_model, unet_model, *, options=None): - from sgm.modules.attention import BasicTransformerBlock as BasicTransformerBlockSGM - from ldm.modules.attention import BasicTransformerBlock as BasicTransformerBlockLDM - from sgm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockSGM - from ldm.modules.diffusionmodules.openaimodel import ResBlock as ResBlockLDM - for module in unet_model.modules(): - if isinstance(module, (BasicTransformerBlockLDM, BasicTransformerBlockSGM)): - module.checkpoint = False - if isinstance(module, (ResBlockLDM, ResBlockSGM)): - module.use_checkpoint = False - # TODO: refine here - compiled_model = oneflow_compile(unet_model, options=options) - compiled_graph = OneDiffCompiledGraph(sd_model, compiled_model) - compiled_graph.eager_module = unet_model - compiled_graph.name += "_controlnet" - return compiled_graph - - -torch2oflow_class_map = { - CrossAttention: CrossAttentionOflow, - GroupNorm32: GroupNorm32Oflow, - TorchOnediffControlNetModel: OneFlowOnediffControlNetModel, -} -register(package_names=["scripts.hook"], torch2oflow_class_map=torch2oflow_class_map) - - -def check_if_controlnet_ext_loaded() -> bool: - from modules import extensions - - return "sd-webui-controlnet" in extensions.loaded_extensions +from .utils import get_controlnet_script def hijacked_main_entry(self, p): self._original_controlnet_main_entry(p) + from .compile import compile_controlnet_ldm_unet + from .model import OnediffControlNetModel + + if not onediff_shared.onediff_enabled: + return sd_ldm = p.sd_model unet = sd_ldm.model.diffusion_model structure_changed = check_structure_change( onediff_shared.previous_unet_type, sd_ldm ) - if onediff_shared.controlnet_compiled is False or structure_changed: - onediff_model = TorchOnediffControlNetModel(unet) + if not onediff_shared.controlnet_compiled or structure_changed: + onediff_model = OnediffControlNetModel(unet) onediff_shared.current_unet_graph = compile_controlnet_ldm_unet( sd_ldm, onediff_model ) onediff_shared.controlnet_compiled = True - else: - pass - - -def get_controlnet_script(p): - for script in p.scripts.scripts: - if script.__module__ == "controlnet.py": - return script - return None - - -def check_if_controlnet_enabled(p): - controlnet_script_class = get_controlnet_script(p) - return ( - controlnet_script_class is not None - and 
len(controlnet_script_class.get_enabled_units(p)) != 0 - ) # When OneDiff is initializing, the controlnet extension has not yet been loaded. @@ -341,6 +52,18 @@ def hijack_controlnet_extension(p): ) +def unhijack_controlnet_extension(p): + controlnet_script = get_controlnet_script(p) + if controlnet_script is None: + return + + if hasattr(controlnet_script, "_original_controlnet_main_entry"): + controlnet_script.controlnet_main_entry = ( + controlnet_script._original_controlnet_main_entry + ) + delattr(controlnet_script, "_original_controlnet_main_entry") + + # We were intended to only hack the closure function `forward` # in the member function `hook` of the UnetHook class in the ControlNet extension. # But due to certain limitations, we were unable to directly only hack @@ -784,7 +507,10 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): outer.attention_auto_machine = AutoMachine.Read outer.gn_auto_machine = AutoMachine.Read - # modified by OneDiff + # ------ modified by OneDiff below ------ + x = x.half() + context = context.half() + y = y.half() if y is not None else y h = onediff_shared.current_unet_graph.graph_module( x, timesteps, @@ -795,6 +521,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): is_sdxl, require_inpaint_hijack, ) + # ------ modified by OneDiff above ------ # Post-processing for color fix for param in outer.control_params: @@ -863,16 +590,13 @@ def move_all_control_model_to_cpu(): def forward_webui(*args, **kwargs): # ------ modified by OneDiff below ------ forward_func = None - if ( - "forward" - in onediff_shared.current_unet_graph.graph_module._torch_module.__dict__ - ): - forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop( - "forward" - ) - _original_forward_func = onediff_shared.current_unet_graph.graph_module._torch_module.__dict__.pop( - "_original_forward" - ) + graph_module = onediff_shared.current_unet_graph.graph_module + if is_oneflow_backend(): + if "forward" in graph_module._torch_module.__dict__: + forward_func = graph_module._torch_module.__dict__.pop("forward") + _original_forward_func = graph_module._torch_module.__dict__.pop( + "_original_forward" + ) # ------ modified by OneDiff above ------ # webui will handle other compoments @@ -888,13 +612,12 @@ def forward_webui(*args, **kwargs): move_all_control_model_to_cpu() # ------ modified by OneDiff below ------ - if forward_func is not None: - onediff_shared.current_unet_graph.graph_module._torch_module.forward = ( - forward_func - ) - onediff_shared.current_unet_graph.graph_module._torch_module._original_forward = ( - _original_forward_func - ) + if is_oneflow_backend(): + if forward_func is not None: + graph_module._torch_module.forward = forward_func + graph_module._torch_module._original_forward = ( + _original_forward_func + ) # ------ modified by OneDiff above ------ def hacked_basic_transformer_inner_forward(self, x, context=None): diff --git a/onediff_sd_webui_extensions/onediff_controlnet/model.py b/onediff_sd_webui_extensions/onediff_controlnet/model.py new file mode 100644 index 000000000..4c577a013 --- /dev/null +++ b/onediff_sd_webui_extensions/onediff_controlnet/model.py @@ -0,0 +1,124 @@ +import torch +import torch as th +from modules import devices + +cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x) + + +# https://github.com/Mikubill/sd-webui-controlnet/blob/8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0/scripts/hook.py#L238 +def aligned_adding(base, x, require_channel_alignment): + if 
isinstance(x, float):
+        if x == 0.0:
+            return base
+        return base + x
+
+    if require_channel_alignment:
+        zeros = torch.zeros_like(base)
+        zeros[:, : x.shape[1], ...] = x
+        x = zeros
+
+    # resize to sample resolution
+    base_h, base_w = base.shape[-2:]
+    xh, xw = x.shape[-2:]
+
+    if xh > 1 or xw > 1:
+        if base_h != xh or base_w != xw:
+            # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.')
+            x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
+
+    return base + x.half()
+
+
+class OnediffControlNetModel(torch.nn.Module):
+    def __init__(self, unet):
+        super().__init__()
+        self.time_embed = unet.time_embed
+        self.input_blocks = unet.input_blocks
+        self.label_emb = getattr(unet, "label_emb", None)
+        self.middle_block = unet.middle_block
+        self.output_blocks = unet.output_blocks
+        self.out = unet.out
+        self.model_channels = unet.model_channels
+        # rebind the unet's convert_to_fp16 helper so it acts on this wrapper
+        self.convert_to_fp16 = unet.convert_to_fp16.__get__(self)
+        # (the blocks are shared with the wrapped unet, so conversion happens in place)
+
+    @torch.autocast(device_type="cuda", enabled=False)
+    def forward(
+        self,
+        x,
+        timesteps,
+        context,
+        y,
+        total_t2i_adapter_embedding,
+        total_controlnet_embedding,
+        is_sdxl,
+        require_inpaint_hijack,
+    ):
+        from ldm.modules.diffusionmodules.util import timestep_embedding
+
+        # cast to half
+        x = x.half()
+        context = context.half()
+        if y is not None:
+            y = y.half()
+
+        hs = []
+        with th.no_grad():
+            t_emb = cond_cast_unet(
+                timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+            )
+
+            t_emb = t_emb.half()
+            emb = self.time_embed(t_emb).half()
+
+        if is_sdxl:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y).half()
+
+        h = x
+        for i, module in enumerate(self.input_blocks):
+            self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
+            h = module(h, emb, context)
+
+            t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11]
+
+            if i in t2i_injection:
+                h = aligned_adding(
+                    h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack
+                )
+
+            hs.append(h)
+
+        self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
+        h = self.middle_block(h, emb, context)
+
+        # U-Net Middle Block
+        h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)
+
+        if len(total_t2i_adapter_embedding) > 0 and is_sdxl:
+            h = aligned_adding(
+                h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack
+            )
+
+        # U-Net Decoder
+        for i, module in enumerate(self.output_blocks):
+            self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
+            h = th.cat(
+                [
+                    h,
+                    aligned_adding(
+                        hs.pop(),
+                        total_controlnet_embedding.pop(),
+                        require_inpaint_hijack,
+                    ),
+                ],
+                dim=1,
+            )
+            h = module(h, emb, context)
+
+        # U-Net Output
+        h = h.type(x.dtype)
+        h = self.out(h)
+
+        return h
diff --git a/onediff_sd_webui_extensions/onediff_controlnet/utils.py b/onediff_sd_webui_extensions/onediff_controlnet/utils.py
new file mode 100644
index 000000000..37cda8489
--- /dev/null
+++ b/onediff_sd_webui_extensions/onediff_controlnet/utils.py
@@ -0,0 +1,19 @@
+def check_if_controlnet_ext_loaded() -> bool:
+    from modules import extensions
+
+    return "sd-webui-controlnet" in extensions.loaded_extensions
+
+
+def get_controlnet_script(p):
+    for script in p.scripts.scripts:
+        if script.__module__ == "controlnet.py":
+            return script
+    return None
+
+
+def check_if_controlnet_enabled(p):
+    controlnet_script_class = get_controlnet_script(p)
+    return (
+        controlnet_script_class is not None
+        and len(controlnet_script_class.get_enabled_units(p)) 
!= 0 + ) diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py index 47b5c457c..0c6dc5e1b 100644 --- a/onediff_sd_webui_extensions/onediff_hijack.py +++ b/onediff_sd_webui_extensions/onediff_hijack.py @@ -1,40 +1,11 @@ from typing import Any, Mapping -import oneflow import torch -from compile.oneflow.mock import ldm, sgm from modules import sd_models from modules.sd_hijack_utils import CondFunc from onediff_shared import onediff_enabled -# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/sd_hijack_unet.py#L8 -class OneFlowHijackForUnet: - """ - This is oneflow, but with cat that resizes tensors to appropriate dimensions if they do not match; - this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 - """ - - def __getattr__(self, item): - if item == "cat": - return self.cat - if hasattr(oneflow, item): - return getattr(oneflow, item) - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{item}'" - ) - - def cat(self, tensors, *args, **kwargs): - if len(tensors) == 2: - a, b = tensors - a = oneflow.nn.functional.interpolate_like(a, like=b, mode="nearest") - tensors = (a, b) - return oneflow.cat(tensors, *args, **kwargs) - - -hijack_flow = OneFlowHijackForUnet() - - def unload_model_weights(sd_model=None, info=None): from modules import devices, lowvram, shared @@ -66,8 +37,6 @@ def unhijack_function(module, name, new_name): def do_hijack(): - ldm.flow = hijack_flow - sgm.flow = hijack_flow from modules import script_callbacks, sd_models script_callbacks.on_script_unloaded(undo_hijack) @@ -86,8 +55,6 @@ def do_hijack(): def undo_hijack(): - ldm.flow = oneflow - sgm.flow = oneflow from modules import sd_models unhijack_function( diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py index a1f4da8da..5d8f6a4df 100644 --- a/onediff_sd_webui_extensions/onediff_lora.py +++ b/onediff_sd_webui_extensions/onediff_lora.py @@ -1,9 +1,7 @@ import torch +from compile.utils import is_oneflow_backend from onediff.infer_compiler import DeployableModule -from onediff.infer_compiler.backends.oneflow.param_utils import ( - update_graph_related_tensor, -) class HijackLoraActivate: @@ -52,8 +50,12 @@ def activate(self, p, params_list): ): continue networks.network_apply_weights(sub_module) - if isinstance(sub_module, torch.nn.Conv2d): + if is_oneflow_backend() and isinstance(sub_module, torch.nn.Conv2d): # TODO(WangYi): refine here + from onediff.infer_compiler.backends.oneflow.param_utils import ( + update_graph_related_tensor, + ) + try: update_graph_related_tensor(sub_module) except: diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py index bd9041c82..d642d4270 100644 --- a/onediff_sd_webui_extensions/onediff_shared.py +++ b/onediff_sd_webui_extensions/onediff_shared.py @@ -9,6 +9,7 @@ "is_ssd": False, } onediff_enabled = False +onediff_backend = None # controlnet controlnet_enabled = False diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index dc3a42927..29b029bf1 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -1,14 +1,21 @@ import os -from contextlib import contextmanager from functools import wraps from pathlib import Path from textwrap import dedent from zipfile import BadZipFile +import networks import 
onediff_shared -import oneflow as flow -from modules.devices import torch_gc +from importlib_metadata import version + +from onediff.utils.import_utils import is_oneflow_available + +if is_oneflow_available(): + import oneflow as flow + +from compile import init_backend, is_oneflow_backend from modules import shared +from modules.devices import torch_gc from onediff.infer_compiler import DeployableModule @@ -115,15 +122,35 @@ def save_graph(compiled_unet: DeployableModule, saved_cache_name: str = ""): def onediff_enabled_decorator(func): @wraps(func) - def wrapper(self, p, *arg, **kwargs): + def wrapper( + self, + p, + quantization=False, + compiler_cache=None, + saved_cache_name="", + always_recompile=False, + backend=None, + ): onediff_shared.onediff_enabled = True + if networks.originals is not None: + networks.originals.undo() + init_backend(backend) try: - return func(self, p, *arg, **kwargs) + return func( + self, + p, + quantization=quantization, + compiler_cache=compiler_cache, + saved_cache_name=saved_cache_name, + always_recompile=always_recompile, + backend=backend, + ) finally: + if networks.originals is not None: + networks.originals.__init__() onediff_shared.onediff_enabled = False onediff_shared.previous_unet_type.update(**get_model_type(shared.sd_model)) - torch_gc() - flow.cuda.empty_cache() + onediff_gc() return wrapper @@ -139,6 +166,7 @@ def wrapper(*args, **kwargs): return wrapper + def get_model_type(model): return { "is_sdxl": model.is_sdxl, @@ -146,3 +174,24 @@ def get_model_type(model): "is_sd1": model.is_sd1, "is_ssd": model.is_ssd, } + + +def onediff_gc(): + torch_gc() + if is_oneflow_backend(): + flow.cuda.empty_cache() + + +def varify_can_use_quantization(): + try: + import oneflow + + if version("oneflow") < "0.9.1": + return False + except ImportError as e: + return False + try: + import onediff_quant + except ImportError as e: + return False + return hasattr(oneflow._C, "dynamic_quantization") diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 0ad3467b4..06dd0fc11 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -6,26 +6,33 @@ import modules.shared as shared import onediff_controlnet import onediff_shared -import oneflow as flow -from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph +from compile import ( + OneDiffBackend, + SD21CompileCtx, + VaeCompileCtx, + get_compiled_graph, + get_onediff_backend, +) from compile.nexfort.utils import add_nexfort_optimizer from modules import script_callbacks -from modules.devices import torch_gc from modules.processing import process_images from modules.ui_common import create_refresh_button from onediff_hijack import do_hijack as onediff_do_hijack from onediff_lora import HijackLoraActivate + +# from onediff.optimization.quant_optimizer import varify_can_use_quantization from onediff_utils import ( check_structure_change, get_all_compiler_caches, hints_message, load_graph, onediff_enabled_decorator, + onediff_gc, refresh_all_compiler_caches, save_graph, + varify_can_use_quantization, ) -from onediff.optimization.quant_optimizer import varify_can_use_quantization from onediff.utils import logger, parse_boolean_from_env """oneflow_compiled UNetModel""" @@ -113,11 +120,9 @@ def run( ): p.override_settings.pop("sd_model_checkpoint", None) sd_models.reload_model_weights() - torch_gc() - flow.cuda.empty_cache() - - backend = backend or shared.opts.onediff_compiler_backend + 
onediff_gc()
+    backend = backend or get_onediff_backend()
     current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name
     ckpt_changed = (
         shared.sd_model.sd_checkpoint_info.name
@@ -152,9 +157,9 @@ def run(
             f"Model {current_checkpoint_name} has same sd type of graph type {onediff_shared.previous_unet_type}, skip compile"
         )
 
-        with UnetCompileCtx(
-            not onediff_shared.controlnet_enabled
-        ), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate():
+        with UnetCompileCtx(not onediff_shared.controlnet_enabled), VaeCompileCtx(
+            backend=backend
+        ), SD21CompileCtx(), HijackLoraActivate():
             proc = process_images(p)
         save_graph(onediff_shared.current_unet_graph, saved_cache_name)
 
@@ -175,9 +180,9 @@ def on_ui_settings():
         "onediff_compiler_backend",
         shared.OptionInfo(
             "oneflow",
-            "Backend for onediff compiler",
+            "Backend for onediff compiler (switching the backend requires restarting the webui service)",
             gr.Radio,
-            {"choices": ["oneflow", "nexfort"]},
+            {"choices": [OneDiffBackend.ONEFLOW, OneDiffBackend.NEXFORT]},
             section=section,
         ),
     )
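Taken together, the refactor routes every compile entry point through the selected backend. A closing sketch, not part of the diff, of the flow `scripts/onediff.py` now follows; `compile_unet_for` is a hypothetical helper name, and the calls use only functions exported by the new `compile` package:

```python
# Sketch only: end-to-end compile flow after this refactor.
from compile import (
    VaeCompileCtx,
    get_compiled_graph,
    get_onediff_backend,
    init_backend,
)

def compile_unet_for(sd_model, quantization=False):
    backend = get_onediff_backend()  # from the webui settings option
    init_backend(backend)            # register backend-specific hooks once
    # wraps the (optionally quantized) compiled UNet in a OneDiffCompiledGraph
    return get_compiled_graph(sd_model, backend=backend, quantization=quantization)

# the VAE approx model is compiled with the same backend:
# with VaeCompileCtx(backend=backend):
#     proc = process_images(p)
```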