Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

30x speedup when building extra networks UI on startup #8742

Closed
3 changes: 3 additions & 0 deletions extensions-builtin/Lora/ui_extra_networks_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def __init__(self):
def refresh(self):
lora.list_available_loras()

def item_count(self):
return len(lora.available_loras)

def list_items(self):
for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
Expand Down
25 changes: 25 additions & 0 deletions javascript/extraNetworks.js
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,28 @@ function extraNetworksRequestMetadata(event, extraPage, cardName){

event.stopPropagation()
}

var hookedExtraNetworksButtons = {};
function extraNetworksHookPageToggleIfBuilt(tab_name) {
if (hookedExtraNetworksButtons[tab_name]) {
return;
}

let button = gradioApp().querySelector("#" + tab_name + "_extra_networks.gradio-button");
let extra_networks_section = gradioApp().querySelector("div#" + tab_name + "_extra_networks");
let extra_networks_html = extra_networks_section.getElementsByClassName("output-html");
let loaded = new Array(extra_networks_html).some(html => html.innerHTML != "");
if (loaded && button) {
console.log("Replacing Gradio native callback with CSS toggle for extra networks button: " + tab_name)

// Remove event handlers by recreating button node
hookedExtraNetworksButtons[tab_name] = button;
var new_button = button.cloneNode(true);
button.parentNode.replaceChild(new_button, button);

// Add our own event to toggle extra networks section with CSS instead of Gradio
new_button.addEventListener("click", function() {
extra_networks_section.classList.toggle("hidden");
}, false);
}
}
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def list_samplers():
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"ui_extra_networks_defer_load": OptionInfo(True, "Defer loading extra networks UI until button is pressed (improves startup/page load performance)"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))

Expand Down
72 changes: 58 additions & 14 deletions modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import gradio as gr
import json
import html
import tqdm

from modules.generation_parameters_copypaste import image_from_url_text

Expand Down Expand Up @@ -83,7 +84,7 @@ def search_terms_from_path(self, filename, possible_directories=None):

def create_html(self, tabname):
view = shared.opts.extra_networks_default_view
items_html = ''
items_html = []

self.metadata = {}

Expand Down Expand Up @@ -112,16 +113,25 @@ def create_html(self, tabname):
</button>
""" for subdir in subdirs])

for item in self.list_items():
total_items = self.item_count()
if total_items >= 1000:
print(f"Loading extra networks page: {self.title}")
iterator = tqdm.tqdm(self.list_items(), total=total_items)
else:
iterator = self.list_items()

for item in iterator:
metadata = item.get("metadata")
if metadata:
self.metadata[item["name"]] = metadata

items_html += self.create_html_for_item(item, tabname)
items_html.append(self.create_html_for_item(item, tabname))

if items_html == '':
if not items_html:
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
items_html = [shared.html("extra-networks-no-cards.html").format(dirs=dirs)]

items_html = "".join(items_html)

self_name_id = self.name.replace(" ", "_")

Expand All @@ -136,6 +146,9 @@ def create_html(self, tabname):

return res

def item_count(self):
raise NotImplementedError()

def list_items(self):
raise NotImplementedError()

Expand Down Expand Up @@ -240,10 +253,14 @@ def create_ui(container, button, tabname):
ui.tabname = tabname

with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
tab_name_state = gr.Textbox(tabname, visible=False)
for page in ui.stored_extra_pages:
with gr.Tab(page.title):

page_elem = gr.HTML(page.create_html(ui.tabname))
if shared.opts.ui_extra_networks_defer_load:
html = ""
else:
html = page.create_html(ui.tabname)
page_elem = gr.HTML(html)
ui.pages.append(page_elem)

filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
Expand All @@ -252,13 +269,6 @@ def create_ui(container, button, tabname):
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)

def toggle_visibility(is_visible):
is_visible = not is_visible
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))

state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])

def refresh():
res = []

Expand All @@ -270,6 +280,40 @@ def refresh():

button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)

state_visible = gr.State(value=False)

if shared.opts.ui_extra_networks_defer_load:
def toggle_visibility_defer_load(is_visible, *pages):
is_visible = not is_visible
if is_visible:
new_pages = []
for i, pg in enumerate(ui.stored_extra_pages):
html = pages[i]
if not html:
html = pg.create_html(ui.tabname)
new_pages.append(html)
pages = new_pages
return [is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))] + list(pages)
# TODO: Use .then() so the extra networks drawer/loading spinner appears
# instead of nothing happening for X seconds
button.click(fn=toggle_visibility_defer_load, inputs=[state_visible] + ui.pages, outputs=[state_visible, container, button] + ui.pages)

# Gradio has to send the rendered HTML for the extra networks UI to the
# frontend every time the toggle_visibility event handler is called, even if
# all it does is change a single flag on and off. This causes a serious
# performance drop if the pages are huge strings.
# This callback removes Gradio's "click" event listener on the button in the
# frontend once it receives the pages HTML, by replacing the button and
# adding a new click listener to it that toggles the ".hidden" CSS class
# instead, thus bypassing Gradio entirely.
button.click(fn=None, _js="extraNetworksHookPageToggleIfBuilt", inputs=[tab_name_state], outputs=[])
else:
def toggle_visibility(is_visible):
is_visible = not is_visible
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))

button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])

return ui


Expand Down
3 changes: 3 additions & 0 deletions modules/ui_extra_networks_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def __init__(self):
def refresh(self):
shared.refresh_checkpoints()

def item_count(self):
return len(sd_models.checkpoints_list)

def list_items(self):
checkpoint: sd_models.CheckpointInfo
for name, checkpoint in sd_models.checkpoints_list.items():
Expand Down
3 changes: 3 additions & 0 deletions modules/ui_extra_networks_hypernets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ def __init__(self):
def refresh(self):
shared.reload_hypernetworks()

def item_count(self):
return len(shared.hypernetworks)

def list_items(self):
for name, path in shared.hypernetworks.items():
path, ext = os.path.splitext(path)
Expand Down
3 changes: 3 additions & 0 deletions modules/ui_extra_networks_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def __init__(self):
def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)

def item_count(self):
return len(sd_hijack.model_hijack.embedding_db.word_embeddings)

def list_items(self):
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
path, ext = os.path.splitext(embedding.filename)
Expand Down