-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
88 lines (77 loc) · 3.36 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import os
import numpy as np
from src.utils import list_models
from src.prediction import PredictorFatory
import cv2
import copy
import operator
# Locations of bundled assets and model weights (relative to the working directory).
EXAMPLE_PATH = './assets/examples'
MODEL_PATH = './model/'
DEFAULT_MODEL = 'xception_299_299.hdf5'
# Gradio example rows, one per file in EXAMPLE_PATH: [image path, model choice, enhance flag].
# os.listdir() returns entries in arbitrary order, so sort for a stable example gallery.
EXAMPLES = [[os.path.join(EXAMPLE_PATH, fname), DEFAULT_MODEL, True] for fname in sorted(os.listdir(EXAMPLE_PATH))]
# Display names ("English Chinese") passed to the predictor factory; entry() also uses
# them as keys into the returned score mapping. Note 'Figure 人物 ' keeps its trailing space.
CLASSES = ['Figure 人物 ', 'Item 道具','Landscape 自然风光', 'Machine 机械', 'City 城市建筑', 'Indoor 室内']
# Short English labels: the first space-separated word of each CLASSES entry
# (['Figure', 'Item', 'Landscape', 'Machine', 'City', 'Indoor']).
LABELS = [name.split(' ')[0] for name in CLASSES]
# Weights for the optional matting model and the foreground segmentation model.
MATTING_WEIGHT = './seg/onnx_mobilenetv2_hd.onnx'
SEG_MODEL = './seg/seg.h5'
# HTML shown as the interface description.
# NOTE(review): this links to the AnimeGANv2 PyTorch repo, which looks unrelated to
# this classifier — confirm the intended link.
article = """
<p style='text-align: center'>
<a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a>
</p>
"""
def _largest_component(stats):
    """Return ``(label, area)`` of the largest foreground connected component.

    ``stats`` is the per-label statistics array from
    ``cv2.connectedComponentsWithStats``; row 0 is the background and is
    skipped. Returns ``(0, 0)`` when there is no foreground component.
    Ties go to the lowest label.
    """
    if len(stats) <= 1:
        return 0, 0
    areas = stats[1:, cv2.CC_STAT_AREA]
    label = int(np.argmax(areas)) + 1  # +1 because the background row was skipped
    return label, int(areas[label - 1])


def _format_scores(scores, min_score=0.05):
    """Build a one-line " Name: score" summary of the confident classes.

    Classes are listed from highest to lowest score, only scores of at least
    ``min_score`` are kept, and each name is shortened to its first
    space-separated word. Returns ``""`` if nothing qualifies.
    """
    ranked = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
    return "".join(
        " " + name.split(" ")[0] + ": " + str(round(score, 2))
        for name, score in ranked
        if score >= min_score
    )


def _enhanced_predict(factory, predictor, image):
    """Classify ``image`` with segmented regions masked out, then annotate it.

    The segmentation model marks foreground regions. The classifier sees only
    the remaining pixels. If the largest foreground region covers more than
    2% of the image, the first class (CLASSES[0], 'Figure') is boosted and
    labelled at that region's centroid. Region contours and the score summary
    are drawn onto ``image`` in place.

    Returns:
        ``(cls_ret, image)``: the (possibly boosted) score mapping and the
        annotated image.
    """
    segmentor = factory.get_seg(SEG_MODEL)
    seg = segmentor.predict(copy.deepcopy(image))
    # One binary 8-bit mask (1 = segmented pixel). connectedComponentsWithStats
    # requires an 8-bit single-channel input, and findContours reuses it.
    mask = seg.astype(bool).astype(np.uint8)
    inv_mask = (mask == 0).astype(np.uint8)
    # Classify only the pixels *outside* the segmented regions.
    cls_image = image * np.expand_dims(inv_mask, axis=-1)
    cls_ret = predictor.predict(cls_image)

    # Find the largest connected foreground component.
    _, _, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)
    idx, area = _largest_component(stats)
    weight = 0.9
    if area > image.shape[0] * image.shape[1] * 0.02:
        # Pull the first class's score (CLASSES[0], 'Figure') toward 1.
        # NOTE(review): assumes the predictor's mapping contains CLASSES[0] — confirm.
        cls_ret[CLASSES[0]] = weight + (1 - weight) * cls_ret[CLASSES[0]]
        cx, cy = centroids[idx]
        cv2.putText(image, LABELS[0], (int(cx), int(cy)), cv2.FONT_HERSHEY_SIMPLEX, 1, (55, 255, 155), 2)

    # [-2] picks the contour list on both OpenCV 3 (3 return values) and 4 (2 values).
    contours = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(image, contours, -1, (0, 0, 255), 2)

    # Font size and thickness scale with image width (557 px is the reference width);
    # clamp thickness so narrow images still get a visible stroke.
    scale = image.shape[1] / 557
    cv2.putText(
        image,
        _format_scores(cls_ret),
        (int(image.shape[1] / 7), 35),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5 * scale,
        (55, 255, 155),
        max(1, int(1.78 * scale)),
    )
    return cls_ret, image


def entry(image, model_choice, enhance=False):
    """Gradio callback: classify ``image`` and optionally return an annotated copy.

    Args:
        image: Input image from ``gr.Image()`` (presumably an HxWx3 uint8
            numpy array — TODO confirm). In enhance mode it is drawn on in place.
        model_choice: File name of a model under ``MODEL_PATH``. The substrings
            ``'xception'`` and ``'matting'`` select the pipeline.
        enhance: For non-matting xception models, mask out segmented foreground
            before classifying and return an annotated image.

    Returns:
        ``(scores, annotated)``. ``scores`` is the predictor's class->score
        mapping, or ``None`` when no pipeline handles ``model_choice``.
        ``annotated`` is the drawn-on image in enhance mode, else ``None``.
        A pair is always returned so both Gradio outputs are filled.
    """
    factory = PredictorFatory(CLASSES, MODEL_PATH + model_choice)
    if 'xception' in model_choice:
        predictor = factory.get_predictor('xception')
        if 'matting' in model_choice:
            mat_pridictior = factory.get_matting(MATTING_WEIGHT)
            # NOTE(review): the matting outputs (pha, fgr) are not used; the
            # classifier still sees the unmodified image.
            pha, fgr = mat_pridictior.predict(image)
            return predictor.predict(image), None
        if enhance:
            return _enhanced_predict(factory, predictor, image)
        return predictor.predict(image), None
    # TODO: "SWIN2" models are not wired up yet; like any unknown model they
    # fall through here with empty outputs.
    return None, None
# Web UI: image + model picker + "enhance" toggle -> class scores + annotated image.
# gr.inputs.* (with `default=`) is deprecated in Gradio 3 and removed in Gradio 4;
# the top-level components take `value=` instead.
demo = gr.Interface(
    fn=entry,
    description=article,
    inputs=[
        gr.Image(),
        gr.Dropdown(choices=list_models(MODEL_PATH), value=DEFAULT_MODEL, type="value"),
        gr.Checkbox(value=False),
    ],
    outputs=[gr.Label(num_top_classes=len(CLASSES)), gr.Image()],
    examples=EXAMPLES,
    title="Anime Theme Classification",
)
demo.launch()