Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hlabel parent resolver and integrate HLabelData + HLabelInfo --> HLabelInfo #3046

Merged
merged 18 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import nn

if TYPE_CHECKING:
from otx.core.data.entity.classification import HLabelData
from otx.core.data.dataset.classification import HLabelInfo


@MODELS.register_module()
Expand Down Expand Up @@ -74,19 +74,19 @@ def forward(self, feats: tuple[torch.Tensor]) -> torch.Tensor:
pre_logits = self.pre_logits(feats)
return self.fc(pre_logits)

def set_hlabel_data(self, hlabel_data: HLabelData) -> None:
def set_hlabel_info(self, hlabel_info: HLabelInfo) -> None:
"""Set hlabel information."""
self.hlabel_data = hlabel_data
self.hlabel_info = hlabel_info

def _get_gt_label(self, data_samples: list[DataSample]) -> torch.Tensor:
"""Get gt labels from data samples."""
return torch.stack([data_sample.gt_label for data_sample in data_samples])

def _get_head_idx_to_logits_range(self, hlabel_data: HLabelData, idx: int) -> tuple[int, int]:
def _get_head_idx_to_logits_range(self, hlabel_info: HLabelInfo, idx: int) -> tuple[int, int]:
"""Get head_idx_to_logits_range information from hlabel information."""
return (
hlabel_data.head_idx_to_logits_range[str(idx)][0],
hlabel_data.head_idx_to_logits_range[str(idx)][1],
hlabel_info.head_idx_to_logits_range[str(idx)][0],
hlabel_info.head_idx_to_logits_range[str(idx)][1],
)

def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:
Expand All @@ -113,9 +113,9 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwa
# Multiclass loss
num_effective_heads_in_batch = 0 # consider the label removal case
for i in range(self.num_multiclass_heads):
if i not in self.hlabel_data.empty_multiclass_head_indices:
if i not in self.hlabel_info.empty_multiclass_head_indices:
head_gt = gt_labels[:, i]
logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
valid_mask = head_gt >= 0

Expand All @@ -130,15 +130,15 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwa

# Multilabel loss
if self.num_multilabel_classes > 0:
head_gt = gt_labels[:, self.hlabel_data.num_multiclass_heads :]
head_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
head_gt = gt_labels[:, self.hlabel_info.num_multiclass_heads :]
head_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]
valid_mask = head_gt > 0
head_gt = head_gt[valid_mask]
if len(head_gt) > 0:
img_metas = [data_sample.metainfo for data_sample in data_samples]
head_logits = head_logits[valid_mask]
valid_label_mask = self.get_valid_label_mask(img_metas).to(head_logits.device)
valid_label_mask = valid_label_mask[:, self.hlabel_data.num_single_label_classes :]
valid_label_mask = valid_label_mask[:, self.hlabel_info.num_single_label_classes :]
valid_label_mask = valid_label_mask[valid_mask]
losses["loss"] += self.multilabel_loss(head_logits, head_gt, valid_label_mask=valid_label_mask)

Expand Down Expand Up @@ -183,7 +183,7 @@ def _get_predictions(
multiclass_pred_scores = []
multiclass_pred_labels = []
for i in range(self.num_multiclass_heads):
logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
multiclass_logit = cls_scores[:, logit_range[0] : logit_range[1]]
multiclass_pred = torch.softmax(multiclass_logit, dim=1)
multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1)
Expand All @@ -195,7 +195,7 @@ def _get_predictions(
multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1)

if self.num_multilabel_classes > 0:
multilabel_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
multilabel_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]

multilabel_pred_scores = torch.sigmoid(multilabel_logits)
multilabel_pred_labels = (multilabel_pred_scores >= self.thr).int()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import nn

if TYPE_CHECKING:
from otx.core.data.entity.classification import HLabelData
from otx.core.data.dataset.classification import HLabelInfo


@MODELS.register_module()
Expand Down Expand Up @@ -105,19 +105,19 @@ def forward(self, feats: tuple[torch.Tensor]) -> torch.Tensor:
pre_logits = self.pre_logits(feats)
return self.classifier(pre_logits)

def set_hlabel_data(self, hlabel_data: HLabelData) -> None:
def set_hlabel_info(self, hlabel_info: HLabelInfo) -> None:
"""Set hlabel information."""
self.hlabel_data = hlabel_data
self.hlabel_info = hlabel_info

def _get_gt_label(self, data_samples: list[DataSample]) -> torch.Tensor:
"""Get gt labels from data samples."""
return torch.stack([data_sample.gt_label for data_sample in data_samples])

def _get_head_idx_to_logits_range(self, hlabel_data: HLabelData, idx: int) -> tuple[int, int]:
def _get_head_idx_to_logits_range(self, hlabel_info: HLabelInfo, idx: int) -> tuple[int, int]:
"""Get head_idx_to_logits_range information from hlabel information."""
return (
hlabel_data.head_idx_to_logits_range[str(idx)][0],
hlabel_data.head_idx_to_logits_range[str(idx)][1],
hlabel_info.head_idx_to_logits_range[str(idx)][0],
hlabel_info.head_idx_to_logits_range[str(idx)][1],
)

def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:
Expand All @@ -144,9 +144,9 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwa
# Multiclass loss
num_effective_heads_in_batch = 0 # consider the label removal case
for i in range(self.num_multiclass_heads):
if i not in self.hlabel_data.empty_multiclass_head_indices:
if i not in self.hlabel_info.empty_multiclass_head_indices:
head_gt = gt_labels[:, i]
logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
valid_mask = head_gt >= 0

Expand All @@ -161,15 +161,15 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwa

# Multilabel loss
if self.num_multilabel_classes > 0:
head_gt = gt_labels[:, self.hlabel_data.num_multiclass_heads :]
head_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
head_gt = gt_labels[:, self.hlabel_info.num_multiclass_heads :]
head_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]
valid_mask = head_gt > 0
head_gt = head_gt[valid_mask]
if len(head_gt) > 0:
img_metas = [data_sample.metainfo for data_sample in data_samples]
head_logits = head_logits[valid_mask]
valid_label_mask = self.get_valid_label_mask(img_metas).to(head_logits.device)
valid_label_mask = valid_label_mask[:, self.hlabel_data.num_single_label_classes :]
valid_label_mask = valid_label_mask[:, self.hlabel_info.num_single_label_classes :]
valid_label_mask = valid_label_mask[valid_mask]
losses["loss"] += self.multilabel_loss(head_logits, head_gt, valid_label_mask=valid_label_mask)

Expand Down Expand Up @@ -214,7 +214,7 @@ def _get_predictions(
multiclass_pred_scores = []
multiclass_pred_labels = []
for i in range(self.num_multiclass_heads):
logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
multiclass_logit = cls_scores[:, logit_range[0] : logit_range[1]]
multiclass_pred = torch.softmax(multiclass_logit, dim=1)
multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1)
Expand All @@ -226,7 +226,7 @@ def _get_predictions(
multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1)

if self.num_multilabel_classes > 0:
multilabel_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
multilabel_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]

multilabel_pred_scores = torch.sigmoid(multilabel_logits)
multilabel_pred_labels = (multilabel_pred_scores >= self.thr).int()
Expand Down
Loading
Loading