Skip to content

Commit

Permalink
Merge pull request #2025 from wkpark/gui-fix
Browse files Browse the repository at this point in the history
Gui fix
  • Loading branch information
bmaltais authored Mar 2, 2024
2 parents 0c7fdc0 + e2578cf commit f8b1f2d
Show file tree
Hide file tree
Showing 27 changed files with 226 additions and 274 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b

## Change History
* 2024/03/02 (v22.7.0)
- Major code refactoring thanks to @wkpark , This will make updating sd-script cleaner by keeping sd-scripts files seperate from the GUI files.
- Major code refactoring thanks to @wkpark , This will make updating sd-script cleaner by keeping sd-scripts files separate from the GUI files.
* 2024/02/17 (v22.6.2)
- Fix issue with Lora Extract GUI
- Fix syntax issue where parameter lora_network_weights is actually called network_weights
Expand Down
14 changes: 8 additions & 6 deletions kohya_gui/basic_caption_gui.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import gradio as gr
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path, add_pre_postfix, find_replace
from .common_gui import get_folder_path, add_pre_postfix, find_replace, scriptdir
import os
import sys

from .custom_logging import setup_logging

# Set up logging
log = setup_logging()

PYTHON = sys.executable

def caption_images(
caption_text,
Expand Down Expand Up @@ -36,7 +38,7 @@ def caption_images(
log.info(f'Captioning files in {images_dir} with {caption_text}...')

# Build the command to run caption.py
run_cmd = f'python "tools/caption.py"'
run_cmd = fr'{PYTHON} "{scriptdir}/tools/caption.py"'
run_cmd += f' --caption_text="{caption_text}"'

# Add optional flags to the command
Expand All @@ -49,11 +51,11 @@ def caption_images(

log.info(run_cmd)

env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}"

# Run the command based on the operating system
if os.name == 'posix':
os.system(run_cmd)
else:
subprocess.run(run_cmd)
subprocess.run(run_cmd, shell=True, env=env)

# Check if overwrite option is enabled
if overwrite:
Expand Down
15 changes: 8 additions & 7 deletions kohya_gui/blip_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from easygui import msgbox
import subprocess
import os
from .common_gui import get_folder_path, add_pre_postfix
import sys
from .common_gui import get_folder_path, add_pre_postfix, scriptdir
from .custom_logging import setup_logging

# Set up logging
log = setup_logging()

PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
PYTHON = sys.executable


def caption_images(
Expand Down Expand Up @@ -36,7 +37,7 @@ def caption_images(
log.info(f'Captioning files in {train_data_dir}...')

# Construct the command to run
run_cmd = f'{PYTHON} "finetune/make_captions.py"'
run_cmd = fr'{PYTHON} "{scriptdir}/finetune/make_captions.py"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --num_beams="{int(num_beams)}"'
run_cmd += f' --top_p="{top_p}"'
Expand All @@ -51,11 +52,11 @@ def caption_images(

log.info(run_cmd)

env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}"

# Run the command
if os.name == 'posix':
os.system(run_cmd)
else:
subprocess.run(run_cmd)
subprocess.run(run_cmd, shell=True, env=env)

# Add prefix and postfix
add_pre_postfix(
Expand Down
10 changes: 5 additions & 5 deletions kohya_gui/class_advanced_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ def __init__(self, headless=False, finetuning: bool = False, training_type: str
def noise_offset_type_change(noise_offset_type):
if noise_offset_type == 'Original':
return (
gr.Group(visible=True),
gr.Group(visible=False),
gr.Group.update(visible=True),
gr.Group.update(visible=False),
)
else:
return (
gr.Group(visible=False),
gr.Group(visible=True),
gr.Group.update(visible=False),
gr.Group.update(visible=True),
)

with gr.Row(visible=not finetuning):
if training_type != "lora": # Not avaible for LoRA
if training_type != "lora": # Not available for LoRA
self.no_token_padding = gr.Checkbox(
label='No token padding', value=False
)
Expand Down
4 changes: 2 additions & 2 deletions kohya_gui/class_command_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ class CommandExecutor:
def __init__(self):
self.process = None

def execute_command(self, run_cmd):
def execute_command(self, run_cmd, **kwargs):
if self.process and self.process.poll() is None:
log.info(
'The command is already running. Please wait for it to finish.'
)
else:
self.process = subprocess.Popen(run_cmd, shell=True)
self.process = subprocess.Popen(run_cmd, shell=True, **kwargs)

def kill_command(self):
if self.process and self.process.poll() is None:
Expand Down
6 changes: 0 additions & 6 deletions kohya_gui/class_configuration_file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import gradio as gr
from .common_gui import remove_doublequote


class ConfigurationFile:
Expand Down Expand Up @@ -29,8 +28,3 @@ def __init__(self, headless=False):
self.button_load_config = gr.Button(
'Load 💾', elem_id='open_folder'
)
self.config_file_name.blur(
remove_doublequote,
inputs=[self.config_file_name],
outputs=[self.config_file_name],
)
28 changes: 4 additions & 24 deletions kohya_gui/class_folders.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import gradio as gr
from .common_gui import remove_doublequote, get_folder_path
from .common_gui import get_folder_path


class Folders:
def __init__(self, headless=False):
def __init__(self, finetune=False, headless=False):
self.headless = headless

with gr.Row():
Expand All @@ -20,8 +20,8 @@ def __init__(self, headless=False):
show_progress=False,
)
self.reg_data_dir = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
label='Regularisation folder' if not finetune else 'Train config folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located' if not finetune else "folder where the training configuration files will be saved",
)
self.reg_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not self.headless)
Expand Down Expand Up @@ -68,23 +68,3 @@ def __init__(self, headless=False):
placeholder='(Optional) Add training comment to be included in metadata',
interactive=True,
)
self.train_data_dir.blur(
remove_doublequote,
inputs=[self.train_data_dir],
outputs=[self.train_data_dir],
)
self.reg_data_dir.blur(
remove_doublequote,
inputs=[self.reg_data_dir],
outputs=[self.reg_data_dir],
)
self.output_dir.blur(
remove_doublequote,
inputs=[self.output_dir],
outputs=[self.output_dir],
)
self.logging_dir.blur(
remove_doublequote,
inputs=[self.logging_dir],
outputs=[self.logging_dir],
)
37 changes: 22 additions & 15 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄

scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

# define a list of substrings to search for v2 base models
V2_BASE_MODELS = [
"stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned",
Expand Down Expand Up @@ -206,13 +208,6 @@ def get_any_file_path(file_path=""):
return file_path


def remove_doublequote(file_path):
if file_path != None:
file_path = file_path.replace('"', "")

return file_path


def get_folder_path(folder_path=""):
if not any(var in os.environ for var in ENV_EXCLUSION) and sys.platform != "darwin":
current_folder_path = folder_path
Expand Down Expand Up @@ -405,9 +400,9 @@ def color_aug_changed(color_aug):
msgbox(
'Disabling "Cache latent" because "Color augmentation" has been selected...'
)
return gr.Checkbox(value=False, interactive=False)
return gr.Checkbox.update(value=False, interactive=False)
else:
return gr.Checkbox(value=True, interactive=True)
return gr.Checkbox.update(value=True, interactive=True)


def save_inference_file(output_dir, v2, v_parameterization, output_name):
Expand All @@ -429,15 +424,15 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
f"Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml"
)
shutil.copy(
f"./v2_inference/v2-inference-v.yaml",
fr"{scriptdir}/v2_inference/v2-inference-v.yaml",
f"{output_dir}/{file_name}.yaml",
)
elif v2:
log.info(
f"Saving v2-inference.yaml as {output_dir}/{file_name}.yaml"
)
shutil.copy(
f"./v2_inference/v2-inference.yaml",
fr"{scriptdir}/v2_inference/v2-inference.yaml",
f"{output_dir}/{file_name}.yaml",
)

Expand Down Expand Up @@ -800,7 +795,10 @@ def run_cmd_advanced_training(**kwargs):

logging_dir = kwargs.get("logging_dir")
if logging_dir:
run_cmd += f' --logging_dir="{logging_dir}"'
if logging_dir.startswith('"') and logging_dir.endswith('"'):
logging_dir = logging_dir[1:-1]
if os.path.exists(logging_dir):
run_cmd += fr' --logging_dir="{logging_dir}"'

lora_network_weights = kwargs.get("lora_network_weights")
if lora_network_weights:
Expand Down Expand Up @@ -967,7 +965,10 @@ def run_cmd_advanced_training(**kwargs):

output_dir = kwargs.get("output_dir")
if output_dir:
run_cmd += f' --output_dir="{output_dir}"'
if output_dir.startswith('"') and output_dir.endswith('"'):
output_dir = output_dir[1:-1]
if os.path.exists(output_dir):
run_cmd += fr' --output_dir="{output_dir}"'

output_name = kwargs.get("output_name")
if output_name and not output_name == "":
Expand All @@ -991,7 +992,10 @@ def run_cmd_advanced_training(**kwargs):

reg_data_dir = kwargs.get("reg_data_dir")
if reg_data_dir and len(reg_data_dir):
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
if reg_data_dir.startswith('"') and reg_data_dir.endswith('"'):
reg_data_dir = reg_data_dir[1:-1]
if os.path.isdir(reg_data_dir):
run_cmd += fr' --reg_data_dir="{reg_data_dir}"'

resume = kwargs.get("resume")
if resume:
Expand Down Expand Up @@ -1059,7 +1063,10 @@ def run_cmd_advanced_training(**kwargs):

train_data_dir = kwargs.get("train_data_dir")
if train_data_dir:
run_cmd += f' --train_data_dir="{train_data_dir}"'
if train_data_dir.startswith('"') and train_data_dir.endswith('"'):
train_data_dir = train_data_dir[1:-1]
if os.path.exists(train_data_dir):
run_cmd += fr' --train_data_dir="{train_data_dir}"'

train_text_encoder = kwargs.get("train_text_encoder")
if train_text_encoder:
Expand Down
14 changes: 8 additions & 6 deletions kohya_gui/convert_lcm_gui.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import gradio as gr
import os
import subprocess
import sys
from .common_gui import (
get_saveasfilename_path,
get_file_path,
scriptdir,
)
from .custom_logging import setup_logging

Expand All @@ -15,7 +17,7 @@
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄

PYTHON = "python3" if os.name == "posix" else "./venv/Scripts/python.exe"
PYTHON = sys.executable


def convert_lcm(
Expand All @@ -24,7 +26,7 @@ def convert_lcm(
lora_scale,
model_type
):
run_cmd = f'{PYTHON} "{os.path.join("tools","lcm_convert.py")}"'
run_cmd = fr'{PYTHON} "{scriptdir}/tools/lcm_convert.py"'
# Construct the command to run the script
run_cmd += f' --name "{name}"'
run_cmd += f' --model "{model_path}"'
Expand All @@ -37,11 +39,11 @@ def convert_lcm(

log.info(run_cmd)

env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}"

# Run the command
if os.name == "posix":
os.system(run_cmd)
else:
subprocess.run(run_cmd)
subprocess.run(run_cmd, shell=True, env=env)

# Return a success message
log.info("Done extracting...")
Expand Down
Loading

0 comments on commit f8b1f2d

Please sign in to comment.