Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add YOLOv8x Segmentation #196

Merged
merged 1 commit into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions MODEL_ZOO.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# MindYOLO Model Zoo and Baselines

## Detection

| Name | Scale | Context | ImageSize | Dataset | Box mAP (%) | Params | FLOPs | Recipe | Download |
|--------|--------------------|----------|-----------|--------------|-------------|--------|--------|--------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------|
| YOLOv8 | N | D910x8-G | 640 | MS COCO 2017 | 37.2 | 3.2M | 8.7G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/yolov8n.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-n_500e_mAP372-cc07f5bd.ckpt) |
Expand All @@ -26,12 +28,16 @@
| YOLOX | X | D910x8-G | 640 | MS COCO 2017 | 51.6 | 99.1M | 281.9G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolox/yolox-x.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolox/yolox-x_300e_map516-52216d90.ckpt) |
| YOLOX | Darknet53 | D910x8-G | 640 | MS COCO 2017 | 47.7 | 63.7M | 185.3G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolox/yolox-darknet53.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolox/yolox-darknet53_300e_map477-b5fcaba9.ckpt) |

<br>
## Segmentation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quick start里也加下segment的说明

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seg task兼容原有启动方式,用户通过指定不同.yaml文件来选择不同的task


| Name | Scale | Context | ImageSize | Dataset | Box mAP (%) | Mask mAP (%) | Params | FLOPs | Recipe | Download |
|------------|-------|----------|-----------|--------------|-------------|--------------|--------|--------|---------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------|
| YOLOv8-seg | X | D910x8-G | 640 | MS COCO 2017 | 52.5 | 42.9 | 71.8M | 344.1G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/seg/yolov8x-seg.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-x-seg_300e_mAP_mask_429-b4920557.ckpt) |

#### Depoly inference
## Depoly inference

- See [support list](./deploy/README.md)

#### Notes
## Notes
- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
- Box mAP: Accuracy reported on the validation set.
15 changes: 13 additions & 2 deletions configs/yolov8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Ultralytics YOLOv8, developed by Ultralytics, is a cutting-edge, state-of-the-ar

## Results

### Detection

<div align="center">

| Name | Scale | Arch | Context | ImageSize | Dataset | Box mAP (%) | Params | FLOPs | Recipe | Download |
Expand All @@ -20,9 +22,18 @@ Ultralytics YOLOv8, developed by Ultralytics, is a cutting-edge, state-of-the-ar
| YOLOv8 | X | P5 | D910x8-G | 640 | MS COCO 2017 | 53.7 | 68.2M | 257.8G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/yolov8x.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-x_500e_mAP537-b958e1c7.ckpt) |

</div>
<br>

#### Notes
### Segmentation

<div align="center">

| Name | Scale | Arch | Context | ImageSize | Dataset | Box mAP (%) | Mask mAP (%) | Params | FLOPs | Recipe | Download |
|------------|-------|------|----------|-----------|--------------|-------------|--------------|--------|--------|---------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------|
| YOLOv8-seg | X | P5 | D910x8-G | 640 | MS COCO 2017 | 52.5 | 42.9 | 71.8M | 344.1G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/seg/yolov8x-seg.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-x-seg_300e_mAP_mask_429-b4920557.ckpt) |

</div>

### Notes

- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
- Box mAP: Accuracy reported on the validation set.
Expand Down
73 changes: 73 additions & 0 deletions configs/yolov8/seg/hyp.scratch.high.seg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
epochs: 300 # total train epochs

optimizer:
optimizer: momentum
lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
momentum: 0.937 # SGD momentum/Adam beta1
nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm
loss_scale: 1.0 # loss scale for optimizer
warmup_epochs: 3 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.1 # warmup initial bias lr
min_warmup_step: 1000 # minimum warmup step
group_param: yolov8 # group param strategy
gp_weight_decay: 0.0010078125 # group param weight decay 5e-4
start_factor: 1.0
end_factor: 0.01

loss:
name: YOLOv8SegLoss
box: 7.5 # box loss gain
cls: 0.5 # cls loss gain
dfl: 1.5 # dfl loss gain
reg_max: 16
nm: 32
overlap: True
max_object_num: 600

data:
num_parallel_workers: 4

train_transforms: {
stage_epochs: [ 290, 10 ],
trans_list: [
[
{func_name: resample_segments},
{func_name: mosaic, prob: 1.0},
{func_name: copy_paste, prob: 0.3},
{func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0},
{func_name: mixup, alpha: 32.0, beta: 32.0, prob: 0.15, pre_transform: [
{ func_name: resample_segments },
{ func_name: mosaic, prob: 1.0 },
{ func_name: copy_paste, prob: 0.3 },
{ func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0 },]
},
{func_name: albumentations, random_resized_crop: False}, # random_resized_crop not support seg task
{func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4 },
{func_name: fliplr, prob: 0.5 },
{func_name: segment_poly2mask, mask_overlap: True, mask_ratio: 4 },
{func_name: label_norm, xyxy2xywh_: True },
{func_name: label_pad, padding_size: 160, padding_value: -1 },
{func_name: image_norm, scale: 255. },
{func_name: image_transpose, bgr2rgb: True, hwc2chw: True }
],
[
{func_name: resample_segments},
{func_name: letterbox, scaleup: True },
{func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0 },
{func_name: albumentations, random_resized_crop: False}, # random_resized_crop not support seg task
{func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4 },
{func_name: fliplr, prob: 0.5 },
{func_name: segment_poly2mask, mask_overlap: True, mask_ratio: 4 },
{func_name: label_norm, xyxy2xywh_: True },
{func_name: label_pad, padding_size: 160, padding_value: -1 },
{func_name: image_norm, scale: 255. },
{func_name: image_transpose, bgr2rgb: True, hwc2chw: True }
]]
}

test_transforms: [
{ func_name: letterbox, scaleup: False },
{ func_name: image_norm, scale: 255. },
{ func_name: image_transpose, bgr2rgb: True, hwc2chw: True }
]
49 changes: 49 additions & 0 deletions configs/yolov8/seg/yolov8-seg-base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
task: segment
epochs: 500 # total train epochs
per_batch_size: 16 # 16 * 8 = 128
img_size: 640
iou_thres: 0.7
conf_free: True
sync_bn: True
opencv_threads_num: 0 # opencv: disable threading optimizations

network:
model_name: yolov8
nc: 80 # number of classes
reg_max: 16

stride: [8, 16, 32]

# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, ConvNormAct, [64, 3, 2]] # 0-P1/2
- [-1, 1, ConvNormAct, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, ConvNormAct, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, ConvNormAct, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, ConvNormAct, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
- [-1, 1, Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12

- [-1, 1, Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1] ] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)

- [-1, 1, ConvNormAct, [256, 3, 2]]
- [[ -1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)

- [-1, 1, ConvNormAct, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)

- [[15, 18, 21], 1, YOLOv8Head, [nc, reg_max, stride]] # Detect(P3, P4, P5)
13 changes: 13 additions & 0 deletions configs/yolov8/seg/yolov8x-seg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
__BASE__: [
'../../coco.yaml',
'./hyp.scratch.high.seg.yaml',
'./yolov8-seg-base.yaml'
]

recompute: True
recompute_layers: 2

network:
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.25 # scales convolution channels
max_channels: 512
150 changes: 135 additions & 15 deletions demo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from mindyolo.models import create_model
from mindyolo.utils import logger
from mindyolo.utils.config import parse_args
from mindyolo.utils.metrics import non_max_suppression, scale_coords, xyxy2xywh
from mindyolo.utils.metrics import non_max_suppression, scale_coords, xyxy2xywh, process_mask_upsample, scale_image
from mindyolo.utils.utils import draw_result, set_seed


def get_parser_infer(parents=None):
parser = argparse.ArgumentParser(description="Infer", parents=[parents] if parents else [])
parser.add_argument("--task", type=str, default="detect", choices=["detect", "segment"])
parser.add_argument("--device_target", type=str, default="Ascend", help="device target, Ascend/GPU/CPU")
parser.add_argument("--ms_mode", type=int, default=0, help="train mode, graph/pynative")
parser.add_argument("--ms_amp_level", type=str, default="O0", help="amp level, O0/O1/O2")
Expand Down Expand Up @@ -84,6 +85,7 @@ def detect(
nms_time_limit: float = 60.0,
img_size: int = 640,
stride: int = 32,
num_class: int = 80,
is_coco_dataset: bool = True,
):
# Resize
Expand Down Expand Up @@ -159,6 +161,106 @@ def detect(
return result_dict


def segment(
network: nn.Cell,
img: np.ndarray,
conf_thres: float = 0.25,
iou_thres: float = 0.65,
conf_free: bool = False,
nms_time_limit: float = 60.0,
img_size: int = 640,
stride: int = 32,
num_class: int = 80,
is_coco_dataset: bool = True,
):
# Resize
h_ori, w_ori = img.shape[:2] # orig hw
r = img_size / max(h_ori, w_ori) # resize image to img_size
if r != 1: # always resize down, only resize up if training with augmentation
interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
h, w = img.shape[:2]
if h < img_size or w < img_size:
new_h, new_w = math.ceil(h / stride) * stride, math.ceil(w / stride) * stride
dh, dw = (new_h - h) / 2, (new_w - w) / 2
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
) # add border

# Transpose Norm
img = img[:, :, ::-1].transpose(2, 0, 1) / 255.0
imgs_tensor = Tensor(img[None], ms.float32)

# Run infer
_t = time.time()
out, (_, _, prototypes) = network(imgs_tensor) # inference and training outputs
infer_times = time.time() - _t

# Run NMS
t = time.time()
_c = num_class + 4 if conf_free else num_class + 5
out = out.asnumpy()
bboxes, mask_coefficient = out[:, :, :_c], out[:, :, _c:]
out = non_max_suppression(
bboxes,
mask_coefficient,
conf_thres=conf_thres,
iou_thres=iou_thres,
conf_free=conf_free,
multi_label=True,
time_limit=nms_time_limit,
)
nms_times = time.time() - t

prototypes = prototypes.asnumpy()

result_dict = {"category_id": [], "bbox": [], "score": [], "segmentation": []}
total_category_ids, total_bboxes, total_scores, total_seg = [], [], [], []
for si, (pred, proto) in enumerate(zip(out, prototypes)):
if len(pred) == 0:
continue

# Predictions
pred_masks = process_mask_upsample(proto, pred[:, 6:], pred[:, :4], shape=imgs_tensor[si].shape[1:])
pred_masks = pred_masks.astype(np.float32)
pred_masks = scale_image((pred_masks.transpose(1, 2, 0)), (h_ori, w_ori))
predn = np.copy(pred)
scale_coords(img.shape[1:], predn[:, :4], (h_ori, w_ori)) # native-space pred

box = xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
category_ids, bboxes, scores, segs = [], [], [], []
for ii, (p, b) in enumerate(zip(pred.tolist(), box.tolist())):
category_ids.append(COCO80_TO_COCO91_CLASS[int(p[5])] if is_coco_dataset else int(p[5]))
bboxes.append([round(x, 3) for x in b])
scores.append(round(p[4], 5))
segs.append(pred_masks[:, :, ii])

total_category_ids.extend(category_ids)
total_bboxes.extend(bboxes)
total_scores.extend(scores)
total_seg.extend(segs)

result_dict["category_id"].extend(total_category_ids)
result_dict["bbox"].extend(total_bboxes)
result_dict["score"].extend(total_scores)
result_dict["segmentation"].extend(total_seg)

t = tuple(x * 1e3 for x in (infer_times, nms_times, infer_times + nms_times)) + (img_size, img_size, 1) # tuple
logger.info(f"Predict result is:")
for k, v in result_dict.items():
if k == "segmentation":
logger.info(f"{k} shape: {v[0].shape}")
else:
logger.info(f"{k}: {v}")
logger.info(f"Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g;" % t)
logger.info(f"Detect a image success.")

return result_dict


def infer(args):
# Init
set_seed(args.seed)
Expand All @@ -184,20 +286,38 @@ def infer(args):

# Detect
is_coco_dataset = "coco" in args.data.dataset_name
result_dict = detect(
network=network,
img=img,
conf_thres=args.conf_thres,
iou_thres=args.iou_thres,
conf_free=args.conf_free,
nms_time_limit=args.nms_time_limit,
img_size=args.img_size,
stride=max(max(args.network.stride), 32),
is_coco_dataset=is_coco_dataset,
)
if args.save_result:
save_path = os.path.join(args.save_dir, "detect_results")
draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path)
if args.task == "detect":
result_dict = detect(
network=network,
img=img,
conf_thres=args.conf_thres,
iou_thres=args.iou_thres,
conf_free=args.conf_free,
nms_time_limit=args.nms_time_limit,
img_size=args.img_size,
stride=max(max(args.network.stride), 32),
num_class=args.data.nc,
is_coco_dataset=is_coco_dataset,
)
if args.save_result:
save_path = os.path.join(args.save_dir, "detect_results")
draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path)
elif args.task == "segment":
result_dict = segment(
network=network,
img=img,
conf_thres=args.conf_thres,
iou_thres=args.iou_thres,
conf_free=args.conf_free,
nms_time_limit=args.nms_time_limit,
img_size=args.img_size,
stride=max(max(args.network.stride), 32),
num_class=args.data.nc,
is_coco_dataset=is_coco_dataset,
)
if args.save_result:
save_path = os.path.join(args.save_dir, "segment_results")
draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path)

logger.info("Infer completed.")

Expand Down
Loading