Significantly reduce thread abuse for faster model moving
This moves all major gradio calls into the main thread instead of running them on arbitrary gradio threads.
This ensures that all torch.nn.Module.to() calls are performed in the main thread, avoiding GPU memory fragmentation as much as possible.
In my tests, model moving is now 0.7–1.2 seconds faster, which means all 6GB/8GB VRAM users will generate each SDXL image 0.7–1.2 seconds faster.
lllyasviel committed Feb 8, 2024
1 parent 291ec74 commit f06ba8e
Showing 8 changed files with 122 additions and 31 deletions.
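In outline, the commit's pattern is a single consumer thread plus blocking producers. Below is a minimal sketch under that reading, with hypothetical names; the real implementation is modules_forge/main_thread.py in the diff that follows.

# Minimal sketch of the dispatch pattern this commit implements
# (hypothetical simplification; see modules_forge/main_thread.py below).
import queue
import threading

_tasks = queue.Queue()

def run_on_main_thread(func, *args, **kwargs):
    # Called from a gradio request thread: enqueue the call and
    # block until the main thread has executed it.
    done = threading.Event()
    result = {}

    def job():
        result["value"] = func(*args, **kwargs)
        done.set()

    _tasks.put(job)
    done.wait()
    return result["value"]

def main_loop():
    # Runs on the process's main thread forever; every
    # torch.nn.Module.to() therefore executes on one thread.
    while True:
        _tasks.get()()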
7 changes: 6 additions & 1 deletion modules/img2img.py
@@ -15,6 +15,7 @@
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.scripts
from modules_forge import main_thread


def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
@@ -146,7 +147,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
return batch_results


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)

is_batch = mode == 5
@@ -244,3 +245,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
processed.images = []

return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
return main_thread.run_and_wait_result(img2img_function, id_task, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, sampler_name, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, request, *args)
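The same renaming pattern (img2img becomes img2img_function, with a thin img2img wrapper delegating to the main thread) recurs in txt2img.py below. As a hedged aside, the pattern could also be written as a decorator; this commit writes the wrappers out by hand instead.

# Hypothetical decorator form of the wrapper pattern in this commit
# (not used by the commit itself; main_thread is the module added below).
import functools
from modules_forge import main_thread

def on_main_thread(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return main_thread.run_and_wait_result(func, *args, **kwargs)
    return wrapper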
21 changes: 3 additions & 18 deletions modules/initialize.py
@@ -149,24 +149,9 @@ def initialize_rest(*, reload_script_modules=False):
sd_unet.list_unets()
startup_timer.record("scripts list_unets")

def load_model():
"""
Accesses shared.sd_model property to load model.
After it's available, if it has been loaded before this access by some extension,
its optimization may be None because the list of optimizers has not yet been filled
by that time, so we apply optimization again.
"""
from modules import devices
devices.torch_npu_set_device()

shared.sd_model # noqa: B018

if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()

devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start()
from modules_forge import main_thread
import modules.sd_models
main_thread.async_run(modules.sd_models.model_data.get_sd_model)

from modules import shared_items
shared_items.reload_hypernetworks()
7 changes: 4 additions & 3 deletions modules/initialize_util.py
@@ -170,10 +170,11 @@ def sigint_handler(sig, frame):
def configure_opts_onchange():
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
from modules.call_queue import wrap_queued_call
from modules_forge import main_thread

shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
5 changes: 5 additions & 0 deletions modules/launch_utils.py
@@ -511,6 +511,11 @@ def start():
else:
webui.webui()

from modules_forge import main_thread

main_thread.loop()
return


def dump_sysinfo():
from modules import sysinfo
9 changes: 8 additions & 1 deletion modules/shared_state.py
@@ -2,6 +2,8 @@
import logging
import threading
import time
import traceback
import torch

from modules import errors, shared, devices
from typing import Optional
@@ -134,6 +136,7 @@ def end(self):

devices.torch_gc()

@torch.inference_mode()
def set_current_image(self):
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
if not shared.parallel_processing_allowed:
@@ -142,6 +145,7 @@ def set_current_image(self):
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
self.do_set_current_image()

@torch.inference_mode()
def do_set_current_image(self):
if self.current_latent is None:
return
@@ -156,11 +160,14 @@

self.current_image_sampling_step = self.sampling_step

except Exception:
except Exception as e:
# traceback.print_exc()
# print(e)
# when switching models during generation, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()

@torch.inference_mode()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1
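For context on the decorators added above: torch.inference_mode() is the standard PyTorch context manager and decorator that disables autograd tracking. A small standalone illustration, not taken from this commit:

# Standalone illustration of torch.inference_mode (standard PyTorch API).
import torch

@torch.inference_mode()
def decode_preview(latent: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded, so live-preview decoding allocates
    # less memory; a real implementation would run the VAE here.
    return latent.clamp(-1.0, 1.0)

print(decode_preview(torch.randn(1, 4, 64, 64)).requires_grad)  # False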
13 changes: 11 additions & 2 deletions modules/txt2img.py
@@ -9,6 +9,7 @@
from modules.ui import plaintext_to_html
from PIL import Image
import gradio as gr
from modules_forge import main_thread


def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
@@ -56,7 +57,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
return p


def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
assert len(gallery) > 0, 'No image to upscale'
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'

@@ -100,7 +101,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")


def txt2img(id_task: str, request: gr.Request, *args):
def txt2img_function(id_task: str, request: gr.Request, *args):
p = txt2img_create_processing(id_task, request, *args)

with closing(p):
@@ -119,3 +120,11 @@ def txt2img(id_task: str, request: gr.Request, *args):
processed.images = []

return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")


def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
return main_thread.run_and_wait_result(txt2img_upscale_function, id_task, request, gallery, gallery_index, generation_info, *args)


def txt2img(id_task: str, request: gr.Request, *args):
return main_thread.run_and_wait_result(txt2img_function, id_task, request, *args)
68 changes: 68 additions & 0 deletions modules_forge/main_thread.py
@@ -0,0 +1,68 @@
# This file implements the main thread loop that handles all gradio calls for major t2i or i2i processing.
# Other gradio calls (like those from extensions) are not affected.
# Because one single thread processes all major calls, model moving is significantly faster.


import time
import traceback
import threading


lock = threading.Lock()
last_id = 0
waiting_list = []
finished_list = []


class Task:
def __init__(self, task_id, func, args, kwargs):
self.task_id = task_id
self.func = func
self.args = args
self.kwargs = kwargs
self.result = None

def work(self):
self.result = self.func(*self.args, **self.kwargs)


def loop():
global lock, last_id, waiting_list, finished_list
while True:
time.sleep(0.01)
if len(waiting_list) > 0:
with lock:
task = waiting_list.pop(0)
try:
task.work()
except Exception as e:
traceback.print_exc()
print(e)
with lock:
finished_list.append(task)


def async_run(func, *args, **kwargs):
global lock, last_id, waiting_list, finished_list
with lock:
last_id += 1
new_task = Task(task_id=last_id, func=func, args=args, kwargs=kwargs)
waiting_list.append(new_task)
return new_task.task_id


def run_and_wait_result(func, *args, **kwargs):
global lock, last_id, waiting_list, finished_list
current_id = async_run(func, *args, **kwargs)
while True:
time.sleep(0.01)
finished_task = None
for t in finished_list.copy(): # thread safe shallow copy without needing a lock
if t.task_id == current_id:
finished_task = t
break
if finished_task is not None:
with lock:
finished_list.remove(finished_task)
return finished_task.result

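A hedged usage sketch of the two entry points defined above; load_model and generate are hypothetical stand-ins, and the calls only complete once main_thread.loop() is running on the main thread (see webui.py below).

# Usage sketch for the helpers above. load_model and generate are
# hypothetical stand-ins, not functions from this commit.
from modules_forge import main_thread

def load_model():
    print("loading model on the main thread")

def generate(prompt):
    return [f"image for {prompt!r}"]

main_thread.async_run(load_model)                            # fire-and-forget: queue work, get a task id back
images = main_thread.run_and_wait_result(generate, "a cat")  # block this thread until the main thread finishes

One design note: loop() and run_and_wait_result() poll with time.sleep(0.01) instead of blocking on a queue.Queue or condition variable, which costs up to ~20 ms of latency per call but keeps the implementation small and the lock usage obvious.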
23 changes: 17 additions & 6 deletions webui.py
@@ -6,8 +6,10 @@
from modules import timer
from modules import initialize_util
from modules import initialize

from threading import Thread
from modules_forge.initialization import initialize_forge
from modules_forge import main_thread


startup_timer = timer.startup_timer
startup_timer.record("launcher")
@@ -18,6 +20,8 @@

initialize.check_versions()

initialize.initialize()


def create_api(app):
from modules.api.api import Api
@@ -27,12 +31,10 @@ def create_api(app):
return api


def api_only():
def api_only_worker():
from fastapi import FastAPI
from modules.shared_cmd_options import cmd_opts

initialize.initialize()

app = FastAPI()
initialize_util.setup_middleware(app)
api = create_api(app)
@@ -49,11 +51,10 @@ def api_only():
)


def webui():
def webui_worker():
from modules.shared_cmd_options import cmd_opts

launch_api = cmd_opts.api
initialize.initialize()

from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks

@@ -157,10 +158,20 @@ def webui():
initialize.initialize_rest(reload_script_modules=True)


def api_only():
Thread(target=api_only_worker, daemon=True).start()


def webui():
Thread(target=webui_worker, daemon=True).start()


if __name__ == "__main__":
from modules.shared_cmd_options import cmd_opts

if cmd_opts.nowebui:
api_only()
else:
webui()

main_thread.loop()
