From 09068119378493644efdb5c28f87f0617aa2ef07 Mon Sep 17 00:00:00 2001 From: Sungman Cho Date: Thu, 20 Jul 2023 17:00:18 +0900 Subject: [PATCH] Fix H-label classification (#2377) * Fix h-labelissue * Update unit tests * Make black happy * Fix unittests * Make black happy * Fix update heades information func * Update the logic: consider the loss per batch --- .../adapters/mmcls/configurer.py | 5 ++++ .../adapters/mmcls/datasets/otx_datasets.py | 17 +++++++++-- .../custom_hierarchical_linear_cls_head.py | 30 +++++++++++-------- ...custom_hierarchical_non_linear_cls_head.py | 30 +++++++++++-------- .../classification/utils/cls_utils.py | 1 + .../adapters/mmcls/data/test_datasets.py | 25 +++++++++++++++- .../test_custom_hierarchical_cls_head.py | 27 ++++++++++++----- .../adapters/mmcls/test_configurer.py | 11 +++++++ 8 files changed, 108 insertions(+), 38 deletions(-) diff --git a/src/otx/algorithms/classification/adapters/mmcls/configurer.py b/src/otx/algorithms/classification/adapters/mmcls/configurer.py index d0ecbfb4e2c..fe4529679a9 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/configurer.py +++ b/src/otx/algorithms/classification/adapters/mmcls/configurer.py @@ -132,6 +132,11 @@ def configure_model(self, cfg, ir_options): # noqa: C901 cfg.model.arch_type = cfg.model.type cfg.model.type = super_type + # Hierarchical + if cfg.model.get("hierarchical"): + assert cfg.data.train.hierarchical_info == cfg.data.val.hierarchical_info == cfg.data.test.hierarchical_info + cfg.model.head.hierarchical_info = cfg.data.train.hierarchical_info + # OV-plugin ir_model_path = ir_options.get("ir_model_path") if ir_model_path: diff --git a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py index 70a4500d1b5..7522be7ea33 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py +++ b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py @@ -309,9 +309,7 @@ def load_annotations(self): if item_labels: num_cls_heads = self.hierarchical_info["num_multiclass_heads"] - class_indices = [0] * ( - self.hierarchical_info["num_multiclass_heads"] + self.hierarchical_info["num_multilabel_classes"] - ) + class_indices = [0] * (num_cls_heads + self.hierarchical_info["num_multilabel_classes"]) for j in range(num_cls_heads): class_indices[j] = -1 for otx_lbl in item_labels: @@ -329,6 +327,19 @@ def load_annotations(self): self.gt_labels.append(class_indices) self.gt_labels = np.array(self.gt_labels) + self._update_heads_information() + + def _update_heads_information(self): + """Update heads information to find the empty heads. + + If there are no annotations at a specific head, this should be filtered out to calculate loss correctly. + """ + num_cls_heads = self.hierarchical_info["num_multiclass_heads"] + for head_idx in range(num_cls_heads): + labels_in_head = self.gt_labels[:, head_idx] # type: ignore[call-overload] + if max(labels_in_head) < 0: + self.hierarchical_info["empty_multiclass_head_indices"].append(head_idx) + @staticmethod def mean_top_k_accuracy(scores, labels, k=1): """Return mean of top-k accuracy.""" diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py index 5b3245a4f40..6776756bb61 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py @@ -87,22 +87,26 @@ def forward_train(self, cls_score, gt_label, **kwargs): cls_score = self.fc(cls_score) losses = dict(loss=0.0) + num_effective_heads_in_batch = 0 for i in range(self.hierarchical_info["num_multiclass_heads"]): - head_gt = gt_label[:, i] - head_logits = cls_score[ - :, - self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ - "head_idx_to_logits_range" - ][str(i)][1], - ] - valid_mask = head_gt >= 0 - head_gt = head_gt[valid_mask].long() - head_logits = head_logits[valid_mask, :] - multiclass_loss = self.loss(head_logits, head_gt) - losses["loss"] += multiclass_loss + if i not in self.hierarchical_info["empty_multiclass_head_indices"]: + head_gt = gt_label[:, i] + head_logits = cls_score[ + :, + self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ + "head_idx_to_logits_range" + ][str(i)][1], + ] + valid_mask = head_gt >= 0 + head_gt = head_gt[valid_mask].long() + if len(head_gt) > 0: + head_logits = head_logits[valid_mask, :] + multiclass_loss = self.loss(head_logits, head_gt) + losses["loss"] += multiclass_loss + num_effective_heads_in_batch += 1 if self.hierarchical_info["num_multiclass_heads"] > 1: - losses["loss"] /= self.hierarchical_info["num_multiclass_heads"] + losses["loss"] /= num_effective_heads_in_batch if self.compute_multilabel_loss: head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :] diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py index 4b2691157e1..5397818fbf3 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py @@ -117,22 +117,26 @@ def forward_train(self, cls_score, gt_label, **kwargs): cls_score = self.classifier(cls_score) losses = dict(loss=0.0) + num_effective_heads_in_batch = 0 for i in range(self.hierarchical_info["num_multiclass_heads"]): - head_gt = gt_label[:, i] - head_logits = cls_score[ - :, - self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ - "head_idx_to_logits_range" - ][str(i)][1], - ] - valid_mask = head_gt >= 0 - head_gt = head_gt[valid_mask].long() - head_logits = head_logits[valid_mask, :] - multiclass_loss = self.loss(head_logits, head_gt) - losses["loss"] += multiclass_loss + if i not in self.hierarchical_info["empty_multiclass_head_indices"]: + head_gt = gt_label[:, i] + head_logits = cls_score[ + :, + self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ + "head_idx_to_logits_range" + ][str(i)][1], + ] + valid_mask = head_gt >= 0 + head_gt = head_gt[valid_mask].long() + if len(head_gt) > 0: + head_logits = head_logits[valid_mask, :] + multiclass_loss = self.loss(head_logits, head_gt) + losses["loss"] += multiclass_loss + num_effective_heads_in_batch += 1 if self.hierarchical_info["num_multiclass_heads"] > 1: - losses["loss"] /= self.hierarchical_info["num_multiclass_heads"] + losses["loss"] /= num_effective_heads_in_batch if self.compute_multilabel_loss: head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :] diff --git a/src/otx/algorithms/classification/utils/cls_utils.py b/src/otx/algorithms/classification/utils/cls_utils.py index 8bb2b9630f2..23dc1ba1fa6 100644 --- a/src/otx/algorithms/classification/utils/cls_utils.py +++ b/src/otx/algorithms/classification/utils/cls_utils.py @@ -62,6 +62,7 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl "class_to_group_idx": class_to_idx, "all_groups": exclusive_groups + single_label_groups, "label_to_idx": label_to_idx, + "empty_multiclass_head_indices": [], } return mixed_cls_heads_info diff --git a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py index 41e6890e02d..b4719680125 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py @@ -142,9 +142,32 @@ def test_metric_hierarchical_adapter(self): dataset = OTXHierarchicalClsDataset( otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info ) - results = np.zeros((len(dataset), dataset.num_classes)) metrics = dataset.evaluate(results) assert len(metrics) > 0 assert metrics["accuracy"] > 0 + + @e2e_pytest_unit + def test_hierarchical_with_empty_heads(self): + self.task_environment, self.dataset = init_environment( + self.hyper_parameters, self.model_template, False, True, self.dataset_len + ) + class_info = get_multihead_class_info(self.task_environment.label_schema) + dataset = OTXHierarchicalClsDataset( + otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info + ) + pseudo_gt_labels = [] + pseudo_head_idx = 0 + for label in dataset.gt_labels: + pseudo_gt_label = label + pseudo_gt_label[pseudo_head_idx] = -1 + pseudo_gt_labels.append(pseudo_gt_label) + pseudo_gt_labels = np.array(pseudo_gt_labels) + + from copy import deepcopy + + pseudo_dataset = deepcopy(dataset) + pseudo_dataset.gt_labels = pseudo_gt_labels + pseudo_dataset._update_heads_information() + assert pseudo_dataset.hierarchical_info["empty_multiclass_head_indices"][pseudo_head_idx] == 0 diff --git a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py index 11f6e100996..8f8ec9b6550 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py @@ -24,13 +24,14 @@ def head_type(self) -> None: @pytest.fixture(autouse=True) def setup(self, head_type) -> None: - self.num_classes = 3 - self.head_dim = 5 + self.num_classes = 6 + self.head_dim = 10 self.cls_heads_info = { - "num_multiclass_heads": 1, - "num_multilabel_classes": 1, - "head_idx_to_logits_range": {"0": (0, 2)}, - "num_single_label_classes": 2, + "num_multiclass_heads": 3, + "num_multilabel_classes": 0, + "head_idx_to_logits_range": {"0": (0, 2), "1": (2, 4), "2": (4, 6)}, + "num_single_label_classes": 6, + "empty_multiclass_head_indices": [], } self.loss = dict(type="CrossEntropyLoss", use_sigmoid=False, reduction="mean", loss_weight=1.0) self.multilabel_loss = dict(type=AsymmetricLossWithIgnore.__name__, reduction="sum") @@ -43,13 +44,23 @@ def setup(self, head_type) -> None: ) self.default_head.init_weights() self.default_input = torch.ones((2, self.head_dim)) - self.default_gt = torch.zeros((2, 2)) + self.default_gt = torch.zeros((2, 3)) @e2e_pytest_unit def test_forward(self) -> None: result = self.default_head.forward_train(self.default_input, self.default_gt) assert "loss" in result - assert result["loss"] >= 0 + assert result["loss"] >= 0 and not torch.isnan(result["loss"]) + + empty_head_gt_full = torch.tensor([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]) + result_include_empty_full = self.default_head.forward_train(self.default_input, empty_head_gt_full) + assert "loss" in result_include_empty_full + assert result_include_empty_full["loss"] >= 0 and not torch.isnan(result_include_empty_full["loss"]) + + empty_head_gt_partial = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]) + result_include_empty_partial = self.default_head.forward_train(self.default_input, empty_head_gt_partial) + assert "loss" in result_include_empty_partial + assert result_include_empty_partial["loss"] >= 0 and not torch.isnan(result_include_empty_partial["loss"]) @e2e_pytest_unit def test_simple_test(self) -> None: diff --git a/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py b/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py index 96e1efbf685..ae058a4d56d 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py @@ -23,6 +23,11 @@ def setup(self) -> None: self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model.py")) self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "data_pipeline.py")) + self.multilabel_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_multilabel.py")) + self.hierarchical_model_cfg = MPAConfig.fromfile( + os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_hierarchical.py") + ) + @e2e_pytest_unit def test_configure(self, mocker): mock_cfg_base = mocker.patch.object(ClassificationConfigurer, "configure_base") @@ -119,6 +124,12 @@ def test_configure_model(self): assert self.model_cfg.model_task assert self.model_cfg.model.head.in_channels == 960 + multilabel_model_cfg = self.multilabel_model_cfg + self.configurer.configure_model(multilabel_model_cfg, ir_options) + + h_label_model_cfg = self.hierarchical_model_cfg + self.configurer.configure_model(h_label_model_cfg, ir_options) + @e2e_pytest_unit def test_configure_model_not_classification_task(self): ir_options = {"ir_model_path": {"ir_weight_path": "", "ir_weight_init": ""}}