Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor XAI data entities #3230

Merged
merged 3 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/otx/algo/classification/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
from torchvision.models import get_model, get_model_weights

from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
MulticlassClsBatchDataEntity,
MulticlassClsBatchPredEntity,
MulticlassClsBatchPredEntityWithXAI,
)
from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXMulticlassClsModel
Expand Down Expand Up @@ -224,7 +220,7 @@
self,
outputs: Any, # noqa: ANN401
inputs: MulticlassClsBatchDataEntity,
) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI | OTXBatchLossEntity:
) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
if self.training:
return OTXBatchLossEntity(loss=outputs)

Expand All @@ -240,7 +236,7 @@

saliency_maps = outputs["saliency_map"].detach().cpu().numpy()

return MulticlassClsBatchPredEntityWithXAI(
return MulticlassClsBatchPredEntity(

Check warning on line 239 in src/otx/algo/classification/torchvision_model.py

View check run for this annotation

Codecov / codecov/patch

src/otx/algo/classification/torchvision_model.py#L239

Added line #L239 was not covered by tests
batch_size=len(preds),
images=inputs.images,
imgs_info=inputs.imgs_info,
Expand Down
20 changes: 10 additions & 10 deletions src/otx/algo/utils/xai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from datumaro import Image

from otx.core.config.explain import ExplainConfig
from otx.core.data.entity.base import OTXBatchPredEntityWithXAI
from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntityWithXAI
from otx.core.data.entity.base import OTXBatchPredEntity
from otx.core.types.explain import TargetExplainGroup

if TYPE_CHECKING:
Expand All @@ -23,22 +22,23 @@


def process_saliency_maps_in_pred_entity(
predict_result: list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI | Any],
predict_result: list[OTXBatchPredEntity],
explain_config: ExplainConfig,
) -> list[Any] | list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI]:
) -> list[OTXBatchPredEntity]:
"""Process saliency maps in PredEntity."""
for predict_result_per_batch in predict_result:

def _process(predict_result_per_batch: OTXBatchPredEntity) -> OTXBatchPredEntity:
saliency_maps = predict_result_per_batch.saliency_maps
imgs_info = predict_result_per_batch.imgs_info
ori_img_shapes = [img_info.ori_shape for img_info in imgs_info]
pred_labels = predict_result_per_batch.labels # type: ignore[union-attr]
if pred_labels:
if pred_labels := getattr(predict_result_per_batch, "labels", None):
pred_labels = [pred.tolist() for pred in pred_labels]

processed_saliency_maps = process_saliency_maps(saliency_maps, explain_config, pred_labels, ori_img_shapes)

predict_result_per_batch.saliency_maps = processed_saliency_maps
return predict_result
return predict_result_per_batch.wrap(saliency_maps=processed_saliency_maps)

return [_process(predict_result_per_batch) for predict_result_per_batch in predict_result]


def process_saliency_maps(
Expand Down Expand Up @@ -116,7 +116,7 @@ def postprocess(saliency_map: np.ndarray, output_size: tuple[int, int] | None) -


def dump_saliency_maps(
predict_result: list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI | Any],
predict_result: list[OTXBatchPredEntity],
explain_config: ExplainConfig,
datamodule: EVAL_DATALOADERS | OTXDataModule,
output_dir: Path,
Expand Down
13 changes: 7 additions & 6 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,12 +862,13 @@ def preprocess(self, x: Image) -> Image:

def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShotVisualPromptingBatchDataEntity:
"""Transforms for ZeroShotVisualPromptingBatchDataEntity."""
entity.images = [self.preprocess(self.apply_image(image)) for image in entity.images]
entity.prompts = [
self.apply_prompts(prompt, info.ori_shape, self.model.image_size)
for prompt, info in zip(entity.prompts, entity.imgs_info)
]
return entity
return entity.wrap(
images=[self.preprocess(self.apply_image(image)) for image in entity.images],
prompts=[
self.apply_prompts(prompt, info.ori_shape, self.model.image_size)
for prompt, info in zip(entity.prompts, entity.imgs_info)
],
)

def initialize_reference_info(self) -> None:
"""Initialize reference information."""
Expand Down
7 changes: 5 additions & 2 deletions src/otx/core/data/dataset/visual_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@
)
transformed_entity = self._apply_transforms(entity)

if transformed_entity is None:
msg = "This is not allowed."
raise RuntimeError(msg)

Check warning on line 151 in src/otx/core/data/dataset/visual_prompting.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/visual_prompting.py#L150-L151

Added lines #L150 - L151 were not covered by tests

# insert masks to transformed_entity
transformed_entity.masks = masks # type: ignore[union-attr]
return transformed_entity
return transformed_entity.wrap(masks=masks)

@property
def collate_fn(self) -> Callable:
Expand Down
20 changes: 3 additions & 17 deletions src/otx/core/data/entity/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
from otx.core.data.entity.base import (
OTXBatchDataEntity,
OTXBatchPredEntity,
OTXBatchPredEntityWithXAI,
OTXDataEntity,
OTXPredEntity,
OTXPredEntityWithXAI,
)
from otx.core.data.entity.utils import register_pytree_node
from otx.core.types.task import OTXTaskType
Expand Down Expand Up @@ -51,15 +49,10 @@


@dataclass
class ActionClsPredEntity(ActionClsDataEntity, OTXPredEntity):
class ActionClsPredEntity(OTXPredEntity, ActionClsDataEntity):
"""Data entity to represent the action classification model's output prediction."""


@dataclass
class ActionClsPredEntityWithXAI(ActionClsDataEntity, OTXPredEntityWithXAI):
"""Data entity to represent the detection model output prediction with explanations."""


@dataclass
class ActionClsBatchDataEntity(OTXBatchDataEntity[ActionClsDataEntity]):
"""Batch data entity for action classification.
Expand Down Expand Up @@ -92,16 +85,9 @@

def pin_memory(self) -> ActionClsBatchDataEntity:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
return self
return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels])

Check warning on line 88 in src/otx/core/data/entity/action_classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/action_classification.py#L88

Added line #L88 was not covered by tests


@dataclass
class ActionClsBatchPredEntity(ActionClsBatchDataEntity, OTXBatchPredEntity):
class ActionClsBatchPredEntity(OTXBatchPredEntity, ActionClsBatchDataEntity):
"""Data entity to represent model output predictions for action classification task."""


@dataclass
class ActionClsBatchPredEntityWithXAI(ActionClsBatchDataEntity, OTXBatchPredEntityWithXAI):
"""Data entity to represent model output predictions for multi-class classification task with explanations."""
24 changes: 11 additions & 13 deletions src/otx/core/data/entity/action_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from otx.core.data.entity.base import (
OTXBatchDataEntity,
OTXBatchPredEntity,
OTXBatchPredEntityWithXAI,
OTXDataEntity,
OTXPredEntity,
)
Expand Down Expand Up @@ -48,7 +47,7 @@


@dataclass
class ActionDetPredEntity(ActionDetDataEntity, OTXPredEntity):
class ActionDetPredEntity(OTXPredEntity, ActionDetDataEntity):

Check warning on line 50 in src/otx/core/data/entity/action_detection.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/action_detection.py#L50

Added line #L50 was not covered by tests
"""Data entity to represent the action classification model's output prediction."""


Expand Down Expand Up @@ -89,18 +88,17 @@

def pin_memory(self) -> ActionDetBatchDataEntity:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.bboxes = [tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes]
self.labels = [label.pin_memory() for label in self.labels]
self.proposals = [tv_tensors.wrap(proposal.pin_memory(), like=proposal) for proposal in self.proposals]
return self
return (

Check warning on line 91 in src/otx/core/data/entity/action_detection.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/action_detection.py#L91

Added line #L91 was not covered by tests
super()
.pin_memory()
.wrap(
bboxes=[tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes],
labels=[label.pin_memory() for label in self.labels],
proposals=[tv_tensors.wrap(proposal.pin_memory(), like=proposal) for proposal in self.proposals],
)
)


@dataclass
class ActionDetBatchPredEntity(ActionDetBatchDataEntity, OTXBatchPredEntity):
class ActionDetBatchPredEntity(OTXBatchPredEntity, ActionDetBatchDataEntity):

Check warning on line 103 in src/otx/core/data/entity/action_detection.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/action_detection.py#L103

Added line #L103 was not covered by tests
"""Data entity to represent model output predictions for action classification task."""


@dataclass
class ActionDetBatchPredEntityWithXAI(ActionDetBatchDataEntity, OTXBatchPredEntityWithXAI):
"""Data entity to represent model output predictions for multi-class classification task with explanations."""
10 changes: 4 additions & 6 deletions src/otx/core/data/entity/anomaly/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,16 @@

def pin_memory(self) -> AnomalyClassificationDataBatch:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
return self
return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels])

Check warning on line 66 in src/otx/core/data/entity/anomaly/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/classification.py#L66

Added line #L66 was not covered by tests


@dataclass
class AnomalyClassificationPrediction(AnomalyClassificationDataItem, OTXPredEntity):
class AnomalyClassificationPrediction(OTXPredEntity, AnomalyClassificationDataItem):

Check warning on line 70 in src/otx/core/data/entity/anomaly/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/classification.py#L70

Added line #L70 was not covered by tests
"""Anomaly classification Prediction item."""


@dataclass
class AnomalyClassificationBatchPrediction(AnomalyClassificationDataBatch, OTXBatchPredEntity):
@dataclass(kw_only=True)
class AnomalyClassificationBatchPrediction(OTXBatchPredEntity, AnomalyClassificationDataBatch):

Check warning on line 75 in src/otx/core/data/entity/anomaly/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/classification.py#L74-L75

Added lines #L74 - L75 were not covered by tests
"""Anomaly classification batch prediction."""

anomaly_maps: torch.Tensor
20 changes: 12 additions & 8 deletions src/otx/core/data/entity/anomaly/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,24 @@

def pin_memory(self) -> AnomalyDetectionDataBatch:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
self.masks = self.masks.pin_memory()
self.boxes = [box.pin_memory() for box in self.boxes]
return self
return (

Check warning on line 71 in src/otx/core/data/entity/anomaly/detection.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/detection.py#L71

Added line #L71 was not covered by tests
super()
.pin_memory()
.wrap(
labels=[label.pin_memory() for label in self.labels],
masks=self.masks.pin_memory(),
boxes=[box.pin_memory() for box in self.boxes],
)
)


@dataclass
class AnomalyDetectionPrediction(AnomalyDetectionDataItem, OTXPredEntity):
class AnomalyDetectionPrediction(OTXPredEntity, AnomalyDetectionDataItem):

Check warning on line 83 in src/otx/core/data/entity/anomaly/detection.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/detection.py#L83

Added line #L83 was not covered by tests
"""Anomaly Detection Prediction item."""


@dataclass
class AnomalyDetectionBatchPrediction(AnomalyDetectionDataBatch, OTXBatchPredEntity):
@dataclass(kw_only=True)
class AnomalyDetectionBatchPrediction(OTXBatchPredEntity, AnomalyDetectionDataBatch):

Check warning on line 88 in src/otx/core/data/entity/anomaly/detection.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/detection.py#L87-L88

Added lines #L87 - L88 were not covered by tests
"""Anomaly classification batch prediction."""

anomaly_maps: torch.Tensor
Expand Down
18 changes: 11 additions & 7 deletions src/otx/core/data/entity/anomaly/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,23 @@

def pin_memory(self) -> AnomalySegmentationDataBatch:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
self.masks = self.masks.pin_memory()
return self
return (

Check warning on line 68 in src/otx/core/data/entity/anomaly/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/segmentation.py#L68

Added line #L68 was not covered by tests
super()
.pin_memory()
.wrap(
labels=[label.pin_memory() for label in self.labels],
masks=self.masks.pin_memory(),
)
)


@dataclass
class AnomalySegmentationPrediction(AnomalySegmentationDataItem, OTXPredEntity):
class AnomalySegmentationPrediction(OTXPredEntity, AnomalySegmentationDataItem):

Check warning on line 79 in src/otx/core/data/entity/anomaly/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/segmentation.py#L79

Added line #L79 was not covered by tests
"""Anomaly Segmentation Prediction item."""


@dataclass
class AnomalySegmentationBatchPrediction(AnomalySegmentationDataBatch, OTXBatchPredEntity):
@dataclass(kw_only=True)
class AnomalySegmentationBatchPrediction(OTXBatchPredEntity, AnomalySegmentationDataBatch):

Check warning on line 84 in src/otx/core/data/entity/anomaly/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/entity/anomaly/segmentation.py#L83-L84

Added lines #L83 - L84 were not covered by tests
"""Anomaly classification batch prediction."""

anomaly_maps: torch.Tensor
Loading
Loading