-
Notifications
You must be signed in to change notification settings - Fork 29
/
main.py
63 lines (42 loc) · 1.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Demo - train the DeepFuse network & use it to generate an image
from __future__ import print_function
import time
from train_recons import train_recons
from generate import generate
from utils import list_images
import os
# Uncomment to restrict the process to a specific GPU.
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# True: train the network. False: fuse image pairs with the saved model.
IS_TRAINING = False

BATCH_SIZE = 2
EPOCHES = 4

MODEL_SAVE_PATH = './models/deepfuse_models/deepfuse_model_bs2_epoch4_all.ckpt'

# model_pre_path is an optional pre-trained model and is not required.
# Leave it as None when you want to train your own model.
# model_pre_path = 'your own pre-train model'
model_pre_path = None

# Directory whose images (listed via utils.list_images) are used for training.
# NOTE(review): this is a machine-specific Windows path; adjust per setup.
TRAIN_IMAGES_DIR = 'D:/ImageDatabase/Image_fusion_MSCOCO/original/'

# Inference inputs: pair number `index` is read from
#   IMAGES_DIR/<SOURCE_A_NAME><index>.png and IMAGES_DIR/<SOURCE_B_NAME><index>.png
# for index = 1 .. NUM_IMAGE_PAIRS.
# (The original script also carried an alternative setting of 'image' for both
# source names, and a '<name><index>_left.png' / '_right.png' naming scheme.)
IMAGES_DIR = 'images/IV_images'
SOURCE_A_NAME = 'IR'
SOURCE_B_NAME = 'VIS'
NUM_IMAGE_PAIRS = 1
OUTPUT_SAVE_PATH = 'outputs'


def main():
    """Train the network or fuse image pairs, depending on IS_TRAINING.

    Training branch: lists the images in TRAIN_IMAGES_DIR and passes them to
    train_recons together with MODEL_SAVE_PATH, model_pre_path, EPOCHES and
    BATCH_SIZE.

    Inference branch: for each index in 1 .. NUM_IMAGE_PAIRS, calls generate on
    the pair <SOURCE_A_NAME><index>.png / <SOURCE_B_NAME><index>.png from
    IMAGES_DIR, using MODEL_SAVE_PATH and passing OUTPUT_SAVE_PATH as
    output_path.
    """
    if IS_TRAINING:
        original_imgs_path = list_images(TRAIN_IMAGES_DIR)
        print('\nBegin to train the network ...\n')
        train_recons(original_imgs_path, MODEL_SAVE_PATH, model_pre_path,
                     EPOCHES, BATCH_SIZE, debug=True)
        print('\nSuccessfully! Done training...\n')
    else:
        print('\nBegin to generate pictures ...\n')
        # Indices are 1-based because they are embedded in the input file
        # names (e.g. IR1.png / VIS1.png) and passed on to generate.
        for index in range(1, NUM_IMAGE_PAIRS + 1):
            content_path = os.path.join(
                IMAGES_DIR, SOURCE_A_NAME + str(index) + '.png')
            style_path = os.path.join(
                IMAGES_DIR, SOURCE_B_NAME + str(index) + '.png')
            generate(content_path, style_path, MODEL_SAVE_PATH, model_pre_path,
                     index, output_path=OUTPUT_SAVE_PATH)


if __name__ == '__main__':
    main()