forked from BinRoot/TensorFlow-Book
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconv_visuals.py
74 lines (56 loc) · 1.88 KB
/
conv_visuals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import matplotlib.pyplot as plt
import cifar_tools
import tensorflow as tf
# Load the CIFAR-10 batches through the project's cifar_tools helper.
# NOTE(review): the code below indexes `data` as data[4, :] and reshapes that
# row to 24x24, so each row presumably holds one flattened grey-scale image --
# confirm against cifar_tools.read_data.
names, data, labels = cifar_tools.read_data(
    '/home/binroot/res/cifar-10-batches-py')
def show_conv_results(data, filename=None):
    """Draw each slice ``data[0, :, :, i]`` as a grey-scale tile in one figure.

    Args:
        data: 4-D array; one tile is drawn for every index ``i`` along the
            last axis, always taken from the first entry of the first axis.
            Presumably laid out as (batch, height, width, channels) --
            confirm against the caller.
        filename: If truthy, the figure is written to this path with
            ``plt.savefig`` and then closed. Otherwise the figure is
            displayed with ``plt.show``.
    """
    fig = plt.figure()
    num_maps = np.shape(data)[3]
    cols = 8
    # Keep the original 4x8 layout for up to 32 tiles, but add rows beyond
    # that: plt.subplot raises when the index exceeds rows * cols.
    rows = max(4, (num_maps + cols - 1) // cols)
    for i in range(num_maps):
        img = data[0, :, :, i]
        plt.subplot(rows, cols, i + 1)
        plt.imshow(img, cmap='Greys_r', interpolation='none')
        plt.axis('off')
    if filename:
        plt.savefig(filename)
        # pyplot keeps every figure alive until it is closed explicitly;
        # release this one now that it has been written to disk.
        plt.close(fig)
    else:
        plt.show()
def show_weights(W, filename=None):
    """Draw each slice ``W[:, :, 0, i]`` as a grey-scale tile in one figure.

    Args:
        W: 4-D array; one tile is drawn for every index ``i`` along the
            last axis, always using index 0 of the third axis.
            Presumably a convolution kernel laid out as
            (height, width, in_channels, out_channels) -- confirm against
            the caller.
        filename: If truthy, the figure is written to this path with
            ``plt.savefig`` and then closed. Otherwise the figure is
            displayed with ``plt.show``.
    """
    fig = plt.figure()
    num_filters = np.shape(W)[3]
    cols = 8
    # Keep the original 4x8 layout for up to 32 tiles, but add rows beyond
    # that: plt.subplot raises when the index exceeds rows * cols.
    rows = max(4, (num_filters + cols - 1) // cols)
    for i in range(num_filters):
        img = W[:, :, 0, i]
        plt.subplot(rows, cols, i + 1)
        plt.imshow(img, cmap='Greys_r', interpolation='none')
        plt.axis('off')
    if filename:
        plt.savefig(filename)
        # pyplot keeps every figure alive until it is closed explicitly;
        # release this one now that it has been written to disk.
        plt.close(fig)
    else:
        plt.show()
# Use row 4 of the loaded data as the sample to visualise. The reshape
# requires that row to hold exactly 24 * 24 = 576 values.
raw_data = data[4, :]
raw_img = np.reshape(raw_data, (24, 24))

# Save the unfiltered input image to disk.
plt.figure()
plt.imshow(raw_img, cmap='Greys_r')
plt.savefig('input_image.png')
# pyplot keeps figures alive until they are closed explicitly; release this
# one now that it has been saved.
plt.close()
# Graph input: view the sample as 24x24 images with a single trailing channel;
# the -1 lets TensorFlow infer the leading (batch) dimension.
# NOTE(review): conv2d needs x and W to share a dtype -- confirm that
# cifar_tools yields data matching tf.random_normal's output dtype.
x = tf.reshape(raw_data, shape=[-1, 24, 24, 1])
# Randomly initialised parameters. In tf.nn.conv2d's filter layout,
# [5, 5, 1, 32] is a 5x5 window, 1 input channel and 32 output channels;
# b holds one bias per output channel.
W = tf.Variable(tf.random_normal([5, 5, 1, 32]))
b = tf.Variable(tf.random_normal([32]))
# Stride-1 convolution with 'SAME' padding, then bias add, then ReLU. Each
# stage keeps its own name so it can be evaluated and plotted separately.
conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
conv_with_b = tf.nn.bias_add(conv, b)
conv_out = tf.nn.relu(conv_with_b)
# Max pooling over k x k windows moved in steps of k along the two middle axes.
k = 2
maxpool = tf.nn.max_pool(conv_out, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')
# Evaluate each stage of the graph and save a picture of its output.
with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; use its replacement and
    # fall back only on TensorFlow releases that predate it.
    if hasattr(tf, 'global_variables_initializer'):
        init_op = tf.global_variables_initializer()
    else:
        init_op = tf.initialize_all_variables()
    sess.run(init_op)

    # Step 0: the randomly initialised filters.
    W_val = sess.run(W)
    show_weights(W_val, 'step0_weights.png')

    # Step 1: convolution output, before the bias add and ReLU.
    conv_val = sess.run(conv)
    show_conv_results(conv_val, 'step1_convs.png')
    print(np.shape(conv_val))

    # Step 2: output after the bias add and ReLU.
    conv_out_val = sess.run(conv_out)
    show_conv_results(conv_out_val, 'step2_conv_outs.png')
    print(np.shape(conv_out_val))

    # Step 3: output after max pooling.
    maxpool_val = sess.run(maxpool)
    show_conv_results(maxpool_val, 'step3_maxpool.png')
    print(np.shape(maxpool_val))