Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't suppress OneDiff logging if client code has already defined handlers in parent logger #888

Merged
merged 11 commits into from
Jul 9, 2024
77 changes: 33 additions & 44 deletions onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,33 @@
from comfy import model_management
from comfy.cli_args import args

from onediff.infer_compiler.utils import is_community_version
from onediff.utils.import_utils import is_onediff_quant_available
from onediff.infer_compiler.backends.oneflow.utils.version_util import (
is_community_version,
)


from ..modules import BoosterScheduler
from ..modules.oneflow import (
BasicOneFlowBoosterExecutor,
DeepcacheBoosterExecutor,
PatchBoosterExecutor,
)
from ..modules.oneflow.config import ONEDIFF_QUANTIZED_OPTIMIZED_MODELS
from ..modules.oneflow.hijack_animatediff import animatediff_hijacker
from ..modules.oneflow.hijack_ipadapter_plus import ipadapter_plus_hijacker
from ..modules.oneflow.hijack_model_management import model_management_hijacker
from ..modules.oneflow.hijack_nodes import nodes_hijacker
from ..modules.oneflow.hijack_samplers import samplers_hijack
from ..modules.oneflow.hijack_comfyui_instantid import comfyui_instantid_hijacker
from ..modules.oneflow.hijack_model_patcher import model_patch_hijacker
from ..modules.oneflow import BasicOneFlowBoosterExecutor
from ..modules.oneflow import DeepcacheBoosterExecutor
from ..modules.oneflow import PatchBoosterExecutor
from ..modules.oneflow.utils import OUTPUT_FOLDER, load_graph, save_graph
from ..modules import BoosterScheduler
from ..utils.import_utils import is_onediff_quant_available


if is_onediff_quant_available() and not is_community_version():
from ..modules.oneflow.booster_quantization import OnelineQuantizationBoosterExecutor # type: ignore
from ..modules.oneflow.booster_quantization import (
OnelineQuantizationBoosterExecutor,
) # type: ignore

model_management_hijacker.hijack() # add flow.cuda.empty_cache()
nodes_hijacker.hijack()
samplers_hijack.hijack()
animatediff_hijacker.hijack()
ipadapter_plus_hijacker.hijack()
comfyui_instantid_hijacker.hijack()
model_patch_hijacker.hijack()

import comfy_extras.nodes_video_model
from nodes import CheckpointLoaderSimple


# https://github.com/comfyanonymous/ComfyUI/commit/bb4940d837f0cfd338ff64776b084303be066c67#diff-fab3fbd81daf87571b12fb3e4d80fc7d6bbbcf0f3dafed1dbc55d81998d82539L54
if hasattr(args, "dont_upcast_attention") and not args.dont_upcast_attention:
if hasattr(args, "dont_upcast_attention") and not args.dont_upcast_attention:
os.environ["ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M"] = "0"


Expand Down Expand Up @@ -166,6 +158,7 @@ def deep_cache_convert(
start_step,
end_step,
):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
booster = BoosterScheduler(
DeepcacheBoosterExecutor(
cache_interval=cache_interval,
Expand Down Expand Up @@ -309,19 +302,16 @@ def onediff_load_checkpoint(
self,
ckpt_name,
vae_speedup,
output_vae=True,
output_clip=True,
static_mode="enable",
cache_interval=3,
cache_layer_id=0,
cache_block_id=1,
start_step=0,
end_step=1000,
):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
# CheckpointLoaderSimple.load_checkpoint
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
booster = BoosterScheduler(
DeepcacheBoosterExecutor(
cache_interval=cache_interval,
Expand Down Expand Up @@ -391,6 +381,7 @@ def speedup(
cache_name="svd",
custom_booster: BoosterScheduler = None,
):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
if custom_booster:
booster = custom_booster
booster.inplace = inplace
Expand Down Expand Up @@ -420,6 +411,7 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff"

def load_graph(self, vae, graph):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
vae_model = vae.first_stage_model
device = model_management.vae_offload_device()
load_graph(vae_model, graph, device, subfolder="vae")
Expand All @@ -443,6 +435,7 @@ def INPUT_TYPES(s):
OUTPUT_NODE = True

def save_graph(self, images, vae, filename_prefix):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
vae_model = vae.first_stage_model
vae_device = model_management.vae_offload_device()
save_graph(vae_model, filename_prefix, vae_device, subfolder="vae")
Expand All @@ -468,6 +461,7 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff"

def load_graph(self, model, graph):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')

diffusion_model = model.model.diffusion_model

Expand All @@ -492,6 +486,7 @@ def INPUT_TYPES(s):
OUTPUT_NODE = True

def save_graph(self, samples, model, filename_prefix):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
diffusion_model = model.model.diffusion_model
save_graph(diffusion_model, filename_prefix, "cuda", subfolder="unet")
return {}
Expand Down Expand Up @@ -545,6 +540,7 @@ def INPUT_TYPES(cls):
CATEGORY = "OneDiff"

def load_unet_int8(self, model_path):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
from ..modules.oneflow.utils.onediff_quant_utils import (
replace_module_with_quantizable_module,
)
Expand Down Expand Up @@ -583,6 +579,7 @@ def INPUT_TYPES(s):
OUTPUT_NODE = True

def quantize_model(self, model, output_dir, conv, linear):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
from ..modules.oneflow.utils import quantize_and_save_model

diffusion_model = model.model.diffusion_model
Expand Down Expand Up @@ -611,12 +608,10 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff/Loaders"
FUNCTION = "onediff_load_checkpoint"

def onediff_load_checkpoint(
self, ckpt_name, vae_speedup, output_vae=True, output_clip=True
):
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)

def onediff_load_checkpoint(self, ckpt_name, vae_speedup):
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
booster = BoosterScheduler(
OnelineQuantizationBoosterExecutor(
conv_percentage=100,
Expand Down Expand Up @@ -664,19 +659,12 @@ def INPUT_TYPES(s):
FUNCTION = "onediff_load_checkpoint"

def onediff_load_checkpoint(
self,
ckpt_name,
model_path,
compile,
vae_speedup,
output_vae=True,
output_clip=True,
self, ckpt_name, model_path, compile, vae_speedup,
):
need_compile = compile == "enable"
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')

modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
# TODO fix by op.compile
from ..modules.oneflow.utils.onediff_load_utils import (
onediff_load_quant_checkpoint_advanced,
Expand Down Expand Up @@ -727,6 +715,7 @@ def onediff_load_checkpoint(
output_vae=True,
output_clip=True,
):
print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.')
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
Expand Down
67 changes: 33 additions & 34 deletions src/onediff/infer_compiler/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,44 +29,43 @@ def __getattr__(self, name):
def configure_logging(self, name, level, log_dir=None, file_name=None):
logger = logging.getLogger(name)

if logger.hasHandlers():
logger.warning("Logging handlers already exist for %s", name)
return

logger.setLevel(level)

# Create a console formatter and add it to a console handler
console_formatter = ColorFormatter(
fmt="%(levelname)s [%(asctime)s] %(pathname)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)

# Create a file formatter and add it to a file handler if log_dir is provided
if log_dir:
log_dir = Path(log_dir)
os.makedirs(log_dir, exist_ok=True)

file_prefix = "{}_".format(
time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
)

if file_name:
log_file_name = file_prefix + file_name
else:
log_file_name = file_prefix + name + ".log"

log_file = log_dir / log_file_name
file_formatter = logging.Formatter(
fmt="%(levelname)s [%(asctime)s] - %(message)s",
if logger.hasHandlers():
logger.warning("Logging handlers already exist for %s", name)
else:
# Create a console formatter and add it to a console handler
console_formatter = ColorFormatter(
fmt="%(levelname)s [%(asctime)s] %(pathname)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)

console_handler = logging.StreamHandler()
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)

# Create a file formatter and add it to a file handler if log_dir is provided
if log_dir:
log_dir = Path(log_dir)
os.makedirs(log_dir, exist_ok=True)

file_prefix = "{}_".format(
time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
)

if file_name:
log_file_name = file_prefix + file_name
else:
log_file_name = file_prefix + name + ".log"

log_file = log_dir / log_file_name
file_formatter = logging.Formatter(
fmt="%(levelname)s [%(asctime)s] - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)

self.logger = logger

Expand Down