MMDet MaskRCNN ResNet50/SwinTransformer Decouple #3281

Merged
Changes from 1 commit (of 72)
fa0d462
migrate mmdet maskrcnn modules
eugene123tw Apr 5, 2024
763ce7d
style reformat
eugene123tw Apr 5, 2024
5ace6b0
style reformat
eugene123tw Apr 5, 2024
b2b6e0c
style reformat
eugene123tw Apr 5, 2024
a305c7c
ignore mypy, ruff errors
eugene123tw Apr 8, 2024
6b682b7
skip mypy error
eugene123tw Apr 8, 2024
727cf2b
update
eugene123tw Apr 8, 2024
22f8f81
fix loss
eugene123tw Apr 8, 2024
08d766f
add maskrcnn
eugene123tw Apr 8, 2024
873a101
update import
eugene123tw Apr 8, 2024
bfca5f0
update import
eugene123tw Apr 8, 2024
47f3744
add necks
eugene123tw Apr 8, 2024
7bb7f39
update
eugene123tw Apr 8, 2024
d37f8a1
update
eugene123tw Apr 8, 2024
abd1f21
add cross-entropy loss
eugene123tw Apr 9, 2024
4978171
style changes
eugene123tw Apr 9, 2024
330f721
mypy changes and style changes
eugene123tw Apr 10, 2024
92dc0bf
update style
eugene123tw Apr 11, 2024
87a2415
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 11, 2024
c4dec94
remove box structures
eugene123tw Apr 11, 2024
5723ec2
add resnet
eugene123tw Apr 11, 2024
47de4f1
update
eugene123tw Apr 11, 2024
f5a51da
modify resnet
eugene123tw Apr 15, 2024
63f46c4
add annotation
eugene123tw Apr 15, 2024
46b37d5
style changes
eugene123tw Apr 15, 2024
cef1e24
update
eugene123tw Apr 15, 2024
655baea
fix all mypy issues
eugene123tw Apr 15, 2024
b1ed150
fix mypy issues
eugene123tw Apr 15, 2024
27b6a4a
style changes
eugene123tw Apr 16, 2024
e87cfa9
remove unused losses
eugene123tw Apr 16, 2024
c2c2394
remove focal_loss_pb
eugene123tw Apr 16, 2024
edd85e0
fix all ruff and mypy issues
eugene123tw Apr 16, 2024
742b6fd
fix conflicts
eugene123tw Apr 16, 2024
194a6c2
style change
eugene123tw Apr 16, 2024
c1938de
update
eugene123tw Apr 16, 2024
734a459
update license
eugene123tw Apr 16, 2024
93fcd79
update
eugene123tw Apr 16, 2024
1d0926d
remove duplicates
eugene123tw Apr 17, 2024
1238d28
remove as F
eugene123tw Apr 17, 2024
55d77f5
remove as F
eugene123tw Apr 17, 2024
1598929
remove mmdet mask structures
eugene123tw Apr 17, 2024
498b750
remove duplicates
eugene123tw Apr 17, 2024
a0f52de
style changes
eugene123tw Apr 17, 2024
549d0ef
add new test
eugene123tw Apr 17, 2024
0a074ae
test style change
eugene123tw Apr 17, 2024
9106bc1
fix test
eugene123tw Apr 17, 2024
8197825
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 17, 2024
07cd25e
change device for unit test
eugene123tw Apr 17, 2024
ab07675
add deployment files
eugene123tw Apr 17, 2024
70717ce
remove deployment from inst-seg
eugene123tw Apr 17, 2024
e5027d9
update deployment
eugene123tw Apr 18, 2024
03e26d3
add mmdeploy maskrcnn opset
eugene123tw Apr 18, 2024
bbeaa32
fix linter
eugene123tw Apr 18, 2024
478158a
update test
eugene123tw Apr 18, 2024
052a582
update test
eugene123tw Apr 18, 2024
28cea6f
update test
eugene123tw Apr 18, 2024
ed7275e
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 22, 2024
04cd223
replace mmcv.cnn module
eugene123tw Apr 22, 2024
11ae260
remove upsample building
eugene123tw Apr 22, 2024
d8a78ca
remove upsample building
eugene123tw Apr 22, 2024
34ea7d7
use batch_nms from otx
eugene123tw Apr 22, 2024
617e1ac
add swintransformer
eugene123tw Apr 22, 2024
5a6737f
add transformers
eugene123tw Apr 22, 2024
fe34145
add swin transformer
eugene123tw Apr 23, 2024
4dd8bdc
style changes
eugene123tw Apr 23, 2024
a6de9b4
merge upstream
eugene123tw Apr 23, 2024
1b199a0
solve conflicts
eugene123tw Apr 23, 2024
c332430
update instance_segmentation/maskrcnn.py
eugene123tw Apr 23, 2024
6d92199
update nms
eugene123tw Apr 23, 2024
4d75001
fix xai
eugene123tw Apr 24, 2024
7b4b43e
change rotate detection recipe
eugene123tw Apr 24, 2024
0b2a326
fix swint recipe
eugene123tw Apr 24, 2024
style changes
eugene123tw committed Apr 15, 2024

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature (the key has since expired).
commit 46b37d56ad7dd7859a681f61e23d503786ada75b
src/otx/algo/detection/losses/cross_entropy_loss.py (78 changes: 27 additions, 51 deletions)
@@ -13,13 +13,12 @@

 def cross_entropy(
     pred: torch.Tensor,
-    label: torch.Tensor,
+    target: torch.Tensor,
     weight: torch.Tensor | None = None,
     reduction: str = "mean",
     avg_factor: int | None = None,
     class_weight: list[float] | None = None,
-    ignore_index: int = -100,
-    avg_non_ignore: bool = False,
+    ignore_index: int | None = None,
 ) -> torch.Tensor:
     """Calculate the CrossEntropy loss.
@@ -34,22 +33,14 @@ def cross_entropy(
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
-        avg_non_ignore (bool): The flag decides to whether the loss is
-            only averaged over non-ignored targets. Default: False.

     Returns:
         torch.Tensor: The calculated loss
     """
     # The default value of ignore_index is the same as F.cross_entropy
     ignore_index = -100 if ignore_index is None else ignore_index
     # element-wise losses
-    loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index)
-
-    # average loss over non-ignored elements
-    # pytorch's official cross_entropy average loss over non-ignored elements
-    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
-    if (avg_factor is None) and avg_non_ignore and reduction == "mean":
-        avg_factor = label.numel() - (label == ignore_index).sum().item()
+    loss = F.cross_entropy(pred, target, weight=class_weight, reduction="none", ignore_index=ignore_index)

     # apply weights and do the reduction
     if weight is not None:
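For reference, the dropped avg_non_ignore path emulated PyTorch's own mean-over-non-ignored behaviour, as the removed comment notes. A minimal sketch of that built-in behaviour (plain PyTorch, not code from this PR; tensors are made up):

import torch
import torch.nn.functional as F

# Hypothetical example: 4 samples, 3 classes, last sample ignored.
pred = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, -100])

# reduction="none" (as in the refactored cross_entropy above):
# the ignored element contributes a loss of exactly 0.
per_element = F.cross_entropy(pred, target, reduction="none", ignore_index=-100)

# reduction="mean": PyTorch averages over non-ignored elements only,
# which is what the deleted avg_non_ignore bookkeeping re-implemented.
mean_loss = F.cross_entropy(pred, target, ignore_index=-100)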
@@ -83,13 +74,12 @@ def _expand_onehot_labels(

 def binary_cross_entropy(
     pred: torch.Tensor,
-    label: torch.Tensor,
+    target: torch.Tensor,
     weight: torch.Tensor | None = None,
     reduction: str = "mean",
     avg_factor: int | None = None,
     class_weight: list[float] | None = None,
-    ignore_index: int = -100,
-    avg_non_ignore: bool = False,
+    ignore_index: int | None = None,
 ) -> torch.Tensor:
     """Calculate the binary CrossEntropy loss.
@@ -108,56 +98,47 @@ def binary_cross_entropy(
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
-        avg_non_ignore (bool): The flag decides to whether the loss is
-            only averaged over non-ignored targets. Default: False.

     Returns:
         torch.Tensor: The calculated loss.
     """
     # The default value of ignore_index is the same as F.cross_entropy
     ignore_index = -100 if ignore_index is None else ignore_index

-    if pred.dim() != label.dim():
-        label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.size(-1), ignore_index)
+    if pred.dim() != target.dim():
+        target, weight, valid_mask = _expand_onehot_labels(target, weight, pred.size(-1), ignore_index)
     else:
         # should mask out the ignored elements
-        valid_mask = ((label >= 0) & (label != ignore_index)).float()
-        if weight is not None:
-            # The inplace writing method will have a mismatched broadcast
-            # shape error if the weight and valid_mask dimensions
-            # are inconsistent such as (B,N,1) and (B,N,C).
-            weight = weight * valid_mask
-        else:
-            weight = valid_mask
+        valid_mask = ((target >= 0) & (target != ignore_index)).float()

-    # average loss over non-ignored elements
-    if (avg_factor is None) and avg_non_ignore and reduction == "mean":
-        avg_factor = valid_mask.sum().item()
+    # The inplace writing method will have a mismatched broadcast
+    # shape error if the weight and valid_mask dimensions
+    # are inconsistent such as (B,N,1) and (B,N,C).
+    weight = weight * valid_mask if weight is not None else valid_mask

     # weighted element-wise losses
     weight = weight.float()
-    loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
+    loss = F.binary_cross_entropy_with_logits(pred, target.float(), pos_weight=class_weight, reduction="none")
     # do the reduction for the weighted loss
     return weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)


 def mask_cross_entropy(
     pred: torch.Tensor,
     target: torch.Tensor,
-    label: torch.Tensor,
+    weight: torch.Tensor | None = None,
     reduction: str = "mean",
     avg_factor: int | None = None,
     class_weight: list[float] | None = None,
     ignore_index: int | None = None,
     **kwargs,
 ) -> torch.Tensor:
     """Calculate the CrossEntropy loss for masks.

     Args:
         pred (torch.Tensor): The prediction with shape (N, C, *), C is the
             number of classes. The trailing * indicates arbitrary shape.
         target (torch.Tensor): The learning label of the prediction.
-        label (torch.Tensor): ``label`` indicates the class label of the mask
+        weight (torch.Tensor): ``label`` indicates the class label of the mask
             corresponding object. This will be used to select the mask in the
             of the class which the object belongs to when the mask prediction
             if not class-agnostic.
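As a small illustration of the consolidated masking in binary_cross_entropy above (shapes and values made up, not from the PR's tests): the valid mask zeroes ignored positions and is folded into the element weights in a single expression.

import torch

ignore_index = -100
target = torch.tensor([1.0, 0.0, float(ignore_index)])
weight = torch.tensor([1.0, 0.5, 1.0])

# Same one-liner as in the diff: ignored entries get weight 0,
# everything else keeps (or acquires) its per-element weight.
valid_mask = ((target >= 0) & (target != ignore_index)).float()
weight = weight * valid_mask if weight is not None else valid_mask
# weight is now tensor([1.0000, 0.5000, 0.0000])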
@@ -189,7 +170,7 @@ def mask_cross_entropy(
     assert reduction == "mean" and avg_factor is None
     num_rois = pred.size()[0]
     inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
-    pred_slice = pred[inds, label].squeeze(1)
+    pred_slice = pred[inds, weight].squeeze(1)
     return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None]
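The pred_slice line above selects, for each RoI, the predicted mask of its ground-truth class; note that after this commit those class indices arrive through the renamed weight parameter. A hedged sketch of the indexing with made-up shapes:

import torch

num_rois, num_classes, h, w = 2, 3, 4, 4
pred = torch.randn(num_rois, num_classes, h, w)  # per-class mask logits
class_ids = torch.tensor([2, 0])                 # ground-truth class of each RoI

inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, class_ids]               # shape: (num_rois, h, w)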


@@ -222,7 +203,7 @@ def __init__(
         avg_non_ignore (bool): The flag decides to whether the loss is
             only averaged over non-ignored targets. Default: False.
         """
-        super(CrossEntropyLoss, self).__init__()
+        super().__init__()
         assert (use_sigmoid is False) or (use_mask is False)
         self.use_sigmoid = use_sigmoid
         self.use_mask = use_mask
@@ -246,21 +227,19 @@ def __init__(
         else:
             self.cls_criterion = cross_entropy

-    def extra_repr(self):
+    def extra_repr(self) -> str:
         """Extra repr."""
-        s = f"avg_non_ignore={self.avg_non_ignore}"
-        return s
+        return f"avg_non_ignore={self.avg_non_ignore}"

     def forward(
         self,
-        cls_score,
-        label,
-        weight=None,
-        avg_factor=None,
-        reduction_override=None,
-        ignore_index=None,
-        **kwargs,
-    ):
+        cls_score: torch.Tensor,
+        label: torch.Tensor,
+        weight: torch.Tensor | None = None,
+        avg_factor: int | None = None,
+        reduction_override: str | None = None,
+        ignore_index: int | None = None,
+    ) -> torch.Tensor:
         """Forward function.

         Args:
@@ -286,15 +265,12 @@ def forward(
             class_weight = cls_score.new_tensor(self.class_weight, device=cls_score.device)
         else:
             class_weight = None
-        loss_cls = self.loss_weight * self.cls_criterion(
+        return self.loss_weight * self.cls_criterion(
             cls_score,
             label,
             weight,
             class_weight=class_weight,
             reduction=reduction,
             avg_factor=avg_factor,
             ignore_index=ignore_index,
-            avg_non_ignore=self.avg_non_ignore,
-            **kwargs,
         )
-        return loss_cls
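A hedged end-to-end usage sketch of the module after this change (import path taken from the file header above; constructor arguments assumed from the surrounding diff, defaults not verified):

import torch

from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss

criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)

cls_score = torch.randn(8, 5)        # 8 proposals, 5 classes
labels = torch.randint(0, 5, (8,))

loss = criterion(cls_score, labels)  # forward() now returns the loss tensor directly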
(Diff continues in a second file, the ResNet backbone blocks; the file header with path and change counts was not captured in this view.)
@@ -60,19 +60,19 @@ def __init__(
         self.with_cp = with_cp

     @property
-    def norm1(self):
+    def norm1(self) -> nn.Module:
         """nn.Module: normalization layer after the first convolution layer."""
         return getattr(self, self.norm1_name)

     @property
-    def norm2(self):
+    def norm2(self) -> nn.Module:
         """nn.Module: normalization layer after the second convolution layer."""
         return getattr(self, self.norm2_name)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward function."""

-        def _inner_forward(x):
+        def _inner_forward(x: torch.Tensor) -> torch.Tensor:
             identity = x

             out = self.conv1(x)
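In these blocks, with_cp toggles gradient checkpointing around _inner_forward. A simplified, hedged sketch of the dispatch pattern (mirroring the upstream mmdet idiom; TinyBlock and its layers are hypothetical):

import torch
import torch.utils.checkpoint as cp
from torch import nn

class TinyBlock(nn.Module):
    def __init__(self, with_cp: bool = False) -> None:
        super().__init__()
        self.with_cp = with_cp
        self.body = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def _inner_forward(x: torch.Tensor) -> torch.Tensor:
            return self.body(x) + x  # residual connection

        # Trade compute for memory: re-run _inner_forward during backward
        # instead of storing its intermediate activations.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)

x = torch.randn(1, 8, 16, 16, requires_grad=True)
TinyBlock(with_cp=True)(x).sum().backward()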
@@ -103,11 +103,11 @@ def __init__(
         planes: int,
         stride: int = 1,
         dilation: int = 1,
-        downsample=None,
-        with_cp=False,
-        conv_cfg=None,
-        norm_cfg=dict(type="BN"),
-        init_cfg=None,
+        downsample: nn.Module | None = None,
+        with_cp: bool = False,
+        conv_cfg: dict | None = None,
+        norm_cfg: dict = dict(type="BN"),
+        init_cfg: dict | None = None,
     ):
         """Bottleneck block for ResNet.
@@ -277,7 +277,7 @@ def __init__(
         norm_eval: bool = True,
         with_cp: bool = False,
         zero_init_residual: bool = True,
-        pretrained: str | bool = None,
+        pretrained: str | bool | None = None,
         init_cfg: list[dict] | dict | None = None,
     ):
         super(ResNet, self).__init__(init_cfg)
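The pretrained annotation fix is the standard implicit-Optional repair: with mypy's default no_implicit_optional, a None default must be reflected in the declared type. A minimal sketch (function name and body hypothetical, not from the PR):

# Rejected by mypy: None is not a valid "str | bool".
# def init_weights(pretrained: str | bool = None) -> None: ...

# The diff's fix spells the Optional out:
def init_weights(pretrained: str | bool | None = None) -> None:
    """Load a checkpoint path, toggle default weights, or do nothing."""
    if isinstance(pretrained, str):
        print(f"loading checkpoint from {pretrained}")
    elif pretrained:
        print("loading default pretrained weights")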
(Diff continues in a third file, the residual layer builder; the file header with path and change counts was not captured in this view.)
@@ -50,7 +50,12 @@ def __init__(
             if avg_down:
                 conv_stride = 1
                 downsample.append(
-                    nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
+                    nn.AvgPool2d(
+                        kernel_size=stride,
+                        stride=stride,
+                        ceil_mode=True,
+                        count_include_pad=False,
+                    ),
                 )
             downsample.extend(
                 [
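For context, avg_down is the ResNet-D style shortcut: the average pool carries the stride and the following 1x1 convolution stays at stride 1, so no activations are simply skipped by a strided 1x1 conv. A hedged sketch of the resulting downsample path (channel sizes made up):

import torch
from torch import nn

stride, inplanes, planes = 2, 64, 128

downsample = nn.Sequential(
    nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
    nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
    nn.BatchNorm2d(planes),
)

x = torch.randn(1, inplanes, 32, 32)
assert downsample(x).shape == (1, planes, 16, 16)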
@@ -81,10 +86,10 @@ def __init__(
                     ),
                 )
                 inplanes = planes * block.expansion
-            layers = [
-                block(inplanes=inplanes, planes=planes, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs)
-                for _ in range(1, num_blocks)
-            ]
+            for _ in range(1, num_blocks):
+                layers.append(
+                    block(inplanes=inplanes, planes=planes, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs),
+                )

         else:  # downsample_first=False is for HourglassModule
             for _ in range(num_blocks - 1):
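One detail worth flagging in the hunk above: the removed comprehension rebound layers, so, assuming layers already held the first stride-and-downsample block at that point (as it does in upstream mmdet's ResLayer), the rebinding would silently drop that block, while the append loop preserves it. A self-contained illustration with stand-in blocks:

from torch import nn

num_blocks = 3
layers: list[nn.Module] = [nn.Identity()]  # stand-in for the first, stride-carrying block

# Removed form (rebinding) would discard the first block:
# layers = [nn.Identity() for _ in range(1, num_blocks)]  # len(layers) == 2

# Added form (appending) keeps it:
for _ in range(1, num_blocks):
    layers.append(nn.Identity())

assert len(layers) == num_blocks  # stride block + (num_blocks - 1) unit-stride blocks
res_layer = nn.Sequential(*layers)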