-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtest.py
73 lines (63 loc) · 3.01 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import sys
import time
import os
import csv
import torch
from util import Logger, printSet
from validate import validate
from networks.resnet import resnet50
from options.test_options import TestOptions
import networks.resnet as resnet
import numpy as np
import random
import random
def seed_torch(seed=1029):
    """Seed all RNGs used by the script and pin cuDNN to deterministic behavior.

    Seeds Python's ``random`` module, NumPy's global RNG and torch's CPU/CUDA
    generators with the same value. It then turns off cuDNN autotuning,
    requests deterministic kernels and disables cuDNN entirely.

    Note: ``PYTHONHASHSEED`` is only read at interpreter start-up. Exporting it
    here does not change string hashing in the current process; it only
    affects child processes launched afterwards.

    Args:
        seed: integer seed shared by every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same call order as always: Python, NumPy, torch CPU, current CUDA device, all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # matters for multi-GPU runs
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False


# Seed once at import time so every evaluation below is reproducible.
seed_torch(100)
# Evaluation suites: suite name -> data root and preprocessing flags.
# Every suite lives under the same base directory and disables cropping.
# Only GANGen-Detection skips resizing. The other suites mix image shapes,
# so their images must be resized before they can be batched together.
DetectionTests = {
    suite: {
        'dataroot': '/opt/data/private/DeepfakeDetection/' + suite + '/',
        'no_resize': skip_resize,
        'no_crop': True,
    }
    for suite, skip_resize in (
        ('ForenSynths', False),
        ('GANGen-Detection', True),
        ('DiffusionForensics', False),
        ('UniversalFakeDetect', False),
    )
}
opt = TestOptions().parse(print_options=False)
print(f'Model_path {opt.model_path}')

# Build the single-logit ResNet-50 detector.
# Load the checkpoint onto the CPU first, then move the model to the GPU.
model = resnet50(num_classes=1)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch >= 1.13, passing weights_only=True restricts loading to tensor data.
model.load_state_dict(torch.load(opt.model_path, map_location='cpu'), strict=True)
model.cuda()
model.eval()

for testSet, suite_cfg in DetectionTests.items():
    dataroot = suite_cfg['dataroot']
    printSet(testSet)

    accs = []
    aps = []
    print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()))
    # os.listdir returns entries in arbitrary order. Sort them so the
    # per-subset rows and their indices are reproducible across machines.
    for v_id, val in enumerate(sorted(os.listdir(dataroot))):
        # os.path.join avoids a doubled '/' (every dataroot already ends in '/').
        opt.dataroot = os.path.join(dataroot, val)
        opt.classes = ''
        # Reset the flags on each iteration: validate() is opaque and might mutate opt.
        opt.no_resize = suite_cfg['no_resize']
        opt.no_crop = suite_cfg['no_crop']
        acc, ap, _, _, _, _ = validate(model, opt)
        accs.append(acc)
        aps.append(ap)
        print("({} {:12}) acc: {:.1f}; ap: {:.1f}".format(v_id, val, acc*100, ap*100))

    # Use len(accs) rather than v_id + 1. v_id is unbound (or stale from the
    # previous suite) when a suite directory is empty. An empty suite reports nan.
    mean_acc = np.mean(accs) * 100 if accs else float('nan')
    mean_ap = np.mean(aps) * 100 if aps else float('nan')
    print("({} {:10}) acc: {:.1f}; ap: {:.1f}".format(len(accs), 'Mean', mean_acc, mean_ap))
    print('*' * 25)