Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for CFG Scheduler #31

Merged
merged 7 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ https://arxiv.org/abs/2404.07724 and https://arxiv.org/abs/2404.13040

Constrains the usage of CFG to within a specified noise interval. Allows usage of high CFG levels (>15) without drastic alteration of composition.

Adds controllable CFG schedules. For Clamp-Linear, use (c=2.0) for SD1.5 and (c=4.0) for SDXL. For PCS, use (s=1.0) for SD1.5 and (s=0.1) for SDXL.
Adds controllable CFG schedules. For Clamp-Linear, use (c=2.0) for SD1.5 and (c=4.0) for SDXL. For PCS, use (s=1.0) for SD1.5 and (s=0.1) for SDXL.

To use CFG Scheduler, PAG Active must be set True! PAG scale can be set to 0.

#### Controls
* **Enable CFG Interval**: Enables the CFG Interval (PAG must be active! PAG scale can be set to 0.)
* **CFG Noise Interval Start**: Minimum noise level to use CFG with. SDXL recommended value: 0.28.
* **CFG Noise Interval End**: Maximum noise level to use CFG with. SDXL recommended value: >5.42.
* **CFG Scheduler**: Sets the schedule type to apply CFG.
* **Enable CFG Scheduler**: Enables the CFG Scheduler.
* **CFG Schedule Type**: Sets the schedule type to apply CFG.
- Constant: The default CFG method (constant value over all timesteps)
- Interval: Constant with CFG only being applied within the specified noise interval!
- Clamp-Linear: Clamps the CFG to the maximum of (c, Linear)
- Clamp-Cosine: Clamps the CFG to the maximum of (c, Cosine)
- PCS: Powered Cosine, lower values are better
- PCS: Powered Cosine, lower values are typically better
* **CFG Noise Interval Start**: Minimum noise level to use CFG with. SDXL recommended value: 0.28.
* **CFG Noise Interval End**: Maximum noise level to use CFG with. SDXL recommended value: >5.42.


#### Results
##### CFG Interval
Expand Down
32 changes: 19 additions & 13 deletions scripts/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,12 @@ def setup_ui(self, is_img2img) -> list:
start_step = gr.Slider(value = 0, minimum = 0, maximum = 150, step = 1, label="Start Step", elem_id = 'pag_start_step', info="")
end_step = gr.Slider(value = 150, minimum = 0, maximum = 150, step = 1, label="End Step", elem_id = 'pag_end_step', info="")
with gr.Row():
cfg_interval_enable = gr.Checkbox(value=False, default=False, label="Enable CFG Interval", elem_id='cfg_interval_enable', info="Apply CFG only within noise interval. PAG must be enabled (scale can be 0). SDXL recommend CFG=15; CFG interval (0.28, 5.42]")
cfg_interval_enable = gr.Checkbox(value=False, default=False, label="Enable CFG Scheduler", elem_id='cfg_interval_enable', info="If enabled, applies CFG only within noise interval with the selected schedule type. PAG must be enabled (scale can be 0). SDXL recommend CFG=15; CFG interval (0.28, 5.42]")
cfg_schedule = gr.Dropdown(
value='Constant',
choices= SCHEDULES,
label="CFG Interval Schedule",
label="CFG Schedule Type",
elem_id='cfg_interval_schedule',
info="Select the CFG schedule"
)
with gr.Row():
cfg_interval_low = gr.Slider(value = 0, minimum = 0, maximum = 100, step = 0.01, label="CFG Noise Interval Low", elem_id = 'cfg_interval_low', info="")
Expand Down Expand Up @@ -242,7 +241,7 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste
high_index = find_closest_index(cfg_interval_high, pag_params.max_sampling_step)
pag_params.cfg_interval_low = calculate_noise_level(low_index, pag_params.max_sampling_step)
pag_params.cfg_interval_high = calculate_noise_level(high_index, pag_params.max_sampling_step)
logger.debug(f"Low Index, High Index: ({low_index}, {high_index}), CFG Interval Low, High: ({pag_params.cfg_interval_low}, {pag_params.cfg_interval_high})")
logger.debug(f"Step Aligned CFG Interval (low, high): ({low_index}, {high_index}), Step Aligned CFG Interval: ({round(pag_params.cfg_interval_low, 4)}, {round(pag_params.cfg_interval_high, 4)})")

# Get all the qv modules
cross_attn_modules = self.get_cross_attn_modules()
Expand Down Expand Up @@ -481,10 +480,10 @@ def get_xyz_axis_options(self) -> dict:
xyz_grid.AxisOption("[PAG] PAG Scale", float, pag_apply_field("pag_scale")),
xyz_grid.AxisOption("[PAG] PAG Start Step", int, pag_apply_field("pag_start_step")),
xyz_grid.AxisOption("[PAG] PAG End Step", int, pag_apply_field("pag_end_step")),
xyz_grid.AxisOption("[PAG] CFG Interval Enable", str, pag_apply_override('cfg_interval_enable', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[PAG] CFG Interval Low", float, pag_apply_field("cfg_interval_low")),
xyz_grid.AxisOption("[PAG] CFG Interval High", float, pag_apply_field("cfg_interval_high")),
xyz_grid.AxisOption("[PAG] CFG Schedule", str, pag_apply_override('cfg_interval_schedule', boolean=False), choices=lambda: SCHEDULES),
xyz_grid.AxisOption("[PAG] Enable CFG Scheduler", str, pag_apply_override('cfg_interval_enable', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[PAG] CFG Noise Interval Low", float, pag_apply_field("cfg_interval_low")),
xyz_grid.AxisOption("[PAG] CFG Noise Interval High", float, pag_apply_field("cfg_interval_high")),
xyz_grid.AxisOption("[PAG] CFG Schedule Type", str, pag_apply_override('cfg_interval_schedule', boolean=False), choices=lambda: SCHEDULES),
#xyz_grid.AxisOption("[PAG] ctnms_alpha", float, pag_apply_field("pag_ctnms_alpha")),
}
return extra_axis_options
Expand All @@ -502,19 +501,22 @@ def combine_denoised_pass_conds_list(*args, **kwargs):
def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
cfg_scale = cond_scale

noise_level = calculate_noise_level(new_params.step, new_params.max_sampling_step)

# Calculate CFG Scale
cfg_scale = cond_scale
if new_params.cfg_interval_enable:
if new_params.cfg_interval_schedule == 'Interval':
if new_params.cfg_interval_schedule != 'Constant':
# Calculate noise interval
start = new_params.cfg_interval_low
end = new_params.cfg_interval_high
begin_range = start if start <= end else end
end_range = end if start <= end else start
cfg_scale = cfg_scale if begin_range <= noise_level <= end_range else 1.0
else:
cfg_scale = cfg_scheduler(new_params.cfg_interval_schedule, new_params.step, new_params.max_sampling_step, cond_scale)
# Scheduled CFG Value
scheduled_cfg_scale = cfg_scheduler(new_params.cfg_interval_schedule, new_params.step, new_params.max_sampling_step, cond_scale)
# Only apply CFG in the interval
cfg_scale = scheduled_cfg_scale if begin_range <= noise_level <= end_range else 1.0

if incantations_debug:
logger.debug(f"Schedule: {new_params.cfg_interval_schedule}, CFG Scale: {cfg_scale}, Noise_level: {round(noise_level,3)}")
Expand Down Expand Up @@ -769,6 +771,10 @@ def fun(p, x, xs):
if boolean:
x = True if x.lower() == "true" else False
setattr(p, field, x)
if not hasattr(p, "pag_active"):
setattr(p, "pag_active", True)
if 'cfg_interval_' in field and not hasattr(p, "cfg_interval_enable"):
setattr(p, "cfg_interval_enable", True)
return fun


Expand Down