Skip to content

Commit

Permalink
improve yolo performance
Browse files Browse the repository at this point in the history
  • Loading branch information
WongGawa committed Jan 9, 2025
1 parent 2382df1 commit 470b829
Show file tree
Hide file tree
Showing 22 changed files with 52 additions and 38 deletions.
2 changes: 1 addition & 1 deletion configs/yolov10/hyp.scratch.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov10/hyp.scratch.low.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov10/hyp.scratch.med.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov3/hyp.scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov4/hyp.scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov5/hyp.scratch-high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov5/hyp.scratch-low.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov7/hyp.scratch.p5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.2 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov7/hyp.scratch.p6.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.2 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov7/hyp.scratch.tiny.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.2 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/hyp.scratch.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/hyp.scratch.low.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/hyp.scratch.med.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/seg/hyp.scratch.high.seg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ loss:
max_object_num: 600

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms: {
stage_epochs: [ 290, 10 ],
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov9/hyp.scratch.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolox/hyp.scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ img_size: 640
sync_bn: False

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms: {
stage_epochs: [ 285, 15 ],
Expand Down
2 changes: 1 addition & 1 deletion docs/en/tutorials/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ This part of the parameters is defined in [configs/yolov3/hyp.scratch.yaml](http
```yaml
data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.0, translate: 0.1, scale: 0.9 }
Expand Down
2 changes: 1 addition & 1 deletion docs/zh/tutorials/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ data:
```yaml
data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.0, translate: 0.1, scale: 0.9 }
Expand Down
48 changes: 31 additions & 17 deletions mindyolo/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def __init__(
self.is_training = is_training

# set column names
self.column_names_getitem = ['samples']
self.column_names_getitem = ['im_file', 'cls', 'bboxes', 'segments', 'keypoints', 'bbox_format', 'segment_format',
'img', 'ori_shape', 'hw_scale', 'hw_pad'] if self.is_training else ['samples']
if self.is_training:
self.column_names_collate = ['images', 'labels']
if self.return_segments:
Expand Down Expand Up @@ -169,7 +170,10 @@ def __init__(
self.batch = bi # batch index of image

# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
self.imgs, self.img_hw_ori, self.indices = None, None, range(n)
self.imgs, self.img_hw_ori, self.indices = [None] * n, [None] * n, range(n)
# Buffer of recently loaded image indices; mosaic partner images are sampled from it
self.buffer = []
self.max_buffer_length = min((n, batch_size * 8, 1000)) if self.augment else 0

# Rectangular Train/Test
if self.rect:
Expand Down Expand Up @@ -313,6 +317,14 @@ def __getitem__(self, index):
sample = getattr(self, func_name)(sample, **_trans)

sample['img'] = np.ascontiguousarray(sample['img'])
if self.is_training:
train_sample = []
for col_name in self.column_names_getitem:
if sample.get(col_name) is None:
train_sample.append(np.nan)
else:
train_sample.append(sample.get(col_name, np.nan))
return tuple(train_sample)
return sample

def __len__(self):
Expand All @@ -321,7 +333,8 @@ def __len__(self):
def get_sample(self, index):
"""Get and return label information from the dataset."""
sample = deepcopy(self.labels[index])
if self.imgs is None:
img = self.imgs[index]
if img is None:
path = self.img_files[index]
img = cv2.imread(path) # BGR
assert img is not None, "Image Not Found " + path
Expand All @@ -331,8 +344,13 @@ def get_sample(self, index):
interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)

if self.augment:
self.imgs[index], self.img_hw_ori[index] = img, np.array([h_ori, w_ori]) # img, hw_original
self.buffer.append(index)
if 1 < len(self.buffer) >= self.max_buffer_length:
j = self.buffer.pop(0)
self.imgs[j], self.img_hw_ori[j] = None, np.array([None, None])
sample['img'], sample['ori_shape'] = img, np.array([h_ori, w_ori]) # img, hw_original

else:
sample['img'], sample['ori_shape'] = self.imgs[index], self.img_hw_ori[index] # img, hw_original

Expand Down Expand Up @@ -367,7 +385,7 @@ def _mosaic4(self, sample):
# loads images in a 4-mosaic
classes4, bboxes4, segments4 = [], [], []
mosaic_samples = [sample, ]
indices = random.choices(self.indices, k=3) # 3 additional image indices
indices = random.choices(self.buffer, k=3) # 3 additional image indices

segments_is_list = isinstance(sample['segments'], list)
if segments_is_list:
Expand Down Expand Up @@ -444,7 +462,7 @@ def _mosaic9(self, sample):
# loads images in a 9-mosaic
classes9, bboxes9, segments9 = [], [], []
mosaic_samples = [sample, ]
indices = random.choices(self.indices, k=8) # 8 additional image indices
indices = random.choices(self.buffer, k=8) # 8 additional image indices

segments_is_list = isinstance(sample['segments'], list)
if segments_is_list:
Expand Down Expand Up @@ -1156,21 +1174,17 @@ def _exif_size(self, img):

return s

def train_collate_fn(self, batch_samples, batch_info):
imgs = [sample.pop('img') for sample in batch_samples]
def train_collate_fn(self, im_file, cls, bboxes, segments, keypoints, bbox_format,
segment_format, img, ori_shape, hw_scale, hw_pad, batch_info):
labels = []
for i, sample in enumerate(batch_samples):
cls, bboxes = sample.pop('cls'), sample.pop('bboxes')
labels.append(np.concatenate((np.full_like(cls, i), cls, bboxes), axis=-1))
return_items = [np.stack(imgs, 0), np.stack(labels, 0)]

for i, (c, b) in enumerate(zip(cls, bboxes)):
labels.append(np.concatenate((np.full_like(c, i), c, b), axis=-1))
return_items = [np.stack(img, 0), np.stack(labels, 0)]
if self.return_segments:
masks = [sample.pop('segments', None) for sample in batch_samples]
return_items.append(np.stack(masks, 0))
return_items.append(np.stack(segments, 0))
if self.return_keypoints:
keypoints = [sample.pop('keypoints', None) for sample in batch_samples]
return_items.append(np.stack(keypoints, 0))

return tuple(return_items)

def test_collate_fn(self, batch_samples, batch_info):
Expand Down
2 changes: 1 addition & 1 deletion mindyolo/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def create_loader(
Returns:
BatchDataset, dataset batched.
"""
de.config.set_seed(1236517205 + rank)
cores = multiprocessing.cpu_count()
num_parallel_workers = min(int(cores / rank_size), num_parallel_workers)
logger.info(f"Dataloader num parallel workers: [{num_parallel_workers}]")
de.config.set_seed(1236517205 + rank * num_parallel_workers)
if rank_size > 1:
ds = de.GeneratorDataset(
dataset,
Expand Down
2 changes: 1 addition & 1 deletion mindyolo/utils/trainer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def train(
manager = CheckpointManager(ckpt_save_policy="latest_k")
manager_ema = CheckpointManager(ckpt_save_policy="latest_k") if self.ema else None

loader = self.dataloader.create_dict_iterator(output_numpy=False, num_epochs=1)
loader = self.dataloader.create_dict_iterator(output_numpy=False, num_epochs=1, do_copy=False)
s_step_time = time.time()
s_epoch_time = time.time()
run_context = RunContext(
Expand Down
2 changes: 1 addition & 1 deletion tutorials/configuration_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ data:
```yaml
data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.0, translate: 0.1, scale: 0.9 }
Expand Down

0 comments on commit 470b829

Please sign in to comment.