-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathimaginarynet.py
98 lines (93 loc) · 2.66 KB
/
imaginarynet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from gen_prompt import generating_prompts
from gen_image import p_generator
from gen_image import get_xml
from CLIP_filter import CLIP_filter
from tqdm import tqdm
import os
import argparse
import time
import random
def _parse_args():
    """Build the command-line parser and return the parsed options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="random seed",
    )
    parser.add_argument(
        "--clip",
        action="store_true",
        help="use CLIP as filter",
    )
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="use CLIP on CPU",
    )
    # NOTE(review): --backend has no default and is not required, so
    # opt.backend is None when the flag is omitted -- confirm that
    # p_generator handles None or make the flag required.
    parser.add_argument(
        "--backend",
        type=str,
        choices=["dalle-mini", "stablediffusion"],
        help="choose the backend",
    )
    parser.add_argument(
        "--num",
        type=int,
        default=1250,
        help="the number of pictures per prompt",
    )
    parser.add_argument(
        "--classfile",
        type=str,
        default="COCO.txt",
        help="file which records the name of each class",
    )
    parser.add_argument(
        "--outputdir",
        type=str,
        default="./dallemini_coco_GPT_10W",
        help="output dir",
    )
    parser.add_argument(
        "--gpt",
        action="store_true",
        help="whether using GPT2 to generate descriptions or not",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.6,
        help="the threshold of clip in filtering the images",
    )
    return parser.parse_args()


def _split_prompt_line(line, line_no):
    """Split one stripped prompt-file line into ``(prompt, class_name)``.

    The class name is whatever follows the LAST tab, so a prompt that itself
    contains tab characters still parses.

    Raises:
        ValueError: if the line contains no tab separator.
    """
    prompt, sep, class_name = line.rpartition("\t")
    if not sep:
        raise ValueError(
            f"prompt file line {line_no}: expected '<prompt>\\t<class>', "
            f"got {line!r}"
        )
    return prompt, class_name


def main():
    """Generate prompts, render images for each prompt and write annotations.

    Steps, in order:
      1. Parse the command-line options and seed ``random``.
      2. Make sure ``<outputdir>/annotation`` and ``<outputdir>/image`` exist.
      3. Call ``generating_prompts`` to write ``<outputdir>/prompt.txt``.
      4. For every non-blank line of that file, call ``gen_group`` on the
         prompt, optionally filter the images with ``CLIP_filter`` (``--clip``),
         then save each remaining image as ``<outputdir>/image/NNNNNN.jpg`` and
         call ``get_xml`` with the same running index.

    Raises:
        ValueError: if a non-blank line of the prompt file has no tab.
    """
    opt = _parse_args()
    random.seed(opt.seed)

    # makedirs(exist_ok=True) also covers the cases the old mkdir chain missed:
    # an existing output dir lacking its sub-directories, and nested paths.
    for sub_dir in ("annotation", "image"):
        os.makedirs(f"{opt.outputdir}/{sub_dir}", exist_ok=True)

    prompt_path = f"{opt.outputdir}/prompt.txt"
    print("generating prompts:")
    objects = generating_prompts(opt.num, opt.classfile, prompt_path, opt.gpt)

    print("\nstarting generating images\n")
    pg = p_generator(opt.backend)
    # Read with the platform default encoding; the writer lives in gen_prompt,
    # so the two should be kept in agreement -- TODO confirm.
    with open(prompt_path) as reader:
        lines = reader.readlines()

    all_idx = 0  # running index shared by image files and annotations
    for line_no, raw_line in enumerate(tqdm(lines), start=1):
        line = raw_line.strip()
        if not line:
            # A blank line (e.g. trailing newline) carries no prompt.
            continue
        txt, obj = _split_prompt_line(line, line_no)
        print(txt)

        # Only the images are used; the other two returned values are ignored.
        images, _, _ = pg.gen_group(txt, 1)
        if opt.clip:
            images = CLIP_filter(
                images, obj, objects, threshold=opt.threshold, cpu=opt.cpu
            )

        for img in images:
            img.save(f"{opt.outputdir}/image/{all_idx:06d}.jpg")
            get_xml(all_idx, [obj], opt.outputdir)
            all_idx += 1
# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__=="__main__":
    main()