forked from AlexeyAB/darknet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdarknet_images_sets.py
97 lines (83 loc) · 4.11 KB
/
darknet_images_sets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Author: LiaoSteve
"""
from ctypes import *
import random
import os
import cv2
import darknet
import argparse
# (year, split) pairs to process. For each pair the script reads the list file
# "<dataset_list><year>_<split>.txt" and writes its results into a
# sub-directory of --save_dir named after the split.
sets = [
    ('2007', 'test'),
    ('2007', 'val'),
    ('2007', 'train'),
]
def parser(argv=None):
    """Build the command-line interface and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``dataset_list``, ``save_dir``,
        ``weights``, ``config_file``, ``data_file``, ``thresh`` and
        ``iou_thresh``.
    """
    # Named ``arg_parser`` so the local does not shadow this function's name.
    arg_parser = argparse.ArgumentParser(description="YOLO Object Detection")
    arg_parser.add_argument("--dataset_list", type=str, default="./data/",
                            help="directory prefix of the <year>_<set>.txt image list files")
    arg_parser.add_argument("--save_dir", type=str, default="./predict_image/tiny-best/",
                            help="path to save detection images")
    arg_parser.add_argument("--weights", default="./backup/yolov4-tiny-3l-2_best.weights",
                            help="yolo weights path")
    arg_parser.add_argument("--config_file", default="./cfg/tiny/yolov4-tiny-3l-2.cfg",
                            help="path to config file")
    arg_parser.add_argument("--data_file", default="./data/obj.data",
                            help="path to data file")
    arg_parser.add_argument("--thresh", type=float, default=.25,
                            help="remove detections with confidence below this value")
    arg_parser.add_argument("--iou_thresh", type=float, default=.45,
                            help="nms: remove detections with iou higher than this value")
    return arg_parser.parse_args(argv)
def check_arguments_errors(args):
    """Validate parsed command-line arguments and create the output directory.

    Args:
        args: Namespace produced by ``parser()``. Reads ``thresh``,
            ``iou_thresh``, ``config_file``, ``weights``, ``data_file``,
            ``dataset_list`` and ``save_dir``.

    Raises:
        ValueError: if ``thresh`` or ``iou_thresh`` is not strictly between
            0 and 1, or if any of the input paths does not exist. The message
            contains the absolute form of the offending path.

    Side effects:
        Creates ``args.save_dir`` (including parents) if it does not exist.
    """
    # Explicit raises instead of ``assert``: asserts are removed under ``python -O``.
    if not 0 < args.thresh < 1:
        raise ValueError("Threshold should be a float between zero and one (non-inclusive)")
    if not 0 < args.iou_thresh < 1:
        raise ValueError("IoU threshold should be a float between zero and one (non-inclusive)")
    # (path, message template) pairs, checked in order; each message reports
    # the path that was actually checked.
    required_paths = (
        (args.config_file, "Invalid config path {}"),
        (args.weights, "Invalid weight path {}"),
        (args.data_file, "Invalid data file path {}"),
        (args.dataset_list, "Invalid image set file path {}"),
    )
    for path, message in required_paths:
        if not os.path.exists(path):
            raise ValueError(message.format(os.path.abspath(path)))
    os.makedirs(args.save_dir, exist_ok=True)
if __name__ == '__main__':
    # Script entry point: for every (year, split) pair in ``sets``, run the
    # detector over each image listed in "<dataset_list>/<year>_<split>.txt",
    # save an annotated copy under "<save_dir>/<split>/", and finally print
    # the number of images processed per split.
    args = parser()
    check_arguments_errors(args)
    network, class_names, class_colors = darknet.load_network(
        args.config_file,
        args.data_file,
        args.weights,
        batch_size=1
    )
    darknet_width = darknet.network_width(network)
    darknet_height = darknet.network_height(network)
    # One network-resolution, 3-channel input buffer reused for every image.
    darknet_image = darknet.make_image(darknet_width, darknet_height, 3)
    # Accepted file suffixes, compared case-insensitively (covers .JPG, .Png, ...).
    image_suffixes = ('jpg', 'jpeg', 'png')
    info = dict()
    for year, split in sets:
        # os.path.join works whether or not the CLI paths end with a separator.
        save_dir = os.path.join(args.save_dir, split)
        os.makedirs(save_dir, exist_ok=True)
        list_path = os.path.join(args.dataset_list, f'{year}_{split}.txt')
        with open(list_path) as list_file:
            # strip() also removes '\r' from CRLF list files; blank lines (e.g. a
            # trailing newline) are skipped instead of being rejected below.
            images = [line.strip() for line in list_file if line.strip()]
        # Validate the whole list before spending time on any detection.
        for filename in images:
            if not filename.lower().endswith(image_suffixes):
                raise RuntimeError(f'notice that {filename} image format are not accepted(.jpg, .png, .jpeg)')
        for image in images:
            frame = cv2.imread(image)
            if frame is None:
                # cv2.imread signals a missing/unreadable file by returning None.
                raise RuntimeError(f'failed to read image {image}')
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_resized = cv2.resize(frame_rgb, (darknet_width, darknet_height),
                                       interpolation=cv2.INTER_LINEAR)
            darknet.copy_image_from_bytes(darknet_image, frame_resized.tobytes())
            detections = darknet.detect_image(network, class_names, darknet_image,
                                              thresh=args.thresh, nms=args.iou_thresh)
            # NOTE(review): detection runs on the resized buffer, but boxes are drawn
            # on the original-size ``frame``; the extra ``darknet_width`` argument
            # presumably lets this fork's draw_boxes rescale them — confirm in darknet.py.
            frame = darknet.draw_boxes(detections, frame, class_colors, darknet_width)
            out_path = os.path.join(save_dir, os.path.basename(image))
            # cv2.imwrite returns False instead of raising when it cannot write.
            if not cv2.imwrite(out_path, frame, [int(cv2.IMWRITE_JPEG_QUALITY), 100]):
                raise RuntimeError(f'failed to write image {out_path}')
            print(f'- [x] save image {image} to {save_dir}')
        info[split] = len(images)
    print(info)