sd-webui refactor, and support refiner model (#930)
marigoold authored Jun 17, 2024
1 parent 5677af5 commit 156d4f0
Showing 16 changed files with 451 additions and 173 deletions.
2 changes: 2 additions & 0 deletions onediff_sd_webui_extensions/README.md
@@ -4,8 +4,10 @@
- [Installation Guide](#installation-guide)
- [Extensions Usage](#extensions-usage)
- [Fast Model Switching](#fast-model-switching)
- [Compiler cache saving and loading](#compiler-cache-saving-and-loading)
- [LoRA](#lora)
- [Quantization](#quantization)
- [Use OneDiff by API](#use-onediff-by-api)
- [Contact](#contact)

## Performance of Community Edition
11 changes: 11 additions & 0 deletions onediff_sd_webui_extensions/compile/__init__.py
@@ -0,0 +1,11 @@
from .compile_ldm import SD21CompileCtx
from .compile_utils import get_compiled_graph
from .compile_vae import VaeCompileCtx
from .onediff_compiled_graph import OneDiffCompiledGraph

__all__ = [
"get_compiled_graph",
"SD21CompileCtx",
"VaeCompileCtx",
"OneDiffCompiledGraph",
]
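
The new `compile` package gathers the compilation entry points behind a single import. A minimal usage sketch under the assumption that a webui `sd_model` handle is already in hand; the caching helper and its name are illustrative, not part of this commit:

```python
import onediff_shared
from compile import OneDiffCompiledGraph, get_compiled_graph


def ensure_compiled(sd_model, quantization: bool = False) -> OneDiffCompiledGraph:
    # Reuse the cached graph when the checkpoint hash and quantization flag
    # match; otherwise compile the current UNet and remember it in shared state.
    cached = onediff_shared.current_unet_graph
    if cached.sha == sd_model.sd_model_hash and cached.quantized == quantization:
        return cached
    graph = get_compiled_graph(sd_model, quantization)
    onediff_shared.current_unet_graph = graph
    return graph
```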
@@ -9,15 +9,16 @@
from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel
from ldm.modules.diffusionmodules.util import GroupNorm32
from modules import shared
from sd_webui_onediff_utils import (

from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register

from .sd_webui_onediff_utils import (
CrossAttentionOflow,
GroupNorm32Oflow,
timestep_embedding,
)

from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register

__all__ = ["compile_ldm_unet"]


@@ -1,9 +1,4 @@
import oneflow as flow
from sd_webui_onediff_utils import (
CrossAttentionOflow,
GroupNorm32Oflow,
timestep_embedding,
)
from sgm.modules.attention import (
BasicTransformerBlock,
CrossAttention,
@@ -15,6 +10,12 @@
from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register

from .sd_webui_onediff_utils import (
CrossAttentionOflow,
GroupNorm32Oflow,
timestep_embedding,
)

__all__ = ["compile_sgm_unet"]


67 changes: 67 additions & 0 deletions onediff_sd_webui_extensions/compile/compile_utils.py
@@ -0,0 +1,67 @@
import warnings
from pathlib import Path
from typing import Dict, Union

from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM
from modules.sd_models import select_checkpoint
from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM

from onediff.optimization.quant_optimizer import (
quantize_model,
varify_can_use_quantization,
)
from onediff.utils import logger

from .compile_ldm import compile_ldm_unet
from .compile_sgm import compile_sgm_unet
from .onediff_compiled_graph import OneDiffCompiledGraph


def compile_unet(
unet_model, quantization=False, *, options=None,
):
if isinstance(unet_model, UNetModelLDM):
compiled_unet = compile_ldm_unet(unet_model, options=options)
elif isinstance(unet_model, UNetModelSGM):
compiled_unet = compile_sgm_unet(unet_model, options=options)
else:
warnings.warn(
f"Unsupported model type: {type(unet_model)} for compilation , skip",
RuntimeWarning,
)
compiled_unet = unet_model
# In OneDiff Community, quantization can be True when called by api
if quantization and varify_can_use_quantization():
calibrate_info = get_calibrate_info(
f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt"
)
compiled_unet = quantize_model(
compiled_unet, inplace=False, calibrate_info=calibrate_info
)
return compiled_unet


def get_calibrate_info(filename: str) -> Union[None, Dict]:
calibration_path = Path(select_checkpoint().filename).parent / filename
if not calibration_path.exists():
return None

logger.info(f"Got calibrate info at {str(calibration_path)}")
calibrate_info = {}
with open(calibration_path, "r") as f:
for line in f.readlines():
line = line.strip()
items = line.split(" ")
calibrate_info[items[0]] = [
float(items[1]),
int(items[2]),
[float(x) for x in items[3].split(",")],
]
return calibrate_info


def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph:
compiled_unet = compile_unet(
sd_model.model.diffusion_model, quantization=quantization
)
return OneDiffCompiledGraph(sd_model, compiled_unet, quantization)
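
`get_calibrate_info` looks for `<checkpoint stem>_sd_calibrate_info.txt` next to the selected checkpoint and expects one whitespace-separated record per line: a tensor name, a float, an int, and a comma-separated float list. A round-trip sketch with made-up values (the field meanings are not spelled out in this commit, so the names below are only placeholders):

```python
# Hypothetical calibrate-info record; real files come from OneDiff's
# quantization tooling and sit in the same directory as the checkpoint.
sample_line = "input_blocks.1.1.proj_in 0.0132 8 0.25,0.5,0.75"

name, scale, bits, per_channel = sample_line.split(" ")
entry = [float(scale), int(bits), [float(x) for x in per_channel.split(",")]]
print(name, entry)
# input_blocks.1.1.proj_in [0.0132, 8, [0.25, 0.5, 0.75]]
```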
File renamed without changes.
31 changes: 31 additions & 0 deletions onediff_sd_webui_extensions/compile/onediff_compiled_graph.py
@@ -0,0 +1,31 @@
import dataclasses

import torch
from modules import sd_models_types

from onediff.infer_compiler import DeployableModule


@dataclasses.dataclass
class OneDiffCompiledGraph:
name: str = None
filename: str = None
sha: str = None
eager_module: torch.nn.Module = None
graph_module: DeployableModule = None
quantized: bool = False

def __init__(
self,
sd_model: sd_models_types.WebuiSdModel = None,
graph_module: DeployableModule = None,
quantized=False,
):
if sd_model is None:
return
self.name = sd_model.sd_checkpoint_info.name
self.filename = sd_model.sd_checkpoint_info.filename
self.sha = sd_model.sd_model_hash
self.eager_module = sd_model.model.diffusion_model
self.graph_module = graph_module
self.quantized = quantized
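
Because the explicit `__init__` returns early when no `sd_model` is passed, the dataclass defaults act as an "empty" sentinel; this is how `onediff_shared.current_unet_graph` starts out, and the weight-loading hijack below only reuses `graph_module` once the checkpoint shorthash matches. A tiny sketch of that sentinel behaviour:

```python
from compile import OneDiffCompiledGraph

empty = OneDiffCompiledGraph()  # no sd_model: all fields keep their None defaults
assert empty.sha is None and empty.graph_module is None
```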
136 changes: 134 additions & 2 deletions onediff_sd_webui_extensions/onediff_hijack.py
@@ -1,6 +1,11 @@
import compile_ldm
import compile_sgm
from typing import Any, Mapping

import oneflow
import torch
from compile import compile_ldm, compile_sgm
from modules import sd_models
from modules.sd_hijack_utils import CondFunc
from onediff_shared import onediff_enabled


# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/sd_hijack_unet.py#L8
@@ -95,3 +100,130 @@ def undo_hijack():
name="send_model_to_cpu",
new_name="__onediff_original_send_model_to_cpu",
)


def onediff_hijack_load_model_weights(
orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer
):
# load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer)
sd_model_hash = checkpoint_info.calculate_shorthash()
import onediff_shared

if onediff_shared.current_unet_graph.sha == sd_model_hash:
model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module
state_dict = {
k: v
for k, v in state_dict.items()
if not k.startswith("model.diffusion_model.")
}

# for stable-diffusion-webui/modules/sd_models.py:load_model_weights model.is_ssd check
state_dict[
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight"
] = model.get_parameter(
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight"
)
return orig_func(model, checkpoint_info, state_dict, timer)


def onediff_hijack_load_state_dict(
orig_func,
self,
state_dict: Mapping[str, Any],
strict: bool = True,
assign: bool = False,
):
if (
len(state_dict) > 0
and next(iter(state_dict.values())).is_cuda
and next(self.parameters()).is_meta
):
return orig_func(self, state_dict, strict, assign=True)
else:
return orig_func(self, state_dict, strict, assign)


# fmt: off
def onediff_hijaced_LoadStateDictOnMeta___enter__(orig_func, self):
from modules import shared
if shared.cmd_opts.disable_model_loading_ram_optimization:
return

sd = self.state_dict
device = self.device

def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
used_param_keys = []

for name, param in module._parameters.items():
if param is None:
continue

key = prefix + name
sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
used_param_keys.append(key)

if param.is_meta:
dtype = sd_param.dtype if sd_param is not None else param.dtype
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

for name in module._buffers:
key = prefix + name

sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param
used_param_keys.append(key)

original(module, state_dict, prefix, *args, **kwargs)

for key in used_param_keys:
state_dict.pop(key, None)

# def load_state_dict(original, module, state_dict, strict=True):
def load_state_dict(original, module, state_dict, strict=True):
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
the function and does not call the original) the state dict will just fail to load because weights
would be on the meta device.
"""

if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

# ------------------- DIFF HERE -------------------
# original(module, state_dict, strict=strict)
if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(module.parameters()).is_meta:
assign = True
else:
assign = False
# orig_func(original, module, state_dict, strict=strict, assign=assign)
original(module, state_dict, strict=strict, assign=assign)

module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
# fmt: on


CondFunc(
"modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__",
onediff_hijaced_LoadStateDictOnMeta___enter__,
lambda _, *args, **kwargs: onediff_enabled,
)
CondFunc(
"modules.sd_models.load_model_weights",
onediff_hijack_load_model_weights,
lambda _, *args, **kwargs: onediff_enabled,
)
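
The `assign=True` branch matters because a module whose parameters still sit on the meta device has no real storage to copy into. A standalone sketch of the behaviour being relied on, assuming torch >= 2.1 (where `load_state_dict` accepts `assign`); it uses CPU tensors for brevity, while the hijack above additionally checks that the incoming tensors are on CUDA:

```python
import torch

# Create a module whose parameters live on the meta device (no storage yet).
with torch.device("meta"):
    layer = torch.nn.Linear(4, 4)

state = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# Copying into meta parameters cannot materialize real storage; assign=True
# replaces the parameters with the loaded tensors instead of copying into them.
layer.load_state_dict(state, assign=True)
print(layer.weight.device)  # cpu
```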
6 changes: 5 additions & 1 deletion onediff_sd_webui_extensions/onediff_lora.py
@@ -53,7 +53,11 @@ def activate(self, p, params_list):
continue
networks.network_apply_weights(sub_module)
if isinstance(sub_module, torch.nn.Conv2d):
update_graph_related_tensor(sub_module)
# TODO(WangYi): refine here
try:
update_graph_related_tensor(sub_module)
except Exception:
pass

activate._onediff_hijacked = True
return activate
11 changes: 11 additions & 0 deletions onediff_sd_webui_extensions/onediff_shared.py
@@ -0,0 +1,11 @@
from compile.onediff_compiled_graph import OneDiffCompiledGraph

current_unet_graph = OneDiffCompiledGraph()
current_quantization = False
current_unet_type = {
"is_sdxl": False,
"is_sd2": False,
"is_sd1": False,
"is_ssd": False,
}
onediff_enabled = False
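
`onediff_shared` is plain module-level state, so other extension modules read and write it by attribute. A small illustrative helper (not part of this commit) that refreshes `current_unet_type` from the flags webui sets on the active model:

```python
import onediff_shared


def refresh_unet_type(sd_model) -> None:
    # is_sdxl / is_sd2 / is_sd1 / is_ssd are attributes webui attaches to the model.
    for key in onediff_shared.current_unet_type:
        onediff_shared.current_unet_type[key] = bool(getattr(sd_model, key, False))
```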