webui remove import oneflow for nexfort backend #973

Merged · 38 commits · Jun 26, 2024

Commits
c5e2234
refactor
marigoold May 29, 2024
e4332cf
move mock utils
marigoold May 29, 2024
686d533
fix bug of refiner
marigoold Jun 4, 2024
5c45ddc
Merge branch 'main' into webui-refactor-support-refiner
marigoold Jun 4, 2024
156724c
refine, format
marigoold Jun 4, 2024
7b51da0
add test
marigoold Jun 4, 2024
0843f45
fix cuda memory of refiner
marigoold Jun 4, 2024
78bbe55
Merge branch 'main' into webui-refactor-support-refiner
marigoold Jun 5, 2024
345da80
refine
marigoold Jun 5, 2024
37b8f51
Merge branch 'main' into webui-refactor-support-refiner
marigoold Jun 6, 2024
51f0b06
Merge branch 'webui-refactor-support-refiner' of github.com:siliconfl…
marigoold Jun 6, 2024
03b3a89
api test add model
marigoold Jun 6, 2024
e3acdbb
support controlnet unet (controlnet model not supported now)
marigoold Jun 13, 2024
1529402
Merge branch 'main' into dev_wy_webui_controlnet
marigoold Jun 13, 2024
75a1ec3
merge master
marigoold Jun 17, 2024
b66fed5
refine
marigoold Jun 17, 2024
e356672
support recompile when switching model
marigoold Jun 18, 2024
fbfe345
Merge branch 'main' into dev_wy_webui_controlnet
lijunliangTG Jun 19, 2024
7674aff
support nexfort
marigoold Jun 19, 2024
6b9a74f
Merge branch 'main' into dev_wy_webui_controlnet
lijunliangTG Jun 20, 2024
7a24a02
use torch functional sdpa
marigoold Jun 20, 2024
e18b54a
refactor compile
marigoold Jun 20, 2024
aa604e5
refine
marigoold Jun 20, 2024
ceedbc3
Merge branch 'dev_wy_webui_controlnet' into dev_wy_support_webui_nexfort
marigoold Jun 20, 2024
70bf929
support quant and refine
marigoold Jun 21, 2024
8fabb7d
merge master
marigoold Jun 22, 2024
2c650e7
refine
marigoold Jun 22, 2024
448d88d
remove oneflow for nexfort backend
marigoold Jun 24, 2024
38f6024
merge master
marigoold Jun 24, 2024
28dbbf1
fix bug
marigoold Jun 24, 2024
cb7903a
refine
marigoold Jun 24, 2024
190a4eb
nexfort support controlnet
marigoold Jun 24, 2024
c386c74
finally launch without oneflow
marigoold Jun 24, 2024
c0887f6
Merge branch 'main' into dev_wy_webui_without_oneflow
marigoold Jun 24, 2024
31b1acc
refine
marigoold Jun 25, 2024
392516f
refine, support lora
marigoold Jun 25, 2024
5f68bb6
Merge branch 'main' into dev_wy_webui_without_oneflow
marigoold Jun 25, 2024
e7f886e
Merge branch 'main' into dev_wy_webui_without_oneflow
marigoold Jun 26, 2024
11 changes: 4 additions & 7 deletions onediff_sd_webui_extensions/README.md
@@ -40,21 +40,18 @@ ln -s "$(pwd)/onediff/onediff_sd_webui_extensions" "$(pwd)/stable-diffusion-webu
cd stable-diffusion-webui

# Install all of stable-diffusion-webui's dependencies.
venv_dir=- bash webui.sh --port=8080

# Exit webui server and upgrade some of the components that conflict with onediff.
cd repositories/generative-models && git checkout 9d759324 && cd -
pip install -U einops==0.7.0
# If you install as root user, append `-f` to the end of the command line.
venv_dir=- bash webui.sh
```

## Run stable-diffusion-webui service

```bash
cd stable-diffusion-webui
python webui.py --port 8080
python webui.py --port 7860
```

Accessing http://server:8080/ from a web browser.
Accessing http://server:7860/ from a web browser.

## Extensions Usage

14 changes: 13 additions & 1 deletion onediff_sd_webui_extensions/compile/__init__.py
@@ -1,11 +1,23 @@
from .backend import OneDiffBackend
from .compile import get_compiled_graph
from .sd2 import SD21CompileCtx
from .utils import OneDiffCompiledGraph
from .utils import (
    OneDiffCompiledGraph,
    get_onediff_backend,
    init_backend,
    is_nexfort_backend,
    is_oneflow_backend,
)
from .vae import VaeCompileCtx

__all__ = [
"get_compiled_graph",
"SD21CompileCtx",
"VaeCompileCtx",
"OneDiffCompiledGraph",
"OneDiffBackend",
"get_onediff_backend",
"is_oneflow_backend",
"is_nexfort_backend",
"init_backend",
]
12 changes: 12 additions & 0 deletions onediff_sd_webui_extensions/compile/backend.py
@@ -0,0 +1,12 @@
from enum import Enum


class OneDiffBackend(Enum):
    ONEFLOW = "oneflow"
    NEXFORT = "nexfort"

    def __str__(self):
        return self.value

    def __repr__(self):
        return f"<{self.__class__.__name__}.{self.name}: {self.value}>"
30 changes: 22 additions & 8 deletions onediff_sd_webui_extensions/compile/compile.py
@@ -1,27 +1,38 @@
from compile import OneDiffBackend
from modules.sd_hijack import apply_optimizations

from onediff.infer_compiler import compile, oneflow_compile

from .utils import OneDiffCompiledGraph, disable_unet_checkpointing
from .quantization import quant_unet_oneflow
from .utils import (
    OneDiffCompiledGraph,
    disable_unet_checkpointing,
    is_nexfort_backend,
    is_oneflow_backend,
)


def get_compiled_graph(
    sd_model, *, backend, quantization=None, options=None
    sd_model, unet_model=None, *, backend=None, quantization=None, options=None
) -> OneDiffCompiledGraph:
    diffusion_model = sd_model.model.diffusion_model
    diffusion_model = unet_model or sd_model.model.diffusion_model
    compiled_unet = onediff_compile(
        diffusion_model, backend=backend, quantization=quantization, options=options
    )
    return OneDiffCompiledGraph(sd_model, compiled_unet, quantization)
    return OneDiffCompiledGraph(sd_model, diffusion_model, compiled_unet, quantization)


def onediff_compile(unet_model, *, quantization=False, backend="oneflow", options=None):
    if backend == "oneflow":
def onediff_compile(
    unet_model,
    *,
    quantization: bool = False,
    backend: OneDiffBackend = None,
    options=None,
):
    if is_oneflow_backend(backend):
        return compile_unet_oneflow(
            unet_model, quantization=quantization, options=options
        )
    elif backend == "nexfort":
    elif is_nexfort_backend(backend):
        return compile_unet_nexfort(
            unet_model, quantization=quantization, options=options
        )
@@ -44,11 +44,14 @@ def compile_unet_oneflow(unet_model, *, quantization=False, options=None):

    compiled_unet_model = oneflow_compile(unet_model, options=options)
    if quantization:
        from .quantization import quant_unet_oneflow

        compiled_unet_model = quant_unet_oneflow(compiled_unet_model)
    return compiled_unet_model


def compile_unet_nexfort(unet_model, *, quantization=False, options=None):
    # TODO: support nexfort quant
    if quantization:
        raise NotImplementedError(
            "Quantization for nexfort backend is not implemented yet."
69 changes: 0 additions & 69 deletions onediff_sd_webui_extensions/compile/compile_utils.py

This file was deleted.

24 changes: 18 additions & 6 deletions onediff_sd_webui_extensions/compile/nexfort/utils.py
@@ -11,6 +11,8 @@
from modules.sd_hijack_utils import CondFunc
from onediff_utils import singleton_decorator

from onediff.utils.import_utils import is_nexfort_available


@singleton_decorator
def init_nexfort_backend():
@@ -26,6 +28,21 @@ def init_nexfort_backend():
        lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled,
    )

    def hijack_groupnorm32_forward(orig_func, self, x):
        return super(type(self), self).forward(x)
        # return self.forward(x)

    CondFunc(
        "ldm.modules.diffusionmodules.util.GroupNorm32.forward",
        hijack_groupnorm32_forward,
        lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled,
    )
    CondFunc(
        "sgm.modules.diffusionmodules.util.GroupNorm32.forward",
        hijack_groupnorm32_forward,
        lambda orig_func, *args, **kwargs: onediff_shared.onediff_enabled,
    )


@torch.autocast("cuda", enabled=False)
def onediff_nexfort_unet_sgm_forward(
@@ -132,12 +149,7 @@ class SdOptimizationNexfort(SdOptimization):
    priority = 10

    def is_available(self):
        try:
            import nexfort
        except ImportError:
            return False
        finally:
            return True
        return is_nexfort_available()

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = (
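Worth noting: the removed `is_available` body always returned `True`, because the `return True` in the `finally` block overrides the `return False` in the `except ImportError` branch. Switching to `is_nexfort_available()` from `onediff.utils.import_utils` fixes that. A minimal sketch of what such an availability check typically looks like; the real helper's implementation may differ:

```python
# Hedged sketch of an is_nexfort_available-style check (assumption: the real
# helper in onediff.utils.import_utils behaves equivalently).
import importlib.util


def is_nexfort_available() -> bool:
    # find_spec returns None when the package cannot be located,
    # without actually importing it.
    return importlib.util.find_spec("nexfort") is not None
```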
107 changes: 107 additions & 0 deletions onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py
@@ -0,0 +1,107 @@
import oneflow as flow
from compile.oneflow.mock.common import timestep_embedding
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from modules import devices

from onediff.infer_compiler.backends.oneflow.transform import proxy_class

cond_cast_unet = getattr(devices, "cond_cast_unet", lambda x: x)


# Due to the tracing mechanism in OneFlow, it's crucial to ensure that
# the same conditional branches are taken during the first run as in subsequent runs.
# Therefore, certain "optimizations" have been modified.
def aligned_adding(base, x, require_channel_alignment):
    if isinstance(x, float):
        # remove `if x == 0.0: return base` here
        return base + x

    if require_channel_alignment:
        zeros = flow.zeros_like(base)
        zeros[:, : x.shape[1], ...] = x
        x = zeros

    # resize to sample resolution
    base_h, base_w = base.shape[-2:]
    xh, xw = x.shape[-2:]

    if xh > 1 or xw > 1 and (base_h != xh or base_w != xw):
        # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.')
        x = flow.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
    return base + x


class OneFlowOnediffControlNetModel(proxy_class(UNetModel)):
    def forward(
        self,
        x,
        timesteps,
        context,
        y,
        total_t2i_adapter_embedding,
        total_controlnet_embedding,
        is_sdxl,
        require_inpaint_hijack,
    ):
        x = x.half()
        if y is not None:
            y = y.half()
        context = context.half()
        hs = []
        with flow.no_grad():
            t_emb = cond_cast_unet(
                timestep_embedding(timesteps, self.model_channels, repeat_only=False)
            )
            emb = self.time_embed(t_emb.half())

            if is_sdxl:
                assert y.shape[0] == x.shape[0]
                emb = emb + self.label_emb(y)

            h = x
            for i, module in enumerate(self.input_blocks):
                self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
                h = module(h, emb, context)

                t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11]

                if i in t2i_injection:
                    h = aligned_adding(
                        h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack
                    )

                hs.append(h)

            self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
            h = self.middle_block(h, emb, context)

            # U-Net Middle Block
            h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)

            if len(total_t2i_adapter_embedding) > 0 and is_sdxl:
                h = aligned_adding(
                    h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack
                )

            # U-Net Decoder
            for i, module in enumerate(self.output_blocks):
                self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
                h = flow.cat(
                    [
                        h,
                        aligned_adding(
                            hs.pop(),
                            total_controlnet_embedding.pop(),
                            require_inpaint_hijack,
                        ),
                    ],
                    dim=1,
                )
                h = h.half()
                h = module(h, emb, context)

            # U-Net Output
            h = h.type(x.dtype)
            h = self.out(h)

            return h
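For readers unfamiliar with the tracing constraint mentioned in the comment above: `aligned_adding` deliberately takes the same branches on every call so that OneFlow's traced graph stays valid across runs. A plain-PyTorch re-statement of its shape logic, illustrative only; the resize condition is simplified and the PR's code uses `oneflow`, not `torch`:

```python
# Illustrative PyTorch version of aligned_adding's channel-padding logic.
# Not the PR's code: torch replaces oneflow and the resize test is simplified.
import torch
import torch.nn.functional as F


def aligned_adding(base, x, require_channel_alignment):
    if isinstance(x, float):
        # Kept unconditional (no `if x == 0.0: return base` shortcut),
        # so tracing sees the same graph whether or not x is zero.
        return base + x

    if require_channel_alignment:
        # Pad x with zero channels up to base's channel count.
        zeros = torch.zeros_like(base)
        zeros[:, : x.shape[1], ...] = x
        x = zeros

    if x.shape[-2:] != base.shape[-2:]:
        x = F.interpolate(x, size=base.shape[-2:], mode="nearest")
    return base + x


base = torch.zeros(1, 8, 16, 16)
ctrl = torch.ones(1, 4, 16, 16)  # fewer channels, same spatial size
out = aligned_adding(base, ctrl, require_channel_alignment=True)
assert out.shape == base.shape
```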
14 changes: 10 additions & 4 deletions onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py
@@ -3,11 +3,19 @@
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from ldm.modules.diffusionmodules.util import GroupNorm32

from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register
from onediff.infer_compiler.backends.oneflow.transform import proxy_class

from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding


def cat(tensors, *args, **kwargs):
    if len(tensors) == 2:
        a, b = tensors
        a = flow.nn.functional.interpolate_like(a, like=b, mode="nearest")
        tensors = (a, b)
    return flow.cat(tensors, *args, **kwargs)


# https://github.com/Stability-AI/stablediffusion/blob/b4bdae9916f628461e1e4edbc62aafedebb9f7ed/ldm/modules/diffusionmodules/openaimodel.py#L775
class UNetModelOflow(proxy_class(UNetModel)):
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
@@ -28,7 +36,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = flow.cat([h, hs.pop()], dim=1)
            h = cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
@@ -66,5 +74,3 @@ def forward(self, x, context=None):
    SpatialTransformer: SpatialTransformerOflow,
    UNetModel: UNetModelOflow,
}

register(package_names=["ldm"], torch2oflow_class_map=torch2oflow_class_map)
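The new `cat` helper guards against mismatched feature-map sizes in the UNet decoder: the upsampled tensor is resized to the skip connection's spatial size before channel-wise concatenation. `interpolate_like` is a OneFlow API; the sketch below is a plain-PyTorch stand-in that uses `F.interpolate` with an explicit size, under the assumption that it matches `interpolate_like(a, like=b)` behavior:

```python
# Plain-PyTorch sketch of the `cat` helper above (illustrative only; the PR
# uses oneflow's interpolate_like, approximated here with F.interpolate).
import torch
import torch.nn.functional as F


def cat(tensors, *args, **kwargs):
    if len(tensors) == 2:
        a, b = tensors
        # Match a's spatial size to b's before concatenating along channels.
        a = F.interpolate(a, size=b.shape[-2:], mode="nearest")
        tensors = (a, b)
    return torch.cat(tensors, *args, **kwargs)


h = torch.zeros(1, 4, 15, 15)     # upsampled decoder features (odd size)
skip = torch.zeros(1, 4, 16, 16)  # encoder skip connection
assert cat([h, skip], dim=1).shape == (1, 8, 16, 16)
```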