From 79693f09ef6c5e5ea5ae0253ade5837998bb2072 Mon Sep 17 00:00:00 2001 From: liesen Date: Thu, 7 Mar 2024 21:12:20 +0300 Subject: [PATCH 1/3] Use snapshot_download to download the models. --- finetune/tag_images_by_wd14_tagger.py | 33 ++++++++++++++------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index b56d921a3..1a05f6202 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -6,7 +6,7 @@ import cv2 import numpy as np import torch -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, snapshot_download from PIL import Image from tqdm import tqdm @@ -25,7 +25,7 @@ FILES_ONNX = ["model.onnx"] SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] -CSV_FILE = FILES[-1] +CSV_FILE = [-1] def preprocess_image(image): @@ -84,20 +84,21 @@ def main(args): # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") - files = FILES - if args.onnx: - files += FILES_ONNX - for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) + snapshot_download(args.repo_id, cache_dir=args.model_dir, force_download=True) + # files = FILES + # if args.onnx: + # files += FILES_ONNX + # for file in files: + # hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + # for file in SUB_DIR_FILES: + # hf_hub_download( + # args.repo_id, + # file, + # subfolder=SUB_DIR, + # cache_dir=os.path.join(args.model_dir, SUB_DIR), + # force_download=True, + # force_filename=file, + # ) else: logger.info("using existing wd14 tagger model") From 2367e8da96422eb521b73fd9d52fb834be8dfd5d Mon Sep 17 00:00:00 2001 From: liesen Date: Thu, 7 Mar 2024 21:16:22 +0300 Subject: [PATCH 2/3] fix CSV_FILE --- finetune/tag_images_by_wd14_tagger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 1a05f6202..5b293f185 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -25,7 +25,7 @@ FILES_ONNX = ["model.onnx"] SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] -CSV_FILE = [-1] +CSV_FILE = FILES[-1] def preprocess_image(image): From 1b4829dc1a2ead01c190e58edea38a8f1eee582a Mon Sep 17 00:00:00 2001 From: liesen Date: Thu, 7 Mar 2024 21:17:17 +0300 Subject: [PATCH 3/3] remove old code --- finetune/tag_images_by_wd14_tagger.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 5b293f185..f68b4ae93 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -85,20 +85,6 @@ def main(args): if not os.path.exists(args.model_dir) or args.force_download: logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") snapshot_download(args.repo_id, cache_dir=args.model_dir, force_download=True) - # files = FILES - # if args.onnx: - # files += FILES_ONNX - # for file in files: - # hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - # for file in SUB_DIR_FILES: - # hf_hub_download( - # args.repo_id, - # file, - # subfolder=SUB_DIR, - # cache_dir=os.path.join(args.model_dir, SUB_DIR), - # force_download=True, - # force_filename=file, - # ) else: logger.info("using existing wd14 tagger model")