Add hlabel parent resolver and integrate HLabelData + HLabelInfo --> HLabelInfo (#3046)

* Initial commit

* Add function: add ancestor

* Integrate HLabelData and HLabelInfo

* Add unit-test for _add_ancestor func

* Fix invalid names

* Fix precommit

* Fix hlabel

* Add type checking

* Enlarge the threshold

* Fix intg test
sungmanc authored Mar 11, 2024
1 parent 506a22f commit 8737903
Showing 13 changed files with 504 additions and 338 deletions.
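
This change replaces the HLabelData entity with HLabelInfo (now imported from otx.core.data.dataset.classification) as the object from which the hierarchical-label classification heads read their hierarchy metadata. As a rough illustration of the fields these heads access, here is a hypothetical, simplified stand-in; the real HLabelInfo carries more metadata than this sketch.

# Hypothetical, simplified stand-in for the fields the heads read from HLabelInfo;
# the real class lives in otx.core.data.dataset.classification and holds more metadata.
from dataclasses import dataclass, field


@dataclass
class SimpleHLabelInfo:
    num_multiclass_heads: int
    num_single_label_classes: int
    # str(head_idx) -> (start, end) slice of the logits owned by that multiclass head
    head_idx_to_logits_range: dict[str, tuple[int, int]]
    # heads left without labels after label removal; skipped when computing the loss
    empty_multiclass_head_indices: list[int] = field(default_factory=list)


# Example: two multiclass heads (3 + 2 classes) followed by a multilabel tail.
info = SimpleHLabelInfo(
    num_multiclass_heads=2,
    num_single_label_classes=5,
    head_idx_to_logits_range={"0": (0, 3), "1": (3, 5)},
)
# head.set_hlabel_info(info)  # how a head receives it after this change

In the two heads shown below, only the attribute and method names change (hlabel_data becomes hlabel_info), not the logic.
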
26 changes: 13 additions & 13 deletions src/otx/algo/classification/heads/custom_hlabel_linear_cls_head.py
@@ -14,7 +14,7 @@
from torch import nn

if TYPE_CHECKING:
- from otx.core.data.entity.classification import HLabelData
+ from otx.core.data.dataset.classification import HLabelInfo


@MODELS.register_module()
@@ -74,19 +74,19 @@ def forward(self, feats: tuple[torch.Tensor]) -> torch.Tensor:
pre_logits = self.pre_logits(feats)
return self.fc(pre_logits)

- def set_hlabel_data(self, hlabel_data: HLabelData) -> None:
+ def set_hlabel_info(self, hlabel_info: HLabelInfo) -> None:
"""Set hlabel information."""
- self.hlabel_data = hlabel_data
+ self.hlabel_info = hlabel_info

def _get_gt_label(self, data_samples: list[DataSample]) -> torch.Tensor:
"""Get gt labels from data samples."""
return torch.stack([data_sample.gt_label for data_sample in data_samples])

- def _get_head_idx_to_logits_range(self, hlabel_data: HLabelData, idx: int) -> tuple[int, int]:
+ def _get_head_idx_to_logits_range(self, hlabel_info: HLabelInfo, idx: int) -> tuple[int, int]:
"""Get head_idx_to_logits_range information from hlabel information."""
return (
- hlabel_data.head_idx_to_logits_range[str(idx)][0],
- hlabel_data.head_idx_to_logits_range[str(idx)][1],
+ hlabel_info.head_idx_to_logits_range[str(idx)][0],
+ hlabel_info.head_idx_to_logits_range[str(idx)][1],
)

def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:
@@ -113,9 +113,9 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:
# Multiclass loss
num_effective_heads_in_batch = 0 # consider the label removal case
for i in range(self.num_multiclass_heads):
- if i not in self.hlabel_data.empty_multiclass_head_indices:
+ if i not in self.hlabel_info.empty_multiclass_head_indices:
head_gt = gt_labels[:, i]
- logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
+ logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
valid_mask = head_gt >= 0

@@ -130,15 +130,15 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:

# Multilabel loss
if self.num_multilabel_classes > 0:
- head_gt = gt_labels[:, self.hlabel_data.num_multiclass_heads :]
- head_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
+ head_gt = gt_labels[:, self.hlabel_info.num_multiclass_heads :]
+ head_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]
valid_mask = head_gt > 0
head_gt = head_gt[valid_mask]
if len(head_gt) > 0:
img_metas = [data_sample.metainfo for data_sample in data_samples]
head_logits = head_logits[valid_mask]
valid_label_mask = self.get_valid_label_mask(img_metas).to(head_logits.device)
- valid_label_mask = valid_label_mask[:, self.hlabel_data.num_single_label_classes :]
+ valid_label_mask = valid_label_mask[:, self.hlabel_info.num_single_label_classes :]
valid_label_mask = valid_label_mask[valid_mask]
losses["loss"] += self.multilabel_loss(head_logits, head_gt, valid_label_mask=valid_label_mask)

@@ -183,7 +183,7 @@ def _get_predictions(
multiclass_pred_scores = []
multiclass_pred_labels = []
for i in range(self.num_multiclass_heads):
- logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
+ logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
multiclass_logit = cls_scores[:, logit_range[0] : logit_range[1]]
multiclass_pred = torch.softmax(multiclass_logit, dim=1)
multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1)
@@ -195,7 +195,7 @@
multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1)

if self.num_multilabel_classes > 0:
- multilabel_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
+ multilabel_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]

multilabel_pred_scores = torch.sigmoid(multilabel_logits)
multilabel_pred_labels = (multilabel_pred_scores >= self.thr).int()
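
For context, here is a minimal, self-contained sketch of the prediction slicing performed in the hunks above: each multiclass head takes a softmax over its own logit range, and the multilabel tail after num_single_label_classes goes through an element-wise sigmoid with a threshold. The shapes, ranges, and the 0.5 threshold below are illustrative assumptions, not values from the repository.

import torch

# Illustrative hierarchy: two multiclass heads (3 + 2 classes) and 4 multilabel classes.
head_idx_to_logits_range = {"0": (0, 3), "1": (3, 5)}
num_single_label_classes = 5
num_multilabel_classes = 4
thr = 0.5  # stands in for the head's self.thr

cls_scores = torch.randn(8, num_single_label_classes + num_multilabel_classes)  # batch of 8

multiclass_pred_scores, multiclass_pred_labels = [], []
for i in range(len(head_idx_to_logits_range)):
    start, end = head_idx_to_logits_range[str(i)]
    probs = torch.softmax(cls_scores[:, start:end], dim=1)      # softmax within one head
    score, label = torch.max(probs, dim=1)
    multiclass_pred_scores.append(score.unsqueeze(1))
    multiclass_pred_labels.append(label.unsqueeze(1))

multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1)  # (batch, num_heads)

if num_multilabel_classes > 0:
    multilabel_probs = torch.sigmoid(cls_scores[:, num_single_label_classes:])
    multilabel_pred_labels = (multilabel_probs >= thr).int()       # per-class binary decision

The hunks that follow apply the same HLabelData --> HLabelInfo change to a second hierarchical-label classification head.
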
@@ -14,7 +14,7 @@
from torch import nn

if TYPE_CHECKING:
- from otx.core.data.entity.classification import HLabelData
+ from otx.core.data.dataset.classification import HLabelInfo


@MODELS.register_module()
@@ -105,19 +105,19 @@ def forward(self, feats: tuple[torch.Tensor]) -> torch.Tensor:
pre_logits = self.pre_logits(feats)
return self.classifier(pre_logits)

- def set_hlabel_data(self, hlabel_data: HLabelData) -> None:
+ def set_hlabel_info(self, hlabel_info: HLabelInfo) -> None:
"""Set hlabel information."""
- self.hlabel_data = hlabel_data
+ self.hlabel_info = hlabel_info

def _get_gt_label(self, data_samples: list[DataSample]) -> torch.Tensor:
"""Get gt labels from data samples."""
return torch.stack([data_sample.gt_label for data_sample in data_samples])

- def _get_head_idx_to_logits_range(self, hlabel_data: HLabelData, idx: int) -> tuple[int, int]:
+ def _get_head_idx_to_logits_range(self, hlabel_info: HLabelInfo, idx: int) -> tuple[int, int]:
"""Get head_idx_to_logits_range information from hlabel information."""
return (
- hlabel_data.head_idx_to_logits_range[str(idx)][0],
- hlabel_data.head_idx_to_logits_range[str(idx)][1],
+ hlabel_info.head_idx_to_logits_range[str(idx)][0],
+ hlabel_info.head_idx_to_logits_range[str(idx)][1],
)

def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:
@@ -144,9 +144,9 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:
# Multiclass loss
num_effective_heads_in_batch = 0 # consider the label removal case
for i in range(self.num_multiclass_heads):
- if i not in self.hlabel_data.empty_multiclass_head_indices:
+ if i not in self.hlabel_info.empty_multiclass_head_indices:
head_gt = gt_labels[:, i]
- logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
+ logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
valid_mask = head_gt >= 0

@@ -161,15 +161,15 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:

# Multilabel loss
if self.num_multilabel_classes > 0:
- head_gt = gt_labels[:, self.hlabel_data.num_multiclass_heads :]
- head_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
+ head_gt = gt_labels[:, self.hlabel_info.num_multiclass_heads :]
+ head_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]
valid_mask = head_gt > 0
head_gt = head_gt[valid_mask]
if len(head_gt) > 0:
img_metas = [data_sample.metainfo for data_sample in data_samples]
head_logits = head_logits[valid_mask]
valid_label_mask = self.get_valid_label_mask(img_metas).to(head_logits.device)
- valid_label_mask = valid_label_mask[:, self.hlabel_data.num_single_label_classes :]
+ valid_label_mask = valid_label_mask[:, self.hlabel_info.num_single_label_classes :]
valid_label_mask = valid_label_mask[valid_mask]
losses["loss"] += self.multilabel_loss(head_logits, head_gt, valid_label_mask=valid_label_mask)

@@ -214,7 +214,7 @@ def _get_predictions(
multiclass_pred_scores = []
multiclass_pred_labels = []
for i in range(self.num_multiclass_heads):
- logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
+ logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
multiclass_logit = cls_scores[:, logit_range[0] : logit_range[1]]
multiclass_pred = torch.softmax(multiclass_logit, dim=1)
multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1)
@@ -226,7 +226,7 @@
multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1)

if self.num_multilabel_classes > 0:
- multilabel_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
+ multilabel_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]

multilabel_pred_scores = torch.sigmoid(multilabel_logits)
multilabel_pred_labels = (multilabel_pred_scores >= self.thr).int()
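
Finally, to make the masking in the loss() hunks concrete: heads listed in empty_multiclass_head_indices are skipped, samples whose ground truth for a head is negative are masked out, and, as the num_effective_heads_in_batch counter suggests, the accumulated loss is normalized by the number of heads that actually contributed. Below is a small sketch under illustrative assumptions (random tensors, and plain cross-entropy standing in for the head's configured loss module), not the head's exact code path.

import torch
import torch.nn.functional as F

head_idx_to_logits_range = {"0": (0, 3), "1": (3, 5)}
empty_multiclass_head_indices: list[int] = []

cls_scores = torch.randn(8, 9)             # 5 single-label logits + 4 multilabel logits
gt_labels = torch.randint(-1, 2, (8, 2))   # one column per multiclass head; -1 = no label

loss = torch.tensor(0.0)
num_effective_heads_in_batch = 0           # consider the label removal case
for i in range(len(head_idx_to_logits_range)):
    if i in empty_multiclass_head_indices:
        continue
    head_gt = gt_labels[:, i]
    start, end = head_idx_to_logits_range[str(i)]
    head_logits = cls_scores[:, start:end]
    valid_mask = head_gt >= 0              # drop samples without a label for this head
    if valid_mask.any():
        loss = loss + F.cross_entropy(head_logits[valid_mask], head_gt[valid_mask])
        num_effective_heads_in_batch += 1

if num_effective_heads_in_batch > 0:
    loss = loss / num_effective_heads_in_batch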