Skip to content

Commit

Permalink
Optimising some code
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Mar 7, 2024
1 parent 32e2927 commit 271b406
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 91 deletions.
17 changes: 8 additions & 9 deletions kohya_gui/basic_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,14 @@ def gradio_basic_caption_gui_tab(headless=False):
placeholder='Directory containing the images to caption',
interactive=True,
)
if not headless:
folder_button = gr.Button(
'📂', elem_id='open_folder_small'
)
folder_button.click(
get_folder_path,
outputs=images_dir,
show_progress=False,
)
folder_button = gr.Button(
'📂', elem_id='open_folder_small', visible=(not headless)
)
folder_button.click(
get_folder_path,
outputs=images_dir,
show_progress=False,
)
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extension for caption file (e.g., .caption, .txt)',
Expand Down
97 changes: 15 additions & 82 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,98 +577,31 @@ def get_int_or_default(kwargs, key, default_value=0):


def get_float_or_default(kwargs, key, default_value=0.0):
# Try to retrieve the value for the specified key from the kwargs.
# Use the provided default_value if the key does not exist.
value = kwargs.get(key, default_value)
if isinstance(value, float):
return value
elif isinstance(value, int):
return float(value)
elif isinstance(value, str):

try:
# Try to convert the value to a float. This should works for int, float,
# and strings that represent a valid floating-point number.
return float(value)
else:
log.info(
f"{key} is not an int, float or a string, setting value to {default_value}"
)
except ValueError:
# If the conversion fails (for example, the value is a string that cannot
# be converted to a float), log the issue and return the provided default_value.
log.info(f"{key} is not an int, float or a valid string for conversion, setting value to {default_value}")
return default_value


def get_str_or_default(kwargs, key, default_value=""):
value = kwargs.get(key, default_value)

# Check if the retrieved value is already a string.
if isinstance(value, str):
return value
elif isinstance(value, int):
return str(value)
elif isinstance(value, str):
return str(value)
else:
return default_value


# def run_cmd_training(**kwargs):
# run_cmd = ""

# lr_scheduler = kwargs.get("lr_scheduler", "")
# if lr_scheduler:
# run_cmd += f' --lr_scheduler="{lr_scheduler}"'

# lr_warmup_steps = kwargs.get("lr_warmup_steps", "")
# if lr_warmup_steps:
# if lr_scheduler == "constant":
# log.info("Can't use LR warmup with LR Scheduler constant... ignoring...")
# else:
# run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"'

# train_batch_size = kwargs.get("train_batch_size", "")
# if train_batch_size:
# run_cmd += f' --train_batch_size="{train_batch_size}"'

# max_train_steps = kwargs.get("max_train_steps", "")
# if max_train_steps:
# run_cmd += f' --max_train_steps="{max_train_steps}"'

# save_every_n_epochs = kwargs.get("save_every_n_epochs")
# if save_every_n_epochs:
# run_cmd += f' --save_every_n_epochs="{int(save_every_n_epochs)}"'

# mixed_precision = kwargs.get("mixed_precision", "")
# if mixed_precision:
# run_cmd += f' --mixed_precision="{mixed_precision}"'

# save_precision = kwargs.get("save_precision", "")
# if save_precision:
# run_cmd += f' --save_precision="{save_precision}"'

# seed = kwargs.get("seed", "")
# if seed != "":
# run_cmd += f' --seed="{seed}"'

# caption_extension = kwargs.get("caption_extension", "")
# if caption_extension:
# run_cmd += f' --caption_extension="{caption_extension}"'

# cache_latents = kwargs.get("cache_latents")
# if cache_latents:
# run_cmd += " --cache_latents"

# cache_latents_to_disk = kwargs.get("cache_latents_to_disk")
# if cache_latents_to_disk:
# run_cmd += " --cache_latents_to_disk"

# optimizer_type = kwargs.get("optimizer", "AdamW")
# run_cmd += f' --optimizer_type="{optimizer_type}"'

# optimizer_args = kwargs.get("optimizer_args", "")
# if optimizer_args != "":
# run_cmd += f" --optimizer_args {optimizer_args}"

# lr_scheduler_args = kwargs.get("lr_scheduler_args", "")
# if lr_scheduler_args != "":
# run_cmd += f" --lr_scheduler_args {lr_scheduler_args}"

# max_grad_norm = kwargs.get("max_grad_norm", "")
# if max_grad_norm != "":
# run_cmd += f' --max_grad_norm="{max_grad_norm}"'

# return run_cmd
# If the value is not a string (e.g., int, float, or any other type),
# convert it to a string and return the converted value.
return str(value)


def run_cmd_advanced_training(**kwargs):
Expand Down

0 comments on commit 271b406

Please sign in to comment.