From e6a8dec98deea047bf44f5c33b79607f008cbae2 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 15 Apr 2024 20:26:09 -0400 Subject: [PATCH] Implement use_shell as parameter (#2297) --- config example.toml | 3 ++ kohya_gui.py | 40 ++++++++++++++++++----- kohya_gui/blip_caption_gui.py | 6 ++-- kohya_gui/class_command_executor.py | 17 +++++----- kohya_gui/class_gui_config.py | 11 +++++++ kohya_gui/class_lora_tab.py | 26 ++++++++------- kohya_gui/common_gui.py | 3 +- kohya_gui/convert_lcm_gui.py | 22 ++++++++++--- kohya_gui/convert_model_gui.py | 6 ++-- kohya_gui/dreambooth_gui.py | 7 +++- kohya_gui/extract_lora_from_dylora_gui.py | 6 ++-- kohya_gui/extract_lora_gui.py | 6 ++-- kohya_gui/extract_lycoris_locon_gui.py | 6 ++-- kohya_gui/finetune_gui.py | 38 ++++++++++++++------- kohya_gui/git_caption_gui.py | 8 +++-- kohya_gui/group_images_gui.py | 6 ++-- kohya_gui/lora_gui.py | 10 ++++-- kohya_gui/merge_lora_gui.py | 7 ++-- kohya_gui/merge_lycoris_gui.py | 6 ++-- kohya_gui/resize_lora_gui.py | 6 ++-- kohya_gui/svd_merge_lora_gui.py | 33 ++++++++++--------- kohya_gui/textual_inversion_gui.py | 28 +++++++++++----- kohya_gui/utilities.py | 14 ++++---- kohya_gui/verify_lora_gui.py | 11 +++++-- kohya_gui/wd14_caption_gui.py | 24 ++++++++++---- 25 files changed, 241 insertions(+), 109 deletions(-) diff --git a/config example.toml b/config example.toml index 0ec325492..11c852346 100644 --- a/config example.toml +++ b/config example.toml @@ -1,6 +1,9 @@ # Copy this file and name it config.toml # Edit the values to suit your needs +[settings] +use_shell = false # Use shell furing process run of sd-scripts oython code. Most secure is false but some systems may require it to be true to properly run sd-scripts. + # Default folders location [model] models_dir = "./models" # Pretrained model name or path diff --git a/kohya_gui.py b/kohya_gui.py index f40814087..ea8624af3 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -12,6 +12,7 @@ from kohya_gui.custom_logging import setup_logging from kohya_gui.localization_ext import add_javascript + def UI(**kwargs): add_javascript(kwargs.get("language")) css = "" @@ -35,9 +36,18 @@ def UI(**kwargs): interface = gr.Blocks( css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default() ) - + config = KohyaSSGUIConfig(config_file_path=kwargs.get("config")) + if config.is_config_loaded(): + log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...") + + use_shell_flag = kwargs.get("use_shell", False) + if use_shell_flag == False: + use_shell_flag = config.get("settings.use_shell", False) + if use_shell_flag: + log.info("Using shell=True when running external commands...") + with interface: with gr.Tab("Dreambooth"): ( @@ -45,13 +55,17 @@ def UI(**kwargs): reg_data_dir_input, output_dir_input, logging_dir_input, - ) = dreambooth_tab(headless=headless, config=config) + ) = dreambooth_tab( + headless=headless, config=config, use_shell_flag=use_shell_flag + ) with gr.Tab("LoRA"): - lora_tab(headless=headless, config=config) + lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag) with gr.Tab("Textual Inversion"): - ti_tab(headless=headless, config=config) + ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag) with gr.Tab("Finetuning"): - finetune_tab(headless=headless, config=config) + finetune_tab( + headless=headless, config=config, use_shell_flag=use_shell_flag + ) with gr.Tab("Utilities"): utilities_tab( train_data_dir_input=train_data_dir_input, @@ -61,9 +75,10 @@ def UI(**kwargs): enable_copy_info_button=True, headless=headless, config=config, + use_shell_flag=use_shell_flag, ) with gr.Tab("LoRA"): - _ = LoRATools(headless=headless) + _ = LoRATools(headless=headless, use_shell_flag=use_shell_flag) with gr.Tab("About"): gr.Markdown(f"kohya_ss GUI release {release}") with gr.Tab("README"): @@ -102,6 +117,7 @@ def UI(**kwargs): launch_kwargs["debug"] = True interface.launch(**launch_kwargs) + if __name__ == "__main__": # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() @@ -141,11 +157,17 @@ def UI(**kwargs): parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment") parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment") - - parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI") + + parser.add_argument( + "--use_shell", action="store_true", help="Use shell environment" + ) + + parser.add_argument( + "--do_not_share", action="store_true", help="Do not share the gradio UI" + ) args = parser.parse_args() - + # Set up logging log = setup_logging(debug=args.debug) diff --git a/kohya_gui/blip_caption_gui.py b/kohya_gui/blip_caption_gui.py index 29db10e33..8b752880c 100644 --- a/kohya_gui/blip_caption_gui.py +++ b/kohya_gui/blip_caption_gui.py @@ -23,6 +23,7 @@ def caption_images( beam_search: bool, prefix: str = "", postfix: str = "", + use_shell: bool = False, ) -> None: """ Automatically generates captions for images in the specified directory using the BLIP model. @@ -96,7 +97,7 @@ def caption_images( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command in the sd-scripts folder context - subprocess.run(run_cmd, env=env, cwd=f"{scriptdir}/sd-scripts") + subprocess.run(run_cmd, env=env, shell=use_shell, cwd=f"{scriptdir}/sd-scripts") # Add prefix and postfix @@ -115,7 +116,7 @@ def caption_images( ### -def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None): +def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None, use_shell: bool = False): from .common_gui import create_refresh_button default_train_dir = ( @@ -205,6 +206,7 @@ def list_train_dirs(path): beam_search, prefix, postfix, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/class_command_executor.py b/kohya_gui/class_command_executor.py index 7fc22f9d8..a8db3ff47 100644 --- a/kohya_gui/class_command_executor.py +++ b/kohya_gui/class_command_executor.py @@ -2,7 +2,7 @@ import psutil import time import gradio as gr -import shlex + from .custom_logging import setup_logging # Set up logging @@ -21,7 +21,7 @@ def __init__(self): self.process = None self.run_state = gr.Textbox(value="", visible=False) - def execute_command(self, run_cmd: str, **kwargs): + def execute_command(self, run_cmd: str, use_shell: bool = False, **kwargs): """ Execute a command if no other command is currently running. @@ -36,11 +36,12 @@ def execute_command(self, run_cmd: str, **kwargs): # log.info(f"{i}: {item}") # Reconstruct the safe command string for display - command_to_run = ' '.join(run_cmd) - log.info(f"Executing command: {command_to_run}") + command_to_run = " ".join(run_cmd) + log.info(f"Executing command: {command_to_run} with shell={use_shell}") # Execute the command securely - self.process = subprocess.Popen(run_cmd, **kwargs) + self.process = subprocess.Popen(run_cmd, **kwargs, shell=use_shell) + log.info("Command executed.") def kill_command(self): """ @@ -64,9 +65,9 @@ def kill_command(self): log.info(f"Error when terminating process: {e}") else: log.info("There is no running process to kill.") - + return gr.Button(visible=True), gr.Button(visible=False) - + def wait_for_training_to_end(self): while self.is_running(): time.sleep(1) @@ -81,4 +82,4 @@ def is_running(self): Returns: - bool: True if the command is running, False otherwise. """ - return self.process and self.process.poll() is None \ No newline at end of file + return self.process and self.process.poll() is None diff --git a/kohya_gui/class_gui_config.py b/kohya_gui/class_gui_config.py index 3624631e6..a19e855af 100644 --- a/kohya_gui/class_gui_config.py +++ b/kohya_gui/class_gui_config.py @@ -80,3 +80,14 @@ def get(self, key: str, default=None): # Return the final value log.debug(f"Returned {data}") return data + + def is_config_loaded(self) -> bool: + """ + Checks if the configuration was loaded from a file. + + Returns: + bool: True if the configuration was loaded from a file, False otherwise. + """ + is_loaded = self.config != {} + log.debug(f"Configuration was loaded from file: {is_loaded}") + return is_loaded diff --git a/kohya_gui/class_lora_tab.py b/kohya_gui/class_lora_tab.py index efeaf952e..798b7d7f7 100644 --- a/kohya_gui/class_lora_tab.py +++ b/kohya_gui/class_lora_tab.py @@ -11,16 +11,18 @@ class LoRATools: - def __init__(self, headless: bool = False): - self.headless = headless - + def __init__( + self, + headless: bool = False, + use_shell_flag: bool = False, + ): gr.Markdown("This section provide various LoRA tools...") - gradio_extract_dylora_tab(headless=headless) - gradio_convert_lcm_tab(headless=headless) - gradio_extract_lora_tab(headless=headless) - gradio_extract_lycoris_locon_tab(headless=headless) - gradio_merge_lora_tab = GradioMergeLoRaTab() - gradio_merge_lycoris_tab(headless=headless) - gradio_svd_merge_lora_tab(headless=headless) - gradio_resize_lora_tab(headless=headless) - gradio_verify_lora_tab(headless=headless) + gradio_extract_dylora_tab(headless=headless, use_shell=use_shell_flag) + gradio_convert_lcm_tab(headless=headless, use_shell=use_shell_flag) + gradio_extract_lora_tab(headless=headless, use_shell=use_shell_flag) + gradio_extract_lycoris_locon_tab(headless=headless, use_shell=use_shell_flag) + gradio_merge_lora_tab = GradioMergeLoRaTab(use_shell=use_shell_flag) + gradio_merge_lycoris_tab(headless=headless, use_shell=use_shell_flag) + gradio_svd_merge_lora_tab(headless=headless, use_shell=use_shell_flag) + gradio_resize_lora_tab(headless=headless, use_shell=use_shell_flag) + gradio_verify_lora_tab(headless=headless, use_shell=use_shell_flag) diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 634e5a652..25594f95e 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -2,7 +2,6 @@ from easygui import msgbox, ynbox from typing import Optional from .custom_logging import setup_logging -from .class_command_executor import CommandExecutor import os import re @@ -12,7 +11,6 @@ import json import math import shutil -import time # Set up logging log = setup_logging() @@ -23,6 +21,7 @@ document_symbol = "\U0001F4C4" # 📄 scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + if os.name == "nt": scriptdir = scriptdir.replace("\\", "/") diff --git a/kohya_gui/convert_lcm_gui.py b/kohya_gui/convert_lcm_gui.py index 5b8fb1e65..0d0445676 100644 --- a/kohya_gui/convert_lcm_gui.py +++ b/kohya_gui/convert_lcm_gui.py @@ -22,7 +22,13 @@ PYTHON = sys.executable -def convert_lcm(name, model_path, lora_scale, model_type): +def convert_lcm( + name, + model_path, + lora_scale, + model_type, + use_shell: bool = False, +): run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"' # Check if source model exist @@ -62,7 +68,7 @@ def convert_lcm(name, model_path, lora_scale, model_type): run_cmd.append("--ssd-1b") # Log the command - log.info(' '.join(run_cmd)) + log.info(" ".join(run_cmd)) # Set up the environment env = os.environ.copy() @@ -72,13 +78,13 @@ def convert_lcm(name, model_path, lora_scale, model_type): env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) # Return a success message log.info("Done extracting...") -def gradio_convert_lcm_tab(headless=False): +def gradio_convert_lcm_tab(headless=False, use_shell: bool = False): current_model_dir = os.path.join(scriptdir, "outputs") current_save_dir = os.path.join(scriptdir, "outputs") @@ -183,6 +189,12 @@ def list_save_to(path): extract_button.click( convert_lcm, - inputs=[name, model_path, lora_scale, model_type], + inputs=[ + name, + model_path, + lora_scale, + model_type, + gr.Checkbox(value=use_shell, visible=False), + ], show_progress=False, ) diff --git a/kohya_gui/convert_model_gui.py b/kohya_gui/convert_model_gui.py index 14cb71eb4..c366499e6 100644 --- a/kohya_gui/convert_model_gui.py +++ b/kohya_gui/convert_model_gui.py @@ -26,6 +26,7 @@ def convert_model( target_model_type, target_save_precision_type, unet_use_linear_projection, + use_shell: bool = False, ): # Check for caption_text_input if source_model_type == "": @@ -107,7 +108,7 @@ def convert_model( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) @@ -116,7 +117,7 @@ def convert_model( ### -def gradio_convert_model_tab(headless=False): +def gradio_convert_model_tab(headless=False, use_shell: bool = False): from .common_gui import create_refresh_button default_source_model = os.path.join(scriptdir, "outputs") @@ -276,6 +277,7 @@ def list_target_folder(path): target_model_type, target_save_precision_type, unet_use_linear_projection, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index 7e9538df0..32954964d 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -47,6 +47,7 @@ # Setup huggingface huggingface = None +use_shell = False PYTHON = sys.executable @@ -843,7 +844,7 @@ def train_model( # Run the command - executor.execute_command(run_cmd=run_cmd, env=env) + executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env) return ( gr.Button(visible=False), @@ -859,10 +860,14 @@ def dreambooth_tab( # logging_dir=gr.Textbox(), headless=False, config: KohyaSSGUIConfig = {}, + use_shell_flag: bool = False, ): dummy_db_true = gr.Checkbox(value=True, visible=False) dummy_db_false = gr.Checkbox(value=False, visible=False) dummy_headless = gr.Checkbox(value=headless, visible=False) + + global use_shell + use_shell = use_shell_flag with gr.Tab("Training"), gr.Column(variant="compact"): gr.Markdown("Train a custom model using kohya dreambooth python code...") diff --git a/kohya_gui/extract_lora_from_dylora_gui.py b/kohya_gui/extract_lora_from_dylora_gui.py index e87c29d37..e1000dc1a 100644 --- a/kohya_gui/extract_lora_from_dylora_gui.py +++ b/kohya_gui/extract_lora_from_dylora_gui.py @@ -27,6 +27,7 @@ def extract_dylora( model, save_to, unit, + use_shell: bool = False, ): # Check for caption_text_input if model == "": @@ -71,7 +72,7 @@ def extract_dylora( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) log.info("Done extracting DyLoRA...") @@ -81,7 +82,7 @@ def extract_dylora( ### -def gradio_extract_dylora_tab(headless=False): +def gradio_extract_dylora_tab(headless=False, use_shell: bool = False): current_model_dir = os.path.join(scriptdir, "outputs") current_save_dir = os.path.join(scriptdir, "outputs") @@ -170,6 +171,7 @@ def list_save_to(path): model, save_to, unit, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index ea0cc53dd..1a1bf9cb2 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -39,6 +39,7 @@ def extract_lora( load_original_model_to, load_tuned_model_to, load_precision, + use_shell: bool = False, ): # Check for caption_text_input if model_tuned == "": @@ -120,7 +121,7 @@ def extract_lora( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) ### @@ -128,7 +129,7 @@ def extract_lora( ### -def gradio_extract_lora_tab(headless=False): +def gradio_extract_lora_tab(headless=False, use_shell: bool = False): current_model_dir = os.path.join(scriptdir, "outputs") current_model_org_dir = os.path.join(scriptdir, "outputs") current_save_dir = os.path.join(scriptdir, "outputs") @@ -358,6 +359,7 @@ def change_sdxl(sdxl): load_original_model_to, load_tuned_model_to, load_precision, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/extract_lycoris_locon_gui.py b/kohya_gui/extract_lycoris_locon_gui.py index e22ef8504..af17163c5 100644 --- a/kohya_gui/extract_lycoris_locon_gui.py +++ b/kohya_gui/extract_lycoris_locon_gui.py @@ -43,6 +43,7 @@ def extract_lycoris_locon( use_sparse_bias, sparsity, disable_cp, + use_shell: bool = False, ): # Check for caption_text_input if db_model == "": @@ -135,7 +136,7 @@ def extract_lycoris_locon( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) log.info("Done extracting...") @@ -171,7 +172,7 @@ def update_mode(mode): return tuple(updates) -def gradio_extract_lycoris_locon_tab(headless=False): +def gradio_extract_lycoris_locon_tab(headless=False, use_shell: bool = False): current_model_dir = os.path.join(scriptdir, "outputs") current_base_model_dir = os.path.join(scriptdir, "outputs") @@ -449,6 +450,7 @@ def list_save_to(path): use_sparse_bias, sparsity, disable_cp, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index b7b3dd2b1..26a4a1822 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -32,6 +32,7 @@ from .class_sample_images import SampleImages, create_prompt_file from .class_huggingface import HuggingFace from .class_metadata import MetaData +from .class_gui_config import KohyaSSGUIConfig from .custom_logging import setup_logging @@ -43,6 +44,7 @@ # Setup huggingface huggingface = None +use_shell = False # from easygui import msgbox @@ -592,7 +594,6 @@ def train_model( if not print_only: subprocess.run(run_cmd, env=env) - # create images buckets if generate_image_buckets: # Build the command to run the preparation script @@ -639,7 +640,6 @@ def train_model( if not print_only: subprocess.run(run_cmd, env=env) - if image_folder == "": log.error("Image folder dir is empty") return TRAIN_BUTTON_VISIBLE @@ -709,7 +709,7 @@ def train_model( ) cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs no_half_vae = sdxl_checkbox and sdxl_no_half_vae - + if max_data_loader_n_workers == "" or None: max_data_loader_n_workers = 0 else: @@ -719,7 +719,7 @@ def train_model( max_train_steps = 0 else: max_train_steps = int(max_train_steps) - + config_toml_data = { # Update the values in the TOML data "huggingface_repo_id": huggingface_repo_id, @@ -758,16 +758,22 @@ def train_model( "ip_noise_gamma": ip_noise_gamma, "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength, "keep_tokens": int(keep_tokens), - "learning_rate": learning_rate, # both for sd1.5 and sdxl - "learning_rate_te": learning_rate_te if not sdxl_checkbox else None, # only for sd1.5 - "learning_rate_te1": learning_rate_te1 if sdxl_checkbox else None, # only for sdxl - "learning_rate_te2": learning_rate_te2 if sdxl_checkbox else None, # only for sdxl + "learning_rate": learning_rate, # both for sd1.5 and sdxl + "learning_rate_te": ( + learning_rate_te if not sdxl_checkbox else None + ), # only for sd1.5 + "learning_rate_te1": ( + learning_rate_te1 if sdxl_checkbox else None + ), # only for sdxl + "learning_rate_te2": ( + learning_rate_te2 if sdxl_checkbox else None + ), # only for sdxl "logging_dir": logging_dir, "log_tracker_name": log_tracker_name, "log_tracker_config": log_tracker_config, "loss_type": loss_type, "lr_scheduler": lr_scheduler, - "lr_scheduler_args": str(lr_scheduler_args).replace('"', '').split(), + "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(), "lr_warmup_steps": lr_warmup_steps, "max_bucket_reso": int(max_bucket_reso), "max_data_loader_n_workers": max_data_loader_n_workers, @@ -792,7 +798,7 @@ def train_model( "noise_offset_random_strength": noise_offset_random_strength, "noise_offset_type": noise_offset_type, "optimizer_type": optimizer, - "optimizer_args": str(optimizer_args).replace('"', '').split(), + "optimizer_args": str(optimizer_args).replace('"', "").split(), "output_dir": output_dir, "output_name": output_name, "persistent_data_loader_workers": persistent_data_loader_workers, @@ -892,7 +898,7 @@ def train_model( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - executor.execute_command(run_cmd=run_cmd, env=env) + executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env) return ( gr.Button(visible=False), @@ -901,10 +907,18 @@ def train_model( ) -def finetune_tab(headless=False, config: dict = {}): +def finetune_tab( + headless=False, + config: KohyaSSGUIConfig = {}, + use_shell_flag: bool = False, +): dummy_db_true = gr.Checkbox(value=True, visible=False) dummy_db_false = gr.Checkbox(value=False, visible=False) dummy_headless = gr.Checkbox(value=headless, visible=False) + + global use_shell + use_shell = use_shell_flag + with gr.Tab("Training"), gr.Column(variant="compact"): gr.Markdown("Train a custom model using kohya finetune python code...") diff --git a/kohya_gui/git_caption_gui.py b/kohya_gui/git_caption_gui.py index e413d24ae..b82882f99 100644 --- a/kohya_gui/git_caption_gui.py +++ b/kohya_gui/git_caption_gui.py @@ -22,6 +22,7 @@ def caption_images( model_id, prefix, postfix, + use_shell: bool = False, ): # Check for images_dir_input if train_data_dir == "": @@ -70,7 +71,7 @@ def caption_images( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) # Add prefix and postfix add_pre_postfix( @@ -88,7 +89,9 @@ def caption_images( ### -def gradio_git_caption_gui_tab(headless=False, default_train_dir=None): +def gradio_git_caption_gui_tab( + headless=False, default_train_dir=None, use_shell: bool = False +): from .common_gui import create_refresh_button default_train_dir = ( @@ -178,6 +181,7 @@ def list_train_dirs(path): model_id, prefix, postfix, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/group_images_gui.py b/kohya_gui/group_images_gui.py index 9536dd5e7..62258d533 100644 --- a/kohya_gui/group_images_gui.py +++ b/kohya_gui/group_images_gui.py @@ -21,6 +21,7 @@ def group_images( do_not_copy_other_files, generate_captions, caption_ext, + use_shell: bool = False, ): if input_folder == "": msgbox("Input folder is missing...") @@ -63,12 +64,12 @@ def group_images( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) log.info("...grouping done") -def gradio_group_images_gui_tab(headless=False): +def gradio_group_images_gui_tab(headless=False, use_shell: bool = False): from .common_gui import create_refresh_button current_input_folder = os.path.join(scriptdir, "data") @@ -200,6 +201,7 @@ def list_output_dirs(path): do_not_copy_other_files, generate_captions, caption_ext, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index d0f7c4297..d86eec80c 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -34,6 +34,7 @@ from .class_lora_tab import LoRATools from .class_huggingface import HuggingFace from .class_metadata import MetaData +from .class_gui_config import KohyaSSGUIConfig from .dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -50,6 +51,7 @@ # Setup huggingface huggingface = None +use_shell = False button_run = gr.Button("Start training", variant="primary") @@ -1193,7 +1195,7 @@ def train_model( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - executor.execute_command(run_cmd=run_cmd, env=env) + executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env) return ( gr.Button(visible=False), @@ -1208,11 +1210,15 @@ def lora_tab( output_dir_input=gr.Dropdown(), logging_dir_input=gr.Dropdown(), headless=False, - config: dict = {}, + config: KohyaSSGUIConfig = {}, + use_shell_flag: bool = False, ): dummy_db_true = gr.Checkbox(value=True, visible=False) dummy_db_false = gr.Checkbox(value=False, visible=False) dummy_headless = gr.Checkbox(value=headless, visible=False) + + global use_shell + use_shell = use_shell_flag with gr.Tab("Training"), gr.Column(variant="compact") as tab: gr.Markdown( diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py index 01962178e..e7d9645f8 100644 --- a/kohya_gui/merge_lora_gui.py +++ b/kohya_gui/merge_lora_gui.py @@ -48,8 +48,9 @@ def verify_conditions(sd_model, lora_models): class GradioMergeLoRaTab: - def __init__(self, headless=False): + def __init__(self, headless=False, use_shell: bool = False): self.headless = headless + self.use_shell = use_shell self.build_tab() def save_inputs_to_json(self, file_path, inputs): @@ -379,6 +380,7 @@ def list_save_to(path): save_to, precision, save_precision, + gr.Checkbox(value=self.use_shell, visible=False), ], show_progress=False, ) @@ -398,6 +400,7 @@ def merge_lora( save_to, precision, save_precision, + use_shell: bool = False, ): log.info("Merge model...") @@ -458,6 +461,6 @@ def merge_lora( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) log.info("Done merging...") diff --git a/kohya_gui/merge_lycoris_gui.py b/kohya_gui/merge_lycoris_gui.py index c958a7a64..935cfb1af 100644 --- a/kohya_gui/merge_lycoris_gui.py +++ b/kohya_gui/merge_lycoris_gui.py @@ -33,6 +33,7 @@ def merge_lycoris( device, is_sdxl, is_v2, + use_shell: bool = False, ): log.info("Merge model...") @@ -67,7 +68,7 @@ def merge_lycoris( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Execute the command with the modified environment - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) log.info("Done merging...") @@ -77,7 +78,7 @@ def merge_lycoris( ### -def gradio_merge_lycoris_tab(headless=False): +def gradio_merge_lycoris_tab(headless=False, use_shell: bool = False): current_model_dir = os.path.join(scriptdir, "outputs") current_lycoris_dir = current_model_dir current_save_dir = current_model_dir @@ -250,6 +251,7 @@ def list_save_to(path): device, is_sdxl, is_v2, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/resize_lora_gui.py b/kohya_gui/resize_lora_gui.py index 465acc9e5..8f90fa5d8 100644 --- a/kohya_gui/resize_lora_gui.py +++ b/kohya_gui/resize_lora_gui.py @@ -33,6 +33,7 @@ def resize_lora( dynamic_method, dynamic_param, verbose, + use_shell: bool = False, ): # Check for caption_text_input if model == "": @@ -100,7 +101,7 @@ def resize_lora( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) log.info("Done resizing...") @@ -110,7 +111,7 @@ def resize_lora( ### -def gradio_resize_lora_tab(headless=False): +def gradio_resize_lora_tab(headless=False, use_shell: bool = False): current_model_dir = os.path.join(scriptdir, "outputs") current_save_dir = os.path.join(scriptdir, "outputs") @@ -246,6 +247,7 @@ def list_save_to(path): dynamic_method, dynamic_param, verbose, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/svd_merge_lora_gui.py b/kohya_gui/svd_merge_lora_gui.py index dc1c78cdc..a896ebc69 100644 --- a/kohya_gui/svd_merge_lora_gui.py +++ b/kohya_gui/svd_merge_lora_gui.py @@ -38,6 +38,7 @@ def svd_merge_lora( new_rank, new_conv_rank, device, + use_shell: bool = False, ): # Check if the output file already exists if os.path.isfile(save_to): @@ -53,10 +54,14 @@ def svd_merge_lora( ratio_d /= total_ratio run_cmd = [ - PYTHON, f"{scriptdir}/sd-scripts/networks/svd_merge_lora.py", - '--save_precision', save_precision, - '--precision', precision, - '--save_to', save_to + PYTHON, + f"{scriptdir}/sd-scripts/networks/svd_merge_lora.py", + "--save_precision", + save_precision, + "--precision", + precision, + "--save_to", + save_to, ] # Variables for model paths and their ratios @@ -82,17 +87,15 @@ def add_model(model_path, ratio): pass if models and ratios: # Ensure we have valid models and ratios before appending - run_cmd.extend(['--models'] + models) - run_cmd.extend(['--ratios'] + ratios) + run_cmd.extend(["--models"] + models) + run_cmd.extend(["--ratios"] + ratios) - run_cmd.extend([ - '--device', device, - '--new_rank', new_rank, - '--new_conv_rank', new_conv_rank - ]) + run_cmd.extend( + ["--device", device, "--new_rank", new_rank, "--new_conv_rank", new_conv_rank] + ) # Log the command - log.info(' '.join(run_cmd)) + log.info(" ".join(run_cmd)) env = os.environ.copy() env["PYTHONPATH"] = ( @@ -102,8 +105,7 @@ def add_model(model_path, ratio): env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) - + subprocess.run(run_cmd, env=env, shell=use_shell) ### @@ -111,7 +113,7 @@ def add_model(model_path, ratio): ### -def gradio_svd_merge_lora_tab(headless=False): +def gradio_svd_merge_lora_tab(headless=False, use_shell: bool = False): current_save_dir = os.path.join(scriptdir, "outputs") current_a_model_dir = current_save_dir current_b_model_dir = current_save_dir @@ -406,6 +408,7 @@ def list_save_to(path): new_rank, new_conv_rank, device, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py index d29d140ba..1c1b241fe 100644 --- a/kohya_gui/textual_inversion_gui.py +++ b/kohya_gui/textual_inversion_gui.py @@ -37,6 +37,7 @@ ) from .dataset_balancing_gui import gradio_dataset_balancing_tab from .class_sample_images import SampleImages, create_prompt_file +from .class_gui_config import KohyaSSGUIConfig from .custom_logging import setup_logging @@ -48,6 +49,7 @@ # Setup huggingface huggingface = None +use_shell = False TRAIN_BUTTON_VISIBLE = [gr.Button(visible=True), gr.Button(visible=False)] @@ -624,7 +626,7 @@ def train_model( run_cmd.append(f"{scriptdir}/sd-scripts/sdxl_train_textual_inversion.py") else: run_cmd.append(f"{scriptdir}/sd-scripts/train_textual_inversion.py") - + if max_data_loader_n_workers == "" or None: max_data_loader_n_workers = 0 else: @@ -634,7 +636,7 @@ def train_model( max_train_steps = 0 else: max_train_steps = int(max_train_steps) - + # def save_huggingface_to_toml(self, toml_file_path: str): config_toml_data = { # Update the values in the TOML data @@ -675,8 +677,10 @@ def train_model( "log_tracker_config": log_tracker_config, "loss_type": loss_type, "lr_scheduler": lr_scheduler, - "lr_scheduler_args": str(lr_scheduler_args).replace('"', '').split(), - "lr_scheduler_num_cycles": lr_scheduler_num_cycles if lr_scheduler_num_cycles != "" else int(epoch), + "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(), + "lr_scheduler_num_cycles": ( + lr_scheduler_num_cycles if lr_scheduler_num_cycles != "" else int(epoch) + ), "lr_scheduler_power": lr_scheduler_power, "lr_warmup_steps": lr_warmup_steps, "max_bucket_reso": max_bucket_reso, @@ -704,7 +708,7 @@ def train_model( "noise_offset_type": noise_offset_type, "num_vectors_per_token": int(num_vectors_per_token), "optimizer_type": optimizer, - "optimizer_args": str(optimizer_args).replace('"', '').split(), + "optimizer_args": str(optimizer_args).replace('"', "").split(), "output_dir": output_dir, "output_name": output_name, "persistent_data_loader_workers": persistent_data_loader_workers, @@ -766,7 +770,7 @@ def train_model( run_cmd.append(f"--config_file") run_cmd.append(tmpfilename) - + # Initialize a dictionary with always-included keyword arguments kwargs_for_training = { "max_data_loader_n_workers": max_data_loader_n_workers, @@ -811,7 +815,7 @@ def train_model( # Run the command - executor.execute_command(run_cmd=run_cmd, env=env) + executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env) return ( gr.Button(visible=False), @@ -820,11 +824,19 @@ def train_model( ) -def ti_tab(headless=False, default_output_dir=None, config: dict = {}): +def ti_tab( + headless=False, + default_output_dir=None, + config: KohyaSSGUIConfig = {}, + use_shell_flag: bool = False, +): dummy_db_true = gr.Checkbox(value=True, visible=False) dummy_db_false = gr.Checkbox(value=False, visible=False) dummy_headless = gr.Checkbox(value=headless, visible=False) + global use_shell + use_shell = use_shell_flag + current_embedding_dir = ( default_output_dir if default_output_dir is not None and default_output_dir != "" diff --git a/kohya_gui/utilities.py b/kohya_gui/utilities.py index 408ce15b4..f5209d42f 100644 --- a/kohya_gui/utilities.py +++ b/kohya_gui/utilities.py @@ -8,6 +8,7 @@ from .wd14_caption_gui import gradio_wd14_caption_gui_tab from .manual_caption_gui import gradio_manual_caption_gui_tab from .group_images_gui import gradio_group_images_gui_tab +from .class_gui_config import KohyaSSGUIConfig def utilities_tab( @@ -18,17 +19,18 @@ def utilities_tab( enable_copy_info_button=bool(False), enable_dreambooth_tab=True, headless=False, - config: dict = {}, + config: KohyaSSGUIConfig = {}, + use_shell_flag: bool = False, ): with gr.Tab("Captioning"): gradio_basic_caption_gui_tab(headless=headless) - gradio_blip_caption_gui_tab(headless=headless) + gradio_blip_caption_gui_tab(headless=headless, use_shell=use_shell_flag) gradio_blip2_caption_gui_tab(headless=headless) - gradio_git_caption_gui_tab(headless=headless) - gradio_wd14_caption_gui_tab(headless=headless, config=config) + gradio_git_caption_gui_tab(headless=headless, use_shell=use_shell_flag) + gradio_wd14_caption_gui_tab(headless=headless, config=config, use_shell=use_shell_flag) gradio_manual_caption_gui_tab(headless=headless) - gradio_convert_model_tab(headless=headless) - gradio_group_images_gui_tab(headless=headless) + gradio_convert_model_tab(headless=headless, use_shell=use_shell_flag) + gradio_group_images_gui_tab(headless=headless, use_shell=use_shell_flag) return ( train_data_dir_input, diff --git a/kohya_gui/verify_lora_gui.py b/kohya_gui/verify_lora_gui.py index 0f153fca4..ec5bdf087 100644 --- a/kohya_gui/verify_lora_gui.py +++ b/kohya_gui/verify_lora_gui.py @@ -24,6 +24,7 @@ def verify_lora( lora_model, + use_shell: bool = False, ): # verify for caption_text_input if lora_model == "": @@ -37,11 +38,13 @@ def verify_lora( # Build the command to run check_lora_weights.py run_cmd = [ - PYTHON, f"{scriptdir}/sd-scripts/networks/check_lora_weights.py", lora_model + PYTHON, + f"{scriptdir}/sd-scripts/networks/check_lora_weights.py", + lora_model, ] # Log the command - log.info(' '.join(run_cmd)) + log.info(" ".join(run_cmd)) # Set the environment variable for the Python path env = os.environ.copy() @@ -57,6 +60,7 @@ def verify_lora( stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, + shell=use_shell, ) output, error = process.communicate() @@ -68,7 +72,7 @@ def verify_lora( ### -def gradio_verify_lora_tab(headless=False): +def gradio_verify_lora_tab(headless=False, use_shell: bool = False): current_model_dir = os.path.join(scriptdir, "outputs") def list_models(path): @@ -139,6 +143,7 @@ def list_models(path): verify_lora, inputs=[ lora_model, + gr.Checkbox(value=use_shell, visible=False), ], outputs=[lora_model_verif_output, lora_model_verif_error], show_progress=False, diff --git a/kohya_gui/wd14_caption_gui.py b/kohya_gui/wd14_caption_gui.py index eb9e836bb..62266201d 100644 --- a/kohya_gui/wd14_caption_gui.py +++ b/kohya_gui/wd14_caption_gui.py @@ -1,7 +1,13 @@ import gradio as gr from easygui import msgbox import subprocess -from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs, get_executable_path +from .common_gui import ( + get_folder_path, + add_pre_postfix, + scriptdir, + list_dirs, + get_executable_path, +) from .class_gui_config import KohyaSSGUIConfig import os @@ -34,6 +40,7 @@ def caption_images( use_rating_tags_as_last_tag: bool, remove_underscore: bool, thresh: float, + use_shell: bool = False, ) -> None: # Check for images_dir_input if train_data_dir == "": @@ -46,7 +53,9 @@ def caption_images( log.info(f"Captioning files in {train_data_dir}...") run_cmd = [ - get_executable_path("accelerate"), "launch", f"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py" + get_executable_path("accelerate"), + "launch", + f"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py", ] # Uncomment and modify if needed @@ -106,7 +115,7 @@ def caption_images( run_cmd.append(train_data_dir) # Log the command - log.info(' '.join(run_cmd)) + log.info(" ".join(run_cmd)) env = os.environ.copy() env["PYTHONPATH"] = ( @@ -116,9 +125,8 @@ def caption_images( env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - subprocess.run(run_cmd, env=env) + subprocess.run(run_cmd, env=env, shell=use_shell) - # Add prefix and postfix add_pre_postfix( folder=train_data_dir, @@ -135,7 +143,10 @@ def caption_images( def gradio_wd14_caption_gui_tab( - headless=False, default_train_dir=None, config: KohyaSSGUIConfig = {} + headless=False, + default_train_dir=None, + config: KohyaSSGUIConfig = {}, + use_shell: bool = False, ): from .common_gui import create_refresh_button @@ -374,6 +385,7 @@ def list_train_dirs(path): use_rating_tags_as_last_tag, remove_underscore, thresh, + gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, )