Skip to content

Commit

Permalink
Manually merged: Support ML-Danbooru #6, changes amended from CCRcmcpe's
Browse files Browse the repository at this point in the history
pull request to Toriato
  • Loading branch information
Roel Kluin committed Jul 16, 2023
1 parent 2748a7e commit b7918a2
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 130 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
deepdanbooru
onnxruntime-silicon; sys_platform == 'darwin'
onnxruntime-gpu; sys_platform != 'darwin'
fastapi
gradio
huggingface_hub
Expand Down
23 changes: 23 additions & 0 deletions tagger/dbimutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,29 @@
from PIL import Image


def fill_transparent(image: Image.Image, color='WHITE'):
    """Flatten ``image`` onto an opaque background and return it as RGB.

    The image is converted to RGBA and pasted onto a canvas of the same
    size filled with ``color``; the result is converted to RGB.
    """
    rgba = image.convert('RGBA')
    canvas = Image.new('RGBA', rgba.size, color)
    # the RGBA image is passed as its own paste mask
    canvas.paste(rgba, mask=rgba)
    return canvas.convert('RGB')


def resize(pic: Image.Image, size: int, keep_ratio=True) -> Image.Image:
    """Resize ``pic`` to dimensions that are multiples of 4.

    Args:
        pic: image to resize.
        size: target edge length in pixels. With ``keep_ratio`` the shorter
            edge is scaled to this length, otherwise both edges are.
        keep_ratio: preserve the aspect ratio when True (the default).

    Returns:
        A new image resampled with the LANCZOS filter; each dimension is
        rounded down to a multiple of 4.

    Raises:
        ValueError: if ``size`` is below 4, because rounding down to a
            multiple of 4 would produce a zero-sized dimension.
    """
    if size < 4:
        raise ValueError(f'size must be at least 4, got {size}')

    if keep_ratio:
        width, height = pic.size
        min_edge = min(width, height)
        target_size = (
            int(width / min_edge * size),
            int(height / min_edge * size),
        )
    else:
        target_size = (size, size)

    # clear the two lowest bits: round each dimension down to a multiple of 4
    target_size = (target_size[0] & ~3, target_size[1] & ~3)

    # Image.Resampling was added in Pillow 9.1; older releases only expose
    # the filter constants on the Image module itself
    resampling = getattr(Image, 'Resampling', Image)
    return pic.resize(target_size, resample=resampling.LANCZOS)


def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
""" Read an image, convert to 24-bit if necessary """
if img.endswith(".gif"):
Expand Down
197 changes: 128 additions & 69 deletions tagger/interrogator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io
from hashlib import sha256
import json
import numpy as np
from platform import system
from typing import Tuple, List, Dict, Callable
from pandas import read_csv, read_json
Expand All @@ -17,15 +18,21 @@
from . import dbimutils
from tagger import settings
from tagger.uiset import QData, IOData, ItRetTP
import gradio as gr

Its = settings.InterrogatorSettings

# select a device to process: the CPU is used when the command line lists
# either 'all' or 'interrogate' in use_cpu
use_cpu = any(
    target in shared.cmd_opts.use_cpu for target in ('all', 'interrogate'))

# https://onnxruntime.ai/docs/execution-providers/
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

TF_DEVICE_NAME = '/cpu:0' if use_cpu else '/gpu:0'
if use_cpu:
    # CPU only: drop the CUDA provider from the front of the list
    del onnxrt_providers[0]

Expand Down Expand Up @@ -63,7 +70,7 @@ class Interrogator:
"output_dir": '',
}
output = None
#odd_increment = 0
# odd_increment = 0

@classmethod
def flip(cls, key):
Expand Down Expand Up @@ -131,10 +138,12 @@ def unload(self) -> bool:
del self.model
self.model = None
unloaded = True
gr.collect()
print(f'Unloaded {self.name}')

if hasattr(self, 'tags'):
del self.tags
self.tags = None

return unloaded

Expand Down Expand Up @@ -259,6 +268,7 @@ class DeepDanbooruInterrogator(Interrogator):
def __init__(self, name: str, project_path: os.PathLike) -> None:
    """Store the project path; model and tags start out as None."""
    super().__init__(name)
    # nothing is loaded at construction time
    self.tags = None
    self.model = None
    self.project_path = project_path

def load(self) -> None:
Expand Down Expand Up @@ -331,7 +341,7 @@ def interrogate(
Dict[str, float] # tag confidences
]:
# init model
if not hasattr(self, 'model') or self.model is None:
if self.model is None:
self.load()

import deepdanbooru.data as ddd
Expand Down Expand Up @@ -363,36 +373,62 @@ def interrogate(
return ratings, tags


def get_onnxrt():
    """Return the ``onnxruntime`` module, installing a package if needed.

    When the import fails and ``launch.is_installed`` reports that
    onnxruntime is missing, a package is installed with ``launch.run_pip``
    and the import is retried. The package defaults to
    "onnxruntime-silicon" on Darwin and "onnxruntime-gpu" elsewhere; the
    ONNXRUNTIME_PACKAGE environment variable overrides it.
    """
    try:
        import onnxruntime
    except ImportError:
        # only one of these packages should be installed at one time in an env
        # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
        # TODO: remove old package when the environment changes?
        from launch import is_installed, run_pip

        if not is_installed('onnxruntime'):
            on_darwin = system() == "Darwin"
            default_package = (
                "onnxruntime-silicon" if on_darwin else "onnxruntime-gpu")
            chosen = os.environ.get('ONNXRUNTIME_PACKAGE', default_package)
            run_pip(f'install {chosen}', 'onnxruntime')

        import onnxruntime

    return onnxruntime


class WaifuDiffusionInterrogator(Interrogator):
""" Interrogator for Waifu Diffusion models """
def __init__(
self,
name: str,
model_path='model.onnx',
tags_path='selected_tags.csv',
**kwargs
repo_id=None,
) -> None:
super().__init__(name)
self.repo_id = repo_id
self.model_path = model_path
self.tags_path = tags_path
self.tags = None
self.kwargs = kwargs

def download(self) -> Tuple[os.PathLike, os.PathLike]:
print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")
self.model = None
self.tags = None

def download(self) -> None:
mdir = Path(shared.models_path, 'interrogators')
model_path = Path(hf_hub_download(**self.kwargs,
filename=self.model_path,
cache_dir=mdir))
tags_path = Path(hf_hub_download(**self.kwargs,
filename=self.tags_path,
cache_dir=mdir))
if self.repo_id is not None:
print(f"Loading {self.name} model file from {self.repo_id}")

self.model_path = hf_hub_download(self.repo_id, self.model_path,
cache_dir=mdir)
self.tags_path = hf_hub_download(self.repo_id, self.tags_path,
cache_dir=mdir)

download_model = {
'name': self.name,
'model_path': str(model_path),
'tags_path': str(tags_path),
'model_path': self.model_path,
'tags_path': self.tags_path,
}
mpath = Path(mdir, 'model.json')

Expand All @@ -411,56 +447,14 @@ def download(self) -> Tuple[os.PathLike, os.PathLike]:
with io.open(mpath, 'w') as filename:
json.dump(data, filename)

return model_path, tags_path

def get_model_path(self) -> Tuple[os.PathLike, os.PathLike]:
model_path = ''
tags_path = ''
mpath = Path(shared.models_path, 'interrogators', 'model.json')
try:
models = read_json(mpath).to_dict(orient='records')
i = next(i for i in models if i['name'] == self.name)
model_path = i['model_path']
tags_path = i['tags_path']
except Exception as e:
print(f'{mpath}: requires a name, model_ and tags_path: {repr(e)}')
model_path, tags_path = self.download()
return model_path, tags_path

def load(self) -> None:
if isinstance(self.model_path, str) or isinstance(self.tags_path, str):
model_path, tags_path = self.download()
else:
model_path = self.model_path
tags_path = self.tags_path

# only one of these packages should be installed a time in any one env
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
# TODO: remove old package when the environment changes?
from launch import is_installed, run_pip
if not is_installed('onnxruntime'):
if system() == "Darwin":
package_name = "onnxruntime-silicon"
else:
package_name = "onnxruntime-gpu"
package = os.environ.get(
'ONNXRUNTIME_PACKAGE',
package_name
)

run_pip(f'install {package}', 'onnxruntime')

from onnxruntime import InferenceSession

# https://onnxruntime.ai/docs/execution-providers/
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
if use_cpu:
providers.pop(0)
self.download()
ort = get_onnxrt()
self.model = ort.InferenceSession(self.model_path,
providers=onnxrt_providers)

print(f'Loading {self.name} model from {model_path}, {tags_path}')
self.model = InferenceSession(str(model_path), providers=providers)
self.tags = read_csv(tags_path)
print(f'Loaded {self.name} model from {self.repo_id}')
self.tags = read_csv(self.tags_path)

def interrogate(
self,
Expand All @@ -470,7 +464,7 @@ def interrogate(
Dict[str, float] # tag confidences
]:
# init model
if not hasattr(self, 'model') or self.model is None:
if self.model is None:
self.load()

# code for converting the image and running the model is taken from the
Expand All @@ -481,15 +475,14 @@ def interrogate(
_, height, _, _ = self.model.get_inputs()[0].shape

# alpha to white
image = image.convert('RGBA')
new_image = Image.new('RGBA', image.size, 'WHITE')
new_image.paste(image, mask=image)
image = new_image.convert('RGB')
image = asarray(image)
image = dbimutils.fill_transparent(image)

image = np.asarray(image)
# PIL RGB to OpenCV BGR
image = image[:, :, ::-1]

tags = dict

image = dbimutils.make_square(image, height)
image = dbimutils.smart_resize(image, height)
image = image.astype(float32)
Expand Down Expand Up @@ -609,3 +602,69 @@ def pred_model(model):
QData.add_tag = orig_add_tags
del os.environ["TF_XLA_FLAGS"]
return ''


class MLDanbooruInterrogator(Interrogator):
    """ Interrogator for ML-Danbooru models """
    def __init__(
        self,
        name: str,
        repo_id: str,
        model_path: str,
        tags_path='classes.json'
    ) -> None:
        """Remember where the model lives; nothing is downloaded yet.

        Args:
            name: name passed on to the Interrogator base class.
            repo_id: Hugging Face repository holding the files.
            model_path: file name of the ONNX model within the repository.
            tags_path: file name of the JSON tag list within the repository.
        """
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.repo_id = repo_id
        self.tags = None
        self.model = None
        # load() replaces model_path/tags_path with local file paths, so the
        # names within the repository are kept separately; without them a
        # second load() (e.g. after the model was unloaded) would pass a
        # local path as the filename to hf_hub_download
        self._repo_model_file = model_path
        self._repo_tags_file = tags_path

    def download(self) -> Tuple[str, str]:
        """Fetch the model and tags files from the hub.

        Returns:
            The local paths of the model file and the tags file.
        """
        print(f"Loading {self.name} model file from {self.repo_id}")

        model_path = hf_hub_download(
            repo_id=self.repo_id, filename=self._repo_model_file)
        tags_path = hf_hub_download(
            repo_id=self.repo_id, filename=self._repo_tags_file)
        return model_path, tags_path

    def load(self) -> None:
        """Download the files, create the ONNX session and read the tags."""
        self.model_path, self.tags_path = self.download()

        ort = get_onnxrt()
        self.model = ort.InferenceSession(self.model_path,
                                          providers=onnxrt_providers)

        print(f'Loaded {self.name} model from {self.model_path}')

        with open(self.tags_path, 'r', encoding='utf-8') as f:
            self.tags = json.load(f)

    def interrogate(
        self,
        image: Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences
        Dict[str, float]  # tag confidences
    ]:
        """Run the model on ``image`` and return its tag confidences.

        The model is loaded on first use. The first element of the returned
        tuple (ratings) is always an empty dict.
        """
        # init model
        if self.model is None:
            self.load()

        image = dbimutils.fill_transparent(image)
        image = dbimutils.resize(image, 448)  # TODO CUSTOMIZE

        # scale pixel values to [0, 1]
        x = np.asarray(image, dtype=np.float32) / 255
        # HWC -> 1CHW
        x = x.transpose((2, 0, 1))
        x = np.expand_dims(x, 0)

        input_ = self.model.get_inputs()[0]
        output = self.model.get_outputs()[0]
        # evaluate model
        y, = self.model.run([output.name], {input_.name: x})

        # Sigmoid: maps each output independently into (0, 1)
        y = 1 / (1 + np.exp(-y))

        # presumably self.tags is a sequence of tag names ordered like the
        # model outputs; verify against the repository's classes.json
        tags = {tag: float(conf) for tag, conf in zip(self.tags, y.flatten())}
        return {}, tags
13 changes: 9 additions & 4 deletions tagger/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def unload_interrogators() -> List[str]:

def check_for_errors(name) -> str:
errors = It.get_errors()
if name not in utils.interrogators:
if not any(i.name == name for i in utils.interrogators.values()):
errors += f"'{name}': invalid interrogator"

return errors
Expand All @@ -44,7 +44,8 @@ def on_interrogate(name: str, inverse=False) -> ItRetTP:
if err != '':
return (None, None, None, err)

interrogator: It = utils.interrogators[name]
such_name = (i for i in utils.interrogators.values() if name == i.name)
interrogator: It = next(such_name, None)
QData.inverse = inverse
return interrogator.batch_interrogate()

Expand Down Expand Up @@ -209,7 +210,11 @@ def on_ui_tabs():
# interrogator selector
with gr.Column():
with gr.Row(variant='compact'):
interrogator_names = utils.refresh_interrogators()
def refresh():
    """Call utils.refresh_interrogators() and return the sorted names."""
    utils.refresh_interrogators()
    names = [entry.name for entry in utils.interrogators.values()]
    names.sort()
    return names
interrogator_names = refresh()
interrogator = utils.preset.component(
gr.Dropdown,
label='Interrogator',
Expand All @@ -224,7 +229,7 @@ def on_ui_tabs():
ui.create_refresh_button(
interrogator,
lambda: None,
lambda: {'choices': utils.refresh_interrogators()},
lambda: {'choices': refresh()},
'refresh_interrogator'
)

Expand Down
Loading

0 comments on commit b7918a2

Please sign in to comment.