Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve yolo performance #409

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/yolov10/hyp.scratch.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov10/hyp.scratch.low.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov10/hyp.scratch.med.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov3/hyp.scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov4/hyp.scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov5/hyp.scratch-high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov5/hyp.scratch-low.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov7/hyp.scratch.p5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.2 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov7/hyp.scratch.p6.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.2 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov7/hyp.scratch.tiny.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ loss:
label_smoothing: 0.0 # label smoothing epsilon

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.2 }
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/hyp.scratch.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/hyp.scratch.low.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/hyp.scratch.med.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov8/seg/hyp.scratch.high.seg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ loss:
max_object_num: 600

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms: {
stage_epochs: [ 290, 10 ],
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov9/hyp.scratch.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ loss:
reg_max: 16

data:
num_parallel_workers: 4
num_parallel_workers: 8

# multi-stage data augment
train_transforms: {
Expand Down
2 changes: 1 addition & 1 deletion configs/yolox/hyp.scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ img_size: 640
sync_bn: False

data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms: {
stage_epochs: [ 285, 15 ],
Expand Down
2 changes: 1 addition & 1 deletion docs/en/tutorials/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ This part of the parameters is defined in [configs/yolov3/hyp.scratch.yaml](http
```yaml
data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.0, translate: 0.1, scale: 0.9 }
Expand Down
2 changes: 1 addition & 1 deletion docs/zh/tutorials/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ data:
```yaml
data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.0, translate: 0.1, scale: 0.9 }
Expand Down
50 changes: 33 additions & 17 deletions mindyolo/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def __init__(
self.is_training = is_training

# set column names
self.column_names_getitem = ['samples']
# https://www.mindspore.cn/docs/zh-CN/master/api_python/dataset/mindspore.dataset.config.set_enable_shared_mem.html
# Due to a MindSpore version limitation, shared memory does not support the dict data type, so training samples are returned as separate columns instead of a single dict
self.column_names_getitem = ['im_file', 'cls', 'bboxes', 'segments', 'keypoints', 'bbox_format', 'segment_format',
WongGawa marked this conversation as resolved.
Show resolved Hide resolved
'img', 'ori_shape', 'hw_scale', 'hw_pad'] if self.is_training else ['samples']
if self.is_training:
self.column_names_collate = ['images', 'labels']
if self.return_segments:
Expand Down Expand Up @@ -169,7 +172,10 @@ def __init__(
self.batch = bi # batch index of image

# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
self.imgs, self.img_hw_ori, self.indices = None, None, range(n)
self.imgs, self.img_hw_ori, self.indices = [None] * n, [None] * n, range(n)
# Buffer of recently loaded image indices from which the additional mosaic images are sampled
self.buffer = []
self.max_buffer_length = min((n, batch_size * 8, 1000)) if self.augment else 0

# Rectangular Train/Test
if self.rect:
Expand Down Expand Up @@ -313,6 +319,14 @@ def __getitem__(self, index):
sample = getattr(self, func_name)(sample, **_trans)

sample['img'] = np.ascontiguousarray(sample['img'])
if self.is_training:
train_sample = []
for col_name in self.column_names_getitem:
if sample.get(col_name) is None:
train_sample.append(np.nan)
else:
train_sample.append(sample.get(col_name, np.nan))
return tuple(train_sample)
return sample

def __len__(self):
Expand All @@ -321,7 +335,8 @@ def __len__(self):
def get_sample(self, index):
"""Get and return label information from the dataset."""
sample = deepcopy(self.labels[index])
if self.imgs is None:
img = self.imgs[index]
if img is None:
path = self.img_files[index]
img = cv2.imread(path) # BGR
assert img is not None, "Image Not Found " + path
Expand All @@ -331,8 +346,13 @@ def get_sample(self, index):
interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)

if self.augment:
self.imgs[index], self.img_hw_ori[index] = img, np.array([h_ori, w_ori]) # img, hw_original
self.buffer.append(index)
if 1 < len(self.buffer) >= self.max_buffer_length:
j = self.buffer.pop(0)
self.imgs[j], self.img_hw_ori[j] = None, np.array([None, None])
sample['img'], sample['ori_shape'] = img, np.array([h_ori, w_ori]) # img, hw_original

else:
sample['img'], sample['ori_shape'] = self.imgs[index], self.img_hw_ori[index] # img, hw_original

Expand Down Expand Up @@ -367,7 +387,7 @@ def _mosaic4(self, sample):
# loads images in a 4-mosaic
classes4, bboxes4, segments4 = [], [], []
mosaic_samples = [sample, ]
indices = random.choices(self.indices, k=3) # 3 additional image indices
indices = random.choices(self.buffer, k=3) # 3 additional image indices

segments_is_list = isinstance(sample['segments'], list)
if segments_is_list:
Expand Down Expand Up @@ -444,7 +464,7 @@ def _mosaic9(self, sample):
# loads images in a 9-mosaic
classes9, bboxes9, segments9 = [], [], []
mosaic_samples = [sample, ]
indices = random.choices(self.indices, k=8) # 8 additional image indices
indices = random.choices(self.buffer, k=8) # 8 additional image indices

segments_is_list = isinstance(sample['segments'], list)
if segments_is_list:
Expand Down Expand Up @@ -1156,21 +1176,17 @@ def _exif_size(self, img):

return s

def train_collate_fn(self, batch_samples, batch_info):
imgs = [sample.pop('img') for sample in batch_samples]
def train_collate_fn(self, im_file, cls, bboxes, segments, keypoints, bbox_format,
segment_format, img, ori_shape, hw_scale, hw_pad, batch_info):
labels = []
for i, sample in enumerate(batch_samples):
cls, bboxes = sample.pop('cls'), sample.pop('bboxes')
labels.append(np.concatenate((np.full_like(cls, i), cls, bboxes), axis=-1))
return_items = [np.stack(imgs, 0), np.stack(labels, 0)]

for i, (c, b) in enumerate(zip(cls, bboxes)):
labels.append(np.concatenate((np.full_like(c, i), c, b), axis=-1))
return_items = [np.stack(img, 0), np.stack(labels, 0)]
if self.return_segments:
masks = [sample.pop('segments', None) for sample in batch_samples]
return_items.append(np.stack(masks, 0))
return_items.append(np.stack(segments, 0))
if self.return_keypoints:
keypoints = [sample.pop('keypoints', None) for sample in batch_samples]
return_items.append(np.stack(keypoints, 0))

return tuple(return_items)

def test_collate_fn(self, batch_samples, batch_info):
Expand Down
2 changes: 1 addition & 1 deletion mindyolo/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def create_loader(
Returns:
BatchDataset, dataset batched.
"""
de.config.set_seed(1236517205 + rank)
cores = multiprocessing.cpu_count()
num_parallel_workers = min(int(cores / rank_size), num_parallel_workers)
logger.info(f"Dataloader num parallel workers: [{num_parallel_workers}]")
de.config.set_seed(1236517205 + rank * num_parallel_workers)
if rank_size > 1:
ds = de.GeneratorDataset(
dataset,
Expand Down
2 changes: 1 addition & 1 deletion mindyolo/utils/trainer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def train(
manager = CheckpointManager(ckpt_save_policy="latest_k")
manager_ema = CheckpointManager(ckpt_save_policy="latest_k") if self.ema else None

loader = self.dataloader.create_dict_iterator(output_numpy=False, num_epochs=1)
loader = self.dataloader.create_dict_iterator(output_numpy=False, num_epochs=1, do_copy=False)
s_step_time = time.time()
s_epoch_time = time.time()
run_context = RunContext(
Expand Down
2 changes: 1 addition & 1 deletion tutorials/configuration_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ data:
```yaml
data:
num_parallel_workers: 4
num_parallel_workers: 8

train_transforms:
- { func_name: mosaic, prob: 1.0, mosaic9_prob: 0.0, translate: 0.1, scale: 0.9 }
Expand Down
Loading