forked from Jongchan/tensorflow-vdsr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathincrease_resolution.py
51 lines (44 loc) · 1.73 KB
/
increase_resolution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import os

import numpy as np
import tensorflow as tf
from PIL import Image
from scipy import misc

from model import model
# Command-line interface. Arguments are parsed at import time, so importing
# this module also consumes sys.argv.
parser = argparse.ArgumentParser()
# dest/default spelled out explicitly; they match argparse's own defaults.
parser.add_argument("--img", dest="img", default=None)
args = parser.parse_args()
# Path of the input image; None when --img is omitted.
img_path = args.img
def im2double(img):
    """Scale pixel values to floats by dividing by 255.

    Intended for 8-bit data (0..255), which maps onto [0.0, 1.0]; other
    inputs are simply divided by 255.

    Args:
        img: numpy array of pixel values.

    Returns:
        float64 array of the same shape.
    """
    # np.float was an alias of the builtin float (i.e. float64) and was
    # removed in NumPy 1.24; np.float64 keeps the exact same dtype.
    return img.astype(np.float64) / 255
def double2im(img):
    """Convert float intensities in [0, 1] back to integer pixel values.

    Values are scaled by 255, rounded to the nearest integer (numpy's
    round-half-to-even) and clipped to [0, 255]. Inputs outside [0, 1],
    such as overshoot from the network, therefore map to 0 or 255
    instead of producing invalid pixel values.

    Args:
        img: array-like of float intensities.

    Returns:
        Integer array of the same shape with values in [0, 255].
    """
    # Round rather than truncate: astype(int) alone drops the fractional
    # part and biases every pixel downward (0.9999 -> 254 instead of 255).
    scaled = np.rint(np.asarray(img) * 255)
    # np.int was removed in NumPy 1.24; the builtin int is what it aliased.
    return np.clip(scaled, 0, 255).astype(int)
def YCbCr2rgb(ycbcr_img):
    """Convert a YCbCr image to RGB using BT.601 studio-swing coefficients.

    The inverse used here is R = 1.164*(Y-16) + 1.596*(Cr-128), etc.; the
    constants below are those factors multiplied by 256.

    Args:
        ycbcr_img: array indexed as [row, col, channel] with Y, Cb and Cr
            in channels 0, 1 and 2.

    Returns:
        Integer array of the same shape holding R, G, B. Results are
        truncated toward zero and are NOT clipped, so they can fall outside
        [0, 255] (e.g. Y < 16 yields negative values).
    """
    # NOTE(review): these are studio-range (16-235) coefficients. Pillow's
    # 'YCbCr' mode (used for loading in this script) is full-range JPEG
    # YCbCr, so a round trip is not exact -- confirm which convention the
    # model's training data used.
    # np.int was an alias of the builtin int and was removed in NumPy 1.24.
    img = ycbcr_img.astype(int)
    y = img[:, :, 0] - 16
    cb = img[:, :, 1] - 128
    cr = img[:, :, 2] - 128
    rgb = np.zeros(img.shape, dtype=np.float32)
    rgb[:, :, 0] = (y * 298.082 + cr * 408.583) / 256
    rgb[:, :, 1] = (y * 298.082 - cb * 100.291 - cr * 208.12) / 256
    rgb[:, :, 2] = (y * 298.082 + cb * 516.411) / 256
    return rgb.astype(int)
if __name__ == '__main__':
    # Run the model restored from ./checkpoints on the luma (Y) channel of
    # --img and write the result next to the input as <name>out<ext>.
    # The output keeps the input's pixel dimensions; any upsampling is
    # presumably expected to have happened beforehand -- TODO confirm.
    # NOTE(review): tf.Session / tf.placeholder are TensorFlow 1.x APIs
    # (tf.compat.v1 under TensorFlow 2.x).
    ckpt_dir = './checkpoints'

    # scipy.misc.imread/imsave were removed in SciPy 1.2. imread(mode='YCbCr')
    # was a thin wrapper around Pillow's convert('YCbCr'), so Pillow is
    # used directly; the context manager closes the file handle.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert('YCbCr'))
    # Only Y goes through the network; Cb/Cr are reused unchanged from img.
    img_Y = im2double(img[:, :, 0])
    height, width = img_Y.shape

    with tf.Session() as sess:
        input_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1))
        shared_model = tf.make_template('shared_model', model)
        output_tensor, weights = shared_model(input_tensor)
        sess.run(tf.global_variables_initializer())

        # latest_checkpoint returns None when nothing is found; fail with a
        # clear message instead of handing None to saver.restore.
        ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt_path is None:
            raise SystemExit('No checkpoint found in %s' % ckpt_dir)
        saver = tf.train.Saver(weights)
        saver.restore(sess, ckpt_path)

        # reshape (unlike np.resize) raises on an element-count mismatch
        # instead of silently repeating or truncating data.
        output_Y = sess.run(
            output_tensor,
            feed_dict={input_tensor: img_Y.reshape(1, height, width, 1)})
    output_Y = np.reshape(output_Y, (height, width))

    # splitext only looks at the final extension, so './dir/x.png' becomes
    # './dir/xout.png' (splitting on every '.' turned it into 'out.png').
    root, ext = os.path.splitext(img_path)
    out_path = root + 'out' + ext

    output_Y = double2im(output_Y)
    # Widen from uint8 so the new Y channel can be stored without wrapping.
    img = img.astype(int)
    img[:, :, 0] = output_Y
    rgb_img = YCbCr2rgb(img)
    # YCbCr2rgb may return values outside [0, 255]. misc.imsave byte-scaled
    # any non-uint8 array to its own min/max, which stretched the contrast,
    # so clip and cast explicitly before saving.
    rgb_img = np.clip(rgb_img, 0, 255).astype(np.uint8)
    Image.fromarray(rgb_img).save(out_path)