Skip to content

Commit

Permalink
Merge pull request #2224 from bmaltais/wd14
Browse files Browse the repository at this point in the history
Update WD14 captioning
  • Loading branch information
bmaltais authored Apr 7, 2024
2 parents 33ef8b3 + d812b1f commit 1bd9f50
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 88 deletions.
27 changes: 16 additions & 11 deletions kohya_gui/blip2_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def generate_caption(
max_new_tokens=40,
min_new_tokens=20,
do_sample=True,
temperature=1.0,
top_p=0.0,
):
"""
Expand Down Expand Up @@ -108,6 +109,7 @@ def generate_caption(
top_p=top_p,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
temperature=temperature,
)

generated_text = processor.batch_decode(
Expand Down Expand Up @@ -154,7 +156,7 @@ def caption_images_beam_search(
model=model,
device=device,
num_beams=int(num_beams),
repetition_penalty=repetition_penalty,
repetition_penalty=float(repetition_penalty),
length_penalty=length_penalty,
min_new_tokens=int(min_new_tokens),
max_new_tokens=int(max_new_tokens),
Expand All @@ -165,6 +167,7 @@ def caption_images_beam_search(
def caption_images_nucleus(
directory_path,
do_sample,
temperature,
top_p,
min_new_tokens,
max_new_tokens,
Expand All @@ -190,6 +193,7 @@ def caption_images_nucleus(
model=model,
device=device,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
min_new_tokens=int(min_new_tokens),
max_new_tokens=int(max_new_tokens),
Expand Down Expand Up @@ -278,16 +282,6 @@ def list_train_dirs(path):
label="Number of beams",
)

temperature = gr.Slider(
minimum=0.5,
maximum=1.0,
value=1.0,
step=0.1,
interactive=True,
label="Temperature",
info="used with nucleus sampling",
)

len_penalty = gr.Slider(
minimum=-1.0,
maximum=2.0,
Expand Down Expand Up @@ -326,6 +320,16 @@ def list_train_dirs(path):
with gr.Tab("Nucleus sampling"):
with gr.Row():
do_sample = gr.Checkbox(label="Sample", value=True)

temperature = gr.Slider(
minimum=0.5,
maximum=1.0,
value=1.0,
step=0.1,
interactive=True,
label="Temperature",
info="used with nucleus sampling",
)

top_p = gr.Slider(
minimum=-0,
Expand All @@ -344,6 +348,7 @@ def list_train_dirs(path):
inputs=[
directory_path_dir,
do_sample,
temperature,
top_p,
min_new_tokens,
max_new_tokens,
Expand Down
179 changes: 102 additions & 77 deletions kohya_gui/wd14_caption_gui.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gradio as gr
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs
from .common_gui import get_folder_path, scriptdir, list_dirs
import os

from .custom_logging import setup_logging
Expand All @@ -16,43 +16,24 @@ def caption_images(
batch_size: int,
general_threshold: float,
character_threshold: float,
replace_underscores: bool,
repo_id: str,
recursive: bool,
max_data_loader_n_workers: int,
debug: bool,
undesired_tags: str,
frequency_tags: bool,
prefix: str,
postfix: str,
always_first_tags: str,
onnx: bool,
append_tags: bool,
force_download: bool,
caption_separator: str,
tag_replacement: bool,
character_tag_expand: str,
use_rating_tags: bool,
use_ratuse_rating_tags_as_last_taging_tags: bool,
remove_underscore: bool,
thresh: float,
) -> None:
"""
Captions images in a given directory using the WD14 model.
Args:
train_data_dir (str): The directory containing the images to be captioned.
caption_extension (str): The extension to be used for the caption files.
batch_size (int): The batch size for the captioning process.
general_threshold (float): The general threshold for the captioning process.
character_threshold (float): The character threshold for the captioning process.
replace_underscores (bool): Whether to replace underscores in filenames with spaces.
repo_id (str): The ID of the repository containing the WD14 model.
recursive (bool): Whether to process subdirectories recursively.
max_data_loader_n_workers (int): The maximum number of workers for the data loader.
debug (bool): Whether to enable debug mode.
undesired_tags (str): Comma-separated list of tags to be removed from the captions.
frequency_tags (bool): Whether to include frequency tags in the captions.
prefix (str): The prefix to be added to the captions.
postfix (str): The postfix to be added to the captions.
onnx (bool): Whether to use ONNX for the captioning process.
append_tags (bool): Whether to append tags to existing tags.
force_download (bool): Whether to force the model to be downloaded.
caption_separator (str): The separator to be used for the captions.
"""
# Check for images_dir_input
if train_data_dir == "":
msgbox("Image folder is missing...")
Expand All @@ -64,11 +45,15 @@ def caption_images(

log.info(f"Captioning files in {train_data_dir}...")
run_cmd = rf'accelerate launch "{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py"'
# Forward the comma-separated tag list as one double-quoted argument. The
# closing quote is required; without it the shell command string contains an
# unterminated quote and the remaining options are swallowed into this value.
if always_first_tags:
    run_cmd += f' --always_first_tags="{always_first_tags}"'
if append_tags:
run_cmd += f" --append_tags"
run_cmd += f" --batch_size={int(batch_size)}"
run_cmd += f' --caption_extension="{caption_extension}"'
run_cmd += f' --caption_separator="{caption_separator}"'
# NOTE(review): the sd-scripts tagger declares --character_tag_expand as a
# boolean switch (no value) - confirm against the bundled sd-scripts version.
# Passing `="..."` to such a switch makes argparse reject the command, so emit
# the bare flag whenever the GUI value is truthy.
if character_tag_expand:
    run_cmd += " --character_tag_expand"
run_cmd += f" --character_threshold={character_threshold}"
if debug:
run_cmd += f" --debug"
Expand All @@ -82,11 +67,19 @@ def caption_images(
run_cmd += f" --onnx"
if recursive:
run_cmd += f" --recursive"
if replace_underscores:
if remove_underscore:
run_cmd += f" --remove_underscore"
run_cmd += f' --repo_id="{repo_id}"'
# The GUI field holds the replacement mapping itself
# (`source1,target1;source2,target2;...`), so pass that text as the option's
# value; a bare switch would drop the mapping.
if tag_replacement:
    run_cmd += f' --tag_replacement="{tag_replacement}"'
if thresh:
run_cmd += f" --thresh={thresh}"
if not undesired_tags == "":
run_cmd += f' --undesired_tags="{undesired_tags}"'
if use_rating_tags:
run_cmd += f" --use_rating_tags"
# The Python variable keeps its (garbled) name because the function signature
# and the Gradio inputs list refer to it, but the command-line flag must be the
# tagger script's real option name, otherwise argparse rejects the command.
if use_ratuse_rating_tags_as_last_taging_tags:
    run_cmd += " --use_rating_tags_as_last_tag"
run_cmd += rf' "{train_data_dir}"'

log.info(run_cmd)
Expand All @@ -95,18 +88,11 @@ def caption_images(
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command
subprocess.run(run_cmd, shell=True, env=env)

# Add prefix and postfix
add_pre_postfix(
folder=train_data_dir,
caption_file_ext=caption_extension,
prefix=prefix,
postfix=postfix,
)

log.info("...captioning done")


Expand Down Expand Up @@ -162,6 +148,30 @@ def list_train_dirs(path):
outputs=train_data_dir,
show_progress=False,
)

repo_id = gr.Dropdown(
label="Repo ID",
choices=[
"SmilingWolf/wd-v1-4-convnext-tagger-v2",
"SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
"SmilingWolf/wd-v1-4-vit-tagger-v2",
"SmilingWolf/wd-v1-4-swinv2-tagger-v2",
"SmilingWolf/wd-v1-4-moat-tagger-v2",
'SmilingWolf/wd-swinv2-tagger-v3',
'SmilingWolf/wd-vit-tagger-v3',
'SmilingWolf/wd-convnext-tagger-v3',
],
value="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
show_label="Repo id for wd14 tagger on Hugging Face",
)

force_download = gr.Checkbox(
label="Force model re-download",
value=False,
info="Useful to force model re download when switching to onnx",
)

with gr.Row():

caption_extension = gr.Textbox(
label="Caption file extension",
Expand All @@ -175,6 +185,22 @@ def list_train_dirs(path):
value=",",
interactive=True,
)

with gr.Row():

tag_replacement = gr.Textbox(
label="Tag replacement",
info="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`",
value="",
interactive=True,
)

character_tag_expand = gr.Textbox(
label="Character tag expand",
info="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`",
value="",
interactive=True,
)

undesired_tags = gr.Textbox(
label="Undesired tags",
Expand All @@ -183,22 +209,17 @@ def list_train_dirs(path):
)

with gr.Row():
prefix = gr.Textbox(
always_first_tags = gr.Textbox(
label="Prefix to add to WD14 caption",
placeholder="(Optional)",
interactive=True,
)

postfix = gr.Textbox(
label="Postfix to add to WD14 caption",
info="comma-separated list of tags to always put at the beginning, e.g. 1girl,1boy",
placeholder="(Optional)",
interactive=True,
)

with gr.Row():
onnx = gr.Checkbox(
label="Use onnx",
value=False,
value=True,
interactive=True,
info="https://github.com/onnx/onnx",
)
Expand All @@ -208,54 +229,54 @@ def list_train_dirs(path):
interactive=True,
info="This option appends the tags to the existing tags, instead of replacing them.",
)

with gr.Row():
replace_underscores = gr.Checkbox(
label="Replace underscores in filenames with spaces",
value=True,

use_rating_tags = gr.Checkbox(
label="Use rating tags",
value=False,
interactive=True,
info="Adds rating tags as the first tag",
)

use_ratuse_rating_tags_as_last_taging_tags = gr.Checkbox(
label="Use rating tags as last tag",
value=False,
interactive=True,
info="Adds rating tags as the last tag",
)

with gr.Row():
recursive = gr.Checkbox(
label="Recursive",
value=False,
info="Tag subfolders images as well",
)
remove_underscore = gr.Checkbox(
label="Remove underscore",
value=True,
info="replace underscores with spaces in the output tags",
)

debug = gr.Checkbox(
label="Verbose logging",
label="Debug",
value=True,
info="Debug while tagging, it will print your image file with general tags and character tags.",
info="Debug mode",
)
frequency_tags = gr.Checkbox(
label="Show tags frequency",
value=True,
info="Show frequency of tags for images.",
)

# Model Settings

with gr.Row():
repo_id = gr.Dropdown(
label="Repo ID",
choices=[
"SmilingWolf/wd-v1-4-convnext-tagger-v2",
"SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
"SmilingWolf/wd-v1-4-vit-tagger-v2",
"SmilingWolf/wd-v1-4-swinv2-tagger-v2",
"SmilingWolf/wd-v1-4-moat-tagger-v2",
# 'SmilingWolf/wd-swinv2-tagger-v3',
# 'SmilingWolf/wd-vit-tagger-v3',
# 'SmilingWolf/wd-convnext-tagger-v3',
],
value="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
show_label="Repo id for wd14 tagger on Hugging Face",
)

force_download = gr.Checkbox(
label="Force model re-download",
value=False,
info="Useful to force model re download when switching to onnx",
thresh = gr.Slider(
value=0.35,
label="Threshold",
info="threshold of confidence to add a tag",
minimum=0,
maximum=1,
step=0.05,
)

general_threshold = gr.Slider(
value=0.35,
label="General threshold",
Expand Down Expand Up @@ -290,19 +311,23 @@ def list_train_dirs(path):
batch_size,
general_threshold,
character_threshold,
replace_underscores,
repo_id,
recursive,
max_data_loader_n_workers,
debug,
undesired_tags,
frequency_tags,
prefix,
postfix,
always_first_tags,
onnx,
append_tags,
force_download,
caption_separator,
tag_replacement,
character_tag_expand,
use_rating_tags,
use_ratuse_rating_tags_as_last_taging_tags,
remove_underscore,
thresh,
],
show_progress=False,
)
Expand Down

0 comments on commit 1bd9f50

Please sign in to comment.