Remove runtime parameter type check (#1999)

* Remove check_input_parameters_type() decorators
* Remove params validation tests
* Fix pre-commit

goodsong81 authored Apr 12, 2023
1 parent 1263b82 commit 5652811
Showing 63 changed files with 9 additions and 4,578 deletions.
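
For context, a minimal sketch of the pattern this commit removes, assembled from the decorators and imports that appear in the hunks below; SomeDataset and its method bodies are hypothetical stand-ins, not code from any single changed file.

# Before: argument types were validated at call time by decorators from
# otx.api.utils.argument_checks (the imports deleted throughout this diff).
from typing import Any, Dict, List

from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.label import LabelEntity
from otx.api.utils.argument_checks import (
    DatasetParamTypeCheck,
    check_input_parameters_type,
)


class SomeDataset:  # hypothetical stand-in for the dataset/task classes below
    @check_input_parameters_type({"otx_dataset": DatasetParamTypeCheck})
    def __init__(self, otx_dataset: DatasetEntity, labels: List[LabelEntity]):
        self.otx_dataset = otx_dataset
        self.labels = labels

    @check_input_parameters_type()
    def prepare_train_frames(self, idx: int) -> Dict[str, Any]:
        return {"index": idx}


# After: the import block and both decorators are simply deleted; signatures
# and type hints stay, so type correctness is left to static checking instead
# of being enforced on every call.
class SomeDataset:  # same hypothetical class with the runtime checks removed
    def __init__(self, otx_dataset: DatasetEntity, labels: List[LabelEntity]):
        self.otx_dataset = otx_dataset
        self.labels = labels

    def prepare_train_frames(self, idx: int) -> Dict[str, Any]:
        return {"index": idx}
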
7 changes: 0 additions & 7 deletions otx/algorithms/action/adapters/mmaction/data/cls_dataset.py
@@ -25,10 +25,6 @@
from otx.algorithms.action.adapters.mmaction.data.pipelines import RawFrameDecode
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.label import LabelEntity
from otx.api.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)


# pylint: disable=too-many-instance-attributes
@@ -96,7 +92,6 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
item = self.video_info[list(self.video_info.keys())[index]]
return item

@check_input_parameters_type({"otx_dataset": DatasetParamTypeCheck})
# pylint: disable=too-many-arguments, invalid-name, super-init-not-called
# TODO Check need for additional params such as multi_class, with_offset
def __init__(
@@ -123,7 +118,6 @@ def __len__(self) -> int:
"""Return length of dataset."""
return len(self.video_infos)

@check_input_parameters_type()
def prepare_train_frames(self, idx: int) -> Dict[str, Any]:
"""Get training data and annotations after pipeline.
@@ -133,7 +127,6 @@ def prepare_train_frames(self, idx: int) -> Dict[str, Any]:
item = copy(self.video_infos[idx]) # Copying dict(), not contents
return self.pipeline(item)

@check_input_parameters_type()
def prepare_test_frames(self, idx: int) -> Dict[str, Any]:
"""Get testing data after pipeline.
7 changes: 0 additions & 7 deletions otx/algorithms/action/adapters/mmaction/data/det_dataset.py
@@ -35,10 +35,6 @@
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.label import LabelEntity
from otx.api.entities.metadata import VideoMetadata
from otx.api.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)
from otx.api.utils.shape_factory import ShapeFactory

root_logger = get_root_logger()
@@ -209,7 +205,6 @@ def _update_annotations(self, metadata: VideoMetadata, anns: List[Annotation]):
metadata.update("proposals", proposals[:, :4])
metadata.update("scores", proposals[:, 4])

@check_input_parameters_type({"otx_dataset": DatasetParamTypeCheck})
# TODO Remove duplicated codes with mmaction's AVADataset
def __init__(
self,
@@ -245,7 +240,6 @@ def __init__(
# TODO. Handle exclude file for AVA dataset
self.exclude_file = None

@check_input_parameters_type()
def prepare_train_frames(self, idx: int) -> Dict[str, Any]:
"""Get training data and annotations after pipeline.
@@ -255,7 +249,6 @@ def prepare_train_frames(self, idx: int) -> Dict[str, Any]:
item = copy(self.video_infos[idx]) # Copying dict(), not contents
return self.pipeline(item)

@check_input_parameters_type()
def prepare_test_frames(self, idx: int) -> Dict[str, Any]:
"""Get testing data after pipeline.
9 changes: 0 additions & 9 deletions otx/algorithms/action/adapters/mmaction/utils/config_utils.py
@@ -26,13 +26,8 @@
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.label import LabelEntity
from otx.api.entities.model_template import TaskType
from otx.api.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)


@check_input_parameters_type()
def patch_config(config: Config, data_pipeline_path: str, work_dir: str, task_type: TaskType):
"""Patch recipe config suitable to mmaction."""
# FIXME omnisource is hard coded
@@ -47,7 +42,6 @@ def patch_config(config: Config, data_pipeline_path: str, work_dir: str, task_ty
raise NotImplementedError(f"{task_type} is not supported in action task")


@check_input_parameters_type()
def _patch_cls_datasets(config: Config):
"""Patch cls dataset config suitable to mmaction."""

@@ -61,7 +55,6 @@ def _patch_cls_datasets(config: Config):
cfg.labels = None


@check_input_parameters_type()
def _patch_det_dataset(config: Config):
"""Patch det dataset config suitable to mmaction."""
assert "data" in config
@@ -72,7 +65,6 @@ def _patch_det_dataset(config: Config):
cfg.type = "OTXActionDetDataset"


@check_input_parameters_type()
def set_data_classes(config: Config, labels: List[LabelEntity], task_type: TaskType):
"""Setter data classes into config."""
for subset in ("train", "val", "test"):
@@ -88,7 +80,6 @@ def set_data_classes(config: Config, labels: List[LabelEntity], task_type: TaskT
config.model["roi_head"]["bbox_head"]["topk"] = len(labels) - 1


@check_input_parameters_type({"train_dataset": DatasetParamTypeCheck, "val_dataset": DatasetParamTypeCheck})
def prepare_for_training(
config: Union[Config, ConfigDict],
train_dataset: DatasetEntity,
@@ -21,7 +21,6 @@
import numpy as np

from otx.api.entities.datasets import DatasetItemEntity
from otx.api.utils.argument_checks import check_input_parameters_type

try:
from openvino.model_zoo.model_api.adapters import OpenvinoAdapter
@@ -37,15 +36,13 @@
warnings.warn(f"{e}, ModelAPI was not found.")


@check_input_parameters_type()
def softmax_numpy(x: np.ndarray):
"""Softmax numpy."""
x = np.exp(x - np.max(x))
x /= np.sum(x)
return x


@check_input_parameters_type()
def get_multiclass_predictions(logits: np.ndarray, activate: bool = True):
"""Get multiclass predictions."""
index = np.argmax(logits)
@@ -93,7 +90,6 @@ def _get_outputs(self):
layer_name = name
return layer_name

@check_input_parameters_type()
def preprocess(self, inputs: List[DatasetItemEntity]):
"""Pre-process."""
meta = {"original_shape": inputs[0].media.numpy.shape}
@@ -108,14 +104,12 @@ def preprocess(self, inputs: List[DatasetItemEntity]):
return dict_inputs, meta

@staticmethod
@check_input_parameters_type()
def _reshape(inputs: List[np.ndarray]) -> np.ndarray:
"""Reshape(expand, transpose, permute) the input np.ndarray."""
np_inputs = np.expand_dims(inputs, axis=(0, 1)) # [1, 1, T, H, W, C]
np_inputs = np_inputs.transpose(0, 1, -1, 2, 3, 4) # [1, 1, C, T, H, W]
return np_inputs

@check_input_parameters_type()
# pylint: disable=unused-argument
def postprocess(self, outputs: Dict[str, np.ndarray], meta: Dict[str, Any]):
"""Post-process."""
@@ -162,7 +156,6 @@ def _get_outputs(self):
out_names["labels"] = name
return out_names

@check_input_parameters_type()
def preprocess(self, inputs: List[DatasetItemEntity]):
"""Pre-process."""
meta = {"original_shape": inputs[0].media.numpy.shape}
@@ -177,14 +170,12 @@ def preprocess(self, inputs: List[DatasetItemEntity]):
return dict_inputs, meta

@staticmethod
@check_input_parameters_type()
def reshape(inputs: List[np.ndarray]) -> np.ndarray:
"""Reshape(expand, transpose, permute) the input np.ndarray."""
np_inputs = np.expand_dims(inputs, axis=0) # [1, T, H, W, C]
np_inputs = np_inputs.transpose(0, -1, 1, 2, 3) # [1, C, T, H, W]
return np_inputs

@check_input_parameters_type()
def postprocess(self, outputs: Dict[str, np.ndarray], meta: Dict[str, Any]):
"""Post-process."""
# TODO Support multi label classification
13 changes: 1 addition & 12 deletions otx/algorithms/action/tasks/inference.py
@@ -61,10 +61,6 @@
from otx.api.usecases.tasks.interfaces.export_interface import ExportType, IExportTask
from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask
from otx.api.usecases.tasks.interfaces.unload_interface import IUnload
from otx.api.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)
from otx.api.utils.vis_utils import get_actmap

logger = get_root_logger()
@@ -74,12 +70,10 @@
class ActionInferenceTask(BaseTask, IInferenceTask, IExportTask, IEvaluationTask, IUnload):
"""Inference Task Implementation of OTX Action Task."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment, **kwargs):
super().__init__(ActionConfig, task_environment, **kwargs)
self.deploy_cfg = None

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(
self,
dataset: DatasetEntity,
@@ -287,11 +281,7 @@ def _create_model(config: Config, from_scratch: bool = False):
model = build_model(model_cfg)
return model

@check_input_parameters_type()
def evaluate(
self,
output_resultset: ResultSetEntity,
):
def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optional[str] = None):
"""Evaluate function of OTX Action Task."""
logger.info("called evaluate()")
self._remove_empty_frames(output_resultset.ground_truth_dataset)
@@ -321,7 +311,6 @@ def unload(self):
if self._work_dir_is_temp:
self._delete_scratch_space()

@check_input_parameters_type()
def export(
self,
export_type: ExportType,
16 changes: 0 additions & 16 deletions otx/algorithms/action/tasks/openvino.py
@@ -69,10 +69,6 @@
IOptimizationTask,
OptimizationType,
)
from otx.api.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)

try:
from openvino.model_zoo.model_api.adapters import OpenvinoAdapter, create_core
@@ -89,7 +85,6 @@
class ActionOpenVINOInferencer(BaseInferencer):
"""ActionOpenVINOInferencer class in OpenVINO task for action recognition."""

@check_input_parameters_type()
def __init__(
self,
task_type: str,
@@ -126,12 +121,10 @@ def __init__(
else:
self.converter = DetectionBoxToAnnotationConverter(self.label_schema)

@check_input_parameters_type()
def pre_process(self, image: List[DatasetItemEntity]) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""Pre-process function of OpenVINO Inferencer for Action Recognition."""
return self.model.preprocess(image)

@check_input_parameters_type()
def post_process(
self, prediction: Dict[str, np.ndarray], metadata: Dict[str, Any]
) -> Optional[AnnotationSceneEntity]:
@@ -140,15 +133,13 @@ def post_process(
prediction = self.model.postprocess(prediction, metadata)
return self.converter.convert_to_annotation(prediction, metadata)

@check_input_parameters_type()
def predict(self, image: List[DatasetItemEntity]) -> AnnotationSceneEntity:
"""Predict function of OpenVINO Action Inferencer for Action Recognition."""
data, metadata = self.pre_process(image)
raw_predictions = self.forward(data)
predictions = self.post_process(raw_predictions, metadata)
return predictions

# @check_input_parameters_type()
def forward(self, image: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""Forward function of OpenVINO Action Inferencer for Action Recognition."""

@@ -158,13 +149,11 @@ def forward(self, image: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
class DataLoaderWrapper(DataLoader):
"""DataLoader implementation for ActionOpenVINOTask."""

@check_input_parameters_type()
def __init__(self, dataloader: DataLoader, inferencer: BaseInferencer):
super().__init__(config=None)
self.dataloader = dataloader
self.inferencer = inferencer

@check_input_parameters_type()
def __getitem__(self, index: int):
"""Get item from dataset."""
item = self.dataloader[index]
@@ -180,7 +169,6 @@ def __len__(self):
class ActionOpenVINOTask(IDeploymentTask, IInferenceTask, IEvaluationTask, IOptimizationTask):
"""Task implementation for OTX Action Recognition using OpenVINO backend."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
self.task_environment = task_environment
self.hparams = self.task_environment.get_hyper_parameters(ActionConfig)
@@ -203,7 +191,6 @@ def load_inferencer(self) -> ActionOpenVINOInferencer:
)

# pylint: disable=no-value-for-parameter
@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(
self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None
) -> DatasetEntity:
@@ -223,7 +210,6 @@ def infer(
update_progress_callback(int(i / dataset_size * 100))
return dataset

@check_input_parameters_type()
def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optional[str] = None):
"""Evaluate function of OpenVINOTask."""

@@ -234,7 +220,6 @@ def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optiona
elif self.task_type == "ACTION_DETECTION":
output_resultset.performance = MetricsHelper.compute_f_measure(output_resultset).get_performance()

@check_input_parameters_type()
def deploy(self, output_model: ModelEntity) -> None:
"""Deploy function of OpenVINOTask."""

@@ -271,7 +256,6 @@ def deploy(self, output_model: ModelEntity) -> None:
output_model.exportable_code = zip_buffer.getvalue()
logger.info("Deploying completed")

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def optimize(
self,
optimization_type: OptimizationType,
2 changes: 0 additions & 2 deletions otx/algorithms/action/utils/data.py
@@ -32,10 +32,8 @@
from otx.api.entities.scored_label import ScoredLabel
from otx.api.entities.shapes.rectangle import Rectangle
from otx.api.entities.subset import Subset
from otx.api.utils.argument_checks import check_input_parameters_type


@check_input_parameters_type()
def find_label_by_name(labels: List[LabelEntity], name: str, domain: Domain):
"""Return label from name."""
matching_labels = [label for label in labels if label.name == name]
@@ -22,10 +22,6 @@
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.id import ID
from otx.api.entities.label import LabelEntity
from otx.api.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)

logger = get_logger()

@@ -35,7 +31,6 @@
class OTXClsDataset(BaseDataset):
"""Multi-class classification dataset class."""

@check_input_parameters_type({"otx_dataset": DatasetParamTypeCheck})
def __init__(
self, otx_dataset: DatasetEntity, labels: List[LabelEntity], empty_label=None, **kwargs
): # pylint: disable=super-init-not-called
@@ -83,7 +78,6 @@ def load_annotations(self):
self.gt_labels.append(class_indices)
self.gt_labels = np.array(self.gt_labels)

@check_input_parameters_type()
def __getitem__(self, index: int):
"""Get item from dataset."""
dataset = self.otx_dataset
@@ -435,7 +429,6 @@ class SelfSLDataset(Dataset):

CLASSES = None

@check_input_parameters_type({"otx_dataset": DatasetParamTypeCheck})
def __init__(
self, otx_dataset: DatasetEntity, pipeline: Dict[str, Any], **kwargs
): # pylint: disable=unused-argument