Implement use_shell as parameter (#2297)
bmaltais authored Apr 16, 2024
1 parent f4658f9 commit e6a8dec
Showing 25 changed files with 241 additions and 109 deletions.
3 changes: 3 additions & 0 deletions config example.toml
@@ -1,6 +1,9 @@
# Copy this file and name it config.toml
# Edit the values to suit your needs

[settings]
use_shell = false # Use shell during process run of sd-scripts python code. Most secure is false, but some systems may require it to be true to properly run sd-scripts.

# Default folders location
[model]
models_dir = "./models" # Pretrained model name or path
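For illustration only (not part of this commit): a minimal sketch of reading a settings.use_shell value like the one above from config.toml with the standard-library tomllib (Python 3.11+). The kohya_ss GUI actually loads it through its own KohyaSSGUIConfig class; the file path and fallback default here are assumptions.

# Hypothetical reader for settings.use_shell; illustrative, not the repository's loader.
import tomllib

def read_use_shell(config_path: str = "config.toml") -> bool:
    try:
        with open(config_path, "rb") as f:  # tomllib requires a binary file handle
            config = tomllib.load(f)
    except FileNotFoundError:
        return False  # safer default when no config file is present
    return bool(config.get("settings", {}).get("use_shell", False))

print(read_use_shell())  # -> False unless config.toml sets use_shell = true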
40 changes: 31 additions & 9 deletions kohya_gui.py
@@ -12,6 +12,7 @@
from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript


def UI(**kwargs):
add_javascript(kwargs.get("language"))
css = ""
@@ -35,23 +36,36 @@ def UI(**kwargs):
interface = gr.Blocks(
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
)

config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))

if config.is_config_loaded():
log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")

use_shell_flag = kwargs.get("use_shell", False)
if use_shell_flag == False:
use_shell_flag = config.get("settings.use_shell", False)
if use_shell_flag:
log.info("Using shell=True when running external commands...")

with interface:
with gr.Tab("Dreambooth"):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab(headless=headless, config=config)
) = dreambooth_tab(
headless=headless, config=config, use_shell_flag=use_shell_flag
)
with gr.Tab("LoRA"):
lora_tab(headless=headless, config=config)
lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
with gr.Tab("Textual Inversion"):
ti_tab(headless=headless, config=config)
ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
with gr.Tab("Finetuning"):
finetune_tab(headless=headless, config=config)
finetune_tab(
headless=headless, config=config, use_shell_flag=use_shell_flag
)
with gr.Tab("Utilities"):
utilities_tab(
train_data_dir_input=train_data_dir_input,
@@ -61,9 +75,10 @@ def UI(**kwargs):
enable_copy_info_button=True,
headless=headless,
config=config,
use_shell_flag=use_shell_flag,
)
with gr.Tab("LoRA"):
_ = LoRATools(headless=headless)
_ = LoRATools(headless=headless, use_shell_flag=use_shell_flag)
with gr.Tab("About"):
gr.Markdown(f"kohya_ss GUI release {release}")
with gr.Tab("README"):
@@ -102,6 +117,7 @@ def UI(**kwargs):
launch_kwargs["debug"] = True
interface.launch(**launch_kwargs)


if __name__ == "__main__":
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
@@ -141,11 +157,17 @@ def UI(**kwargs):

parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")

parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI")

parser.add_argument(
"--use_shell", action="store_true", help="Use shell environment"
)

parser.add_argument(
"--do_not_share", action="store_true", help="Do not share the gradio UI"
)

args = parser.parse_args()

# Set up logging
log = setup_logging(debug=args.debug)

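The kohya_gui.py change above resolves the flag in two steps: an explicit --use_shell on the command line wins, otherwise settings.use_shell from config.toml is used, defaulting to false. A standalone sketch of that precedence, with FakeConfig standing in for KohyaSSGUIConfig (illustrative only):

# Sketch of the resolution order: CLI flag first, then config.toml, then False.
import argparse

def resolve_use_shell(cli_flag: bool, config) -> bool:
    if cli_flag:  # an explicit --use_shell always wins
        return True
    return bool(config.get("settings.use_shell", False))  # otherwise fall back to the config

class FakeConfig:  # stand-in for KohyaSSGUIConfig, assumption for this sketch
    def get(self, key, default=None):
        return {"settings.use_shell": True}.get(key, default)

parser = argparse.ArgumentParser()
parser.add_argument("--use_shell", action="store_true", help="Use shell environment")
args = parser.parse_args([])  # simulate launching without the flag
print(resolve_use_shell(args.use_shell, FakeConfig()))  # -> True, taken from the config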
6 changes: 4 additions & 2 deletions kohya_gui/blip_caption_gui.py
@@ -23,6 +23,7 @@ def caption_images(
beam_search: bool,
prefix: str = "",
postfix: str = "",
use_shell: bool = False,
) -> None:
"""
Automatically generates captions for images in the specified directory using the BLIP model.
@@ -96,7 +97,7 @@ def caption_images(
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command in the sd-scripts folder context
subprocess.run(run_cmd, env=env, cwd=f"{scriptdir}/sd-scripts")
subprocess.run(run_cmd, env=env, shell=use_shell, cwd=f"{scriptdir}/sd-scripts")


# Add prefix and postfix
@@ -115,7 +116,7 @@ def caption_images(
###


def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None):
def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None, use_shell: bool = False):
from .common_gui import create_refresh_button

default_train_dir = (
@@ -205,6 +206,7 @@ def list_train_dirs(path):
beam_search,
prefix,
postfix,
gr.Checkbox(value=use_shell, visible=False),
],
show_progress=False,
)
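The blip_caption_gui.py change above passes use_shell into the existing Gradio click handler by appending an invisible gr.Checkbox to the inputs list, so the callback receives the flag without any new visible UI. A reduced sketch of that pattern (component names and function bodies are placeholders, not the repository's code):

import gradio as gr

def caption_images(train_dir: str, use_shell: bool) -> str:
    # placeholder for the real captioning call
    return f"caption {train_dir!r} with shell={use_shell}"

def build_tab(use_shell: bool = False) -> gr.Blocks:
    with gr.Blocks() as demo:
        train_dir = gr.Textbox(label="Image folder to caption")
        result = gr.Textbox(label="Result")
        caption_button = gr.Button("Caption images")
        caption_button.click(
            caption_images,
            # The hidden checkbox carries the constant use_shell value into the callback.
            inputs=[train_dir, gr.Checkbox(value=use_shell, visible=False)],
            outputs=[result],
        )
    return demo

if __name__ == "__main__":
    build_tab().launch()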
17 changes: 9 additions & 8 deletions kohya_gui/class_command_executor.py
@@ -2,7 +2,7 @@
import psutil
import time
import gradio as gr
import shlex

from .custom_logging import setup_logging

# Set up logging
@@ -21,7 +21,7 @@ def __init__(self):
self.process = None
self.run_state = gr.Textbox(value="", visible=False)

def execute_command(self, run_cmd: str, **kwargs):
def execute_command(self, run_cmd: str, use_shell: bool = False, **kwargs):
"""
Execute a command if no other command is currently running.
@@ -36,11 +36,12 @@ def execute_command(self, run_cmd: str, **kwargs):
# log.info(f"{i}: {item}")

# Reconstruct the safe command string for display
command_to_run = ' '.join(run_cmd)
log.info(f"Executing command: {command_to_run}")
command_to_run = " ".join(run_cmd)
log.info(f"Executing command: {command_to_run} with shell={use_shell}")

# Execute the command securely
self.process = subprocess.Popen(run_cmd, **kwargs)
self.process = subprocess.Popen(run_cmd, **kwargs, shell=use_shell)
log.info("Command executed.")

def kill_command(self):
"""
@@ -64,9 +65,9 @@ def kill_command(self):
log.info(f"Error when terminating process: {e}")
else:
log.info("There is no running process to kill.")

return gr.Button(visible=True), gr.Button(visible=False)

def wait_for_training_to_end(self):
while self.is_running():
time.sleep(1)
@@ -81,4 +82,4 @@ def is_running(self):
Returns:
- bool: True if the command is running, False otherwise.
"""
return self.process and self.process.poll() is None
return self.process and self.process.poll() is None
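A general subprocess caveat worth keeping in mind here (editorial, not something this commit changes): the command is interpreted differently depending on shell. With shell=False the argument list is passed directly to the new process, while shell=True expects a single command string on POSIX (a list would only forward its first element to the shell); Windows joins a list for you. A minimal sketch handling both cases:

import subprocess
import sys

def launch(run_cmd: list[str], use_shell: bool = False) -> subprocess.Popen:
    # Join the argument list into one string only when a shell will interpret it.
    cmd = " ".join(run_cmd) if use_shell else run_cmd
    return subprocess.Popen(cmd, shell=use_shell)

process = launch([sys.executable, "--version"])  # shell=False: list form, most secure
process.wait()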
11 changes: 11 additions & 0 deletions kohya_gui/class_gui_config.py
@@ -80,3 +80,14 @@ def get(self, key: str, default=None):
# Return the final value
log.debug(f"Returned {data}")
return data

def is_config_loaded(self) -> bool:
"""
Checks if the configuration was loaded from a file.
Returns:
bool: True if the configuration was loaded from a file, False otherwise.
"""
is_loaded = self.config != {}
log.debug(f"Configuration was loaded from file: {is_loaded}")
return is_loaded
26 changes: 14 additions & 12 deletions kohya_gui/class_lora_tab.py
@@ -11,16 +11,18 @@


class LoRATools:
def __init__(self, headless: bool = False):
self.headless = headless

def __init__(
self,
headless: bool = False,
use_shell_flag: bool = False,
):
gr.Markdown("This section provide various LoRA tools...")
gradio_extract_dylora_tab(headless=headless)
gradio_convert_lcm_tab(headless=headless)
gradio_extract_lora_tab(headless=headless)
gradio_extract_lycoris_locon_tab(headless=headless)
gradio_merge_lora_tab = GradioMergeLoRaTab()
gradio_merge_lycoris_tab(headless=headless)
gradio_svd_merge_lora_tab(headless=headless)
gradio_resize_lora_tab(headless=headless)
gradio_verify_lora_tab(headless=headless)
gradio_extract_dylora_tab(headless=headless, use_shell=use_shell_flag)
gradio_convert_lcm_tab(headless=headless, use_shell=use_shell_flag)
gradio_extract_lora_tab(headless=headless, use_shell=use_shell_flag)
gradio_extract_lycoris_locon_tab(headless=headless, use_shell=use_shell_flag)
gradio_merge_lora_tab = GradioMergeLoRaTab(use_shell=use_shell_flag)
gradio_merge_lycoris_tab(headless=headless, use_shell=use_shell_flag)
gradio_svd_merge_lora_tab(headless=headless, use_shell=use_shell_flag)
gradio_resize_lora_tab(headless=headless, use_shell=use_shell_flag)
gradio_verify_lora_tab(headless=headless, use_shell=use_shell_flag)
3 changes: 1 addition & 2 deletions kohya_gui/common_gui.py
@@ -2,7 +2,6 @@
from easygui import msgbox, ynbox
from typing import Optional
from .custom_logging import setup_logging
from .class_command_executor import CommandExecutor

import os
import re
@@ -12,7 +11,6 @@
import json
import math
import shutil
import time

# Set up logging
log = setup_logging()
@@ -23,6 +21,7 @@
document_symbol = "\U0001F4C4" # 📄

scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

if os.name == "nt":
scriptdir = scriptdir.replace("\\", "/")

22 changes: 17 additions & 5 deletions kohya_gui/convert_lcm_gui.py
@@ -22,7 +22,13 @@
PYTHON = sys.executable


def convert_lcm(name, model_path, lora_scale, model_type):
def convert_lcm(
name,
model_path,
lora_scale,
model_type,
use_shell: bool = False,
):
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"'

# Check if source model exist
@@ -62,7 +68,7 @@ def convert_lcm(name, model_path, lora_scale, model_type):
run_cmd.append("--ssd-1b")

# Log the command
log.info(' '.join(run_cmd))
log.info(" ".join(run_cmd))

# Set up the environment
env = os.environ.copy()
@@ -72,13 +78,13 @@ def convert_lcm(name, model_path, lora_scale, model_type):
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command
subprocess.run(run_cmd, env=env)
subprocess.run(run_cmd, env=env, shell=use_shell)

# Return a success message
log.info("Done extracting...")


def gradio_convert_lcm_tab(headless=False):
def gradio_convert_lcm_tab(headless=False, use_shell: bool = False):
current_model_dir = os.path.join(scriptdir, "outputs")
current_save_dir = os.path.join(scriptdir, "outputs")

@@ -183,6 +189,12 @@ def list_save_to(path):

extract_button.click(
convert_lcm,
inputs=[name, model_path, lora_scale, model_type],
inputs=[
name,
model_path,
lora_scale,
model_type,
gr.Checkbox(value=use_shell, visible=False),
],
show_progress=False,
)
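The converter tools above all share one launch pattern: copy the process environment, prepend the repository to PYTHONPATH, disable oneDNN startup notices, then run the tool script with the optional shell. A self-contained sketch of that pattern (the script path, and using the current directory in place of scriptdir, are assumptions for this sketch):

import os
import subprocess
import sys

def run_tool(script_path: str, use_shell: bool = False) -> None:
    env = os.environ.copy()
    env["PYTHONPATH"] = f"{os.getcwd()}{os.pathsep}{env.get('PYTHONPATH', '')}"
    env["TF_ENABLE_ONEDNN_OPTS"] = "0"  # quiet oneDNN notices, as in the diff
    run_cmd = [sys.executable, script_path]
    subprocess.run(" ".join(run_cmd) if use_shell else run_cmd, env=env, shell=use_shell)

# run_tool("tools/lcm_convert.py")  # example invocation (path is illustrative)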
6 changes: 4 additions & 2 deletions kohya_gui/convert_model_gui.py
@@ -26,6 +26,7 @@ def convert_model(
target_model_type,
target_save_precision_type,
unet_use_linear_projection,
use_shell: bool = False,
):
# Check for caption_text_input
if source_model_type == "":
@@ -107,7 +108,7 @@ def convert_model(
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command
subprocess.run(run_cmd, env=env)
subprocess.run(run_cmd, env=env, shell=use_shell)



###


def gradio_convert_model_tab(headless=False):
def gradio_convert_model_tab(headless=False, use_shell: bool = False):
from .common_gui import create_refresh_button

default_source_model = os.path.join(scriptdir, "outputs")
@@ -276,6 +277,7 @@ def list_target_folder(path):
target_model_type,
target_save_precision_type,
unet_use_linear_projection,
gr.Checkbox(value=use_shell, visible=False),
],
show_progress=False,
)
7 changes: 6 additions & 1 deletion kohya_gui/dreambooth_gui.py
@@ -47,6 +47,7 @@

# Setup huggingface
huggingface = None
use_shell = False

PYTHON = sys.executable

@@ -843,7 +844,7 @@ def train_model(

# Run the command

executor.execute_command(run_cmd=run_cmd, env=env)
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)

return (
gr.Button(visible=False),
@@ -859,10 +860,14 @@ def dreambooth_tab(
# logging_dir=gr.Textbox(),
headless=False,
config: KohyaSSGUIConfig = {},
use_shell_flag: bool = False,
):
dummy_db_true = gr.Checkbox(value=True, visible=False)
dummy_db_false = gr.Checkbox(value=False, visible=False)
dummy_headless = gr.Checkbox(value=headless, visible=False)

global use_shell
use_shell = use_shell_flag

with gr.Tab("Training"), gr.Column(variant="compact"):
gr.Markdown("Train a custom model using kohya dreambooth python code...")
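dreambooth_gui.py keeps the flag in a module-level global: dreambooth_tab() assigns use_shell once when the tab is built, and train_model() reads it later when the training command is launched. A reduced sketch of that flow (function bodies are placeholders, not the repository's code):

use_shell = False  # module-level default, mirroring dreambooth_gui.py

def dreambooth_tab(use_shell_flag: bool = False) -> None:
    global use_shell
    use_shell = use_shell_flag  # remember the GUI-wide setting for later training runs

def train_model() -> None:
    print(f"executor.execute_command(..., use_shell={use_shell})")

dreambooth_tab(use_shell_flag=True)
train_model()  # -> executor.execute_command(..., use_shell=True)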