Skip to content

Commit

Permalink
Fix accelerate issue on linux (#2281)
Browse files Browse the repository at this point in the history
* Fix issue discovered after removing shell=True. This is a significant rewrite... but will make things better for the future.
  • Loading branch information
bmaltais authored Apr 13, 2024
1 parent e8b54e6 commit a093b78
Show file tree
Hide file tree
Showing 29 changed files with 1,085 additions and 695 deletions.
12 changes: 12 additions & 0 deletions =13.7.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Requirement already satisfied: rich in ./venv/lib/python3.10/site-packages (13.7.0)
Collecting rich
Using cached rich-13.7.1-py3-none-any.whl (240 kB)
Requirement already satisfied: markdown-it-py>=2.2.0 in ./venv/lib/python3.10/site-packages (from rich) (3.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in ./venv/lib/python3.10/site-packages (from rich) (2.17.2)
Requirement already satisfied: mdurl~=0.1 in ./venv/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich) (0.1.2)
Installing collected packages: rich
Attempting uninstall: rich
Found existing installation: rich 13.7.0
Uninstalling rich-13.7.0:
Successfully uninstalled rich-13.7.0
Successfully installed rich-13.7.1
4 changes: 2 additions & 2 deletions assets/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
}

#myTensorButton {
background: radial-gradient(ellipse, #007bff, #00b0ff);
background: radial-gradient(ellipse, #3a99ff, #52c8ff);
color: white;
border: none;
}

#myTensorButtonStop {
background: radial-gradient(ellipse, #00b0ff, #007bff);
background: radial-gradient(ellipse, #52c8ff, #3a99ff);
color: black;
border: none;
}
20 changes: 13 additions & 7 deletions kohya_gui/basic_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,31 @@ def caption_images(
log.info(f"Captioning files in {images_dir} with {caption_text}...")

# Build the command to run caption.py
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/caption.py"'
run_cmd += f' --caption_text="{caption_text}"'
run_cmd = [PYTHON, f"{scriptdir}/tools/caption.py"]

# Add required arguments
run_cmd.append('--caption_text')
run_cmd.append(caption_text)

# Add optional flags to the command
if overwrite:
run_cmd += f" --overwrite"
run_cmd.append("--overwrite")
if caption_ext:
run_cmd += f' --caption_file_ext="{caption_ext}"'
run_cmd.append('--caption_file_ext')
run_cmd.append(caption_ext)

run_cmd += f' "{images_dir}"'
# Add the directory containing the images
run_cmd.append(images_dir)

# Log the command
log.info(run_cmd)
log.info(' '.join(run_cmd))

# Set the environment variable for the Python path
env = os.environ.copy()
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}"
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command based on the operating system
subprocess.run(run_cmd, env=env)
Expand Down
42 changes: 30 additions & 12 deletions kohya_gui/blip_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,31 +56,49 @@ def caption_images(

log.info(f"Captioning files in {train_data_dir}...")

# Construct the command to run
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions.py"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --num_beams="{int(num_beams)}"'
run_cmd += f' --top_p="{top_p}"'
run_cmd += f' --max_length="{int(max_length)}"'
run_cmd += f' --min_length="{int(min_length)}"'
# Construct the command to run make_captions.py
run_cmd = [PYTHON, f"{scriptdir}/sd-scripts/finetune/make_captions.py"]

# Add required arguments
run_cmd.append('--batch_size')
run_cmd.append(str(batch_size))
run_cmd.append('--num_beams')
run_cmd.append(str(num_beams))
run_cmd.append('--top_p')
run_cmd.append(str(top_p))
run_cmd.append('--max_length')
run_cmd.append(str(max_length))
run_cmd.append('--min_length')
run_cmd.append(str(min_length))

# Add optional flags to the command
if beam_search:
run_cmd += f" --beam_search"
run_cmd.append("--beam_search")
if caption_file_ext:
run_cmd += f' --caption_extension="{caption_file_ext}"'
run_cmd += f' "{train_data_dir}"'
run_cmd += f' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"'
run_cmd.append('--caption_extension')
run_cmd.append(caption_file_ext)

log.info(run_cmd)
# Add the directory containing the training data
run_cmd.append(train_data_dir)

# Add URL for caption model weights
run_cmd.append('--caption_weights')
run_cmd.append("https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth")

# Log the command
log.info(' '.join(run_cmd))

# Set up the environment
env = os.environ.copy()
env["PYTHONPATH"] = (
f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command in the sd-scripts folder context
subprocess.run(run_cmd, env=env, cwd=f"{scriptdir}/sd-scripts")


# Add prefix and postfix
add_pre_postfix(
folder=train_data_dir,
Expand Down
65 changes: 31 additions & 34 deletions kohya_gui/class_accelerate_launch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import gradio as gr
import os
import shlex

from .class_gui_config import KohyaSSGUIConfig


Expand Down Expand Up @@ -75,46 +77,41 @@ def __init__(
info="List of extra parameters to pass to accelerate launch",
)

def run_cmd(**kwargs):
run_cmd = ""

if "extra_accelerate_launch_args" in kwargs:
extra_accelerate_launch_args = kwargs.get("extra_accelerate_launch_args")
if extra_accelerate_launch_args != "":
run_cmd += rf" {extra_accelerate_launch_args}"
def run_cmd(run_cmd: list, **kwargs):
if (
"extra_accelerate_launch_args" in kwargs
and kwargs.get("extra_accelerate_launch_args") != ""
):
run_cmd.append(kwargs["extra_accelerate_launch_args"])

if "gpu_ids" in kwargs:
gpu_ids = kwargs.get("gpu_ids")
if not gpu_ids == "":
run_cmd += f' --gpu_ids="{gpu_ids}"'
if "gpu_ids" in kwargs and kwargs.get("gpu_ids") != "":
run_cmd.append("--gpu_ids")
run_cmd.append(shlex.quote(kwargs["gpu_ids"]))

if "main_process_port" in kwargs:
main_process_port = kwargs.get("main_process_port")
if main_process_port > 0:
run_cmd += f' --main_process_port="{main_process_port}"'
if "main_process_port" in kwargs and kwargs.get("main_process_port", 0) > 0:
run_cmd.append("--main_process_port")
run_cmd.append(str(int(kwargs["main_process_port"])))

if "mixed_precision" in kwargs:
run_cmd += rf' --mixed_precision="{kwargs.get("mixed_precision")}"'
if "mixed_precision" in kwargs and kwargs.get("mixed_precision"):
run_cmd.append("--mixed_precision")
run_cmd.append(shlex.quote(kwargs["mixed_precision"]))

if "multi_gpu" in kwargs:
if kwargs.get("multi_gpu"):
run_cmd += " --multi_gpu"
if "multi_gpu" in kwargs and kwargs.get("multi_gpu"):
run_cmd.append("--multi_gpu")

if "num_processes" in kwargs:
num_processes = kwargs.get("num_processes")
if int(num_processes) > 0:
run_cmd += f" --num_processes={int(num_processes)}"
if "num_processes" in kwargs and int(kwargs.get("num_processes", 0)) > 0:
run_cmd.append("--num_processes")
run_cmd.append(str(int(kwargs["num_processes"])))

if "num_machines" in kwargs:
num_machines = kwargs.get("num_machines")
if int(num_machines) > 0:
run_cmd += f" --num_machines={int(num_machines)}"
if "num_machines" in kwargs and int(kwargs.get("num_machines", 0)) > 0:
run_cmd.append("--num_machines")
run_cmd.append(str(int(kwargs["num_machines"])))

if "num_cpu_threads_per_process" in kwargs:
num_cpu_threads_per_process = kwargs.get("num_cpu_threads_per_process")
if int(num_cpu_threads_per_process) > 0:
run_cmd += (
f" --num_cpu_threads_per_process={int(num_cpu_threads_per_process)}"
)
if (
"num_cpu_threads_per_process" in kwargs
and int(kwargs.get("num_cpu_threads_per_process", 0)) > 0
):
run_cmd.append("--num_cpu_threads_per_process")
run_cmd.append(str(int(kwargs["num_cpu_threads_per_process"])))

return run_cmd
2 changes: 1 addition & 1 deletion kohya_gui/class_basic_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def init_grad_and_lr_controls(self) -> None:
self.lr_scheduler_args = gr.Textbox(
label="LR scheduler extra arguments",
lines=2,
placeholder='(Optional) eg: "milestones=[1,10,30,50]" "gamma=0.1"',
placeholder='(Optional) eg: milestones=[1,10,30,50] gamma=0.1',
value=self.config.get("basic.lr_scheduler_args", ""),
)
# Initialize the optimizer extra arguments textbox
Expand Down
7 changes: 7 additions & 0 deletions kohya_gui/class_command_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import subprocess
import psutil
import os
import gradio as gr
import shlex
from .custom_logging import setup_logging

# Set up logging
Expand Down Expand Up @@ -29,6 +31,11 @@ def execute_command(self, run_cmd: str, **kwargs):
if self.process and self.process.poll() is None:
log.info("The command is already running. Please wait for it to finish.")
else:
# Reconstruct the safe command string for display
command_to_run = ' '.join(run_cmd)
log.info(f"Executings command: {command_to_run}")

Check warning on line 36 in kohya_gui/class_command_executor.py

View workflow job for this annotation

GitHub Actions / build

"Executings" should be "Executions".

# Execute the command securely
self.process = subprocess.Popen(run_cmd, **kwargs)

def kill_command(self):
Expand Down
27 changes: 21 additions & 6 deletions kohya_gui/class_sample_images.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import gradio as gr
import shlex

from .custom_logging import setup_logging
from .class_gui_config import KohyaSSGUIConfig
Expand All @@ -19,6 +20,7 @@


def run_cmd_sample(
run_cmd: list,
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
Expand All @@ -41,8 +43,6 @@ def run_cmd_sample(
output_dir = os.path.join(output_dir, "sample")
os.makedirs(output_dir, exist_ok=True)

run_cmd = ""

if sample_every_n_epochs is None:
sample_every_n_epochs = 0

Expand All @@ -58,18 +58,33 @@ def run_cmd_sample(
with open(sample_prompts_path, "w") as f:
f.write(sample_prompts)

run_cmd += f" --sample_sampler={sample_sampler}"
run_cmd += f' --sample_prompts="{sample_prompts_path}"'
# Append the sampler with proper quoting for safety against special characters
run_cmd.append("--sample_sampler")
run_cmd.append(shlex.quote(sample_sampler))

# Normalize and fix the path for the sample prompts, handle cross-platform path differences
sample_prompts_path = os.path.abspath(os.path.normpath(sample_prompts_path))
if os.name == "nt": # Normalize path for Windows
sample_prompts_path = sample_prompts_path.replace("\\", "/")

# Append the sample prompts path
run_cmd.append('--sample_prompts')
run_cmd.append(sample_prompts_path)

# Append the sampling frequency for epochs, only if non-zero
if sample_every_n_epochs != 0:
run_cmd += f" --sample_every_n_epochs={sample_every_n_epochs}"
run_cmd.append("--sample_every_n_epochs")
run_cmd.append(str(sample_every_n_epochs))

# Append the sampling frequency for steps, only if non-zero
if sample_every_n_steps != 0:
run_cmd += f" --sample_every_n_steps={sample_every_n_steps}"
run_cmd.append("--sample_every_n_steps")
run_cmd.append(str(sample_every_n_steps))

return run_cmd



class SampleImages:
"""
A class for managing the Gradio interface for sampling images during training.
Expand Down
Loading

0 comments on commit a093b78

Please sign in to comment.