Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a way for API txt2img and img2img requests to pass args to always on scripts #8187

Merged
merged 8 commits into from
Mar 12, 2023
83 changes: 68 additions & 15 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,10 @@ def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):

raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

def get_script(self, script_name, script_runner):
if script_name is None:
def get_selectable_script(self, script_name, script_runner):
if script_name is None or script_name == "":
return None, None

if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui()

script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx]
return script, script_idx
Expand All @@ -182,8 +178,49 @@ def get_scripts_list(self):

return ScriptsList(txt2img = t2ilist, img2img = i2ilist)

def get_script(self, script_name, script_runner):
if script_name is None or script_name == "":
return None, None

script_idx = script_name_to_index(script_name, script_runner.scripts)
return script_runner.scripts[script_idx]

def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
#find max idx from the scripts in runner and generate a none array to init script_args
last_arg_index = 1
for script in script_runner.scripts:
if last_arg_index < script.args_to:
last_arg_index = script.args_to
# None everywhere except position 0 to initialize script args
script_args = [None]*last_arg_index
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
script_args[0] = selectable_idx + 1
else:
# when [0] = 0 no selectable script to run
script_args[0] = 0

# Now check for always on scripts
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
for alwayson_script_name in request.alwayson_scripts.keys():
alwayson_script = self.get_script(alwayson_script_name, script_runner)
if alwayson_script == None:
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
# Selectable script in always on script param check
if alwayson_script.alwayson == False:
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
# always on script with no arg should always run so you don't really need to add them to the requests
if "args" in request.alwayson_scripts[alwayson_script_name]:
script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
return script_args

def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui()
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
Expand All @@ -195,20 +232,26 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):

args = vars(populate)
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)

script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)

send_images = args.pop('send_images', True)
args.pop('save_images', None)

with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples

shared.state.begin()
if script is not None:
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args)
if selectable_scripts != None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()

Expand All @@ -221,12 +264,16 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")

script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)

mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)

script_runner = scripts.scripts_img2img
if not script_runner.scripts:
script_runner.initialize_scripts(True)
ui.create_ui()
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": not img2imgreq.save_images,
Expand All @@ -239,21 +286,27 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)

script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)

send_images = args.pop('send_images', True)
args.pop('save_images', None)

with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
p.scripts = script_runner
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples

shared.state.begin()
if script is not None:
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
processed = scripts.scripts_img2img.run(p, *p.script_args)
if selectable_scripts != None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()

Expand Down
2 changes: 2 additions & 0 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def generate_model(self):
{"key": "script_args", "type": list, "default": []},
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
]
).generate_model()

Expand All @@ -122,6 +123,7 @@ def generate_model(self):
{"key": "script_args", "type": list, "default": []},
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
]
).generate_model()

Expand Down