diff --git a/modules/processing_scripts/comments.py b/modules/processing_scripts/comments.py index 316356c77d9..638e39f2989 100644 --- a/modules/processing_scripts/comments.py +++ b/modules/processing_scripts/comments.py @@ -1,4 +1,4 @@ -from modules import scripts, shared +from modules import scripts, shared, script_callbacks import re @@ -27,6 +27,16 @@ def process(self, p, *args): p.main_negative_prompt = strip_comments(p.main_negative_prompt) +def before_token_counter(params: script_callbacks.BeforeTokenCounterParams): + if not shared.opts.enable_prompt_comments: + return + + params.prompt = strip_comments(params.prompt) + + +script_callbacks.on_before_token_counter(before_token_counter) + + shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), { "enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."), })) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a54cb3ebb56..08bc525641d 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,3 +1,4 @@ +import dataclasses import inspect import os from collections import namedtuple @@ -106,6 +107,15 @@ def __init__(self, imgs, cols, rows): self.rows = rows +@dataclasses.dataclass +class BeforeTokenCounterParams: + prompt: str + steps: int + styles: list + + is_positive: bool = True + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], @@ -128,6 +138,7 @@ def __init__(self, imgs, cols, rows): callbacks_on_reload=[], callbacks_list_optimizers=[], callbacks_list_unets=[], + callbacks_before_token_counter=[], ) @@ -309,6 +320,14 @@ def list_unets_callback(): return res +def before_token_counter_callback(params: BeforeTokenCounterParams): + for c in callback_map['callbacks_before_token_counter']: + try: + c.callback(params) + except Exception: + report_exception(c, 'before_token_counter') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if stack else 'unknown file' @@ -483,3 +502,10 @@ def on_list_unets(callback): The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" add_callback(callback_map['callbacks_list_unets'], callback) + + +def on_before_token_counter(callback): + """register a function to be called when UI is counting tokens for a prompt. + The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" + + add_callback(callback_map['callbacks_before_token_counter'], callback) diff --git a/modules/ui.py b/modules/ui.py index 5284a6300ae..dcba8e885c7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -152,6 +152,12 @@ def connect_clear_prompt(button): def update_token_counter(text, steps, styles, *, is_positive=True): + params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive) + script_callbacks.before_token_counter_callback(params) + text = params.prompt + steps = params.steps + styles = params.styles + is_positive = params.is_positive if shared.opts.include_styles_into_token_counters: apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt