-
Notifications
You must be signed in to change notification settings - Fork 446
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/v2' into enhance/perf-classification
- Loading branch information
Showing
73 changed files
with
2,445 additions
and
26 deletions.
There are no files selected for viewing
1,381 changes: 1,381 additions & 0 deletions
1,381
for_developers/images/product_design/core_design_drawing.drawio
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
defaults: | ||
- default | ||
|
||
task: ACTION_DETECTION |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
defaults: | ||
- default | ||
|
||
model_checkpoint: | ||
dirpath: ${base.output_dir}/checkpoints | ||
filename: "epoch_{epoch:03d}" | ||
monitor: "val/map" | ||
mode: "max" | ||
save_last: True | ||
auto_insert_metric_name: False | ||
|
||
early_stopping: | ||
monitor: "val/map" | ||
patience: 100 | ||
mode: "max" |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
defaults: | ||
- default | ||
|
||
data_format: ava | ||
|
||
mem_cache_img_max_size: ${as_int_tuple:500,500} | ||
|
||
train_subset: | ||
batch_size: 64 | ||
transform_lib_type: MMACTION | ||
val_subset: | ||
batch_size: 64 | ||
transform_lib_type: MMACTION | ||
test_subset: | ||
batch_size: 64 | ||
transform_lib_type: MMACTION |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
defaults: | ||
- default | ||
|
||
_target_: otx.core.model.module.action_detection.OTXActionDetLitModule | ||
|
||
otx_model: | ||
_target_: otx.core.model.entity.action_detection.MMActionCompatibleModel | ||
config: ??? | ||
|
||
# compile model for faster training with pytorch 2.0 | ||
torch_compile: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""Module for OTXActionDetDataset.""" | ||
|
||
from __future__ import annotations | ||
|
||
import pickle | ||
from pathlib import Path | ||
from typing import Callable | ||
|
||
import numpy as np | ||
import torch | ||
from datumaro import Bbox, Image | ||
from datumaro.components.annotation import AnnotationType | ||
from torchvision import tv_tensors | ||
|
||
from otx.core.data.dataset.base import OTXDataset | ||
from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetDataEntity | ||
from otx.core.data.entity.base import ImageInfo | ||
|
||
|
||
class OTXActionDetDataset(OTXDataset[ActionDetDataEntity]):
    """OTXDataset class for action detection task."""

    def __init__(self, **kwargs) -> None:
        """Initialize the base dataset and cache the number of label categories."""
        super().__init__(**kwargs)
        self.num_classes = len(self.dm_subset.categories()[AnnotationType.label])

    def _get_item_impl(self, idx: int) -> ActionDetDataEntity | None:
        """Build the (transformed) data entity for the ``idx``-th item of the subset.

        Args:
            idx: Position of the item in ``self.ids``.

        Returns:
            Whatever ``self._apply_transforms`` returns for the assembled entity.
        """
        item = self.dm_subset.get(id=self.ids[idx], subset=self.dm_subset.name)
        img = item.media_as(Image)
        img_data, img_shape = self._get_img_data_and_shape(img)

        bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]
        bboxes = (
            np.stack([ann.points for ann in bbox_anns], axis=0).astype(np.float32)
            if bbox_anns
            else np.zeros((0, 4), dtype=np.float32)
        )

        # An empty Python list converts to a float tensor by default, which
        # ``one_hot`` rejects; forcing an integer dtype keeps frames without
        # any Bbox annotation working (result has shape (0, num_classes)).
        label_indices = torch.as_tensor(
            [ann.label for ann in bbox_anns],
            dtype=torch.long,
        )

        frame_path = item.media.path
        entity = ActionDetDataEntity(
            image=img_data,
            img_info=ImageInfo(
                img_idx=idx,
                img_shape=img_shape,
                ori_shape=img_shape,
                pad_shape=img_shape,
                scale_factor=(1.0, 1.0),
            ),
            bboxes=tv_tensors.BoundingBoxes(
                bboxes,
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=img_shape,
            ),
            labels=torch.nn.functional.one_hot(
                label_indices,
                self.num_classes,
            ).to(torch.float),
            frame_path=frame_path,
            proposals=self._get_proposals(
                frame_path,
                self.dm_subset.infos().get(f"{self.dm_subset.name}_proposals", None),
            ),
        )

        return self._apply_transforms(entity)

    @staticmethod
    def _fallback_proposals() -> np.ndarray:
        """Return the single ``[0, 0, 1, 1]`` box used when no proposals are available."""
        return np.array([[0, 0, 1, 1]], dtype=np.float64)

    @staticmethod
    def _get_proposals(frame_path: str, proposal_file: str | None) -> np.ndarray:
        """Get proposals from frame path and proposal file name.

        Datumaro AVA dataset expects the data structure as

        - data_root/
            - frames/
                - video0
                    - video0_0001.jpg
                    - video0_0002.jpg
            - annotations/
                - train.csv
                - val.csv
                - train.pkl
                - val.pkl

        The proposal file is looked up in ``annotations/`` three directory
        levels above the frame, and the frame's entry is keyed by its file
        stem with the last ``_`` replaced by ``,`` (``video0_0001`` ->
        ``video0,0001``).

        Args:
            frame_path: Path of the frame image.
            proposal_file: File name of the pickled proposals, or ``None``.

        Returns:
            The first four columns of the frame's proposal array, or a single
            ``[0, 0, 1, 1]`` float64 box when ``proposal_file`` is ``None``, the
            file does not exist, or the frame has no entry in it.
        """
        if proposal_file is None:
            return OTXActionDetDataset._fallback_proposals()

        frame = Path(frame_path)
        proposal_file_path = frame.parent.parent.parent / "annotations" / proposal_file
        if not proposal_file_path.exists():
            return OTXActionDetDataset._fallback_proposals()

        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at proposal files from a trusted dataset.
        with proposal_file_path.open("rb") as f:
            info = pickle.load(f)  # noqa: S301

        key = ",".join(frame.stem.rsplit("_", 1))
        if key in info:
            # Only the first four columns are kept (presumably the box
            # coordinates — TODO confirm against the proposal file format).
            return info[key][:, :4]
        # Previously this path returned float32 while the other fallbacks
        # returned float64; all fallbacks now share one dtype.
        return OTXActionDetDataset._fallback_proposals()

    @property
    def collate_fn(self) -> Callable:
        """Collection function to collect ActionDetDataEntity into ActionDetBatchDataEntity."""
        return ActionDetBatchDataEntity.collate_fn
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""Module for OTX action data entities.""" | ||
|
||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING | ||
|
||
from otx.core.data.entity.base import ( | ||
OTXBatchDataEntity, | ||
OTXBatchPredEntity, | ||
OTXDataEntity, | ||
OTXPredEntity, | ||
) | ||
from otx.core.data.entity.utils import register_pytree_node | ||
from otx.core.types.task import OTXTaskType | ||
|
||
if TYPE_CHECKING: | ||
from torch import LongTensor | ||
from torchvision import tv_tensors | ||
|
||
|
||
@register_pytree_node
@dataclass
class ActionDetDataEntity(OTXDataEntity):
    """Data entity for action detection task.

    Args:
        bboxes: 2D bounding boxes for actors.
        labels: One-hot encoded action labels, one row per bounding box.
        frame_path: Data media's file path for getting proper meta information.
        proposals: Pre-calculated actor proposals.
    """

    bboxes: tv_tensors.BoundingBoxes
    labels: LongTensor
    frame_path: str
    proposals: tv_tensors.BoundingBoxes

    @property
    def task(self) -> OTXTaskType:
        """OTX Task type definition."""
        return OTXTaskType.ACTION_DETECTION
|
||
|
||
@dataclass
class ActionDetPredEntity(ActionDetDataEntity, OTXPredEntity):
    """Data entity to represent the action detection model's output prediction."""
|
||
|
||
@dataclass
class ActionDetBatchDataEntity(OTXBatchDataEntity[ActionDetDataEntity]):
    """Batch data entity for action detection.

    Args:
        bboxes(list[tv_tensors.BoundingBoxes]): A list of bounding boxes, one entry per sample.
        labels(list[LongTensor]): A list of labels, one entry per sample.
        proposals(list[tv_tensors.BoundingBoxes]): A list of actor proposals, one entry per sample.
    """

    bboxes: list[tv_tensors.BoundingBoxes]
    labels: list[LongTensor]
    proposals: list[tv_tensors.BoundingBoxes]

    @property
    def task(self) -> OTXTaskType:
        """OTX task type definition."""
        return OTXTaskType.ACTION_DETECTION

    @classmethod
    def collate_fn(cls, entities: list[ActionDetDataEntity]) -> ActionDetBatchDataEntity:
        """Collection function to collect `ActionDetDataEntity` into `ActionDetBatchDataEntity`.

        Images and image infos are batched by the parent ``collate_fn``;
        bounding boxes, labels and proposals are kept as per-sample lists.
        """
        batch_data = super().collate_fn(entities)
        return ActionDetBatchDataEntity(
            batch_size=batch_data.batch_size,
            images=batch_data.images,
            imgs_info=batch_data.imgs_info,
            bboxes=[entity.bboxes for entity in entities],
            labels=[entity.labels for entity in entities],
            proposals=[entity.proposals for entity in entities],
        )
|
||
|
||
@dataclass
class ActionDetBatchPredEntity(ActionDetBatchDataEntity, OTXBatchPredEntity):
    """Data entity to represent model output predictions for action detection task."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.