Skip to content

Commit

Permalink
Merge pull request #31 from v0xie/fix/cfg-interval
Browse files Browse the repository at this point in the history
Fixes for CFG Scheduler
  • Loading branch information
v0xie authored May 1, 2024
2 parents 66899a7 + 9e172ca commit 1b3d17c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ https://arxiv.org/abs/2404.07724 and https://arxiv.org/abs/2404.13040

Constrains the usage of CFG to within a specified noise interval. Allows usage of high CFG levels (>15) without drastic alteration of composition.

Adds controllable CFG schedules. For Clamp-Linear, use (c=2.0) for SD1.5 and (c=4.0) for SDXL. For PCS, use (s=1.0) for SD1.5 and (s=0.1) for SDXL.
Adds controllable CFG schedules. For Clamp-Linear, use (c=2.0) for SD1.5 and (c=4.0) for SDXL. For PCS, use (s=1.0) for SD1.5 and (s=0.1) for SDXL.

To use CFG Scheduler, PAG Active must be set True! PAG scale can be set to 0.

#### Controls
* **Enable CFG Interval**: Enables the CFG Interval (PAG must be active! PAG scale can be set to 0.)
* **CFG Noise Interval Start**: Minimum noise level to use CFG with. SDXL recommended value: 0.28.
* **CFG Noise Interval End**: Maximum noise level to use CFG with. SDXL recommended value: >5.42.
* **CFG Scheduler**: Sets the schedule type to apply CFG.
* **Enable CFG Scheduler**: Enables the CFG Scheduler.
* **CFG Schedule Type**: Sets the schedule type to apply CFG.
- Constant: The default CFG method (constant value over all timesteps)
- Interval: Constant with CFG only being applied within the specified noise interval!
- Clamp-Linear: Clamps the CFG to the maximum of (c, Linear)
- Clamp-Cosine: Clamps the CFG to the maximum of (c, Cosine)
- PCS: Powered Cosine, lower values are better
- PCS: Powered Cosine, lower values are typically better
* **CFG Noise Interval Start**: Minimum noise level to use CFG with. SDXL recommended value: 0.28.
* **CFG Noise Interval End**: Maximum noise level to use CFG with. SDXL recommended value: >5.42.


#### Results
##### CFG Interval
Expand Down
32 changes: 19 additions & 13 deletions scripts/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,12 @@ def setup_ui(self, is_img2img) -> list:
start_step = gr.Slider(value = 0, minimum = 0, maximum = 150, step = 1, label="Start Step", elem_id = 'pag_start_step', info="")
end_step = gr.Slider(value = 150, minimum = 0, maximum = 150, step = 1, label="End Step", elem_id = 'pag_end_step', info="")
with gr.Row():
cfg_interval_enable = gr.Checkbox(value=False, default=False, label="Enable CFG Interval", elem_id='cfg_interval_enable', info="Apply CFG only within noise interval. PAG must be enabled (scale can be 0). SDXL recommend CFG=15; CFG interval (0.28, 5.42]")
cfg_interval_enable = gr.Checkbox(value=False, default=False, label="Enable CFG Scheduler", elem_id='cfg_interval_enable', info="If enabled, applies CFG only within noise interval with the selected schedule type. PAG must be enabled (scale can be 0). SDXL recommend CFG=15; CFG interval (0.28, 5.42]")
cfg_schedule = gr.Dropdown(
value='Constant',
choices= SCHEDULES,
label="CFG Interval Schedule",
label="CFG Schedule Type",
elem_id='cfg_interval_schedule',
info="Select the CFG schedule"
)
with gr.Row():
cfg_interval_low = gr.Slider(value = 0, minimum = 0, maximum = 100, step = 0.01, label="CFG Noise Interval Low", elem_id = 'cfg_interval_low', info="")
Expand Down Expand Up @@ -242,7 +241,7 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste
high_index = find_closest_index(cfg_interval_high, pag_params.max_sampling_step)
pag_params.cfg_interval_low = calculate_noise_level(low_index, pag_params.max_sampling_step)
pag_params.cfg_interval_high = calculate_noise_level(high_index, pag_params.max_sampling_step)
logger.debug(f"Low Index, High Index: ({low_index}, {high_index}), CFG Interval Low, High: ({pag_params.cfg_interval_low}, {pag_params.cfg_interval_high})")
logger.debug(f"Step Aligned CFG Interval (low, high): ({low_index}, {high_index}), Step Aligned CFG Interval: ({round(pag_params.cfg_interval_low, 4)}, {round(pag_params.cfg_interval_high, 4)})")

# Get all the qv modules
cross_attn_modules = self.get_cross_attn_modules()
Expand Down Expand Up @@ -481,10 +480,10 @@ def get_xyz_axis_options(self) -> dict:
xyz_grid.AxisOption("[PAG] PAG Scale", float, pag_apply_field("pag_scale")),
xyz_grid.AxisOption("[PAG] PAG Start Step", int, pag_apply_field("pag_start_step")),
xyz_grid.AxisOption("[PAG] PAG End Step", int, pag_apply_field("pag_end_step")),
xyz_grid.AxisOption("[PAG] CFG Interval Enable", str, pag_apply_override('cfg_interval_enable', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[PAG] CFG Interval Low", float, pag_apply_field("cfg_interval_low")),
xyz_grid.AxisOption("[PAG] CFG Interval High", float, pag_apply_field("cfg_interval_high")),
xyz_grid.AxisOption("[PAG] CFG Schedule", str, pag_apply_override('cfg_interval_schedule', boolean=False), choices=lambda: SCHEDULES),
xyz_grid.AxisOption("[PAG] Enable CFG Scheduler", str, pag_apply_override('cfg_interval_enable', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[PAG] CFG Noise Interval Low", float, pag_apply_field("cfg_interval_low")),
xyz_grid.AxisOption("[PAG] CFG Noise Interval High", float, pag_apply_field("cfg_interval_high")),
xyz_grid.AxisOption("[PAG] CFG Schedule Type", str, pag_apply_override('cfg_interval_schedule', boolean=False), choices=lambda: SCHEDULES),
#xyz_grid.AxisOption("[PAG] ctnms_alpha", float, pag_apply_field("pag_ctnms_alpha")),
}
return extra_axis_options
Expand All @@ -502,19 +501,22 @@ def combine_denoised_pass_conds_list(*args, **kwargs):
def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
cfg_scale = cond_scale

noise_level = calculate_noise_level(new_params.step, new_params.max_sampling_step)

# Calculate CFG Scale
cfg_scale = cond_scale
if new_params.cfg_interval_enable:
if new_params.cfg_interval_schedule == 'Interval':
if new_params.cfg_interval_schedule != 'Constant':
# Calculate noise interval
start = new_params.cfg_interval_low
end = new_params.cfg_interval_high
begin_range = start if start <= end else end
end_range = end if start <= end else start
cfg_scale = cfg_scale if begin_range <= noise_level <= end_range else 1.0
else:
cfg_scale = cfg_scheduler(new_params.cfg_interval_schedule, new_params.step, new_params.max_sampling_step, cond_scale)
# Scheduled CFG Value
scheduled_cfg_scale = cfg_scheduler(new_params.cfg_interval_schedule, new_params.step, new_params.max_sampling_step, cond_scale)
# Only apply CFG in the interval
cfg_scale = scheduled_cfg_scale if begin_range <= noise_level <= end_range else 1.0

if incantations_debug:
logger.debug(f"Schedule: {new_params.cfg_interval_schedule}, CFG Scale: {cfg_scale}, Noise_level: {round(noise_level,3)}")
Expand Down Expand Up @@ -769,6 +771,10 @@ def fun(p, x, xs):
if boolean:
x = True if x.lower() == "true" else False
setattr(p, field, x)
if not hasattr(p, "pag_active"):
setattr(p, "pag_active", True)
if 'cfg_interval_' in field and not hasattr(p, "cfg_interval_enable"):
setattr(p, "cfg_interval_enable", True)
return fun


Expand Down

0 comments on commit 1b3d17c

Please sign in to comment.