Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
jaegukhyun committed Jun 16, 2023
1 parent 2417c01 commit 6903af4
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def multi_scale_deformable_attn_pytorch(
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = custom_grid_sample(
sampling_value_l_ = _custom_grid_sample(
value_l_,
sampling_grid_l_,
# mode='bilinear',
Expand All @@ -59,7 +59,7 @@ def multi_scale_deformable_attn_pytorch(
return output.transpose(1, 2).contiguous()


def custom_grid_sample(im: torch.Tensor, grid: torch.Tensor, align_corners: bool = False) -> torch.Tensor:
def _custom_grid_sample(im: torch.Tensor, grid: torch.Tensor, align_corners: bool = False) -> torch.Tensor:
"""Custom patch for mmcv.ops.point_sample.bilinear_grid_sample.
This function is almost same with mmcv.ops.point_sample.bilinear_grid_sample.
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/algorithms/common/adapters/mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Test for otx.algorithms.common.adapters.mmcv.ops"""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Test for otx.algorithms.common.adapters.mmcv.ops.multi_scale_deformable_attn_pytorch."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import torch

from otx.algorithms.common.adapters.mmcv.ops import multi_scale_deformable_attn_pytorch
from tests.test_suite.e2e_test_system import e2e_pytest_unit


@e2e_pytest_unit
def test_multi_scale_deformable_attn_pytorch():
value = torch.randn([1, 22223, 8, 32])
value_spatial_shapes = torch.tensor([[100, 167], [50, 84], [25, 42], [13, 21]])
sampling_locations = torch.randn([1, 2223, 8, 4, 4, 2])
attention_weights = torch.randn([1, 2223, 8, 4, 4])

out = multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights)
assert out.shape == torch.Size([1, 2223, 256])
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Test for otx.algorithms.common.adapters.mmdeploy.ops"""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Test for otx.algorithms.common.adapters.mmdeploy.ops.custom_ops."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmdeploy.core import SYMBOLIC_REWRITER

from otx.algorithms.common.adapters.mmdeploy.ops.custom_ops import (
squeeze__default,
grid_sampler__default,
)
from tests.test_suite.e2e_test_system import e2e_pytest_unit


@e2e_pytest_unit
def test_symbolic_registery():
assert len(SYMBOLIC_REWRITER._registry._rewrite_records["squeeze"]) == 1
assert len(SYMBOLIC_REWRITER._registry._rewrite_records["grid_sampler"]) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,86 @@ def fxt_cfg_custom_yolox(num_classes: int = 3):
},
}
return cfg


@pytest.fixture
def fxt_cfg_custom_deformable_detr(num_classes: int = 3):
return ConfigDict(
type="CustomDeformableDETR",
backbone=dict(
type="ResNet",
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type="BN", requires_grad=False),
norm_eval=True,
style="pytorch",
init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"),
),
neck=dict(
type="ChannelMapper",
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type="GN", num_groups=32),
num_outs=4,
),
bbox_head=dict(
type="CustomDeformableDETRHead",
num_query=300,
num_classes=80,
in_channels=2048,
sync_cls_avg_factor=True,
with_box_refine=True,
as_two_stage=True,
transformer=dict(
type="DeformableDetrTransformer",
encoder=dict(
type="DetrTransformerEncoder",
num_layers=6,
transformerlayers=dict(
type="BaseTransformerLayer",
attn_cfgs=dict(type="MultiScaleDeformableAttention", embed_dims=256),
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=("self_attn", "norm", "ffn", "norm"),
),
),
decoder=dict(
type="DeformableDetrTransformerDecoder",
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type="DetrTransformerDecoderLayer",
attn_cfgs=[
dict(type="MultiheadAttention", embed_dims=256, num_heads=8, dropout=0.1),
dict(type="MultiScaleDeformableAttention", embed_dims=256),
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
),
),
),
positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True, offset=-0.5),
loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=2.0),
loss_bbox=dict(type="L1Loss", loss_weight=5.0),
loss_iou=dict(type="GIoULoss", loss_weight=2.0),
),
# training and testing settings
train_cfg=dict(
assigner=dict(
type="HungarianAssigner",
cls_cost=dict(type="FocalLossCost", weight=2.0),
reg_cost=dict(type="BBoxL1Cost", weight=5.0, box_format="xywh"),
iou_cost=dict(type="IoUCost", iou_mode="giou", weight=2.0),
)
),
test_cfg=dict(max_per_img=100),
task_adapt=dict(
src_classes=["person", "car"],
dst_classes=["tree", "car", "person"],
),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Test for CustomDeformableDETRHead."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmdet.models.builder import build_detector

from otx.algorithms.detection.adapters.mmdet.models.detectors.custom_deformable_detr_detector import (
CustomDeformableDETR,
)
from tests.test_suite.e2e_test_system import e2e_pytest_unit


class TestCustomDeformableDETR:
@e2e_pytest_unit
def test_custom_deformable_detr_build(self, fxt_cfg_custom_deformable_detr):
model = build_detector(fxt_cfg_custom_deformable_detr)
assert isinstance(model, CustomDeformableDETR)
assert model.task_adapt is not None
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Test for otx.algorithms.mmdetection.adapters.mmdet.models.heads."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Test for otx.algorithms.mmdetection.adapters.mmdet.models.heads.custom_deformable_detr_head."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import numpy as np
import torch
import pytest

from mmcv.utils import ConfigDict
from mmdet.models.builder import build_detector
from mmdet.models.dense_heads.deformable_detr_head import DeformableDETRHead

from tests.test_suite.e2e_test_system import e2e_pytest_unit


class TestCustomDeformableDETRHead:
@pytest.fixture(autouse=True)
def setup(self) -> None:
cfg = ConfigDict(
type="CustomDeformableDETRHead",
num_query=300,
num_classes=80,
in_channels=2048,
sync_cls_avg_factor=True,
with_box_refine=True,
as_two_stage=True,
transformer=dict(
type="DeformableDetrTransformer",
encoder=dict(
type="DetrTransformerEncoder",
num_layers=6,
transformerlayers=dict(
type="BaseTransformerLayer",
attn_cfgs=dict(type="MultiScaleDeformableAttention", embed_dims=256),
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=("self_attn", "norm", "ffn", "norm"),
),
),
decoder=dict(
type="DeformableDetrTransformerDecoder",
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type="DetrTransformerDecoderLayer",
attn_cfgs=[
dict(type="MultiheadAttention", embed_dims=256, num_heads=8, dropout=0.1),
dict(type="MultiScaleDeformableAttention", embed_dims=256),
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
),
),
),
positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True, offset=-0.5),
loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=2.0),
loss_bbox=dict(type="L1Loss", loss_weight=5.0),
loss_iou=dict(type="GIoULoss", loss_weight=2.0),
)
self.head = build_detector(cfg)

@e2e_pytest_unit
def test_forward(self, mocker):
def return_second_arg(a, b):
return b

mocker.patch.object(DeformableDETRHead, "forward", side_effect=return_second_arg)

feats = (
torch.randn([1, 256, 100, 167]),
torch.randn([1, 256, 50, 84]),
torch.randn([1, 256, 25, 42]),
torch.randn([1, 256, 13, 21]),
)
img_metas = [
{
"filename": None,
"ori_filename": None,
"ori_shape": (128, 128, 3),
"img_shape": torch.Tensor([800, 1333]),
"pad_shape": (800, 1333, 3),
"scale_factor": np.array([10.4140625, 6.25, 10.4140625, 6.25], dtype=np.float32),
"flip": False,
"flip_direction": None,
"img_norm_cfg": {
"mean": np.array([123.675, 116.28, 103.53], dtype=np.float32),
"std": np.array([58.395, 57.12, 57.375], dtype=np.float32),
"to_rgb": False,
},
}
]
out = self.head(feats, img_metas)
assert out[0].get("batch_input_shape") == (800, 1333)
assert out[0].get("img_shape") == (800, 1333, 3)

0 comments on commit 6903af4

Please sign in to comment.