From 519abeeb0cbf71bfc138528131f632ea2276e6a9 Mon Sep 17 00:00:00 2001 From: yuedongli1 Date: Tue, 30 May 2023 19:29:35 +0800 Subject: [PATCH] add finetune tutorial --- GETTING_STARTED.md | 2 +- GETTING_STARTED_CN.md | 2 +- examples/finetune_SHWD/README.md | 71 +++++ examples/finetune_SHWD/convert_shwd2yolo.py | 259 ++++++++++++++++++ .../convert_yolov7-tiny_pretrain_ckpt.py | 15 + examples/finetune_SHWD/finetune_shwd.py | 149 ++++++++++ examples/finetune_SHWD/yolov7-tiny_shwd.yaml | 176 ++++++++++++ mindyolo/utils/config.py | 1 - train.py | 2 +- tutorials/custom_dataset.md | 44 +++ 10 files changed, 717 insertions(+), 4 deletions(-) create mode 100644 examples/finetune_SHWD/README.md create mode 100644 examples/finetune_SHWD/convert_shwd2yolo.py create mode 100644 examples/finetune_SHWD/convert_yolov7-tiny_pretrain_ckpt.py create mode 100644 examples/finetune_SHWD/finetune_shwd.py create mode 100644 examples/finetune_SHWD/yolov7-tiny_shwd.yaml create mode 100644 tutorials/custom_dataset.md diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 9ea989ea..1884159b 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -60,7 +60,7 @@ to understand their behavior. Some common arguments are: * To evaluate a model's performance: ``` - python test.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt + python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt ``` *Notes: (1) The default hyper-parameter is used for 8-card training, and some parameters need to be adjusted in the case of a single card. (2) The default device is Ascend, and you can modify it by specifying 'device_target' as Ascend/GPU/CPU, as these are currently supported.* * For more options, see `train/test.py -h`. diff --git a/GETTING_STARTED_CN.md b/GETTING_STARTED_CN.md index 1e17ddf7..5256ecb8 100644 --- a/GETTING_STARTED_CN.md +++ b/GETTING_STARTED_CN.md @@ -60,7 +60,7 @@ python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_c * 评估模型的精度: ```shell - python test.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt + python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt ``` *注意:默认超参为8卡训练,单卡情况需调整部分参数。 默认设备为Ascend,您可以指定'device_target'的值为Ascend/GPU/CPU。* * 有关更多选项,请参阅 `train/test.py -h`. 
diff --git a/examples/finetune_SHWD/README.md b/examples/finetune_SHWD/README.md
new file mode 100644
index 00000000..6a3882b7
--- /dev/null
+++ b/examples/finetune_SHWD/README.md
@@ -0,0 +1,71 @@
+### Finetuning on a custom dataset
+
+This tutorial takes the Safety Helmet Wearing Detection dataset (SHWD) as an example and walks through the main steps of finetuning MindYOLO on a custom dataset.
+
+#### Converting the dataset format
+
+The [SHWD dataset](https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset/tree/master) ships with VOC-style annotations and is organized as follows:
+```
+    ROOT_DIR
+        ├── Annotations
+        │     ├── 000000.xml
+        │     └── 000002.xml
+        ├── ImageSets
+        │     └── Main
+        │           ├── test.txt
+        │           ├── train.txt
+        │           ├── trainval.txt
+        │           └── val.txt
+        └── JPEGImages
+              ├── 000000.jpg
+              └── 000002.jpg
+```
+Each line of the txt files under ImageSets/Main is the file name (without extension) of one image in the corresponding split, for example:
+```
+000002
+000005
+000019
+000022
+000027
+000034
+```
+
+Because MindYOLO uses the image file name as the image_id during evaluation, image names must be numeric rather than arbitrary strings, so the images also have to be renamed. Converting SHWD to the MindYOLO format therefore involves the following steps:
+* copy each image to the target directory and rename it
+* append the image's relative path to the corresponding txt file in the root directory
+* parse the xml annotation and write the matching txt label file to the target directory
+* for the validation set, additionally generate the final json annotation file
+
+See [convert_shwd2yolo.py](./convert_shwd2yolo.py) for the full implementation. Run it as follows:
+
+  ```shell
+  python examples/finetune_SHWD/convert_shwd2yolo.py --root_dir /path_to_shwd/SHWD
+  ```
+
+The command above leaves the original dataset untouched and generates a YOLO-format SHWD dataset in a sibling directory.
+
+#### Converting the pretrained checkpoint
+
+Since SHWD contains only 7000+ images, yolov7-tiny is chosen for training, and the [checkpoint](https://github.com/mindspore-lab/mindyolo/blob/master/MODEL_ZOO.md) trained on COCO that MindYOLO provides can be downloaded as the pretrained model. COCO has 80 object categories while SHWD has only two, and the output of the last head layer depends on the number of classes nc, so the last layer has to be removed from the pretrained checkpoint; see [convert_yolov7-tiny_pretrain_ckpt.py](./convert_yolov7-tiny_pretrain_ckpt.py). Run it as follows:
+
+  ```shell
+  python examples/finetune_SHWD/convert_yolov7-tiny_pretrain_ckpt.py
+  ```
+
+#### Finetuning the model
+
+A condensed training pipeline is provided in [finetune_shwd.py](./finetune_shwd.py).
+
+* Distributed training on multiple NPU/GPU devices, using 8 devices as an example:
+
+  ```shell
+  mpirun --allow-run-as-root -n 8 python examples/finetune_SHWD/finetune_shwd.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml --is_parallel True
+  ```
+
+* Training on a single NPU/GPU/CPU device:
+
+  ```shell
+  python examples/finetune_SHWD/finetune_shwd.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml
+  ```
+
+*Note: training on SHWD with the default yolov7-tiny COCO hyper-parameters already reaches AP50 87.0; changing lr_init from 0.01 to 0.001 raises AP50 to 89.2.*
\ No newline at end of file
diff --git a/examples/finetune_SHWD/convert_shwd2yolo.py b/examples/finetune_SHWD/convert_shwd2yolo.py
new file mode 100644
index 00000000..c248a160
--- /dev/null
+++ b/examples/finetune_SHWD/convert_shwd2yolo.py
@@ -0,0 +1,259 @@
+import os
+from pathlib import Path
+import argparse
+import shutil
+import xml.etree.ElementTree as ET
+import collections
+import json
+from tqdm import tqdm
+
+
+category_set = ['person', 'hat']
+coco = dict()
+coco['images'] = []
+coco['type'] = 'instances'
+coco['annotations'] = []
+coco['categories'] = []
+
+category_item_id = 0
+annotation_id = 0
+image_index = 0
+
+
+def addCatItem(name):
+    global category_item_id
+    category_item = collections.OrderedDict()
+    category_item['supercategory'] = 'none'
+    category_item['id'] = category_item_id
+    category_item['name'] = name
+    coco['categories'].append(category_item)
+    category_item_id += 1
+
+
+def addImgItem(image_id, size):
+    file_name = str(image_id).zfill(8) + '.jpg'
+    if not size['width']:
+        raise Exception('Could not find width tag in xml file.')
+    if not size['height']:
+        raise Exception('Could not find height tag in xml file.')
+
+    image_item = collections.OrderedDict()
+
+    image_item['file_name'] = file_name
+    image_item['width'] = size['width']
+    image_item['height'] = size['height']
+    image_item['id'] = image_id
+    coco['images'].append(image_item)
+
+
+def addAnnoItem(image_id, category_id, bbox):
+    global
annotation_id + + annotation_item = collections.OrderedDict() + annotation_item['segmentation'] = [] + + # segmentation + seg = [] + # left_top + seg.append(bbox[0]) + seg.append(bbox[1]) + # left_bottom + seg.append(bbox[0]) + seg.append(bbox[1] + bbox[3]) + # right_bottom + seg.append(bbox[0] + bbox[2]) + seg.append(bbox[1] + bbox[3]) + # right_top + seg.append(bbox[0] + bbox[2]) + seg.append(bbox[1]) + + annotation_item['segmentation'].append(seg) + annotation_item['area'] = bbox[2] * bbox[3] + annotation_item['iscrowd'] = 0 + annotation_item['image_id'] = image_id + annotation_item['bbox'] = bbox + annotation_item['category_id'] = category_id + annotation_item['id'] = annotation_id + annotation_item['ignore'] = 0 + annotation_id += 1 + coco['annotations'].append(annotation_item) + + +def xxyy2xywhn(size, box): + dw = 1. / (size[0]) + dh = 1. / (size[1]) + x = (box[0] + box[1]) / 2.0 - 1 + y = (box[2] + box[3]) / 2.0 - 1 + w = box[1] - box[0] + h = box[3] - box[2] + x = x * dw + w = w * dw + y = y * dh + h = h * dh + x = round(x, 6) + w = round(w, 6) + y = round(y, 6) + h = round(h, 6) + return x, y, w, h + + +def xml2txt(xml_path, txt_path): + in_file = open(xml_path, encoding='utf-8') + out_file = open(txt_path, 'w', encoding='utf-8') + tree = ET.parse(in_file) + root = tree.getroot() + size = root.find('size') + w = int(size.find('width').text) + h = int(size.find('height').text) + for obj in root.iter('object'): + difficult = obj.find('difficult').text + cls = obj.find('name').text + if cls not in category_set or int(difficult) == 1: + continue + cls_id = category_set.index(cls) + xmlbox = obj.find('bndbox') + x1, x2, y1, y2 = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), + float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text)) + + # clip x2, y2 to normal range + if x2 > w: + x2 = w + if y2 > h: + y2 = h + + # xyxy2xywhn + bbox = (x1, x2, y1, y2) + bbox = xxyy2xywhn((w, h), bbox) + out_file.write(str(cls_id) + " " + + " ".join([str(a) for a in bbox]) + '\n') + + +def xml2json(image_index, xml_path): + bndbox = dict() + size = dict() + size['width'] = None + size['height'] = None + + tree = ET.parse(xml_path) + root = tree.getroot() + + if root.tag != 'annotation': + raise Exception('pascal voc xml root element should be annotation, rather than {}'.format(root.tag)) + + # elem format: , , , + for elem in root: + if elem.tag == 'folder' or elem.tag == 'filename' or elem.tag == 'path' or elem.tag == 'source': + continue + + elif elem.tag == 'size': + # add image information, like file_name, size, image_id + for subelem in elem: + size[subelem.tag] = int(subelem.text) + addImgItem(image_index, size) + + elif elem.tag == 'object': + for subelem in elem: + if subelem.tag == 'name': + object_name = subelem.text + current_category_id = category_set.index(object_name) + + elif subelem.tag == 'bndbox': + for option in subelem: + bndbox[option.tag] = int(option.text) + + bbox = [] + bbox.append(bndbox['xmin']) + bbox.append(bndbox['ymin']) + bbox.append(bndbox['xmax'] - bndbox['xmin']) + bbox.append(bndbox['ymax'] - bndbox['ymin']) + + # add bound box information, include area,image_id, bbox, category_id, id and so on + addAnnoItem(image_index, current_category_id, bbox) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--root_dir', default='', type=str, help='root directory of data set') + opt = parser.parse_args() + + # generate directory structure + root_dir = opt.root_dir + new_dir = os.path.join(root_dir, '..', 
'SHWD') + os.makedirs(os.path.join(new_dir, 'images', 'train')) + os.makedirs(os.path.join(new_dir, 'images', 'val')) + os.makedirs(os.path.join(new_dir, 'labels', 'train')) + os.makedirs(os.path.join(new_dir, 'labels', 'val')) + os.makedirs(os.path.join(new_dir, 'annotations')) + + train_txt_yolo = open(os.path.join(new_dir, 'train.txt'), 'w') + val_txt_yolo = open(os.path.join(new_dir, 'val.txt'), 'w') + + images_path = os.path.join(root_dir, 'JPEGImages') + labels_path = os.path.join(root_dir, 'Annotations') + + train_set_txt = os.path.join(root_dir, 'ImageSets', 'Main', 'trainval.txt') + with open(train_set_txt, 'r', encoding='utf-8') as f: + for line in tqdm(f.readlines(), desc='train_set'): + stem = line.strip('\n') + old_path = os.path.join(images_path, stem + '.jpg') + if not os.path.exists(old_path): + old_path = os.path.join(images_path, stem + '.JPG') + + # copy train_set image to new path + new_images_path = os.path.join(new_dir, 'images', 'train') + shutil.copy(old_path, new_images_path) + + # rename image_file to continuous number + old_name = Path(old_path).name + new_stem = str(image_index).zfill(8) + os.rename(os.path.join(new_images_path, old_name), os.path.join(new_images_path, new_stem + '.jpg')) + + # write the relative path of image to train.txt + train_txt_yolo.write('./images/train/' + new_stem + '.jpg' + '\n') + + # convert xml file to txt file + xml_path = os.path.join(labels_path, stem + '.xml') + txt_path = os.path.join(new_dir, 'labels', 'train', new_stem + '.txt') + xml2txt(xml_path, txt_path) + + image_index += 1 + + val_set_txt = os.path.join(root_dir, 'ImageSets', 'Main', 'test.txt') + with open(val_set_txt, 'r', encoding='utf-8') as f: + for line in tqdm(f.readlines(), desc='val_set'): + stem = line.strip('\n') + old_path = os.path.join(images_path, stem + '.jpg') + if not os.path.exists(old_path): + old_path = os.path.join(images_path, stem + '.JPG') + + # copy val_set image to new path + new_images_path = os.path.join(new_dir, 'images', 'val') + shutil.copy(old_path, new_images_path) + + # rename image_file to continuous number + old_name = Path(old_path).name + new_stem = str(image_index).zfill(8) + os.rename(os.path.join(new_images_path, old_name), os.path.join(new_images_path, new_stem + '.jpg')) + + # write the relative path of image to val.txt + val_txt_yolo.write('./images/val/' + new_stem + '.jpg' + '\n') + + # convert xml file to txt file + xml_path = os.path.join(labels_path, stem + '.xml') + txt_path = os.path.join(new_dir, 'labels', 'val', new_stem + '.txt') + xml2txt(xml_path, txt_path) + + # convert xml file to json file + xml2json(image_index, xml_path) + + image_index += 1 + + for categoryname in category_set: + addCatItem(categoryname) + + train_txt_yolo.close() + val_txt_yolo.close() + + # save ground truth json file + json_file = os.path.join(new_dir, 'annotations', 'instances_val2017.json') + json.dump(coco, open(json_file, 'w')) diff --git a/examples/finetune_SHWD/convert_yolov7-tiny_pretrain_ckpt.py b/examples/finetune_SHWD/convert_yolov7-tiny_pretrain_ckpt.py new file mode 100644 index 00000000..8f506f7f --- /dev/null +++ b/examples/finetune_SHWD/convert_yolov7-tiny_pretrain_ckpt.py @@ -0,0 +1,15 @@ +import mindspore as ms + + +def convert_weight(ori_weight, new_weight): + new_ckpt = [] + param_dict = ms.load_checkpoint(ori_weight) + for k, v in param_dict.items(): + if '77' in k: + continue + new_ckpt.append({'name': k, 'data': v}) + ms.save_checkpoint(new_ckpt, new_weight) + + +if __name__ == '__main__': + 
convert_weight('./yolov7-tiny_300e_mAP375-d8972c94.ckpt', './yolov7-tiny_pretrain.ckpt') diff --git a/examples/finetune_SHWD/finetune_shwd.py b/examples/finetune_SHWD/finetune_shwd.py new file mode 100644 index 00000000..3ae22e22 --- /dev/null +++ b/examples/finetune_SHWD/finetune_shwd.py @@ -0,0 +1,149 @@ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +import mindspore as ms + +from train import get_parser_train +from mindyolo.data import COCODataset, create_loader +from mindyolo.models import create_loss, create_model +from mindyolo.optim import (EMA, create_group_param, create_lr_scheduler, + create_optimizer, create_warmup_momentum_scheduler) +from mindyolo.utils import logger +from mindyolo.utils.config import parse_args +from mindyolo.utils.train_step_factory import (create_train_step_fn, + get_gradreducer, + get_loss_scaler) +from mindyolo.utils.trainer_factory import create_trainer +from mindyolo.utils.utils import (freeze_layers, load_pretrain, set_default, + set_seed) + + +def train_shwd(args): + # Set Default + set_seed(args.seed) + set_default(args) + main_device = args.rank % args.rank_size == 0 + + logger.info("parse_args:") + logger.info("\n" + str(args)) + logger.info("Please check the above information for the configurations") + + # Create Network + args.network.recompute = args.recompute + args.network.recompute_layers = args.recompute_layers + network = create_model( + model_name=args.network.model_name, + model_cfg=args.network, + num_classes=args.data.nc, + sync_bn=args.sync_bn, + ) + + if args.ema and main_device: + ema_network = create_model( + model_name=args.network.model_name, + model_cfg=args.network, + num_classes=args.data.nc, + ) + ema = EMA(network, ema_network) + else: + ema = None + load_pretrain(network, args.weight, ema, args.ema_weight) # load pretrain + freeze_layers(network, args.freeze) # freeze Layers + ms.amp.auto_mixed_precision(network, amp_level=args.ms_amp_level) + if ema: + ms.amp.auto_mixed_precision(ema.ema, amp_level=args.ms_amp_level) + + # Create Dataloaders + transforms = args.data.train_transforms + dataset = COCODataset( + dataset_path=args.data.train_set, + img_size=args.img_size, + transforms_dict=transforms, + is_training=True, + augment=True, + rect=args.rect, + single_cls=args.single_cls, + batch_size=args.total_batch_size, + stride=max(args.network.stride), + ) + dataloader = create_loader( + dataset=dataset, + batch_collate_fn=dataset.train_collate_fn, + dataset_column_names=dataset.dataset_column_names, + batch_size=args.per_batch_size, + epoch_size=args.epochs, + rank=args.rank, + rank_size=args.rank_size, + shuffle=True, + drop_remainder=True, + num_parallel_workers=args.data.num_parallel_workers, + python_multiprocessing=True, + ) + steps_per_epoch = dataloader.get_dataset_size() // args.epochs + + # Create Loss + loss_fn = create_loss( + **args.loss, anchors=args.network.get("anchors", 1), stride=args.network.stride, nc=args.data.nc + ) + ms.amp.auto_mixed_precision(loss_fn, amp_level="O0" if args.keep_loss_fp32 else args.ms_amp_level) + + # Create Optimizer + args.optimizer.steps_per_epoch = steps_per_epoch + lr = create_lr_scheduler(**args.optimizer) + params = create_group_param(params=network.trainable_params(), **args.optimizer) + optimizer = create_optimizer(params=params, lr=lr, **args.optimizer) + warmup_momentum = create_warmup_momentum_scheduler(**args.optimizer) + + # Create train_step_fn + reducer = get_gradreducer(args.is_parallel, 
optimizer.parameters) + scaler = get_loss_scaler(args.ms_loss_scaler, scale_value=args.ms_loss_scaler_value) + train_step_fn = create_train_step_fn( + network=network, + loss_fn=loss_fn, + optimizer=optimizer, + loss_ratio=args.rank_size, + scaler=scaler, + reducer=reducer, + overflow_still_update=args.overflow_still_update, + ms_jit=args.ms_jit, + ) + + # Create Trainer + network.set_train(True) + optimizer.set_train(True) + model_name = os.path.basename(args.config)[:-5] # delete ".yaml" + trainer = create_trainer( + model_name=model_name, + train_step_fn=train_step_fn, + scaler=scaler, + dataloader=dataloader, + steps_per_epoch=steps_per_epoch, + network=network, + ema=ema, + optimizer=optimizer, + summary=args.summary, + ) + + trainer.train( + epochs=args.epochs, + main_device=main_device, + warmup_step=max(round(args.optimizer.warmup_epochs * steps_per_epoch), args.optimizer.min_warmup_step), + warmup_momentum=warmup_momentum, + accumulate=args.accumulate, + overflow_still_update=args.overflow_still_update, + keep_checkpoint_max=args.keep_checkpoint_max, + log_interval=args.log_interval, + loss_item_name=[] if not hasattr(loss_fn, "loss_item_name") else loss_fn.loss_item_name, + save_dir=args.save_dir, + enable_modelarts=args.enable_modelarts, + train_url=args.train_url, + run_eval=args.run_eval, + ) + logger.info("Training completed.") + + +if __name__ == "__main__": + parser = get_parser_train() + args = parse_args(parser) + train_shwd(args) diff --git a/examples/finetune_SHWD/yolov7-tiny_shwd.yaml b/examples/finetune_SHWD/yolov7-tiny_shwd.yaml new file mode 100644 index 00000000..76ee024e --- /dev/null +++ b/examples/finetune_SHWD/yolov7-tiny_shwd.yaml @@ -0,0 +1,176 @@ +per_batch_size: 16 # 16 * 8 = 128 +img_size: 640 # image sizes +sync_bn: True +weight: ./yolov7-tiny_pretrain.ckpt + +data: + dataset_name: shwd + + train_set: ./SHWD/train.txt + val_set: ./SHWD/val.txt + + nc: 2 + + # class names + names: [ 'person', 'hat' ] + + num_parallel_workers: 4 + + train_transforms: + - {func_name: mosaic, prob: 1.0, mosaic9_prob: 0.2, translate: 0.1, scale: 0.5} + - {func_name: mixup, prob: 0.05, alpha: 8.0, beta: 8.0, needed_mosaic: True} + - {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4} + - {func_name: pastein, prob: 0.05, num_sample: 30} + - {func_name: label_norm, xyxy2xywh_: True} + - {func_name: fliplr, prob: 0.5} + - {func_name: label_pad, padding_size: 160, padding_value: -1} + - {func_name: image_norm, scale: 255.} + - {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + + test_transforms: + - {func_name: letterbox, scaleup: False} + - {func_name: label_norm, xyxy2xywh_: True} + - {func_name: label_pad, padding_size: 160, padding_value: -1} + - {func_name: image_norm, scale: 255. 
} + - {func_name: image_transpose, bgr2rgb: True, hwc2chw: True } + +optimizer: + optimizer: momentum + lr_init: 0.001 # initial learning rate + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov7 # group param strategy + gp_weight_decay: 0.0005 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv7Loss + box: 0.05 # box loss gain + cls: 0.5 # cls loss gain + cls_pw: 1.0 # cls BCELoss positive_weight + obj: 1.0 # obj loss gain (scale with pixels) + obj_pw: 1.0 # obj BCELoss positive_weight + fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) + anchor_t: 4.0 # anchor-multiple threshold + label_smoothing: 0.0 # label smoothing epsilon + +network: + model_name: yolov7 + depth_multiple: 1.0 # model depth multiple + width_multiple: 1.0 # layer channel multiple + + stride: [8, 16, 32] + + # anchors + anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + + # yolov7-tiny backbone + backbone: + # [from, number, module, args] c2, k=1, s=1, p=None, g=1, d=1, act=True + [[-1, 1, ConvNormAct, [32, 3, 2, None, 1, 1, nn.LeakyReLU(0.1)]], # 0-P1/2 + + [-1, 1, ConvNormAct, [64, 3, 2, None, 1, 1, nn.LeakyReLU(0.1)]], # 1-P2/4 + + [-1, 1, ConvNormAct, [32, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [32, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [32, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [32, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 7 + + [-1, 1, MP, []], # 8-P3/8 + [-1, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [64, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [64, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 14 + + [-1, 1, MP, []], # 15-P4/16 + [-1, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [128, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [128, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 21 + + [-1, 1, MP, []], # 22-P5/32 + [-1, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [256, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [256, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [512, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 28 + ] + + # yolov7-tiny head + head: + [[-1, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, SP, [5]], + [-2, 1, SP, [9]], + [-3, 1, SP, [13]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -7], 1, 
Concat, [1]], + [-1, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 37 + + [-1, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, Upsample, [None, 2, 'nearest']], + [21, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # route backbone P4 + [[-1, -2], 1, Concat, [1]], + + [-1, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [64, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [64, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 47 + + [-1, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, Upsample, [None, 2, 'nearest']], + [14, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # route backbone P3 + [[-1, -2], 1, Concat, [1]], + + [-1, 1, ConvNormAct, [32, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [32, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [32, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [32, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 57 + + [-1, 1, ConvNormAct, [128, 3, 2, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, 47], 1, Concat, [1]], + + [-1, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [64, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [64, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [64, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 65 + + [-1, 1, ConvNormAct, [256, 3, 2, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, 37], 1, Concat, [1]], + + [-1, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-2, 1, ConvNormAct, [128, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [128, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [-1, 1, ConvNormAct, [128, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [[-1, -2, -3, -4], 1, Concat, [1]], + [-1, 1, ConvNormAct, [256, 1, 1, None, 1, 1, nn.LeakyReLU(0.1)]], # 73 + + [57, 1, ConvNormAct, [128, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [65, 1, ConvNormAct, [256, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + [73, 1, ConvNormAct, [512, 3, 1, None, 1, 1, nn.LeakyReLU(0.1)]], + + [[74,75,76], 1, YOLOv7Head, [nc, anchors, stride]], # Detect(P3, P4, P5) + ] diff --git a/mindyolo/utils/config.py b/mindyolo/utils/config.py index ad681202..0b35a893 100644 --- a/mindyolo/utils/config.py +++ b/mindyolo/utils/config.py @@ -113,7 +113,6 @@ def __init__(self, cfg_dict): super(Config, self).__init__() for k, v in cfg_dict.items(): setattr(self, k, Config(v) if isinstance(v, dict) else v) - self.__dict__.update(cfg_dict) def __setattr__(self, name, value): self[name] = value diff --git a/train.py b/train.py index 1fad02b3..3decbdda 100644 --- a/train.py +++ b/train.py @@ -4,7 +4,6 @@ from functools import partial from mindyolo.utils.callback import create_callback -from test import test import mindspore as ms @@ -155,6 +154,7 @@ def train(args): steps_per_epoch = dataloader.get_dataset_size() // args.epochs if args.run_eval: + from test import test eval_dataset = COCODataset( dataset_path=args.data.val_set, img_size=args.img_size, diff --git a/tutorials/custom_dataset.md b/tutorials/custom_dataset.md new file mode 100644 index 
00000000..9a88a82d
--- /dev/null
+++ b/tutorials/custom_dataset.md
@@ -0,0 +1,44 @@
+# Dataset format
+
+A dataset consumable by MindYOLO is organized as follows:
+```
+    ROOT_DIR
+        ├── val.txt
+        ├── train.txt
+        ├── annotations
+        │     └── instances_val2017.json
+        ├── images
+        │     ├── train
+        │     │     ├── 00000001.jpg
+        │     │     └── 00000002.jpg
+        │     └── val
+        │           ├── 00006563.jpg
+        │           └── 00006564.jpg
+        └── labels
+              └── train
+                    ├── 00000001.txt
+                    └── 00000002.txt
+```
+
+Each line of train.txt is the relative path of one image, for example:
+```
+./images/train/00000000.jpg
+./images/train/00000001.jpg
+./images/train/00000002.jpg
+./images/train/00000003.jpg
+./images/train/00000004.jpg
+./images/train/00000005.jpg
+```
+The txt files under the labels/train folder hold the annotations of the corresponding images. Each line normally has 5 columns: the class id followed by the normalized center coordinates x, y and the normalized width/height w, h of the bounding box, for example:
+```
+62 0.417040 0.206280 0.403600 0.412560
+62 0.818810 0.197933 0.174740 0.189680
+39 0.684540 0.277773 0.086240 0.358960
+0 0.620220 0.725853 0.751680 0.525840
+63 0.197190 0.364053 0.394380 0.669653
+39 0.932330 0.226240 0.034820 0.076640
+```
+
+instances_val2017.json contains the validation-set annotations in COCO format and can be passed directly to the COCO API to compute mAP.
+
+For a complete example of finetuning a custom dataset with the MindYOLO suite, see [README.md](https://github.com/mindspore-lab/mindyolo/blob/master/examples/finetune_SHWD/README.md).
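+
+As a quick reference, the sketch below shows how one such label line can be produced from a pixel-space box. It is only a minimal illustration of the format above (class id followed by the normalized center x, y and width, height); the function name and the sample values are hypothetical, and the actual conversion used in the SHWD example (convert_shwd2yolo.py) additionally handles VOC details such as 1-based pixel coordinates and box clipping.
+
+```python
+def xyxy_to_yolo_line(cls_id, xmin, ymin, xmax, ymax, img_w, img_h):
+    """Turn a pixel-space (xmin, ymin, xmax, ymax) box into one label-file line."""
+    # normalized center coordinates of the box
+    x = (xmin + xmax) / 2.0 / img_w
+    y = (ymin + ymax) / 2.0 / img_h
+    # normalized box width and height
+    w = (xmax - xmin) / img_w
+    h = (ymax - ymin) / img_h
+    return f"{cls_id} {x:.6f} {y:.6f} {w:.6f} {h:.6f}"
+
+
+# e.g. a 'hat' box (class id 1) on a 1920x1080 image
+print(xyxy_to_yolo_line(1, 400, 200, 600, 380, 1920, 1080))
+```
\ No newline at end of file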