Skip to content

Commit

Permalink
roi align mask extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene123tw committed Aug 21, 2024
1 parent cf364b9 commit b9f9c9a
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions src/otx/algo/instance_segmentation/maskdino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch import Tensor, nn
from torch.nn.modules import Module
from torchvision import tv_tensors
from torchvision.ops.roi_align import RoIAlign

from otx.algo.instance_segmentation.mask_dino import box_ops
from otx.algo.instance_segmentation.mask_dino.criterion import SetCriterion
Expand Down Expand Up @@ -109,6 +110,12 @@ def __init__(
self.focus_on_box = focus_on_box
self.transform_eval = transform_eval
self.semantic_ce_loss = semantic_ce_loss
self.roi_align = RoIAlign(
output_size=(28, 28),
sampling_ratio=0,
aligned=True,
spatial_scale=1.0,
)

if not self.semantic_on:
assert self.sem_seg_postprocess_before_inference
Expand Down Expand Up @@ -329,16 +336,25 @@ def export(self, batch_inputs: Tensor, batch_img_metas: list[dict]):

boxes_with_scores = torch.cat([pred_boxes, pred_scores[:, None]], dim=1)

batch_masks, batch_bboxes, batch_labels = [], [], []
boxes_with_scores = boxes_with_scores.unsqueeze(0)
pred_classes = pred_classes.unsqueeze(0)
pred_masks = pred_masks.unsqueeze(0)

batch_masks.append(pred_masks)
batch_bboxes.append(boxes_with_scores)
batch_labels.append(pred_classes)
batch_index = (
torch.arange(boxes_with_scores.size(0))
.float()
.view(-1, 1, 1)
.expand(boxes_with_scores.size(0), boxes_with_scores.size(1), 1)
)
rois = torch.cat([batch_index, boxes_with_scores[..., :4]], dim=-1)
cropped_masks = self.roi_align(pred_masks, rois[0])
cropped_masks = cropped_masks[torch.arange(cropped_masks.size(0)), torch.arange(cropped_masks.size(0))]
cropped_masks = cropped_masks.unsqueeze(0)

return (
batch_bboxes,
batch_labels,
batch_masks,
boxes_with_scores,
pred_classes,
cropped_masks,
)


Expand Down Expand Up @@ -701,7 +717,7 @@ def post_process_instance_segmentation(
mask_box_results,
imgs_info,
):
ori_h, ori_w = img_info.ori_shape
ori_h, ori_w = img_info.ori_shape[-2:]
scores = mask_cls.sigmoid()
labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)

Expand Down

0 comments on commit b9f9c9a

Please sign in to comment.