-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathremove_bg.py
109 lines (86 loc) · 3.01 KB
/
remove_bg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import functools
import io
import numpy as np
from PIL import Image
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from scipy.ndimage.morphology import binary_erosion
from .u2net import detect
def alpha_matting_cutout(
    img,
    mask,
    foreground_threshold,
    background_threshold,
    erode_structure_size,
    base_size,
):
    """Cut the foreground out of ``img`` using closed-form alpha matting.

    A downscaled copy of the image is used to build a trimap from ``mask``,
    estimate alpha and foreground colours with pymatting, and the stacked
    result is resized back to the original image size.

    Args:
        img: PIL image to cut out. It is not modified; a copy is downscaled.
        mask: PIL image whose pixel values are compared against the two
            thresholds (the caller in this file passes a mode ``"L"`` mask).
            It is resized to the working image size.
        foreground_threshold: mask values strictly above this are marked as
            likely foreground.
        background_threshold: mask values strictly below this are marked as
            likely background.
        erode_structure_size: side length of the square structuring element
            used to erode both regions. When not positive, ``binary_erosion``
            is called with ``structure=None`` and uses its own default.
        base_size: maximum width and height of the working image.

    Returns:
        A PIL image built from the estimated foreground stacked with the
        estimated alpha, resized to the original size of ``img``.
    """
    size = img.size

    # Image.thumbnail shrinks in place, so work on a copy to leave the
    # caller's image untouched.
    img = img.copy()
    img.thumbnail((base_size, base_size), Image.LANCZOS)
    mask = mask.resize(img.size, Image.LANCZOS)

    img = np.asarray(img)
    mask = np.asarray(mask)

    # guess likely foreground/background
    is_foreground = mask > foreground_threshold
    is_background = mask < background_threshold

    # erode foreground/background
    structure = None
    if erode_structure_size > 0:
        # The builtin int replaces the np.int alias, which NumPy 1.24 removed.
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=int
        )

    is_foreground = binary_erosion(is_foreground, structure=structure)
    # border_value=1 treats pixels outside the image as background, so the
    # background region is not eroded away from the image edges.
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    # build trimap
    # 0 = background
    # 128 = unknown
    # 255 = foreground
    trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0

    # build the cutout image
    img_normalized = img / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    # Back to 8-bit before handing the array to PIL.
    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    cutout = Image.fromarray(cutout)
    cutout = cutout.resize(size, Image.LANCZOS)

    return cutout
def naive_cutout(img, mask):
    """Composite ``img`` onto a blank RGBA canvas through ``mask``.

    Args:
        img: PIL image supplying the pixels to keep.
        mask: PIL image used as the compositing mask; it is resized to the
            size of ``img`` first.

    Returns:
        The composited PIL image.
    """
    # Blank canvas of the same size as the input, filled with 0.
    blank_canvas = Image.new("RGBA", img.size, 0)
    # The mask has to match the image size before compositing.
    fitted_mask = mask.resize(img.size, Image.LANCZOS)
    return Image.composite(img, blank_canvas, fitted_mask)
@functools.lru_cache(maxsize=None)
def get_model(model_name):
    """Load the segmentation model for ``model_name``, caching per name.

    Args:
        model_name: ``"u2netp"`` or ``"u2net_human_seg"`` select those
            models; any other value falls back to ``"u2net"``.

    Returns:
        Whatever ``detect.load_model`` returns for the selected model.
    """
    for known_name in ("u2netp", "u2net_human_seg"):
        if model_name == known_name:
            return detect.load_model(model_name=known_name)
    # Anything unrecognised uses the default model.
    return detect.load_model(model_name="u2net")
def remove(
    data,
    model_name="u2net",
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_structure_size=10,
    alpha_matting_base_size=1000,
):
    """Remove the background from an encoded image and return it as PNG data.

    Args:
        data: bytes-like encoded image, opened with PIL via ``io.BytesIO``.
        model_name: name passed to ``get_model`` to pick the model.
        alpha_matting: when true, refine the cutout with
            ``alpha_matting_cutout``; otherwise use ``naive_cutout``.
        alpha_matting_foreground_threshold: forwarded to
            ``alpha_matting_cutout`` as the foreground threshold.
        alpha_matting_background_threshold: forwarded to
            ``alpha_matting_cutout`` as the background threshold.
        alpha_matting_erode_structure_size: forwarded to
            ``alpha_matting_cutout`` as the erosion structure size.
        alpha_matting_base_size: forwarded to ``alpha_matting_cutout`` as the
            working image size limit.

    Returns:
        The buffer view (``BytesIO.getbuffer()``) of the PNG-encoded cutout.
    """
    net = get_model(model_name)
    source = Image.open(io.BytesIO(data)).convert("RGB")

    # The predicted mask is reduced to a single channel before compositing.
    prediction = detect.predict(net, np.array(source))
    mask = prediction.convert("L")

    if not alpha_matting:
        cutout = naive_cutout(source, mask)
    else:
        cutout = alpha_matting_cutout(
            source,
            mask,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_structure_size,
            alpha_matting_base_size,
        )

    png_stream = io.BytesIO()
    cutout.save(png_stream, "PNG")
    return png_stream.getbuffer()