-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprocess_image.py
56 lines (48 loc) · 2.07 KB
/
process_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
from os import listdir
from os.path import isfile, join
import numpy as np
class CONFIG:
    """Module-wide image preprocessing constants."""

    # Per-channel means that reshape_and_normalize_image subtracts before
    # images go to VGG16. These are presumably the ImageNet RGB means in
    # R, G, B order (TODO confirm against the decoders' channel order).
    # Shaped (1, 1, 3) so they broadcast over the trailing channel axis.
    MEANS = np.reshape(np.array([123.68, 116.779, 103.939]), (1, 1, 3))
def reshape_and_normalize_image(image):
    """Mean-center an image (content or style) to match VGG16's expected input.

    Despite its name, this function does not reshape anything. It only
    subtracts ``CONFIG.MEANS`` from ``image``. The means have shape
    (1, 1, 3), so the subtraction broadcasts against the trailing axes.

    Args:
        image: image tensor, presumably float with 3 color channels on the
            last axis. Callers in this file pass resized or ``tf.to_float``
            output (TODO confirm for external callers).

    Returns:
        ``image`` minus the per-channel means.
    """
    channel_means = CONFIG.MEANS
    return tf.subtract(image, channel_means)
def preprocess_fn(image, height, width):
    """Resize ``image`` to ``height`` x ``width`` using bilinear interpolation.

    Args:
        image: decoded image tensor.
        height: target height in pixels.
        width: target width in pixels.

    Returns:
        The resized image tensor.
    """
    target_size = [height, width]
    # method=0 is ResizeMethod.BILINEAR in the TF1 image API.
    resized = tf.image.resize_images(image, target_size, method=0)
    return resized
def train_image(batch_size, height, width, path, epochs=2, shuffle=True):
    """Build a TF1 queue-based input pipeline over every file in ``path``.

    Each file is read whole, decoded to 3 channels, resized to
    ``height`` x ``width`` with ``preprocess_fn``, mean-centered with
    ``reshape_and_normalize_image``, and finally batched.

    Args:
        batch_size: number of images per batch.
        height: target image height in pixels.
        width: target image width in pixels.
        path: directory whose regular files are all treated as images.
            Subdirectories are skipped; no other filtering is done.
        epochs: number of passes over the files (``num_epochs`` of the
            filename producer).
        shuffle: whether to shuffle filename order. When False, files are
            read in sorted path order.

    Returns:
        The output of ``tf.train.batch`` for the single processed-image
        component.

    Raises:
        ValueError: if ``path`` contains no regular files.

    NOTE(review): ``num_epochs`` on ``string_input_producer`` presumably
    requires callers to run the local-variables initializer and start queue
    runners. Confirm.
    """
    filenames = [join(path, f) for f in listdir(path) if isfile(join(path, f))]
    if not filenames:
        # Fail clearly instead of raising an opaque IndexError on filenames[0].
        raise ValueError("no image files found in directory: %r" % (path,))
    if not shuffle:
        filenames = sorted(filenames)
    # The decoder is picked once from the first file, so the directory is
    # assumed to hold a single format (all PNG, or all JPEG).
    png = filenames[0].lower().endswith('png')
    filename_queue = tf.train.string_input_producer(
        filenames, shuffle=shuffle, num_epochs=epochs)
    reader = tf.WholeFileReader()
    _, img_bytes = reader.read(filename_queue)
    # channels=3 forces RGB output, matching the 3-channel CONFIG.MEANS.
    if png:
        image = tf.image.decode_png(img_bytes, channels=3)
    else:
        image = tf.image.decode_jpeg(img_bytes, channels=3)
    processed_image = preprocess_fn(image, height, width)
    processed_image = reshape_and_normalize_image(processed_image)
    return tf.train.batch([processed_image], batch_size, dynamic_pad=True)
def style_image(path, height, width):
    """Load one style image as a single-image batch.

    The file is decoded as PNG when ``path`` ends in "png" (case-insensitive)
    and as JPEG otherwise. It is then resized to ``height`` x ``width``,
    mean-centered, and given a leading batch axis of size 1.

    Args:
        path: image file path as a Python ``str`` (``.lower()`` is called
            on it).
        height: target height in pixels.
        width: target width in pixels.

    Returns:
        The processed image with a leading batch dimension of 1.
    """
    img_bytes = tf.read_file(path)
    # channels=3 matches train_image and the 3-channel CONFIG.MEANS. Without
    # it, an RGBA PNG decodes to 4 channels and the mean subtraction fails
    # to broadcast.
    if path.lower().endswith('png'):
        image = tf.image.decode_png(img_bytes, channels=3)
    else:
        image = tf.image.decode_jpeg(img_bytes, channels=3)
    processed_image = preprocess_fn(image, height, width)
    processed_image = reshape_and_normalize_image(processed_image)
    image = tf.expand_dims(processed_image, 0)
    return image
def get_eval_image(path, height, width):
    """Load an evaluation image at its stored size and mean-center it.

    Unlike ``style_image``, this function neither resizes nor adds a batch
    axis. ``set_shape`` only declares the static shape
    ``[height, width, 3]``, so the file on disk is expected to already be
    ``height`` x ``width``.

    NOTE(review): mismatched image sizes are not caught here. Confirm that
    callers guarantee the size.

    Args:
        path: image file path as a Python ``str``. It is decoded as PNG when
            it ends in "png" (case-insensitive), otherwise as JPEG.
        height: expected image height in pixels.
        width: expected image width in pixels.

    Returns:
        Float32 image tensor of static shape [height, width, 3], minus
        ``CONFIG.MEANS``.
    """
    img_bytes = tf.read_file(path)
    # channels=3 forces RGB so the declared 3-channel static shape holds at
    # runtime. Without it, an RGBA PNG decodes to 4 channels.
    if path.lower().endswith('png'):
        image = tf.image.decode_png(img_bytes, channels=3)
    else:
        image = tf.image.decode_jpeg(img_bytes, channels=3)
    image.set_shape([height, width, 3])
    # The decoders yield uint8; cast before subtracting the float means.
    image = tf.to_float(image)
    processed_image = reshape_and_normalize_image(image)
    return processed_image