Skip to content

Commit

Permalink
remove oudated args for load_checkpoint (#962)
Browse files Browse the repository at this point in the history
  • Loading branch information
doombeaker authored Jun 21, 2024
1 parent 1b90bbb commit 30d1168
Showing 1 changed file with 8 additions and 22 deletions.
30 changes: 8 additions & 22 deletions onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from comfy.cli_args import args

from onediff.utils.import_utils import is_onediff_quant_available
from onediff.infer_compiler.backends.oneflow.utils.version_util import is_community_version
from onediff.infer_compiler.backends.oneflow.utils.version_util import (
is_community_version,
)


from ..modules import BoosterScheduler
Expand Down Expand Up @@ -316,8 +318,6 @@ def onediff_load_checkpoint(
self,
ckpt_name,
vae_speedup,
output_vae=True,
output_clip=True,
static_mode="enable",
cache_interval=3,
cache_layer_id=0,
Expand All @@ -326,9 +326,7 @@ def onediff_load_checkpoint(
end_step=1000,
):
# CheckpointLoaderSimple.load_checkpoint
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
booster = BoosterScheduler(
DeepcacheBoosterExecutor(
cache_interval=cache_interval,
Expand Down Expand Up @@ -618,12 +616,8 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff/Loaders"
FUNCTION = "onediff_load_checkpoint"

def onediff_load_checkpoint(
self, ckpt_name, vae_speedup, output_vae=True, output_clip=True
):
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
def onediff_load_checkpoint(self, ckpt_name, vae_speedup):
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
booster = BoosterScheduler(
OnelineQuantizationBoosterExecutor(
conv_percentage=100,
Expand Down Expand Up @@ -671,19 +665,11 @@ def INPUT_TYPES(s):
FUNCTION = "onediff_load_checkpoint"

def onediff_load_checkpoint(
self,
ckpt_name,
model_path,
compile,
vae_speedup,
output_vae=True,
output_clip=True,
self, ckpt_name, model_path, compile, vae_speedup,
):
need_compile = compile == "enable"

modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
# TODO fix by op.compile
from ..modules.oneflow.utils.onediff_load_utils import (
onediff_load_quant_checkpoint_advanced,
Expand Down

0 comments on commit 30d1168

Please sign in to comment.