-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclean_fid.py
69 lines (56 loc) · 2.7 KB
/
clean_fid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os, argparse, sys
# sys.path.insert(0, os.getcwd())
# sys.path.insert(0, "clean_fid")
from glob import glob
# from cleanfid import fid
from source_override.clean_fid.cleanfid import fid
# from source_override.fid import compute_fid, compute_kid
def build_parser():
    """Build the command-line parser for this FID/KID evaluation script.

    Returns:
        argparse.ArgumentParser: parser whose defaults reproduce the original
        hard-coded experiment paths and metric settings.
    """
    parser = argparse.ArgumentParser(
        description="Compute FID and KID between a set of real images and one "
                    "or more sets of generated images.")
    parser.add_argument("--outfile", type=str, default="/home/xinrui/projects/SelfCascade/SDXL/results/sdxl_selfcascade/pexels_txtilm7b_resized_v2_step700_100/inference_image/laion_wofinetune.txt",
                        help="text file that results are appended to")
    parser.add_argument("--real", type=str, default="/home/xinrui/projects/SelfCascade/SDXL/dataset/ScaleCrafter/images_test",
                        help="location of the real (reference) images")
    parser.add_argument("--fake", type=str, default="/home/xinrui/projects/SelfCascade/SDXL/results/sdxl_selfcascade/pexels_txtilm7b_resized_v2_step700_100/inference_image/laion_finetune/upscale_image",
                        help="location of the generated images, or a .txt file listing one such location per line")
    # Not read anywhere below; kept so existing command lines keep parsing.
    parser.add_argument("--test_file", type=str, default="/home/xinrui/projects/SelfCascade/SDXL/dataset/ScaleCrafter/test.txt",
                        help="currently unused")
    parser.add_argument("--random_crop", action="store_true",
                        help="forwarded as random_crop to compute_fid/compute_kid")
    parser.add_argument("--output_block", default=3, choices=[3, 2], type=int,
                        help="set to 2 if sFID")
    parser.add_argument("--n_crops_per_img", default=3, type=int,
                        help="forwarded as n_crops_per_img to compute_fid/compute_kid")
    # Not read anywhere below; kept so existing command lines keep parsing.
    parser.add_argument("--nimgs_real", default=None, type=int,
                        help="currently unused")
    parser.add_argument("--dataset_res", default=2048, type=int,
                        help="forwarded as dataset_res to compute_fid/compute_kid")
    return parser


def read_fake_paths(fake):
    """Return the list of fake-image locations to evaluate.

    Args:
        fake: either a single location, or a path ending in ``.txt`` that
            lists one location per line.

    Returns:
        list[str]: ``[fake]`` for a single location; otherwise the stripped,
        non-blank lines of the list file.
    """
    # Require the dot so a directory such as "results_txt" is not mistaken
    # for a list file.
    if not fake.endswith(".txt"):
        return [fake]
    with open(fake, "r") as list_file:
        # Skip blank lines (e.g. a trailing newline) rather than evaluating "".
        return [line.strip() for line in list_file if line.strip()]


def ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist.

    A bare filename has no parent component, and ``os.makedirs("")`` would
    raise, so that case is a no-op.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def main(argv=None):
    """Parse arguments, then compute and log FID and KID for every fake source.

    Results go to stdout and are appended to ``--outfile`` as
    ``fake path`` / ``real path`` / ``fid=`` / ``kid=`` lines per fake source.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    opt = build_parser().parse_args(argv)
    outfile = opt.outfile
    real = opt.real
    fakes = read_fake_paths(opt.fake)
    ensure_parent_dir(outfile)

    # One set of keyword arguments for both metrics so FID and KID are always
    # computed under identical settings.
    metric_kwargs = dict(
        mode="center_crop",
        random_crop=opt.random_crop,
        output_blocks=[opt.output_block],
        n_crops_per_img=opt.n_crops_per_img,
        dataset_res=opt.dataset_res,
    )

    for fake in fakes:
        # Reopen in append mode per fake source so each finished result is
        # written out (file closed) before the next evaluation starts.
        with open(outfile, 'a') as f:
            print("start fid")
            sfid = fid.compute_fid(real, fake, **metric_kwargs)
            print(f'fid={sfid}')
            print(f'fake path: {fake}', file=f)
            print(f'real path: {real}', file=f)
            print(f'fid={sfid}', file=f)
            print("start kid")
            skid = fid.compute_kid(real, fake, **metric_kwargs)
            print(f'kid={skid}')
            print(f'kid={skid}', file=f)
            print('\n', file=f)


if __name__ == "__main__":
    main()