add finetune tutorial #124
Merged
@@ -0,0 +1,71 @@
### Finetuning on a custom dataset

This tutorial uses the Safety Helmet Wearing Dataset (SHWD) as an example to walk through the main steps of finetuning on a custom dataset with MindYOLO.

#### Dataset format conversion

The [SHWD dataset](https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset/tree/master) ships with VOC-format annotations, organized as follows:
```
ROOT_DIR
├── Annotations
│   ├── 000000.xml
│   └── 000002.xml
├── ImageSets
│   └── Main
│       ├── test.txt
│       ├── train.txt
│       ├── trainval.txt
│       └── val.txt
└── JPEGImages
    ├── 000000.jpg
    └── 000002.jpg
```
Each line of the txt files under ImageSets/Main holds the extension-less file name of one image in the corresponding split, for example:
```
000002
000005
000019
000022
000027
000034
```
Because MindYOLO uses the image file name as the image_id during evaluation, image names must be purely numeric rather than arbitrary strings, so the images also have to be renamed. Converting the SHWD dataset to the required format therefore involves the following steps (an example of the resulting label files is shown after this list):
* copy each image to the corresponding directory and rename it
* write the image's relative path into the corresponding txt file under the root directory
* parse the xml annotation and generate the matching txt label file under the corresponding directory
* for the validation set, additionally generate the final json annotation file
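Each generated label txt file stores one object per line: a class index (0 for person, 1 for hat) followed by the box center x, center y, width and height, all normalized to the image size. This is what the conversion script below writes out; the numbers here are only illustrative:
```
0 0.452500 0.615833 0.105000 0.178333
1 0.713750 0.521667 0.082500 0.141667
```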
The full implementation can be found in [convert_shwd2yolo.py](./convert_shwd2yolo.py). Run it as follows:

```shell
python examples/finetune_SHWD/convert_shwd2yolo.py --root_dir /path_to_shwd/SHWD
```
Running this command does not modify the original dataset; it generates a YOLO-format copy of SHWD in a sibling directory, whose expected layout is sketched below.
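Based on the directories and files the conversion script creates, the generated dataset should look roughly like this (the file names are illustrative):
```
SHWD
├── train.txt
├── val.txt
├── annotations
│   └── instances_val2017.json
├── images
│   ├── train
│   │   ├── 00000000.jpg
│   │   └── ...
│   └── val
│       └── ...
└── labels
    ├── train
    │   ├── 00000000.txt
    │   └── ...
    └── val
        └── ...
```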
#### Pretrained checkpoint conversion

Since SHWD contains only 7,000+ images, yolov7-tiny is chosen for training on it, and the [checkpoint](https://github.com/mindspore-lab/mindyolo/blob/master/MODEL_ZOO.md) trained on the COCO dataset that MindYOLO provides can be downloaded as the pretrained model. COCO has 80 object categories while SHWD has only two, and the output of the model's final head layer depends on the number of classes nc, so the last layer has to be removed from the pretrained checkpoint; see [convert_yolov7-tiny_pretrain_ckpt.py](./convert_yolov7-tiny_pretrain_ckpt.py). Run it as follows:
```shell
python examples/finetune_SHWD/convert_yolov7-tiny_pretrain_ckpt.py
```
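As an optional sanity check, you can load the converted checkpoint and confirm that no head-layer parameters remain. This is only a sketch; it assumes the converted file was written to ./yolov7-tiny_pretrain.ckpt, as in the script shown further down:

```python
import mindspore as ms

# load the converted checkpoint and look for leftover head-layer parameters,
# whose names contain the layer index '77' in yolov7-tiny
param_dict = ms.load_checkpoint('./yolov7-tiny_pretrain.ckpt')
head_params = [name for name in param_dict if '77' in name]
print(f'{len(param_dict)} parameters kept, {len(head_params)} head parameters left (expect 0)')
```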
#### Finetuning the model

A brief training workflow is given in [finetune_shwd.py](./finetune_shwd.py).

* Distributed training on multiple NPU/GPU devices, using 8 devices as an example:
```shell
mpirun --allow-run-as-root -n 8 python examples/finetune_SHWD/finetune_shwd.py --config ./examples/finetune_SHWD/yolov7-tiny.yaml --is_parallel True
```
* Training on a single NPU/GPU/CPU device:

```shell
python examples/finetune_SHWD/finetune_shwd.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml
```
*Note: training on SHWD with the default yolov7-tiny COCO hyperparameters already reaches an AP50 of 87.0; changing lr_init from 0.01 to 0.001 improves AP50 to 89.2.*
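For illustration, that change is a one-line edit in the training yaml. The exact location of lr_init can vary between configs, so treat the layout below (an optimizer section holding the learning-rate settings) as an assumption and check your own copy of the file:

```yaml
optimizer:
  lr_init: 0.001   # default 0.01; lowering it raises AP50 from 87.0 to 89.2 on SHWD
```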
examples/finetune_SHWD/convert_shwd2yolo.py
@@ -0,0 +1,259 @@
import os
from pathlib import Path
import argparse
import shutil
import xml.etree.ElementTree as ET
import collections
import json
from tqdm import tqdm


# the two SHWD classes; the list index doubles as the class id in the txt labels
category_set = ['person', 'hat']
# COCO-style containers used to build the validation-set ground-truth json
coco = dict()
coco['images'] = []
coco['type'] = 'instances'
coco['annotations'] = []
coco['categories'] = []

# global counters for category, annotation and image ids
category_item_id = 0
annotation_id = 0
image_index = 0

def addCatItem(name):
    # append one category entry to the COCO 'categories' list
    global category_item_id
    category_item = collections.OrderedDict()
    category_item['supercategory'] = 'none'
    category_item['id'] = category_item_id
    category_item['name'] = name
    coco['categories'].append(category_item)
    category_item_id += 1


def addImgItem(image_id, size):
    # append one image entry (file name, width, height, id) to the COCO 'images' list
    file_name = str(image_id).zfill(8) + '.jpg'
    if not size['width']:
        raise Exception('Could not find width tag in xml file.')
    if not size['height']:
        raise Exception('Could not find height tag in xml file.')

    image_item = collections.OrderedDict()

    image_item['file_name'] = file_name
    image_item['width'] = size['width']
    image_item['height'] = size['height']
    image_item['id'] = image_id
    coco['images'].append(image_item)


def addAnnoItem(image_id, category_id, bbox):
    # append one annotation entry to the COCO 'annotations' list; bbox is [x, y, w, h]
    global annotation_id

    annotation_item = collections.OrderedDict()
    annotation_item['segmentation'] = []

    # segmentation: the four corners of the box
    seg = []
    # left_top
    seg.append(bbox[0])
    seg.append(bbox[1])
    # left_bottom
    seg.append(bbox[0])
    seg.append(bbox[1] + bbox[3])
    # right_bottom
    seg.append(bbox[0] + bbox[2])
    seg.append(bbox[1] + bbox[3])
    # right_top
    seg.append(bbox[0] + bbox[2])
    seg.append(bbox[1])

    annotation_item['segmentation'].append(seg)
    annotation_item['area'] = bbox[2] * bbox[3]
    annotation_item['iscrowd'] = 0
    annotation_item['image_id'] = image_id
    annotation_item['bbox'] = bbox
    annotation_item['category_id'] = category_id
    annotation_item['id'] = annotation_id
    annotation_item['ignore'] = 0
    annotation_id += 1
    coco['annotations'].append(annotation_item)

def xxyy2xywhn(size, box):
    # convert a (xmin, xmax, ymin, ymax) box in pixels into the YOLO label format:
    # (x_center, y_center, width, height), normalized by the image width/height in `size`
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    x = round(x, 6)
    w = round(w, 6)
    y = round(y, 6)
    h = round(h, 6)
    return x, y, w, h

def xml2txt(xml_path, txt_path):
    # parse one VOC xml annotation and write the YOLO-format txt label file
    in_file = open(xml_path, encoding='utf-8')
    out_file = open(txt_path, 'w', encoding='utf-8')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        # skip classes outside category_set and objects marked as difficult
        if cls not in category_set or int(difficult) == 1:
            continue
        cls_id = category_set.index(cls)
        xmlbox = obj.find('bndbox')
        x1, x2, y1, y2 = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                          float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))

        # clip x2, y2 to normal range
        if x2 > w:
            x2 = w
        if y2 > h:
            y2 = h

        # xyxy2xywhn
        bbox = (x1, x2, y1, y2)
        bbox = xxyy2xywhn((w, h), bbox)
        out_file.write(str(cls_id) + " " +
                       " ".join([str(a) for a in bbox]) + '\n')

def xml2json(image_index, xml_path):
    # parse one validation-set VOC xml annotation and add its image and box entries
    # to the global COCO dict used to build the ground-truth json
    bndbox = dict()
    size = dict()
    size['width'] = None
    size['height'] = None

    tree = ET.parse(xml_path)
    root = tree.getroot()

    if root.tag != 'annotation':
        raise Exception('pascal voc xml root element should be annotation, rather than {}'.format(root.tag))

    # elem format: <folder>, <filename>, <size>, <object>
    for elem in root:
        if elem.tag == 'folder' or elem.tag == 'filename' or elem.tag == 'path' or elem.tag == 'source':
            continue

        elif elem.tag == 'size':
            # add image information, like file_name, size, image_id
            for subelem in elem:
                size[subelem.tag] = int(subelem.text)
            addImgItem(image_index, size)

        elif elem.tag == 'object':
            for subelem in elem:
                if subelem.tag == 'name':
                    object_name = subelem.text
                    current_category_id = category_set.index(object_name)

                elif subelem.tag == 'bndbox':
                    for option in subelem:
                        bndbox[option.tag] = int(option.text)

                    # COCO bbox format: [xmin, ymin, width, height]
                    bbox = []
                    bbox.append(bndbox['xmin'])
                    bbox.append(bndbox['ymin'])
                    bbox.append(bndbox['xmax'] - bndbox['xmin'])
                    bbox.append(bndbox['ymax'] - bndbox['ymin'])

                    # add bound box information, include area, image_id, bbox, category_id, id and so on
                    addAnnoItem(image_index, current_category_id, bbox)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', default='', type=str, help='root directory of data set')
    opt = parser.parse_args()

    # generate directory structure
    root_dir = opt.root_dir
    new_dir = os.path.join(root_dir, '..', 'SHWD')
    os.makedirs(os.path.join(new_dir, 'images', 'train'))
    os.makedirs(os.path.join(new_dir, 'images', 'val'))
    os.makedirs(os.path.join(new_dir, 'labels', 'train'))
    os.makedirs(os.path.join(new_dir, 'labels', 'val'))
    os.makedirs(os.path.join(new_dir, 'annotations'))

    train_txt_yolo = open(os.path.join(new_dir, 'train.txt'), 'w')
    val_txt_yolo = open(os.path.join(new_dir, 'val.txt'), 'w')

    images_path = os.path.join(root_dir, 'JPEGImages')
    labels_path = os.path.join(root_dir, 'Annotations')

    train_set_txt = os.path.join(root_dir, 'ImageSets', 'Main', 'trainval.txt')
    with open(train_set_txt, 'r', encoding='utf-8') as f:
        for line in tqdm(f.readlines(), desc='train_set'):
            stem = line.strip('\n')
            old_path = os.path.join(images_path, stem + '.jpg')
            if not os.path.exists(old_path):
                old_path = os.path.join(images_path, stem + '.JPG')

            # copy train_set image to new path
            new_images_path = os.path.join(new_dir, 'images', 'train')
            shutil.copy(old_path, new_images_path)

            # rename image_file to continuous number (MindYOLO uses the numeric file name as image_id at evaluation time)
            old_name = Path(old_path).name
            new_stem = str(image_index).zfill(8)
            os.rename(os.path.join(new_images_path, old_name), os.path.join(new_images_path, new_stem + '.jpg'))

            # write the relative path of image to train.txt
            train_txt_yolo.write('./images/train/' + new_stem + '.jpg' + '\n')

            # convert xml file to txt file
            xml_path = os.path.join(labels_path, stem + '.xml')
            txt_path = os.path.join(new_dir, 'labels', 'train', new_stem + '.txt')
            xml2txt(xml_path, txt_path)

            image_index += 1

    val_set_txt = os.path.join(root_dir, 'ImageSets', 'Main', 'test.txt')
    with open(val_set_txt, 'r', encoding='utf-8') as f:
        for line in tqdm(f.readlines(), desc='val_set'):
            stem = line.strip('\n')
            old_path = os.path.join(images_path, stem + '.jpg')
            if not os.path.exists(old_path):
                old_path = os.path.join(images_path, stem + '.JPG')

            # copy val_set image to new path
            new_images_path = os.path.join(new_dir, 'images', 'val')
            shutil.copy(old_path, new_images_path)

            # rename image_file to continuous number
            old_name = Path(old_path).name
            new_stem = str(image_index).zfill(8)
            os.rename(os.path.join(new_images_path, old_name), os.path.join(new_images_path, new_stem + '.jpg'))

            # write the relative path of image to val.txt
            val_txt_yolo.write('./images/val/' + new_stem + '.jpg' + '\n')

            # convert xml file to txt file
            xml_path = os.path.join(labels_path, stem + '.xml')
            txt_path = os.path.join(new_dir, 'labels', 'val', new_stem + '.txt')
            xml2txt(xml_path, txt_path)

            # convert xml file to json file
            xml2json(image_index, xml_path)

            image_index += 1

    for categoryname in category_set:
        addCatItem(categoryname)

    train_txt_yolo.close()
    val_txt_yolo.close()

    # save ground truth json file
    json_file = os.path.join(new_dir, 'annotations', 'instances_val2017.json')
    json.dump(coco, open(json_file, 'w'))
examples/finetune_SHWD/convert_yolov7-tiny_pretrain_ckpt.py (15 additions & 0 deletions)
@@ -0,0 +1,15 @@
import mindspore as ms


def convert_weight(ori_weight, new_weight):
    # skip parameters whose names contain '77': in yolov7-tiny this is the index of the
    # final head layer, whose shape depends on the number of classes nc, so its COCO
    # weights cannot be reused for the 2-class SHWD dataset
    new_ckpt = []
    param_dict = ms.load_checkpoint(ori_weight)
    for k, v in param_dict.items():
        if '77' in k:
            continue
        new_ckpt.append({'name': k, 'data': v})
    ms.save_checkpoint(new_ckpt, new_weight)


if __name__ == '__main__':
    convert_weight('./yolov7-tiny_300e_mAP375-d8972c94.ckpt', './yolov7-tiny_pretrain.ckpt')

Review comment (on `if '77' in k:`): Why is '77' dropped?
Review comment: The tutorial does not explain the key parts of the code; comments should be added to the key code sections.