-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval_des3.py
74 lines (61 loc) · 2.56 KB
/
eval_des3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import random
import socket
import yaml
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import models
import datasets
import utils
from models import DenoisingDiffusion, DiffusiveRestoration
def parse_args_and_config():
    """Parse the command-line options and load the YAML experiment config.

    The ``--config`` value is resolved relative to the ``configs`` directory,
    parsed with ``yaml.safe_load`` and converted into nested
    ``argparse.Namespace`` objects by ``dict2namespace``.

    Returns:
        A ``(args, config)`` tuple: the parsed command-line namespace and the
        namespace built from the YAML file.
    """
    parser = argparse.ArgumentParser(description='DeS3')
    parser.add_argument("--config", type=str, required=True,
                        help="Path to the config file")
    parser.add_argument('--resume', default='', type=str,
                        help='Path for the DeS3 model checkpoint to load for evaluation')
    parser.add_argument("--grid_r", type=int, default=16,
                        help="Grid cell width r that defines the overlap between patches")
    parser.add_argument("--sampling_timesteps", type=int, default=25,
                        help="Number of implicit sampling steps")
    parser.add_argument("--test_set", type=str, default='SRD',
                        help="restoration test set options: ['SRD, AISTD, LRSS, UIUC']")
    parser.add_argument("--image_folder", default='results/real_images/', type=str,
                        help="Location to save restored images")
    parser.add_argument('--seed', default=61, type=int, metavar='N',
                        help='Seed for initializing training (default: 61)')
    args = parser.parse_args()

    # Decode the file explicitly as UTF-8 so that parsing the config does not
    # depend on the platform's locale encoding.
    config_path = os.path.join("configs", args.config)
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    new_config = dict2namespace(config)
    return args, new_config
def dict2namespace(config):
    """Convert a (possibly nested) dict into nested ``argparse.Namespace`` objects.

    Each key becomes an attribute on the returned namespace. Values that are
    themselves dicts are converted recursively; every other value (including
    lists that contain dicts) is stored unchanged.
    """
    result = argparse.Namespace()
    for attr_name, entry in config.items():
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(result, attr_name, converted)
    return result
def main():
    """Evaluate a DeS3 model on the test set selected on the command line.

    Parses the options and YAML config, picks the compute device, seeds the
    random number generators, builds the validation loader for
    ``args.test_set`` and passes it to ``DiffusiveRestoration.restore``.
    """
    args, config = parse_args_and_config()

    # Query CUDA availability once and reuse the answer below.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    config.device = device
    if use_cuda:
        print('Note: GPU!')

    # Seed every RNG in play. Python's `random` is seeded as well so that any
    # code drawing from it is reproducible alongside torch and NumPy.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True

    # The dataset class is looked up by name in the `datasets` module
    # namespace; an unknown `config.data.dataset` raises KeyError.
    dataset = datasets.__dict__[config.data.dataset](config)
    # get_loaders returns two loaders; only the second one is used here.
    _, val_loader = dataset.get_loaders(parse_patches=False, validation=args.test_set)

    diffusion = DenoisingDiffusion(args, config)
    model = DiffusiveRestoration(diffusion, args, config)
    model.restore(val_loader, validation=args.test_set, r=args.grid_r)
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()