Skip to content

Commit

Permalink
Fix issues with list_images_dir
Browse files (browse the repository at this point in the history)
  • Loading branch information
bmaltais committed Mar 7, 2024
1 parent c1536c2 commit ebc3296
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 44 deletions.
5 changes: 4 additions & 1 deletion kohya_gui/basic_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def list_images_dirs(path):
allow_custom_value=True,
)
# Refresh button for image folder
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dir(current_images_dir)},"open_folder_small")
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small")
# Button to open folder
folder_button = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not headless)
Expand Down Expand Up @@ -154,6 +154,7 @@ def list_images_dirs(path):
label='Caption text',
placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.',
interactive=True,
lines=2,
)
# Textbox for caption postfix
postfix = gr.Textbox(
Expand All @@ -168,12 +169,14 @@ def list_images_dirs(path):
label='Find text',
placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.',
interactive=True,
lines=2,
)
# Textbox for replace text
replace_text = gr.Textbox(
label='Replacement text',
placeholder='e.g., "by some artist". Leave empty if you want to replace with nothing.',
interactive=True,
lines=2,
)
# Button to caption images
caption_button = gr.Button('Caption images')
Expand Down
87 changes: 47 additions & 40 deletions kohya_gui/blip_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,40 @@ def caption_images(
postfix,
):
# Check if the image folder is provided
if train_data_dir == '':
msgbox('Image folder is missing...')
if train_data_dir == "":
msgbox("Image folder is missing...")
return

# Check if the caption file extension is provided
if caption_file_ext == '':
msgbox('Please provide an extension for the caption files.')
if caption_file_ext == "":
msgbox("Please provide an extension for the caption files.")
return

log.info(f'Captioning files in {train_data_dir}...')
log.info(f"Captioning files in {train_data_dir}...")

# Construct the command to run
run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/finetune/make_captions.py"'
run_cmd = rf'{PYTHON} "{scriptdir}/sd-scripts/finetune/make_captions.py"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --num_beams="{int(num_beams)}"'
run_cmd += f' --top_p="{top_p}"'
run_cmd += f' --max_length="{int(max_length)}"'
run_cmd += f' --min_length="{int(min_length)}"'
if beam_search:
run_cmd += f' --beam_search'
if caption_file_ext != '':
run_cmd += f" --beam_search"
if caption_file_ext != "":
run_cmd += f' --caption_extension="{caption_file_ext}"'
run_cmd += f' "{train_data_dir}"'
run_cmd += f' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"'

log.info(run_cmd)

env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)

# Run the command
subprocess.run(run_cmd, shell=True, env=env)
# Run the command in the sd-scripts folder context
subprocess.run(run_cmd, shell=True, env=env, cwd=rf"{scriptdir}/sd-scripts")

# Add prefix and postfix
add_pre_postfix(
Expand All @@ -66,7 +68,7 @@ def caption_images(
postfix=postfix,
)

log.info('...captioning done')
log.info("...captioning done")


###
Expand All @@ -77,28 +79,41 @@ def caption_images(
def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None):
from .common_gui import create_refresh_button

default_train_dir = default_train_dir if default_train_dir is not None else os.path.join(scriptdir, "data")
default_train_dir = (
default_train_dir
if default_train_dir is not None
else os.path.join(scriptdir, "data")
)
current_train_dir = default_train_dir

def list_train_dirs(path):
nonlocal current_train_dir
current_train_dir = path
return list(list_dirs(path))

with gr.Tab('BLIP Captioning'):
with gr.Tab("BLIP Captioning"):
gr.Markdown(
'This utility uses BLIP to caption files for each image in a folder.'
"This utility uses BLIP to caption files for each image in a folder."
)
with gr.Group(), gr.Row():
train_data_dir = gr.Dropdown(
label='Image folder to caption (containing the images to caption)',
label="Image folder to caption (containing the images to caption)",
choices=list_train_dirs(default_train_dir),
value="",
interactive=True,
allow_custom_value=True,
)
create_refresh_button(train_data_dir, lambda: None, lambda: {"choices": list_train_dir(current_train_dir)},"open_folder_small")
create_refresh_button(
train_data_dir,
lambda: None,
lambda: {"choices": list_train_dirs(current_train_dir)},
"open_folder_small",
)
button_train_data_dir_input = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_train_data_dir_input.click(
get_folder_path,
Expand All @@ -107,44 +122,36 @@ def list_train_dirs(path):
)
with gr.Row():
caption_file_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extension for caption file, e.g., .caption, .txt',
value='.txt',
label="Caption file extension",
placeholder="Extension for caption file, e.g., .caption, .txt",
value=".txt",
interactive=True,
)

prefix = gr.Textbox(
label='Prefix to add to BLIP caption',
placeholder='(Optional)',
label="Prefix to add to BLIP caption",
placeholder="(Optional)",
interactive=True,
)

postfix = gr.Textbox(
label='Postfix to add to BLIP caption',
placeholder='(Optional)',
label="Postfix to add to BLIP caption",
placeholder="(Optional)",
interactive=True,
)

batch_size = gr.Number(
value=1, label='Batch size', interactive=True
)
batch_size = gr.Number(value=1, label="Batch size", interactive=True)

with gr.Row():
beam_search = gr.Checkbox(
label='Use beam search', interactive=True, value=True
)
num_beams = gr.Number(
value=1, label='Number of beams', interactive=True
)
top_p = gr.Number(value=0.9, label='Top p', interactive=True)
max_length = gr.Number(
value=75, label='Max length', interactive=True
)
min_length = gr.Number(
value=5, label='Min length', interactive=True
label="Use beam search", interactive=True, value=True
)
num_beams = gr.Number(value=1, label="Number of beams", interactive=True)
top_p = gr.Number(value=0.9, label="Top p", interactive=True)
max_length = gr.Number(value=75, label="Max length", interactive=True)
min_length = gr.Number(value=5, label="Min length", interactive=True)

caption_button = gr.Button('Caption images')
caption_button = gr.Button("Caption images")

caption_button.click(
caption_images,
Expand Down
9 changes: 6 additions & 3 deletions kohya_gui/manual_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data")
current_images_dir = default_images_dir

# Function to list directories
def list_images_dirs(path):
# Allows list_images_dirs to modify current_images_dir outside of this function
nonlocal current_images_dir
current_images_dir = path
return list(list_dirs(path))

Expand All @@ -288,7 +291,7 @@ def list_images_dirs(path):
interactive=True,
allow_custom_value=True,
)
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dir(current_images_dir)},"open_folder_small")
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small")
folder_button = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
)
Expand All @@ -297,7 +300,7 @@ def list_images_dirs(path):
outputs=images_dir,
show_progress=False,
)
load_images_button = gr.Button('Load 💾', elem_id='open_folder')
load_images_button = gr.Button('Load', elem_id='open_folder')
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extension for caption file. eg: .caption, .txt',
Expand All @@ -322,7 +325,7 @@ def list_images_dirs(path):
placeholder='Comma separated list of tags',
interactive=True,
)
import_tags_button = gr.Button('Import 📄', elem_id='open_folder')
import_tags_button = gr.Button('Import', elem_id='open_folder')
ignore_load_tags_word_count = gr.Slider(
minimum=1,
maximum=100,
Expand Down

0 comments on commit ebc3296

Please sign in to comment.