Skip to content

Commit

Permalink
Added inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Jun 3, 2024
1 parent d3d4be1 commit fde191e
Showing 1 changed file with 3 additions and 22 deletions.
25 changes: 3 additions & 22 deletions src/super_gradients/module_interfaces/exportable_obb_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from torch import nn, Tensor
from torch.utils.data import DataLoader

from .exportable_detector import AbstractObjectDetectionDecodingModule

logger = get_logger(__name__)

__all__ = [
Expand All @@ -33,7 +35,7 @@
]


class AbstractOBBDetectionDecodingModule(nn.Module):
class AbstractOBBDetectionDecodingModule(AbstractObjectDetectionDecodingModule):
"""
Abstract class for decoding outputs from object detection models to a tuple of two tensors (boxes, scores)
"""
Expand Down Expand Up @@ -61,19 +63,6 @@ def forward(self, predictions: Any) -> Tuple[Tensor, Tensor]:
"""
raise NotImplementedError(f"forward() method is not implemented for class {self.__class__.__name__}. ")

@torch.jit.ignore
def infer_total_number_of_predictions(self, predictions: Any) -> int:
"""
This method is used to infer the total number of predictions for a given input resolution.
The function takes raw predictions from the model and returns the total number of predictions.
It is needed to check whether max_predictions_per_image and num_pre_nms_predictions are not greater than
the total number of predictions for a given resolution.
:param predictions: Predictions from the model itself.
:return: A total number of predictions for a given resolution
"""
raise NotImplementedError(f"forward() method is not implemented for class {self.__class__.__name__}. ")

def get_output_names(self) -> List[str]:
"""
Returns the names of the outputs of the module.
Expand All @@ -84,14 +73,6 @@ def get_output_names(self) -> List[str]:
"""
return ["pre_nms_bboxes_cycywhr", "pre_nms_scores"]

@abc.abstractmethod
def get_num_pre_nms_predictions(self) -> int:
"""
Returns the number of predictions per image that this module produces.
:return: Number of predictions per image.
"""
raise NotImplementedError(f"get_num_pre_nms_predictions() method is not implemented for class {self.__class__.__name__}. ")


@dataclasses.dataclass
class OBBDetectionModelExportResult:
Expand Down

0 comments on commit fde191e

Please sign in to comment.