-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbase_model.py
104 lines (83 loc) · 3.17 KB
/
base_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from autodistill.detection import CaptionOntology
import supervision as sv
import argparse
from distutils.dir_util import copy_tree
import shutil
import os
from glob import glob
from utils import load_config
# CUDA runtime settings applied when this module is imported. setdefault lets
# a caller override them from the shell (e.g. `CUDA_VISIBLE_DEVICES=1 python
# base_model.py`) instead of this module silently forcing GPU 0.
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous (a debugging aid
# per the CUDA docs; it may slow labeling down).
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', "1")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Clothing categories to detect. Each caption is used both as the text prompt
# for the base model and, verbatim, as the class name in the generated dataset.
# Keep the order stable: the class index written into the labels presumably
# follows this order (verify against autodistill's CaptionOntology).
# NOTE: "short pants" used to map to the misspelled class "shorts pants";
# datasets labeled before this fix carry that spelling in their data.yaml.
CLOTHING_CLASSES = (
    "blazer",
    "denim jacket",
    "leather jacket",
    "coat",
    "windbreaker jacket",
    "cardigan",
    "puffer",
    "tee shirt",
    "long sleeve shirt",
    "tank top",
    "shirt",
    "polo shirt",
    "sweat shirt",
    "hoodie sweat shirt",
    "knit sweater",
    "dress",
    "jeans",
    "slacks",
    "sweat pants",
    "skirt",
    "short pants",
    "sneakers",
    "dress shoes",
    "sandals",
)

# Identity caption -> class mapping, built from the single list above so a
# caption and its class name can never drift apart.
ontology = CaptionOntology({name: name for name in CLOTHING_CLASSES})
def _rewrite_data_yaml(data_yaml_path):
    """Retarget a copied ``data.yaml`` from the main dataset to the merged one.

    Every occurrence of ``main_data`` in the file is replaced with
    ``augmented_data`` (presumably the dataset paths listed in the YAML —
    confirm against the config layout).
    """
    with open(data_yaml_path, 'r', encoding='utf-8') as f:
        content = f.read()
    with open(data_yaml_path, 'w', encoding='utf-8') as f:
        f.write(content.replace('main_data', 'augmented_data'))


def _label_path_for(image_path):
    """Return the YOLO label path paired with ``<split>/images/<stem>.<ext>``.

    The label is ``<split>/labels/<stem>.txt``. Only the image's own parent
    directory and its extension are rewritten, so ``images`` or ``jpg``
    appearing elsewhere in the path is left untouched.
    """
    images_dir, image_file = os.path.split(image_path)
    split_dir = os.path.dirname(images_dir)
    stem = os.path.splitext(image_file)[0]
    return os.path.join(split_dir, 'labels', stem + '.txt')


def _merge_labeled_output(output_dir, train_images_dir, train_labels_dir):
    """Copy every auto-labeled image and its label into the merged train split.

    Images from all splits under ``output_dir`` (``<output_dir>/*/images/*``)
    are copied under their original file name. Each image's label is copied
    as ``<stem>.txt``. Raises ``FileNotFoundError`` if an image has no label.
    """
    pattern = os.path.join(output_dir, '*', 'images', '*')
    for image_path in sorted(glob(pattern)):
        image_file = os.path.basename(image_path)
        stem = os.path.splitext(image_file)[0]
        shutil.copy(image_path, os.path.join(train_images_dir, image_file))
        shutil.copy(_label_path_for(image_path),
                    os.path.join(train_labels_dir, stem + '.txt'))


def main(parser):
    """Auto-label images with a base model and optionally merge them into a dataset.

    Config keys read from the YAML file given by ``--config``:
        model: ``"GroundedSAM"`` or ``"GroundingDino"``.
        input_dir: folder containing the images to label.
        output_dir: where the labeled dataset is written. It is re-read from
            ``train/images``, ``train/labels`` and ``data.yaml``.
        extension: optional image extension to label (default ``".jpg"``).
        data_merge: optional flag (default off). When true, ``main_data_dir``
            is copied into ``merge_data_dir``, the copied ``data.yaml`` is
            retargeted, and all labeled images/labels are added to its
            ``train`` split.

    Args:
        parser: argparse parser that defines a ``--config`` option.

    Raises:
        ValueError: if ``model`` names an unsupported base model.
    """
    args = parser.parse_args()
    cfg = load_config(args.config)

    model_name = cfg['model']
    # Imports are per-branch, so only the selected model package must be installed.
    if model_name == "GroundedSAM":
        from autodistill_grounded_sam import GroundedSAM
        base_model = GroundedSAM(ontology=ontology)
    elif model_name == "GroundingDino":
        from autodistill_grounding_dino import GroundingDino
        base_model = GroundingDino(ontology=ontology)
    else:
        raise ValueError(
            f"Unsupported base model {model_name!r}; "
            "expected 'GroundedSAM' or 'GroundingDino'."
        )

    output_dir = cfg['output_dir']
    base_model.label(
        input_folder=cfg['input_dir'],
        extension=cfg.get('extension', ".jpg"),
        output_folder=output_dir,
    )

    # Re-read the labeled output in YOLO format. The result is not used
    # further, so this presumably serves as a check that the export loads.
    sv.DetectionDataset.from_yolo(
        images_directory_path=os.path.join(output_dir, 'train', 'images'),
        annotations_directory_path=os.path.join(output_dir, 'train', 'labels'),
        data_yaml_path=os.path.join(output_dir, 'data.yaml'),
    )

    if not cfg.get('data_merge', False):
        return

    merge_dir = cfg['merge_data_dir']
    # shutil.copytree(dirs_exist_ok=True) stands in for distutils' copy_tree:
    # distutils is deprecated and removed in Python 3.12 (PEP 632).
    shutil.copytree(cfg['main_data_dir'], merge_dir, dirs_exist_ok=True)
    _rewrite_data_yaml(os.path.join(merge_dir, 'data.yaml'))
    _merge_labeled_output(
        output_dir,
        os.path.join(merge_dir, 'train', 'images'),
        os.path.join(merge_dir, 'train', 'labels'),
    )
if __name__ == '__main__':
    # Script entry point: the only command-line option is the path of the
    # YAML config that main() loads and acts on.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default='./config/base_model.yaml',
                     help="Set the config to train base model.")
    main(cli)