-
Notifications
You must be signed in to change notification settings - Fork 115
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
sd-webui refactor, and support refiner model (#930)
- Loading branch information
Showing
16 changed files
with
451 additions
and
173 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .compile_ldm import SD21CompileCtx | ||
from .compile_utils import get_compiled_graph | ||
from .compile_vae import VaeCompileCtx | ||
from .onediff_compiled_graph import OneDiffCompiledGraph | ||
|
||
__all__ = [ | ||
"get_compiled_graph", | ||
"SD21CompileCtx", | ||
"VaeCompileCtx", | ||
"OneDiffCompiledGraph", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import warnings | ||
from pathlib import Path | ||
from typing import Dict, Union | ||
|
||
from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM | ||
from modules.sd_models import select_checkpoint | ||
from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM | ||
|
||
from onediff.optimization.quant_optimizer import ( | ||
quantize_model, | ||
varify_can_use_quantization, | ||
) | ||
from onediff.utils import logger | ||
|
||
from .compile_ldm import compile_ldm_unet | ||
from .compile_sgm import compile_sgm_unet | ||
from .onediff_compiled_graph import OneDiffCompiledGraph | ||
|
||
|
||
def compile_unet( | ||
unet_model, quantization=False, *, options=None, | ||
): | ||
if isinstance(unet_model, UNetModelLDM): | ||
compiled_unet = compile_ldm_unet(unet_model, options=options) | ||
elif isinstance(unet_model, UNetModelSGM): | ||
compiled_unet = compile_sgm_unet(unet_model, options=options) | ||
else: | ||
warnings.warn( | ||
f"Unsupported model type: {type(unet_model)} for compilation , skip", | ||
RuntimeWarning, | ||
) | ||
compiled_unet = unet_model | ||
# In OneDiff Community, quantization can be True when called by api | ||
if quantization and varify_can_use_quantization(): | ||
calibrate_info = get_calibrate_info( | ||
f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" | ||
) | ||
compiled_unet = quantize_model( | ||
compiled_unet, inplace=False, calibrate_info=calibrate_info | ||
) | ||
return compiled_unet | ||
|
||
|
||
def get_calibrate_info(filename: str) -> Union[None, Dict]: | ||
calibration_path = Path(select_checkpoint().filename).parent / filename | ||
if not calibration_path.exists(): | ||
return None | ||
|
||
logger.info(f"Got calibrate info at {str(calibration_path)}") | ||
calibrate_info = {} | ||
with open(calibration_path, "r") as f: | ||
for line in f.readlines(): | ||
line = line.strip() | ||
items = line.split(" ") | ||
calibrate_info[items[0]] = [ | ||
float(items[1]), | ||
int(items[2]), | ||
[float(x) for x in items[3].split(",")], | ||
] | ||
return calibrate_info | ||
|
||
|
||
def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph: | ||
compiled_unet = compile_unet( | ||
sd_model.model.diffusion_model, quantization=quantization | ||
) | ||
return OneDiffCompiledGraph(sd_model, compiled_unet, quantization) |
File renamed without changes.
31 changes: 31 additions & 0 deletions
31
onediff_sd_webui_extensions/compile/onediff_compiled_graph.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import dataclasses | ||
|
||
import torch | ||
from modules import sd_models_types | ||
|
||
from onediff.infer_compiler import DeployableModule | ||
|
||
|
||
@dataclasses.dataclass | ||
class OneDiffCompiledGraph: | ||
name: str = None | ||
filename: str = None | ||
sha: str = None | ||
eager_module: torch.nn.Module = None | ||
graph_module: DeployableModule = None | ||
quantized: bool = False | ||
|
||
def __init__( | ||
self, | ||
sd_model: sd_models_types.WebuiSdModel = None, | ||
graph_module: DeployableModule = None, | ||
quantized=False, | ||
): | ||
if sd_model is None: | ||
return | ||
self.name = sd_model.sd_checkpoint_info.name | ||
self.filename = sd_model.sd_checkpoint_info.filename | ||
self.sha = sd_model.sd_model_hash | ||
self.eager_module = sd_model.model.diffusion_model | ||
self.graph_module = graph_module | ||
self.quantized = quantized |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from compile.onediff_compiled_graph import OneDiffCompiledGraph | ||
|
||
current_unet_graph = OneDiffCompiledGraph() | ||
current_quantization = False | ||
current_unet_type = { | ||
"is_sdxl": False, | ||
"is_sd2": False, | ||
"is_sd1": False, | ||
"is_ssd": False, | ||
} | ||
onediff_enabled = False |
Oops, something went wrong.