Merge pull request #69 from v0xie/refactor/reformulate
EP-CFG / Reformulate CFG
v0xie authored Dec 18, 2024
2 parents 0ee1257 + 6ac4529 commit 534fd04
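This PR moves the PAG, TCG and Self-Guidance contributions into a single running guidance term (cfg_x), applies S-CFG, APG and the new EP-CFG energy rescale to that term, and only then adds it to the denoised prediction. A minimal sketch of that reformulated flow, assuming plain torch tensors; combine_guidance and its parameters are illustrative stand-ins, not the extension's API:

import torch

def combine_guidance(cond, uncond, weight, cfg_scale,
                     extra_deltas=(),       # PAG / TCG / SG deltas, already scaled
                     scfg_rate=1.0,         # scalar or tensor rate from S-CFG
                     apg_fn=None,           # APG's normalized-guidance step
                     ep_cfg_fn=None):       # EP-CFG energy rescale
    cfg_o = (cond - uncond) * (weight * cfg_scale)   # plain CFG term ("original")
    cfg_x = cfg_o.clone()                            # running, modified guidance term

    for delta in extra_deltas:                       # 1-3. PAG, TCG, Self-Guidance
        cfg_x = cfg_x + delta                        #      (optionally fused via sanf)

    cfg_x = scfg_rate * cfg_x                        # 4. S-CFG rescale
    if apg_fn is not None:
        cfg_x = apg_fn(cfg_x, uncond)                # 5. APG normalized guidance
    if ep_cfg_fn is not None:
        cfg_x = ep_cfg_fn(uncond, cfg_x)             # 6. EP-CFG energy rescale
    return uncond + cfg_x                            # 7. add to the denoised prediction

if __name__ == "__main__":
    c, u = torch.randn(4, 64, 64), torch.randn(4, 64, 64)
    print(combine_guidance(c, u, weight=1.0, cfg_scale=7.0).shape)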
Showing 2 changed files with 133 additions and 64 deletions.
142 changes: 92 additions & 50 deletions scripts/cfg_combiner.py
@@ -69,7 +69,7 @@ def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
logger.debug("CFGCombinerScript process_batch")
pag_active = p.extra_generation_params.get('PAG Active', False)
scfg_active = p.extra_generation_params.get('SCFG Active', False)
- cfgi_active = p.extra_generation_params.get('CFG Interval Enable', False)
+ cfgi_active = p.extra_generation_params.get('CFG Interval Enable', False) or p.extra_generation_params.get('EP-CFG Enable', False)
tcg_active = p.extra_generation_params.get('TCG Active', False)
apg_active = p.extra_generation_params.get('APG Active', False)
sg_active = p.extra_generation_params.get('SG Active', False)
@@ -244,45 +244,13 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
### Combine Denoised
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:

model_delta = x_out[cond_index] - denoised_uncond[i]

- # S-CFG
- rate = 1.0
- if scfg_params is not None:
- rate = scfg_combine_denoised(
- model_delta = model_delta,
- cfg_scale = cfg_scale,
- scfg_params = scfg_params,
- )
- # If rate is not an int, convert to tensor
- if rate is None:
- logger.error("scfg_combine_denoised returned None, using default rate of 1.0")
- rate = 1.0
- elif not isinstance(rate, int) and not isinstance(rate, float):
- rate = rate.to(device=shared.device, dtype=model_delta.dtype)
- else:
- # rate is tensor, probably
- pass
-
- # 1. Experimental formulation for S-CFG combined with CFG combined with APG
- cfg_x = (model_delta) * rate * (weight * cfg_scale)
+ cfg_o = model_delta * (weight * cfg_scale) # original delta
+ cfg_x = cfg_o.detach().clone() # modified

- if apg_params is not None:
- if apg_params.apg_start_step <= cfg_params.current_step <= apg_params.apg_end_step:
- normalized_cond = normalized_guidance(
- pred_cond=x_out[cond_index],
- pred_uncond=denoised_uncond[i],
- apg_params = apg_params,
- index = i,
- )
- cfg_x = normalized_cond * rate * (weight * (cfg_scale - 1))
-
- if not use_saliency_map or not run_pag:
- denoised[i] += cfg_x
- del rate

- # 2. PAG
+ # 1. PAG
# PAG is added like CFG
if pag_params is not None:
if not run_pag:
@@ -293,16 +261,16 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
pag_delta = x_out[cond_index] - pag_x_out[i]
pag_x = pag_delta * (weight * pag_scale)

- if use_saliency_map:
- sal_cfg = sanf(cfg_x, pag_x)
- denoised[i] += sal_cfg
+ if pag_params.pag_sanf:
+ sal_cfg = sanf(cfg_o, pag_x)
+ cfg_x += sal_cfg
else:
- denoised[i] += pag_x
+ cfg_x += pag_x

except Exception as e:
logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)

- # 3. TCG
+ # 2. TCG
# TCG is added like CFG
if tcg_params is not None:
if not tcg_params.tcg_active or tcg_params.tcg_scale <= 0 or tcg_params.tcg_x_out is None \
@@ -313,16 +281,16 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
tcg_delta = x_out[cond_index] - tcg_params.tcg_x_out[i]
tcg_x = tcg_delta * (weight * tcg_params.tcg_scale)

- if use_saliency_map:
- sal_tcg = sanf(cfg_x, tcg_x)
- denoised[i] += sal_tcg
+ if tcg_params.tcg_sanf:
+ sal_tcg = sanf(cfg_o, tcg_x)
+ cfg_x += sal_tcg
else:
- denoised[i] += tcg_x
+ cfg_x += tcg_x

except Exception as e:
logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)

- # 4. Self-Guidance
+ # 3. Self-Guidance
if sg_params is not None:
if not sg_params.sg_active or not sg_params.sg_start_step <= sg_params.step <= sg_params.sg_end_step or sg_params.sg_scale == 0 or sg_params.sg_x_out is None:
pass
@@ -332,15 +300,89 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
sg_delta = x_out[cond_index] - sg_params.sg_x_out[i]
sg_x = sg_delta * (weight * sg_params.sg_scale)

- if use_saliency_map:
- sal_sg = sanf(cfg_x, sg_x)
- denoised[i] += sal_sg
+ if sg_params.sg_sanf:
+ sal_sg = sanf(cfg_o, sg_x)
+ cfg_x += sal_sg
else:
- denoised[i] += sg_x
+ cfg_x += sg_x

except Exception as e:
logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)

# 4. S-CFG
rate = 1.0
if scfg_params is not None:
rate = scfg_combine_denoised(
model_delta = cfg_x,
cfg_scale = cfg_scale,
scfg_params = scfg_params,
)
# If rate is not an int, convert to tensor
if rate is None:
logger.error("scfg_combine_denoised returned None, using default rate of 1.0")
rate = 1.0
elif not isinstance(rate, int) and not isinstance(rate, float):
rate = rate.to(device=shared.device, dtype=model_delta.dtype)
else:
# rate is a plain int or float here; leave it as-is
pass
cfg_x = rate * cfg_x
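# Note: rate defaults to 1.0 and may come back from scfg_combine_denoised as a tensor
# (converted to the model device/dtype above); multiplying here rescales the whole
# accumulated guidance term before the APG and EP-CFG steps below.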



# 5. APG
if apg_params is not None:
if apg_params.apg_start_step <= cfg_params.current_step <= apg_params.apg_end_step:
cfg_x = normalized_guidance(
#cfg_x = (cfg_scale-1) * normalized_guidance(
pred_cond=cfg_x,
pred_uncond=denoised_uncond[i],
apg_params = apg_params,
index = i,
)
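# Note: between apg_start_step and apg_end_step the accumulated cfg_x is passed to
# normalized_guidance together with the unconditional prediction and replaced by its
# output for the remaining steps.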

# 6. EP-CFG
# Isolate the latents between 0.45 and 0.55 in the energy histogram
# Rescale the cfg term by the square root of the ratio of the energy of the original (unguided) prediction to the energy of the CFG-guided prediction
if cfgi_params is not None:
if cfgi_params.ep_cfg_enable:
min_p = cfgi_params.ep_cfg_min
max_p = cfgi_params.ep_cfg_max
xc = denoised[i]
#xc = x_out[cond_index]
xcfg = denoised[i] + cfg_x
...
# Step 2: Calculate robust energy for xc

#xc_energy = torch.norm(xc, dim=(1, 2))**2
b, h, w = xc.shape
xc_energy = torch.norm(xc, dim=(1, 2))**2
xc_energy = torch.reshape(xc_energy, (xc.shape[0], -1))
xc_energy = xc_energy ** 2
q_45_xc = torch.quantile(xc_energy, 0.45, dim=-1, keepdim=True)
q_55_xc = torch.quantile(xc_energy, 0.55, dim=-1, keepdim=True)
mask_xc = (xc_energy >= q_45_xc) & (xc_energy <= q_55_xc)
robust_energy_xc = torch.sum(xc_energy * mask_xc.float(), dim=-1)

# Step 3: Calculate robust energy for xcfg
xcfg_energy = torch.norm(xcfg, dim=(1, 2))**2
xcfg_energy = torch.reshape(xcfg_energy, (xcfg.shape[0], -1))
xcfg_energy = xcfg_energy ** 2
q_45_xcfg = torch.quantile(xcfg_energy, 0.45, dim=-1, keepdim=True)
q_55_xcfg = torch.quantile(xcfg_energy, 0.55, dim=-1, keepdim=True)
mask_xcfg = (xcfg_energy >= q_45_xcfg) & (xcfg_energy <= q_55_xcfg)
robust_energy_xcfg = torch.sum(xcfg_energy * mask_xcfg.float(), dim=-1)

# Step 4: Rescale xcfg based on the ratio of robust energies
scaling_factor = torch.sqrt(robust_energy_xc / (robust_energy_xcfg + 1e-6))
xcfg_rescaled = xcfg * scaling_factor.unsqueeze(-1).unsqueeze(-1)
cfg_x = cfg_x * scaling_factor.unsqueeze(-1).unsqueeze(-1)
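# Note: scaling_factor = sqrt(robust_energy_xc / robust_energy_xcfg), where the "robust"
# energy keeps only the values inside the 0.45-0.55 quantile band; the guidance term is
# shrunk or grown so the guided prediction stays near the energy of the unguided one.
# (xcfg_rescaled above is computed but not used further.)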


# 7. Add to denoised
denoised[i] += cfg_x
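# Note: cfg_x now carries the CFG delta plus the PAG / TCG / Self-Guidance contributions,
# modified by S-CFG, APG and EP-CFG above; it is applied to denoised[i] once here rather
# than each method adding to denoised[i] separately as before.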


devices.torch_gc()

return denoised
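For reference, a self-contained sketch of the EP-CFG rescale introduced above, following the intent of the comments (sum the squared values inside the 0.45-0.55 quantile band as a "robust" energy, then scale the guidance term by the square root of the energy ratio) rather than the diff line-for-line; the function names, per-channel treatment and (C, H, W) latent shape are assumptions for illustration, not the extension's exact code:

import torch

def robust_energy(x: torch.Tensor, lo: float = 0.45, hi: float = 0.55) -> torch.Tensor:
    """Sum of squared values inside the [lo, hi] quantile band, per channel."""
    e = x.reshape(x.shape[0], -1) ** 2                   # per-element energy, (C, H*W)
    q_lo = torch.quantile(e, lo, dim=-1, keepdim=True)
    q_hi = torch.quantile(e, hi, dim=-1, keepdim=True)
    mask = (e >= q_lo) & (e <= q_hi)
    return torch.sum(e * mask.float(), dim=-1)           # (C,)

def ep_cfg_rescale(denoised: torch.Tensor, cfg_term: torch.Tensor) -> torch.Tensor:
    """Rescale the guidance term so the guided prediction keeps roughly the
    robust energy of the unguided one."""
    e_plain = robust_energy(denoised)
    e_guided = robust_energy(denoised + cfg_term)
    scale = torch.sqrt(e_plain / (e_guided + 1e-6))      # (C,)
    return cfg_term * scale.view(-1, 1, 1)

# usage: denoised = denoised + ep_cfg_rescale(denoised, cfg_x)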