Commit 4f020a3

Resolve conflicts. Delete unused code

kprokofi committed Oct 28, 2024
1 parent 88733c9 commit 4f020a3
Showing 2 changed files with 0 additions and 46 deletions.
25 changes: 0 additions & 25 deletions src/otx/algo/object_detection_3d/heads/depthaware_transformer.py
@@ -4,7 +4,6 @@
"""depth aware transformer head for 3d object detection."""
from __future__ import annotations

import math
from typing import Any, Callable, ClassVar

import torch
@@ -101,30 +100,6 @@ def _reset_parameters(self) -> None:
        constant_(self.reference_points.bias.data, 0.0)
        normal_(self.level_embed)

    def get_proposal_pos_embed(self, proposals: Tensor) -> Tensor:
        """Generate position embeddings for proposal tensor.

        Args:
            proposals (Tensor): Proposal tensor of shape (N, L, 6).
            TODO (Kirill): Not used. Remove this function?

        Returns:
            Tensor: Position embeddings for proposal tensor of shape (N, L, embedding_dim).
        """
        num_pos_feats = 128
        temperature = 10000
        scale = 2 * math.pi

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 6
        proposals = proposals.sigmoid() * scale
        # N, L, 6, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 6, 64, 2
        return torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)

    def get_valid_ratio(self, mask: Tensor) -> Tensor:
        """Calculate the valid ratio of the mask.
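Note: get_proposal_pos_embed depended only on math and torch, which is why the import math line above could be deleted with it. If the embedding is ever needed again, the deleted body can be lifted into a standalone helper. A minimal sketch assembled from the deleted lines follows; the free-function form and the final shape check are illustrative additions, not repo code:

import math

import torch
from torch import Tensor


def get_proposal_pos_embed(proposals: Tensor) -> Tensor:
    """Sinusoidal embeddings for proposals of shape (N, L, 6) -> (N, L, 6 * 128)."""
    num_pos_feats = 128
    temperature = 10000
    scale = 2 * math.pi

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
    # Sigmoid squashes each coordinate into (0, 1); scaling maps it to (0, 2*pi)
    proposals = proposals.sigmoid() * scale          # (N, L, 6)
    pos = proposals[:, :, :, None] / dim_t           # (N, L, 6, 128)
    # Interleaved sin/cos pairs: (N, L, 6, 64, 2), flattened to (N, L, 768)
    return torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)


# Mirrors the deleted unit test below: (2, 10, 6) -> (2, 10, 768)
assert get_proposal_pos_embed(torch.randn(2, 10, 6)).shape == (2, 10, 768)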
21 changes: 0 additions & 21 deletions
@@ -7,7 +7,6 @@
import torch
from otx.algo.object_detection_3d.heads.depthaware_transformer import (
    DepthAwareTransformerBuilder,
    gen_sineembed_for_position,
)


@@ -57,16 +56,6 @@ def test_depth_aware_transformer_forward(self, depth_aware_transformer):
        assert output[2].shape == (1, 550, 2)
        assert output[4] is None

    def test_depth_aware_transformer_get_proposal_pos_embed(self, depth_aware_transformer):
        # Create dummy input tensor
        proposals = torch.randn(2, 10, 6)

        # Get proposal position embeddings
        pos_embed = depth_aware_transformer.get_proposal_pos_embed(proposals)

        # Check output shape
        assert pos_embed.shape == (2, 10, 768)

    def test_depth_aware_transformer_get_valid_ratio(self, depth_aware_transformer):
        # Create dummy input tensor
        mask = torch.randn(2, 32, 32) > 0
@@ -76,13 +65,3 @@ def test_depth_aware_transformer_get_valid_ratio(self, depth_aware_transformer):

        # Check output shape
        assert valid_ratio.shape == (2, 2)

    def test_gen_sineembed_for_position(self):
        # Create dummy input tensor
        pos_tensor = torch.randn(2, 4, 6)

        # Generate sine embeddings for position tensor
        pos_embed = gen_sineembed_for_position(pos_tensor)

        # Check output shape
        assert pos_embed.shape == (2, 4, 768)
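Both deleted tests assert a trailing dimension of 768, which is just the flattened sinusoidal layout: 6 box parameters times 128 features each, stored as 64 (sin, cos) pairs. A quick standalone check of that shape arithmetic (illustrative, not repo code):

import torch

# (N, L, 6, 128): one 128-step frequency ladder per box parameter
pos = torch.randn(2, 4, 6, 128)
# 64 sine features + 64 cosine features per parameter -> (N, L, 6, 64, 2)
pairs = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=4)
assert pairs.shape == (2, 4, 6, 64, 2)
# flatten(2) merges 6 * 64 * 2 = 768
assert pairs.flatten(2).shape == (2, 4, 768)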
