-
Notifications
You must be signed in to change notification settings - Fork 446
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Rename directory from `base_models` to `detectors` * Update pytorchcv version * Create criterion modules * Fix unit tests * Update comment * FIx iseg unit test * Enable factory for ATSS * Enable factory for YOLOX and update * Enable factory for RTMDet * Update recipes * Update * Enable factory for SSD * Fix unit tests * Enable factory for RTDETR * Reduce default parameters * Update huggingface, keypoint, and iseg * Revert default `input_size` argument * Fix unit test * Fix integration test * Add ABC * Fix * Change `model_version` to `model_name` * precommit * Remove `DetectionBackboneFactory` * Use `Literal` for `model_name`
- Loading branch information
Showing
78 changed files
with
2,079 additions
and
1,302 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""Base bounding box coder.""" | ||
|
||
from abc import ABCMeta, abstractmethod | ||
|
||
from torch import Tensor | ||
|
||
|
||
class BaseBBoxCoder(metaclass=ABCMeta): | ||
"""Base class for bounding box coder.""" | ||
|
||
encode_size: int | ||
|
||
@abstractmethod | ||
def encode(self, *args, **kwargs) -> Tensor: | ||
"""Encode bounding boxes.""" | ||
|
||
@abstractmethod | ||
def decode(self, *args, **kwargs) -> Tensor: | ||
"""Decode bounding boxes.""" | ||
|
||
@abstractmethod | ||
def decode_export(self, *args, **kwargs) -> Tensor: | ||
"""Decode bounding boxes for export.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
src/otx/algo/common/utils/prior_generators/base_prior_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""Base prior generator.""" | ||
|
||
from __future__ import annotations | ||
|
||
from abc import ABCMeta, abstractmethod | ||
from typing import TYPE_CHECKING, Callable | ||
|
||
if TYPE_CHECKING: | ||
from torch import Tensor | ||
|
||
|
||
class BasePriorGenerator(metaclass=ABCMeta): | ||
"""Base class for prior generator.""" | ||
|
||
strides: list[tuple[int, int]] | ||
grid_anchors: Callable[..., list[Tensor]] | ||
|
||
@property | ||
@abstractmethod | ||
def num_base_priors(self) -> list[int]: | ||
"""Return the number of priors (anchors/points) at a point on the feature grid.""" | ||
|
||
@property | ||
@abstractmethod | ||
def num_levels(self) -> int: | ||
"""int: number of feature levels that the generator will be applied.""" | ||
|
||
@abstractmethod | ||
def grid_priors(self, *args, **kwargs) -> list[Tensor]: | ||
"""Generate grid anchors/points of multiple feature levels.""" | ||
|
||
@abstractmethod | ||
def valid_flags(self, *args, **kwargs) -> list[Tensor]: | ||
"""Generate valid flags of anchors/points of multiple feature levels.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.