Skip to content

Commit

Permalink
add before_token_counter callback and use it for prompt comments
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Feb 11, 2024
1 parent 02ab75b commit b7f45e6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
12 changes: 11 additions & 1 deletion modules/processing_scripts/comments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from modules import scripts, shared
from modules import scripts, shared, script_callbacks
import re


Expand Down Expand Up @@ -27,6 +27,16 @@ def process(self, p, *args):
p.main_negative_prompt = strip_comments(p.main_negative_prompt)


def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
if not shared.opts.enable_prompt_comments:
return

params.prompt = strip_comments(params.prompt)


script_callbacks.on_before_token_counter(before_token_counter)


shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
"enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
}))
26 changes: 26 additions & 0 deletions modules/script_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import inspect
import os
from collections import namedtuple
Expand Down Expand Up @@ -106,6 +107,15 @@ def __init__(self, imgs, cols, rows):
self.rows = rows


@dataclasses.dataclass
class BeforeTokenCounterParams:
prompt: str
steps: int
styles: list

is_positive: bool = True


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
Expand All @@ -128,6 +138,7 @@ def __init__(self, imgs, cols, rows):
callbacks_on_reload=[],
callbacks_list_optimizers=[],
callbacks_list_unets=[],
callbacks_before_token_counter=[],
)


Expand Down Expand Up @@ -309,6 +320,14 @@ def list_unets_callback():
return res


def before_token_counter_callback(params: BeforeTokenCounterParams):
for c in callback_map['callbacks_before_token_counter']:
try:
c.callback(params)
except Exception:
report_exception(c, 'before_token_counter')


def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
Expand Down Expand Up @@ -483,3 +502,10 @@ def on_list_unets(callback):
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""

add_callback(callback_map['callbacks_list_unets'], callback)


def on_before_token_counter(callback):
"""register a function to be called when UI is counting tokens for a prompt.
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""

add_callback(callback_map['callbacks_before_token_counter'], callback)
6 changes: 6 additions & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ def connect_clear_prompt(button):


def update_token_counter(text, steps, styles, *, is_positive=True):
params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
script_callbacks.before_token_counter_callback(params)
text = params.prompt
steps = params.steps
styles = params.styles
is_positive = params.is_positive

if shared.opts.include_styles_into_token_counters:
apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
Expand Down

0 comments on commit b7f45e6

Please sign in to comment.