Skip to content

Commit

Permalink
Merge pull request #66 from v0xie/self-guidance
Browse files Browse the repository at this point in the history
Self-Guidance
  • Loading branch information
v0xie authored Dec 16, 2024
2 parents 374a8d5 + 7c91a30 commit 7c2b434
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 4 deletions.
33 changes: 29 additions & 4 deletions scripts/cfg_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def before_process(self, p: StableDiffusionProcessing, *args, **kwargs):
"cfgi_params": None,
"tcg_params": None,
"apg_params": None,
"sg_params": None
}
setattr(p, 'incant_cfg_params', cfg_dict)

Expand All @@ -71,11 +72,13 @@ def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
cfgi_active = p.extra_generation_params.get('CFG Interval Enable', False)
tcg_active = p.extra_generation_params.get('TCG Active', False)
apg_active = p.extra_generation_params.get('APG Active', False)
sg_active = p.extra_generation_params.get('SG Active', False)

if not any([
pag_active,
scfg_active,
cfgi_active,
sg_active,
tcg_active,
apg_active
]):
Expand Down Expand Up @@ -135,7 +138,8 @@ def patch_cfg_denoiser(self, denoiser, cfg_dict: dict, cfg_params: CFGCombinerPa
scfg_params = cfg_dict['scfg_params'],
cfgi_params = cfg_dict['cfgi_params'],
tcg_params = cfg_dict['tcg_params'],
apg_params = cfg_dict['apg_params']
apg_params = cfg_dict['apg_params'],
sg_params = cfg_dict['sg_params']
)
patched_combine_denoised = patches.patch(__name__, denoiser, "combine_denoised", pass_conds_func)
setattr(denoiser, 'combine_denoised_patched', True)
Expand Down Expand Up @@ -185,13 +189,15 @@ def combine_denoised_pass_conds_list(*args, **kwargs):
cfgi_params = kwargs.get('cfgi_params', None)
apg_params = kwargs.get('apg_params', None)
tcg_params = kwargs.get('tcg_params', None)
sg_params = kwargs.get('sg_params', None)

if not any([
pag_params,
scfg_params,
cfgi_params,
apg_params,
tcg_params
tcg_params,
sg_params
]):
logger.warning("No reason to hijack combine_denoised")
return original_func(*args)
Expand Down Expand Up @@ -308,14 +314,33 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
tcg_x = tcg_delta * (weight * tcg_params.tcg_scale)

if use_saliency_map:
sal_cfg = sanf(cfg_x, tcg_x)
denoised[i] += sal_cfg
sal_tcg = sanf(cfg_x, tcg_x)
denoised[i] += sal_tcg
else:
denoised[i] += tcg_x

except Exception as e:
logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)

# 4. Self-Guidance
if sg_params is not None:
if not sg_params.sg_active or not sg_params.sg_start_step <= sg_params.step <= sg_params.sg_end_step or sg_params.sg_scale == 0 or sg_params.sg_x_out is None:
pass
# do pag
else:
try:
sg_delta = x_out[cond_index] - sg_params.sg_x_out[i]
sg_x = sg_delta * (weight * sg_params.sg_scale)

if use_saliency_map:
sal_sg = sanf(cfg_x, sg_x)
denoised[i] += sal_sg
else:
denoised[i] += sg_x

except Exception as e:
logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)

devices.torch_gc()

return denoised
Expand Down
2 changes: 2 additions & 0 deletions scripts/incantation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from scripts.smoothed_energy_guidance import SEGExtensionScript
from scripts.adaptive_projected_guidance import APGExtensionScript
from scripts.cfg_scheduler import CFGSchedulerExtensionScript
from scripts.self_guidance import SGExtensionScript

logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))
Expand All @@ -39,6 +40,7 @@ def __init__(self, module: UIWrapper, module_idx = 0, num_args = -1, arg_idx = -

# main scripts
submodules: list[SubmoduleInfo] = [
SubmoduleInfo(module=SGExtensionScript()),
SubmoduleInfo(module=SEGExtensionScript()),
SubmoduleInfo(module=SCFGExtensionScript()),
SubmoduleInfo(module=PAGExtensionScript()),
Expand Down
Loading

0 comments on commit 7c2b434

Please sign in to comment.