diff --git a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py index 99194cbc23e..70a4500d1b5 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py +++ b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py @@ -416,7 +416,10 @@ def evaluate( ) eval_results["MHAcc"] = total_acc - eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"] + if self.hierarchical_info["num_multiclass_heads"] > 0: + eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"] + else: + eval_results["avgClsAcc"] = total_acc_sl eval_results["mAP"] = mAP_value eval_results["accuracy"] = total_acc diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/classifiers/sam_classifier.py b/src/otx/algorithms/classification/adapters/mmcls/models/classifiers/sam_classifier.py index 68249a8f6be..5b03e7c5f0a 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/classifiers/sam_classifier.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/classifiers/sam_classifier.py @@ -16,6 +16,14 @@ logger = get_logger() +def is_hierarchical_chkpt(chkpt: dict): + """Detect whether previous checkpoint is hierarchical or not.""" + for k, v in chkpt.items(): + if "fc" in k: + return True + return False + + @CLASSIFIERS.register_module() class SAMImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier): """SAM-enabled ImageClassifier.""" @@ -193,11 +201,19 @@ def load_state_dict_pre_hook(module, state_dict, prefix, *args, **kwargs): # no def load_state_dict_mixing_hook( model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs ): # pylint: disable=unused-argument, too-many-branches, too-many-locals - """Modify input state_dict according to class name matching before weight loading.""" + """Modify input state_dict according to class name matching before weight loading. + + If previous training is hierarchical training, + then the current training should be hierarchical training. vice versa. + + """ backbone_type = type(model.backbone).__name__ if backbone_type not in ["OTXMobileNetV3", "OTXEfficientNet", "OTXEfficientNetV2"]: return + if model.hierarchical != is_hierarchical_chkpt(chkpt_dict): + return + # Dst to src mapping index model_classes = list(model_classes) chkpt_classes = list(chkpt_classes) @@ -249,13 +265,15 @@ def load_state_dict_mixing_hook( continue # Mix weights - chkpt_param = chkpt_dict[chkpt_name] - for module, c in enumerate(model2chkpt): - if c >= 0: - model_param[module].copy_(chkpt_param[c]) + # NOTE: Label mix is not supported for H-label classification. + if not model.hierarchical: + chkpt_param = chkpt_dict[chkpt_name] + for module, c in enumerate(model2chkpt): + if c >= 0: + model_param[module].copy_(chkpt_param[c]) - # Replace checkpoint weight by mixed weights - chkpt_dict[chkpt_name] = model_param + # Replace checkpoint weight by mixed weights + chkpt_dict[chkpt_name] = model_param def extract_feat(self, img): """Directly extract features from the backbone + neck. diff --git a/src/otx/algorithms/classification/task.py b/src/otx/algorithms/classification/task.py index eb6fab8161d..3c74230dcab 100644 --- a/src/otx/algorithms/classification/task.py +++ b/src/otx/algorithms/classification/task.py @@ -47,6 +47,8 @@ from otx.api.entities.inference_parameters import ( default_progress_callback as default_infer_progress_callback, ) +from otx.api.entities.label import LabelEntity +from otx.api.entities.label_schema import LabelGroup from otx.api.entities.metadata import FloatMetadata, FloatType from otx.api.entities.metrics import ( CurveMetric, @@ -125,16 +127,22 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] if self._task_environment.model is not None: self._load_model() + def _is_multi_label(self, label_groups: List[LabelGroup], all_labels: List[LabelEntity]): + """Check whether the current training mode is multi-label or not.""" + # NOTE: In the current Geti, multi-label should have `___` symbol for all group names. + find_multilabel_symbol = ["___" in getattr(i, "name", "") for i in label_groups] + return ( + (len(label_groups) > 1) and (len(label_groups) == len(all_labels)) and (False not in find_multilabel_symbol) + ) + def _set_train_mode(self): - self._multilabel = len(self._task_environment.label_schema.get_groups(False)) > 1 and len( - self._task_environment.label_schema.get_groups(False) - ) == len( - self._task_environment.get_labels(include_empty=False) - ) # noqa:E127 + label_groups = self._task_environment.label_schema.get_groups(include_empty=False) + all_labels = self._task_environment.label_schema.get_labels(include_empty=False) + + self._multilabel = self._is_multi_label(label_groups, all_labels) if self._multilabel: logger.info("Classification mode: multilabel") - - if not self._multilabel and len(self._task_environment.label_schema.get_groups(False)) > 1: + elif len(label_groups) > 1: logger.info("Classification mode: hierarchical") self._hierarchical = True self._hierarchical_info = get_hierarchical_info(self._task_environment.label_schema) diff --git a/tests/assets/datumaro_h-label_class_decremental/annotations/train.json b/tests/assets/datumaro_h-label_class_decremental/annotations/train.json new file mode 100755 index 00000000000..4bb4caae751 --- /dev/null +++ b/tests/assets/datumaro_h-label_class_decremental/annotations/train.json @@ -0,0 +1,181 @@ +{ + "info": {}, + "categories": { + "label": { + "labels": [ + { + "name": "right", + "parent": "triangle", + "attributes": [] + }, + { + "name": "multi a", + "parent": "triangle", + "attributes": [] + }, + { + "name": "equilateral", + "parent": "triangle", + "attributes": [] + }, + { + "name": "square", + "parent": "rectangle", + "attributes": [] + }, + { + "name": "triangle", + "parent": "", + "attributes": [] + }, + { + "name": "non_square", + "parent": "rectangle", + "attributes": [] + }, + { + "name": "rectangle", + "parent": "", + "attributes": [] + } + ], + "label_groups": [ + { + "name": "shape", + "group_type": "exclusive", + "labels": ["rectangle", "triangle"] + }, + { + "name": "rectangle default", + "group_type": "exclusive", + "labels": ["non_square", "square"] + }, + { + "name": "triangle default", + "group_type": "exclusive", + "labels": ["equilateral", "right"] + }, + { + "name": "shape___multiple example___multi a", + "group_type": "exclusive", + "labels": ["multi a"] + } + ], + "attributes": [] + }, + "mask": { + "colormap": [ + { + "label_id": 0, + "r": 129, + "g": 64, + "b": 123 + }, + { + "label_id": 1, + "r": 91, + "g": 105, + "b": 255 + }, + { + "label_id": 2, + "r": 91, + "g": 105, + "b": 255 + }, + { + "label_id": 3, + "r": 255, + "g": 86, + "b": 98 + }, + { + "label_id": 4, + "r": 204, + "g": 148, + "b": 218 + }, + { + "label_id": 5, + "r": 0, + "g": 251, + "b": 87 + }, + { + "label_id": 6, + "r": 84, + "g": 143, + "b": 173 + } + ] + } + }, + "items": [ + { + "id": "a", + "annotations": [ + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 4 + }, + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 5 + }, + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 1 + } + ], + "image": { + "path": "a.jpg", + "size": [10, 5] + }, + "media": { + "path": "" + } + }, + { + "id": "b", + "annotations": [ + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 6 + }, + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 5 + }, + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 2 + } + ], + "image": { + "path": "b.jpg", + "size": [10, 5] + }, + "media": { + "path": "" + } + } + ] +} diff --git a/tests/assets/datumaro_h-label_class_decremental/annotations/validation.json b/tests/assets/datumaro_h-label_class_decremental/annotations/validation.json new file mode 100755 index 00000000000..d97956af708 --- /dev/null +++ b/tests/assets/datumaro_h-label_class_decremental/annotations/validation.json @@ -0,0 +1,141 @@ +{ + "info": {}, + "categories": { + "label": { + "labels": [ + { + "name": "right", + "parent": "triangle", + "attributes": [] + }, + { + "name": "multi a", + "parent": "triangle", + "attributes": [] + }, + { + "name": "equilateral", + "parent": "triangle", + "attributes": [] + }, + { + "name": "square", + "parent": "rectangle", + "attributes": [] + }, + { + "name": "triangle", + "parent": "", + "attributes": [] + }, + { + "name": "non_square", + "parent": "rectangle", + "attributes": [] + }, + { + "name": "rectangle", + "parent": "", + "attributes": [] + } + ], + "label_groups": [ + { + "name": "shape", + "group_type": "exclusive", + "labels": ["rectangle", "triangle"] + }, + { + "name": "rectangle default", + "group_type": "exclusive", + "labels": ["non_square", "square"] + }, + { + "name": "triangle default", + "group_type": "exclusive", + "labels": ["equilateral", "right"] + }, + { + "name": "shape___multiple example___multi a", + "group_type": "exclusive", + "labels": ["multi a"] + } + ], + "attributes": [] + }, + "mask": { + "colormap": [ + { + "label_id": 0, + "r": 129, + "g": 64, + "b": 123 + }, + { + "label_id": 1, + "r": 91, + "g": 105, + "b": 255 + }, + { + "label_id": 2, + "r": 91, + "g": 105, + "b": 255 + }, + { + "label_id": 3, + "r": 255, + "g": 86, + "b": 98 + }, + { + "label_id": 4, + "r": 204, + "g": 148, + "b": 218 + }, + { + "label_id": 5, + "r": 0, + "g": 251, + "b": 87 + }, + { + "label_id": 6, + "r": 84, + "g": 143, + "b": 173 + } + ] + } + }, + "items": [ + { + "id": "d", + "annotations": [ + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 5 + }, + { + "id": 0, + "type": "label", + "attributes": {}, + "group": 0, + "label_id": 2 + } + ], + "image": { + "path": "d.jpg", + "size": [10, 5] + }, + "media": { + "path": "" + } + } + ] +} diff --git a/tests/assets/datumaro_h-label_class_decremental/images/train/a.jpg b/tests/assets/datumaro_h-label_class_decremental/images/train/a.jpg new file mode 100644 index 00000000000..222682d80bf Binary files /dev/null and b/tests/assets/datumaro_h-label_class_decremental/images/train/a.jpg differ diff --git a/tests/assets/datumaro_h-label_class_decremental/images/train/b.jpg b/tests/assets/datumaro_h-label_class_decremental/images/train/b.jpg new file mode 100644 index 00000000000..222682d80bf Binary files /dev/null and b/tests/assets/datumaro_h-label_class_decremental/images/train/b.jpg differ diff --git a/tests/assets/datumaro_h-label_class_decremental/images/validation/d.jpg b/tests/assets/datumaro_h-label_class_decremental/images/validation/d.jpg new file mode 100644 index 00000000000..222682d80bf Binary files /dev/null and b/tests/assets/datumaro_h-label_class_decremental/images/validation/d.jpg differ diff --git a/tests/assets/datumaro_multilabel/annotations/train.json b/tests/assets/datumaro_multilabel/annotations/train.json index 3b44dd90cdd..1e5e21258c5 100755 --- a/tests/assets/datumaro_multilabel/annotations/train.json +++ b/tests/assets/datumaro_multilabel/annotations/train.json @@ -4,12 +4,17 @@ "label": { "label_groups": [ { - "name": "tom", + "name": "Classification labels___tom", "group_type": "exclusive", "labels": ["tom"] }, { - "name": "mary", + "name": "Classification labels___john", + "group_type": "exclusive", + "labels": ["john"] + }, + { + "name": "Classification labels___mary", "group_type": "exclusive", "labels": ["mary"] } @@ -20,6 +25,11 @@ "parent": "", "attributes": [] }, + { + "name": "john", + "parent": "", + "attributes": [] + }, { "name": "mary", "parent": "", @@ -42,7 +52,7 @@ { "id": 1, "type": "label", - "group": 0, + "group": 1, "label_id": 1 } ], @@ -56,8 +66,8 @@ { "id": 0, "type": "label", - "group": 0, - "label_id": 0 + "group": 2, + "label_id": 2 } ], "image": { diff --git a/tests/assets/datumaro_multilabel/annotations/validation.json b/tests/assets/datumaro_multilabel/annotations/validation.json index 570904a8add..7ec693402f7 100755 --- a/tests/assets/datumaro_multilabel/annotations/validation.json +++ b/tests/assets/datumaro_multilabel/annotations/validation.json @@ -4,12 +4,17 @@ "label": { "label_groups": [ { - "name": "tom", + "name": "Classification labels___tom", "group_type": "exclusive", "labels": ["tom"] }, { - "name": "mary", + "name": "Classification labels___john", + "group_type": "exclusive", + "labels": ["john"] + }, + { + "name": "Classification labels___mary", "group_type": "exclusive", "labels": ["mary"] } @@ -20,6 +25,11 @@ "parent": "", "attributes": [] }, + { + "name": "john", + "parent": "", + "attributes": [] + }, { "name": "mary", "parent": "", @@ -42,8 +52,14 @@ { "id": 1, "type": "label", - "group": 0, + "group": 1, "label_id": 1 + }, + { + "id": 1, + "type": "label", + "group": 2, + "label_id": 2 } ], "image": { diff --git a/tests/assets/datumaro_multilabel_class_decremental/annotations/train.json b/tests/assets/datumaro_multilabel_class_decremental/annotations/train.json new file mode 100755 index 00000000000..d8098a84ecf --- /dev/null +++ b/tests/assets/datumaro_multilabel_class_decremental/annotations/train.json @@ -0,0 +1,68 @@ +{ + "info": {}, + "categories": { + "label": { + "label_groups": [ + { + "name": "Classification labels___tom", + "group_type": "exclusive", + "labels": ["tom"] + }, + { + "name": "Classification labels___mary", + "group_type": "exclusive", + "labels": ["mary"] + } + ], + "labels": [ + { + "name": "tom", + "parent": "", + "attributes": [] + }, + { + "name": "mary", + "parent": "", + "attributes": [] + } + ], + "attributes": [] + } + }, + "items": [ + { + "id": "a", + "annotations": [ + { + "id": 0, + "type": "label", + "group": 0, + "label_id": 0 + }, + { + "id": 1, + "type": "label", + "group": 0, + "label_id": 1 + } + ], + "image": { + "path": "a.jpg" + } + }, + { + "id": "b", + "annotations": [ + { + "id": 0, + "type": "label", + "group": 0, + "label_id": 0 + } + ], + "image": { + "path": "b.jpg" + } + } + ] +} diff --git a/tests/assets/datumaro_multilabel_class_decremental/annotations/validation.json b/tests/assets/datumaro_multilabel_class_decremental/annotations/validation.json new file mode 100755 index 00000000000..49b35f786d3 --- /dev/null +++ b/tests/assets/datumaro_multilabel_class_decremental/annotations/validation.json @@ -0,0 +1,54 @@ +{ + "info": {}, + "categories": { + "label": { + "label_groups": [ + { + "name": "Classification labels___tom", + "group_type": "exclusive", + "labels": ["tom"] + }, + { + "name": "Classification labels___mary", + "group_type": "exclusive", + "labels": ["mary"] + } + ], + "labels": [ + { + "name": "tom", + "parent": "", + "attributes": [] + }, + { + "name": "mary", + "parent": "", + "attributes": [] + } + ], + "attributes": [] + } + }, + "items": [ + { + "id": "d", + "annotations": [ + { + "id": 0, + "type": "label", + "group": 0, + "label_id": 0 + }, + { + "id": 1, + "type": "label", + "group": 0, + "label_id": 1 + } + ], + "image": { + "path": "d.jpg" + } + } + ] +} diff --git a/tests/assets/datumaro_multilabel_class_decremental/images/train/a.jpg b/tests/assets/datumaro_multilabel_class_decremental/images/train/a.jpg new file mode 100644 index 00000000000..222682d80bf Binary files /dev/null and b/tests/assets/datumaro_multilabel_class_decremental/images/train/a.jpg differ diff --git a/tests/assets/datumaro_multilabel_class_decremental/images/train/b.jpg b/tests/assets/datumaro_multilabel_class_decremental/images/train/b.jpg new file mode 100644 index 00000000000..222682d80bf Binary files /dev/null and b/tests/assets/datumaro_multilabel_class_decremental/images/train/b.jpg differ diff --git a/tests/assets/datumaro_multilabel_class_decremental/images/validation/d.jpg b/tests/assets/datumaro_multilabel_class_decremental/images/validation/d.jpg new file mode 100644 index 00000000000..222682d80bf Binary files /dev/null and b/tests/assets/datumaro_multilabel_class_decremental/images/validation/d.jpg differ diff --git a/tests/integration/cli/classification/test_classification.py b/tests/integration/cli/classification/test_classification.py index 59e06be73b0..45a98517284 100644 --- a/tests/integration/cli/classification/test_classification.py +++ b/tests/integration/cli/classification/test_classification.py @@ -324,6 +324,18 @@ def test_otx_train(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_train_testing(template, tmp_dir_path, otx_dir, args_m) + @e2e_pytest_component + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_train_cls_decr(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "multi_label_cls/test_cls_decr" + otx_train_testing(template, tmp_dir_path, otx_dir, args_m) + template_work_dir = get_template_dir(template, tmp_dir_path) + args1 = copy.deepcopy(args_m) + args1["--train-data-roots"] = "tests/assets/datumaro_multilabel_class_decremental" + args1["--val-data-roots"] = "tests/assets/datumaro_multilabel_class_decremental" + args1["--load-weights"] = f"{template_work_dir}/trained_{template.model_template_id}/models/weights.pth" + otx_train_testing(template, tmp_dir_path, otx_dir, args1) + @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) @pytest.mark.parametrize("dump_features", [True, False]) @@ -452,6 +464,18 @@ def test_otx_train(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "h_label_cls" otx_train_testing(template, tmp_dir_path, otx_dir, args_h) + @e2e_pytest_component + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_train_cls_decr(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "h_label_cls/test_cls_decr" + otx_train_testing(template, tmp_dir_path, otx_dir, args_h) + template_work_dir = get_template_dir(template, tmp_dir_path) + args1 = copy.deepcopy(args_h) + args1["--train-data-roots"] = "tests/assets/datumaro_h-label_class_decremental" + args1["--val-data-roots"] = "tests/assets/datumaro_h-label_class_decremental" + args1["--load-weights"] = f"{template_work_dir}/trained_{template.model_template_id}/models/weights.pth" + otx_train_testing(template, tmp_dir_path, otx_dir, args1) + @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) @pytest.mark.parametrize("dump_features", [True, False])