-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvisualize.py
41 lines (29 loc) · 1.12 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# coding: UTF-8
import os
import numpy as np
from PIL import Image
import chainer
def _tile_images(images, rows, cols):
    """Arrange a (rows * cols, C, H, W) batch into a single image grid.

    Returns an array of shape (rows * H, cols * W) when C == 1 (grayscale)
    and (rows * H, cols * W, C) otherwise, i.e. the layouts that
    ``PIL.Image.fromarray`` accepts. Image k lands in grid row k // cols,
    column k % cols.
    """
    _, channels, height, width = images.shape
    grid = images.reshape((rows, cols, channels, height, width))
    # (rows, cols, C, H, W) -> (rows, H, cols, W, C) so that flattening the
    # first two and next two axes yields contiguous pixel rows across a grid row.
    grid = grid.transpose(0, 3, 1, 4, 2)
    grid = grid.reshape((rows * height, cols * width, channels))
    if channels == 1:
        # Drop the singleton channel axis so PIL produces a mode-'L' image.
        grid = grid[:, :, 0]
    return grid


def out_generated_image(gen, rows, cols, seed, iteration, xp,
                        n_classes=10, preview_dir='images/train'):
    """Generate a rows x cols grid of class-conditional samples and save it as PNG.

    The noise batch is drawn under a fixed ``seed`` so that previews saved at
    different iterations are comparable. Grid row ``i`` is conditioned on
    class label ``i``; every column in that row shares the label.

    Args:
        gen: Conditional generator, called as ``gen(z, labels)``. Must provide
            ``make_hidden(n)``, whose result is passed through ``xp.asarray``;
            presumably it draws from the global NumPy RNG (that is what
            ``seed`` controls) — TODO confirm.
        rows: Number of grid rows; labels ``0 .. rows - 1`` are used, so this
            presumably must not exceed ``n_classes`` — TODO confirm against
            ``train.create_one_hot_label``.
        cols: Number of samples per row.
        seed: Seed for the global NumPy RNG while the noise batch is drawn.
        iteration: Training iteration, zero-padded to 8 digits in the file name.
        xp: Array module (e.g. ``numpy`` or ``cupy``) for the generator's device.
        n_classes: Length of each one-hot label vector. Defaults to 10.
        preview_dir: Directory the PNG is written into; created if missing.
    """
    n_images = rows * cols

    # Draw the noise under a fixed seed, then restore the global RNG exactly as
    # it was. Re-seeding from entropy (np.random.seed()) would break the
    # reproducibility of a seeded training run; finally also covers the case
    # where make_hidden raises.
    rng_state = np.random.get_state()
    np.random.seed(seed)
    try:
        z = chainer.Variable(xp.asarray(gen.make_hidden(n_images)))
    finally:
        np.random.set_state(rng_state)

    # Row-major class labels: cols copies of 0, then cols copies of 1, ...
    labels = [row for row in range(rows) for _ in range(cols)]
    # Function-scope import, presumably to avoid a circular import with
    # train.py — TODO confirm before moving it to the top of the module.
    from train import create_one_hot_label
    one_hot = xp.asarray(create_one_hot_label(n_classes, labels))
    one_hot = one_hot.reshape(n_images, n_classes, 1, 1)

    # Inference only: test-mode layers and no computational graph retained.
    with chainer.using_config('train', False), \
            chainer.using_config('enable_backprop', False):
        x = gen(z, one_hot)
    x = chainer.cuda.to_cpu(x.data)
    # Scale to 8-bit; assumes generator output lies in [0, 1] (TODO confirm),
    # clipping guards against overshoot before the uint8 cast.
    x = np.clip(x * 255, 0.0, 255.0).astype(np.uint8)
    grid = _tile_images(x, rows, cols)

    preview_path = os.path.join(
        preview_dir, 'image_iteration_{:0>8}.png'.format(iteration))
    # Race-free directory creation that does not depend on the Python-3-only
    # exist_ok flag.
    try:
        os.makedirs(preview_dir)
    except OSError:
        if not os.path.isdir(preview_dir):
            raise
    Image.fromarray(grid).save(preview_path)