-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_prep.py
96 lines (67 loc) · 2.9 KB
/
data_prep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import tensorflow as tf
import numpy as np
import cv2
from model import INPUT_SHAPE
def class_map_road(seg):
    """Encode a label map as two channels: "anything else" vs. "road".

    Pixels whose label equals 7 become ``[0, 1]`` (channel 1 = road). Every
    other pixel becomes ``[1, 0]`` (channel 0 = anything else).

    ``tf.where`` broadcasts the two-element encodings against ``seg``, so a
    label map of shape ``(H, W, 1)`` yields an ``(H, W, 2)`` float tensor.
    """
    road_encoding = tf.constant([0, 1.0])
    other_encoding = tf.constant([1.0, 0])
    is_road = tf.equal(seg, 7)
    return tf.where(is_road, road_encoding, other_encoding)
def cityscapes_prep(output_shape, input_shape=INPUT_SHAPE, class_map_func=None, float_range=True):
    """Build a ``tf.data`` map function for samples with a left image and a label map.

    Args:
        output_shape: Shape of the produced label tensor. ``output_shape[0:2]``
            is the (height, width) the labels are resized to, and
            ``output_shape[-1]`` is the number of classes used for one-hot
            encoding when ``class_map_func`` is not callable.
        input_shape: Shape of the produced image; ``input_shape[0:2]`` is the
            (height, width) the image is resized to.
        class_map_func: Optional callable applied to the resized label map to
            produce the final label tensor (e.g. ``class_map_road``). If it is
            not callable, the labels are one-hot encoded instead.
        float_range: If true, image pixel values are divided by 255.

    Returns:
        A function that takes a sample dict with ``'image_left'`` and
        ``'segmentation_label'`` entries and returns an ``(image, labels)`` pair.
    """
    def prep_map(sample):
        img = sample['image_left']
        seg = sample['segmentation_label']
        if float_range:
            img /= 255
        img = tf.image.resize(img, input_shape[0:2])
        # Labels are class ids, not intensities. Nearest-neighbour resizing
        # keeps them exact (and keeps their dtype). Bilinear would blend
        # neighbouring ids into fractional values that match no class.
        seg = tf.image.resize(seg, output_shape[0:2], method='nearest')
        if callable(class_map_func):
            seg = class_map_func(seg)
        else:
            # tf.one_hot works on symbolic tensors inside a tf.data map.
            # Drop the trailing label channel (assumed to be size 1) so the
            # result is (..., H, W, num_classes) rather than
            # (..., H, W, 1, num_classes).
            class_ids = tf.cast(tf.squeeze(seg, axis=-1), tf.int32)
            seg = tf.one_hot(class_ids, depth=output_shape[-1])
        return img, seg
    return prep_map
def create_labelme_segmentation(contents):
    """Rasterise the shapes of a LabelMe JSON annotation into a binary mask.

    This is meant to run inside ``tf.py_function`` (as ``labelme_prep`` does),
    where ``contents`` arrives as an eager string tensor. Plain ``bytes`` or
    ``str`` are accepted too.

    Every shape is painted with the value 1, regardless of its label. The mask
    therefore only separates "inside some shape" (1) from background (0).

    Args:
        contents: The JSON document, as an eager string tensor, bytes or str.

    Returns:
        ``(seg, height, width)``, where ``seg`` is a float32 array of shape
        ``(imageHeight, imageWidth)``.
    """
    if hasattr(contents, 'numpy'):
        contents = contents.numpy()
    if isinstance(contents, bytes):
        contents = contents.decode('utf-8')
    meta = json.loads(contents)
    height, width = meta['imageHeight'], meta['imageWidth']
    # float32 matches the output dtype declared for this value in tf.py_function.
    seg = np.zeros((height, width), dtype=np.float32)
    # TODO: create a good way to have multiple classes
    for shape in meta['shapes']:
        # Annotation coordinates may be sub-pixel floats; round rather than
        # truncate so shapes are not biased toward the image origin.
        points = np.rint(np.asarray(shape['points'], dtype=np.float64)).astype(np.int32)
        if points.size == 0:
            continue
        shape_type = shape.get('shape_type') or 'polygon'
        if shape_type == 'rectangle' and len(points) == 2:
            # A rectangle is stored as two opposite corners. Filling those two
            # points as a polygon would only draw a line.
            (x1, y1), (x2, y2) = points
            cv2.rectangle(seg, (int(x1), int(y1)), (int(x2), int(y2)), 1, thickness=-1)
        elif shape_type == 'circle' and len(points) == 2:
            # A circle is stored as its centre plus one point on the rim.
            center, rim = points
            radius = int(round(float(np.hypot(*(rim - center)))))
            cv2.circle(seg, (int(center[0]), int(center[1])), radius, 1, thickness=-1)
        else:
            cv2.fillPoly(seg, [points], 1)
    return seg, height, width
def labelme_prep(output_shape, input_shape, float_range=True):
    """Build a ``tf.data`` map function for LabelMe JSON annotations.

    The returned function reads a ``.json`` annotation and rasterises it into
    a background/foreground mask. It then loads the ``.jpg`` image stored
    alongside it (same path, with the trailing ``.json`` replaced by ``.jpg``).

    Args:
        output_shape: ``output_shape[0:2]`` is the (height, width) the mask is
            resized to.
        input_shape: ``input_shape[0:2]`` is the (height, width) the image is
            resized to.
        float_range: If true, image pixel values are divided by 255.

    Returns:
        A function mapping a JSON file path to ``(image, labels)``. ``labels``
        has two channels: channel 0 = background, channel 1 = annotated region.
    """
    def labelme_map(json_file):
        contents = tf.io.read_file(json_file)
        # JSON parsing and polygon filling need Python/OpenCV, so they run
        # eagerly through py_function.
        seg, h, w = tf.py_function(create_labelme_segmentation, [contents], [tf.float32, tf.int32, tf.int32])
        # py_function output has no static shape. Restore (h, w) and add the
        # channel axis that tf.image.resize requires.
        seg = tf.reshape(seg, (h, w, 1))
        seg = tf.image.resize(seg, output_shape[0:2], method='nearest')
        seg = tf.where(seg == 1, [0, 1.0], [1.0, 0])
        # Raw, end-anchored pattern: only the file extension is rewritten, and
        # '\.' is not parsed as an (invalid) Python string escape.
        jpeg_filename = tf.strings.regex_replace(json_file, r'\.json$', '.jpg')
        jpeg_contents = tf.io.read_file(jpeg_filename)
        img = tf.io.decode_jpeg(jpeg_contents, channels=3)
        img = tf.image.resize(img, input_shape[0:2])
        if float_range:
            img /= 255
        return img, seg
    return labelme_map
def uwula_prep(output_shape, input_shape, float_range=True, class_colors=None):
    """Build a ``tf.data`` map function for JPEG images with colour-coded PNG labels.

    For each ``.jpg`` path, the label image is read from the same path with the
    trailing ``.jpg`` replaced by ``.png``. Each label pixel is one-hot encoded
    by exact RGB match against a colour palette. A pixel matching no palette
    colour gets all-zero channels.

    Args:
        output_shape: ``output_shape[0:2]`` is the (height, width) the label
            image is resized to.
        input_shape: ``input_shape[0:2]`` is the (height, width) the photo is
            resized to.
        float_range: If true, image pixel values are divided by 255.
        class_colors: Optional sequence of ``(R, G, B)`` uint8 colours, one per
            class, in channel order. Defaults to background ``(0, 0, 0)`` and
            road ``(32, 224, 224)``.

    Returns:
        A function mapping a JPEG path to ``(image, labels)``. ``labels`` is a
        float32 tensor with one channel per palette colour.
    """
    if class_colors is None:
        class_colors = (
            (0, 0, 0),       # background
            (32, 224, 224),  # road
        )
    color_map = tf.constant(class_colors, dtype=tf.uint8)

    def uwula_map(jpeg_filename):
        jpeg_contents = tf.io.read_file(jpeg_filename)
        img = tf.io.decode_jpeg(jpeg_contents, channels=3)
        img = tf.image.resize(img, input_shape[0:2])
        if float_range:
            img /= 255
        # Raw, end-anchored pattern: only the file extension is rewritten.
        seg_filename = tf.strings.regex_replace(jpeg_filename, r'\.jpg$', '.png')
        seg_contents = tf.io.read_file(seg_filename)
        seg = tf.io.decode_png(seg_contents, channels=3)
        # Nearest-neighbour keeps exact palette colours (and the uint8 dtype).
        seg = tf.image.resize(seg, output_shape[0:2], method='nearest')
        # Compare every pixel with every palette colour in one broadcast:
        # (H, W, 1, 3) against (C, 3) gives (H, W, C, 3). A pixel belongs to a
        # class only when all three channels match.
        matches = tf.equal(seg[..., tf.newaxis, :], color_map)
        class_map = tf.reduce_all(matches, axis=-1)
        class_map = tf.cast(class_map, tf.float32)
        return img, class_map
    return uwula_map