Skip to content

Commit

Permalink
Fix H-label classification (#2377)
Browse files Browse the repository at this point in the history
* Fix h-label issue

* Update unit tests

* Make black happy

* Fix unittests

* Make black happy

* Fix update heads information func

* Update the logic: consider the loss per batch
  • Loading branch information
sungmanc authored Jul 20, 2023
1 parent 43eb838 commit 0906811
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def configure_model(self, cfg, ir_options): # noqa: C901
cfg.model.arch_type = cfg.model.type
cfg.model.type = super_type

# Hierarchical
if cfg.model.get("hierarchical"):
assert cfg.data.train.hierarchical_info == cfg.data.val.hierarchical_info == cfg.data.test.hierarchical_info
cfg.model.head.hierarchical_info = cfg.data.train.hierarchical_info

# OV-plugin
ir_model_path = ir_options.get("ir_model_path")
if ir_model_path:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,7 @@ def load_annotations(self):
if item_labels:
num_cls_heads = self.hierarchical_info["num_multiclass_heads"]

class_indices = [0] * (
self.hierarchical_info["num_multiclass_heads"] + self.hierarchical_info["num_multilabel_classes"]
)
class_indices = [0] * (num_cls_heads + self.hierarchical_info["num_multilabel_classes"])
for j in range(num_cls_heads):
class_indices[j] = -1
for otx_lbl in item_labels:
Expand All @@ -329,6 +327,19 @@ def load_annotations(self):
self.gt_labels.append(class_indices)
self.gt_labels = np.array(self.gt_labels)

self._update_heads_information()

def _update_heads_information(self):
    """Record which multiclass heads have no annotations in this dataset.

    A head is treated as empty when the largest label in that head's column of
    ``self.gt_labels`` is negative, i.e. no sample carries a valid (>= 0) label for it.
    The indices of such heads are added to
    ``self.hierarchical_info["empty_multiclass_head_indices"]`` so they can be filtered
    out to calculate the loss correctly.

    The method is idempotent: a head index that is already recorded is not appended
    again, so calling it repeatedly never produces duplicate entries.
    """
    num_cls_heads = self.hierarchical_info["num_multiclass_heads"]
    # Extend the existing list in place rather than rebinding the key, so any other
    # reference to this list keeps seeing the recorded indices.
    empty_head_indices = self.hierarchical_info["empty_multiclass_head_indices"]
    for head_idx in range(num_cls_heads):
        labels_in_head = self.gt_labels[:, head_idx]  # type: ignore[call-overload]
        # ``.max()`` reduces inside NumPy instead of iterating element by element.
        if labels_in_head.max() < 0 and head_idx not in empty_head_indices:
            empty_head_indices.append(head_idx)

@staticmethod
def mean_top_k_accuracy(scores, labels, k=1):
"""Return mean of top-k accuracy."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,26 @@ def forward_train(self, cls_score, gt_label, **kwargs):
cls_score = self.fc(cls_score)

losses = dict(loss=0.0)
num_effective_heads_in_batch = 0
for i in range(self.hierarchical_info["num_multiclass_heads"]):
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
if i not in self.hierarchical_info["empty_multiclass_head_indices"]:
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
if len(head_gt) > 0:
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
num_effective_heads_in_batch += 1

if self.hierarchical_info["num_multiclass_heads"] > 1:
losses["loss"] /= self.hierarchical_info["num_multiclass_heads"]
losses["loss"] /= num_effective_heads_in_batch

if self.compute_multilabel_loss:
head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,26 @@ def forward_train(self, cls_score, gt_label, **kwargs):
cls_score = self.classifier(cls_score)

losses = dict(loss=0.0)
num_effective_heads_in_batch = 0
for i in range(self.hierarchical_info["num_multiclass_heads"]):
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
if i not in self.hierarchical_info["empty_multiclass_head_indices"]:
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
if len(head_gt) > 0:
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
num_effective_heads_in_batch += 1

if self.hierarchical_info["num_multiclass_heads"] > 1:
losses["loss"] /= self.hierarchical_info["num_multiclass_heads"]
losses["loss"] /= num_effective_heads_in_batch

if self.compute_multilabel_loss:
head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :]
Expand Down
1 change: 1 addition & 0 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl
"class_to_group_idx": class_to_idx,
"all_groups": exclusive_groups + single_label_groups,
"label_to_idx": label_to_idx,
"empty_multiclass_head_indices": [],
}
return mixed_cls_heads_info

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,32 @@ def test_metric_hierarchical_adapter(self):
dataset = OTXHierarchicalClsDataset(
otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info
)

results = np.zeros((len(dataset), dataset.num_classes))
metrics = dataset.evaluate(results)

assert len(metrics) > 0
assert metrics["accuracy"] > 0

@e2e_pytest_unit
def test_hierarchical_with_empty_heads(self):
    """Check that a head whose labels are all -1 is reported as an empty head.

    A copy of the dataset gets every label of one multiclass head replaced with -1;
    after re-running ``_update_heads_information`` that head index must be present in
    ``hierarchical_info["empty_multiclass_head_indices"]``.
    """
    from copy import deepcopy

    self.task_environment, self.dataset = init_environment(
        self.hyper_parameters, self.model_template, False, True, self.dataset_len
    )
    class_info = get_multihead_class_info(self.task_environment.label_schema)
    dataset = OTXHierarchicalClsDataset(
        otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info
    )

    pseudo_head_idx = 0
    # Work on a deep copy so the original dataset's labels and head info stay untouched.
    pseudo_dataset = deepcopy(dataset)
    # np.array() copies by default, so masking the column cannot alias the source labels.
    pseudo_gt_labels = np.array(pseudo_dataset.gt_labels)
    pseudo_gt_labels[:, pseudo_head_idx] = -1
    pseudo_dataset.gt_labels = pseudo_gt_labels

    pseudo_dataset._update_heads_information()
    # Assert membership rather than comparing a list position, so the check stays
    # meaningful for any choice of pseudo_head_idx.
    assert pseudo_head_idx in pseudo_dataset.hierarchical_info["empty_multiclass_head_indices"]
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ def head_type(self) -> None:

@pytest.fixture(autouse=True)
def setup(self, head_type) -> None:
self.num_classes = 3
self.head_dim = 5
self.num_classes = 6
self.head_dim = 10
self.cls_heads_info = {
"num_multiclass_heads": 1,
"num_multilabel_classes": 1,
"head_idx_to_logits_range": {"0": (0, 2)},
"num_single_label_classes": 2,
"num_multiclass_heads": 3,
"num_multilabel_classes": 0,
"head_idx_to_logits_range": {"0": (0, 2), "1": (2, 4), "2": (4, 6)},
"num_single_label_classes": 6,
"empty_multiclass_head_indices": [],
}
self.loss = dict(type="CrossEntropyLoss", use_sigmoid=False, reduction="mean", loss_weight=1.0)
self.multilabel_loss = dict(type=AsymmetricLossWithIgnore.__name__, reduction="sum")
Expand All @@ -43,13 +44,23 @@ def setup(self, head_type) -> None:
)
self.default_head.init_weights()
self.default_input = torch.ones((2, self.head_dim))
self.default_gt = torch.zeros((2, 2))
self.default_gt = torch.zeros((2, 3))

@e2e_pytest_unit
def test_forward(self) -> None:
result = self.default_head.forward_train(self.default_input, self.default_gt)
assert "loss" in result
assert result["loss"] >= 0
assert result["loss"] >= 0 and not torch.isnan(result["loss"])

empty_head_gt_full = torch.tensor([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
result_include_empty_full = self.default_head.forward_train(self.default_input, empty_head_gt_full)
assert "loss" in result_include_empty_full
assert result_include_empty_full["loss"] >= 0 and not torch.isnan(result_include_empty_full["loss"])

empty_head_gt_partial = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
result_include_empty_partial = self.default_head.forward_train(self.default_input, empty_head_gt_partial)
assert "loss" in result_include_empty_partial
assert result_include_empty_partial["loss"] >= 0 and not torch.isnan(result_include_empty_partial["loss"])

@e2e_pytest_unit
def test_simple_test(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def setup(self) -> None:
self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model.py"))
self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "data_pipeline.py"))

self.multilabel_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_multilabel.py"))
self.hierarchical_model_cfg = MPAConfig.fromfile(
os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_hierarchical.py")
)

@e2e_pytest_unit
def test_configure(self, mocker):
mock_cfg_base = mocker.patch.object(ClassificationConfigurer, "configure_base")
Expand Down Expand Up @@ -119,6 +124,12 @@ def test_configure_model(self):
assert self.model_cfg.model_task
assert self.model_cfg.model.head.in_channels == 960

multilabel_model_cfg = self.multilabel_model_cfg
self.configurer.configure_model(multilabel_model_cfg, ir_options)

h_label_model_cfg = self.hierarchical_model_cfg
self.configurer.configure_model(h_label_model_cfg, ir_options)

@e2e_pytest_unit
def test_configure_model_not_classification_task(self):
ir_options = {"ir_model_path": {"ir_weight_path": "", "ir_weight_init": ""}}
Expand Down

0 comments on commit 0906811

Please sign in to comment.