SightSpotDetector.py
#!/usr/bin/env python
"""
SightSpotDetector is fully-functional visual attention framework.
It allows the following:
1) Calculation of saliency map for given image (grayscale or heatmap-colored).
2) Obtain pixel-precise saliency map of given image.
3) Extract main objects using specific algorithm.
Framework uses lazy calculation so it`s doing less worthless work.
"""
__author__ = 'Igor Ryabtsov aka Tinnulion'
__copyright__ = "Copyright (c) 2014"
__license__ = "Apache 2.0"
__version__ = "1.0"

import math
import numpy
import scipy.ndimage
from PIL import Image
import SightSpotUtil


class SightSpotDetector(object):
    # Class-level tuning constants.
    _SMALL_BLUR_RATIO = 0.1    # Small blur sigma as a fraction of the characteristic image dimension.
    _LARGE_BLUR_RATIO = 0.15   # Large blur sigma as a fraction of the characteristic image dimension.
    _CELL_SIZE_RATIO = 1.0     # SLIC cell size as a fraction of the characteristic image dimension.
    _SLIC_ALPHA = 0.25         # Alpha parameter passed to the SLIC segmentation.
    _ITERATION_NUMBER = 4      # Number of SLIC iterations.

    def __init__(self, image, grain=0.1):
        """
        Public constructor of SightSpotDetector.

        Parameters
        ----------
        image : PIL.Image, numpy.ndarray or string
            Either a 3-channel RGB image with 8-bit channels (PIL.Image or numpy.ndarray) or a filename.
        grain : float
            Relative granularity of the analysis; the value should be between 0.0 and 1.0.
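
        Examples
        --------
        Illustrative sketch ('photo.jpg' is a hypothetical filename):

        >>> detector = SightSpotDetector('photo.jpg', grain=0.1)
        >>> detector = SightSpotDetector(Image.open('photo.jpg'))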
"""
rgb_image = image
if isinstance(rgb_image, basestring):
rgb_image = Image.open(rgb_image)
if isinstance(rgb_image, Image.Image):
rgb_image = numpy.asarray(rgb_image, dtype='float32')
if not isinstance(rgb_image, numpy.ndarray):
raise Exception('Argument "image" should be either PIL.Image, numpy.ndarray or string filename!')
self._rgb_image = rgb_image
self._orgb_image = None
self._saliency_map = None
self._fusion_map = None
self._raw_heatmap = None
self._precise_heatmap = None
        height, width = self._rgb_image.shape[0], self._rgb_image.shape[1]
        area = float(width * height)
        # Characteristic dimension of the analysis: all scale parameters below are proportional to it.
        dim = math.sqrt(area) * grain
        self._small_sigma = self._SMALL_BLUR_RATIO * dim
        self._large_sigma = self._LARGE_BLUR_RATIO * dim
        self._cell_size = max(self._CELL_SIZE_RATIO * dim, 4.0)

    def _get_orgb_image(self):
        # Lazily computes and caches the oRGB representation of the input image.
if self._orgb_image is None:
self._orgb_image = SightSpotUtil.eval_orgb_image(self._rgb_image)
return self._orgb_image

    def _get_saliency_map(self, type):
        # Lazily computes and caches the requested saliency map: 'raw' uses blur-based
        # saliency only, 'precise' additionally refines it with SLIC segmentation.
orgb_image = self._get_orgb_image()
if type == 'raw':
if self._saliency_map is None:
small_sigma = self._small_sigma
large_sigma = self._large_sigma
self._saliency_map = SightSpotUtil.eval_saliency_map(orgb_image, small_sigma, large_sigma, 'auto')
return self._saliency_map
if type == 'precise':
if self._fusion_map is None:
saliency_map = self._get_saliency_map(type='raw')
cell_size = self._cell_size
alpha = self._SLIC_ALPHA
iterations = self._ITERATION_NUMBER
segmentation_map = SightSpotUtil.eval_slic_map(orgb_image, cell_size, alpha, iterations)
self._fusion_map = SightSpotUtil.combine_saliency_and_segmentation(saliency_map, segmentation_map)
return self._fusion_map
raise Exception('Unknown argument value type = "' + str(type) + '"')

    def get_saliency_map(self, type='precise'):
        """
        Returns the saliency map of the image passed to the SightSpotDetector constructor.

        Parameters
        ----------
        type : 'raw' or 'precise'
            Specifies whether superpixel segmentation is used to refine the saliency map ('precise' is slower).

        Returns
        -------
        out : ndarray
            Saliency map.
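
        Examples
        --------
        Illustrative sketch, assuming ``detector`` was constructed as in the ``__init__`` example:

        >>> raw_map = detector.get_saliency_map(type='raw')
        >>> precise_map = detector.get_saliency_map(type='precise')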
"""
return self._get_saliency_map(type).copy()

    def get_heatmap(self, type='raw'):
        """
        Returns a heatmap calculated from the saliency map of the image passed to the SightSpotDetector constructor.

        Parameters
        ----------
        type : 'raw' or 'precise'
            Specifies whether superpixel segmentation is used to refine the saliency map ('precise' is slower).

        Returns
        -------
        out : PIL.Image
            Heatmap.
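
        Examples
        --------
        Illustrative sketch ('photo_heatmap.png' is a hypothetical output filename):

        >>> heatmap = detector.get_heatmap(type='precise')
        >>> heatmap.save('photo_heatmap.png')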
"""
if type == 'raw':
            if self._raw_heatmap is None:
saliency_map = self._get_saliency_map(type)
self._raw_heatmap = SightSpotUtil.eval_heatmap(saliency_map)
return self._raw_heatmap.copy()
if type == 'precise':
            if self._precise_heatmap is None:
fusion_map = self._get_saliency_map(type)
self._precise_heatmap = SightSpotUtil.eval_heatmap(fusion_map)
return self._precise_heatmap.copy()
raise Exception('Unknown argument value type = "' + str(type) + '"')

    def threshold(self, source, value='auto'):
        """
        Returns a black-and-white image: the thresholded version of the saliency map.

        Parameters
        ----------
        source : 'raw' or 'precise'
            Which saliency map to use for thresholding.
        value : float or 'auto'
            Specifies the thresholding value.

        Returns
        -------
        out : PIL.Image
            Binary image.
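
        Examples
        --------
        Illustrative sketch, assuming ``detector`` was constructed as in the ``__init__`` example:

        >>> mask = detector.threshold('precise', value='auto')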
"""
saliency_map = self._get_saliency_map(source)
return Image.fromarray(255.0 * SightSpotUtil.threshold(saliency_map, value)).convert('RGB')

    def get_foreground(self, source, value='auto'):
        """
        Removes background pixels from the input image and returns the result.

        Parameters
        ----------
        source : 'raw' or 'precise'
            Which saliency map to use for thresholding.
        value : float or 'auto'
            Specifies the thresholding value.

        Returns
        -------
        out : PIL.Image
            Image without background pixels.
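
        Examples
        --------
        Illustrative sketch ('photo_foreground.png' is a hypothetical output filename):

        >>> foreground = detector.get_foreground('precise')
        >>> foreground.save('photo_foreground.png')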
"""
saliency_map = self._get_saliency_map(source)
return Image.fromarray(SightSpotUtil.remove_background(self._rgb_image, saliency_map, value))

    def cut_objects(self, source, value='auto'):
        """
        Extracts connected components from the thresholded saliency map and cuts out the corresponding image regions.

        Parameters
        ----------
        source : 'raw' or 'precise'
            Which saliency map to use for thresholding.
        value : float or 'auto'
            Specifies the thresholding value.

        Returns
        -------
        out : list of PIL.Image
            Detected objects.
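
        Examples
        --------
        Illustrative sketch ('object_%d.png' is a hypothetical output filename pattern):

        >>> for i, window in enumerate(detector.cut_objects('precise')):
        ...     window.save('object_%d.png' % i)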
"""
saliency_map = self._get_saliency_map(source)
result = SightSpotUtil.detect_objects(self._rgb_image, saliency_map, value)
        return [Image.fromarray(item) for item in result]


if __name__ == "__main__":
    print __doc__
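
# ---------------------------------------------------------------------------
# Illustrative end-to-end usage sketch (kept as a comment so that importing or
# running this module is unaffected; the filenames below are hypothetical):
#
#     detector = SightSpotDetector('photo.jpg', grain=0.1)
#     detector.get_heatmap(type='precise').save('photo_heatmap.png')
#     detector.get_foreground('precise').save('photo_foreground.png')
#     for i, window in enumerate(detector.cut_objects('precise')):
#         window.save('photo_object_%d.png' % i)
# ---------------------------------------------------------------------------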