Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vinnamkim committed Mar 28, 2024
1 parent 63f426f commit a554467
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
24 changes: 14 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@ def fxt_seg_data_entity() -> tuple[tuple, SegDataEntity, SegBatchDataEntity]:
fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size)
fake_masks = Mask(torch.randint(low=0, high=255, size=img_size, dtype=torch.uint8))
# define data entity
single_data_entity = SegDataEntity(fake_image, fake_image_info, fake_masks)
single_data_entity = SegDataEntity(
image=fake_image,
img_info=fake_image_info,
gt_seg_map=fake_masks,
)
batch_data_entity = SegBatchDataEntity(
1,
[Image(data=torch.from_numpy(fake_image))],
[fake_image_info],
[fake_masks],
batch_size=1,
images=[Image(data=torch.from_numpy(fake_image))],
imgs_info=[fake_image_info],
masks=[fake_masks],
)
batch_pred_data_entity = SegBatchPredEntity(
1,
[Image(data=torch.from_numpy(fake_image))],
[fake_image_info],
[],
[fake_masks],
batch_size=1,
images=[Image(data=torch.from_numpy(fake_image))],
imgs_info=[fake_image_info],
masks=[fake_masks],
scores=[],
)

return single_data_entity, batch_pred_data_entity, batch_data_entity
Expand Down
11 changes: 7 additions & 4 deletions tests/integration/api/test_xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import openvino.runtime as ov
import pytest
from otx.core.data.entity.base import OTXBatchPredEntity, OTXBatchPredEntityWithXAI
from otx.core.data.entity.base import OTXBatchPredEntity
from otx.engine import Engine

RECIPE_LIST_ALL = pytest.RECIPE_LIST
Expand Down Expand Up @@ -57,7 +57,8 @@ def test_forward_explain(
assert isinstance(predict_result[0], OTXBatchPredEntity)

predict_result_explain = engine.predict(explain=True)
assert isinstance(predict_result_explain[0], OTXBatchPredEntityWithXAI)
assert isinstance(predict_result_explain[0], OTXBatchPredEntity)
assert predict_result_explain[0].has_xai_outputs

batch_size = len(predict_result[0].scores)
for i in range(batch_size):
Expand Down Expand Up @@ -106,7 +107,8 @@ def test_predict_with_explain(

# Predict with explain torch & process maps
predict_result_explain_torch = engine.predict(explain=True)
assert isinstance(predict_result_explain_torch[0], OTXBatchPredEntityWithXAI)
assert isinstance(predict_result_explain_torch[0], OTXBatchPredEntity)
assert predict_result_explain_torch[0].has_xai_outputs
assert predict_result_explain_torch[0].saliency_maps is not None
assert isinstance(predict_result_explain_torch[0].saliency_maps[0], dict)

Expand Down Expand Up @@ -134,7 +136,8 @@ def test_predict_with_explain(

# Predict OV model with xai & process maps
predict_result_explain_ov = engine.predict(checkpoint=exported_model_path, explain=True)
assert isinstance(predict_result_explain_ov[0], OTXBatchPredEntityWithXAI)
assert isinstance(predict_result_explain_ov[0], OTXBatchPredEntity)
assert predict_result_explain_ov[0].has_xai_outputs
assert predict_result_explain_ov[0].saliency_maps is not None
assert isinstance(predict_result_explain_ov[0].saliency_maps[0], dict)
assert predict_result_explain_ov[0].feature_vectors is not None
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/algo/hooks/test_saliency_map_dumping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from otx.algo.utils.xai_utils import dump_saliency_maps
from otx.core.config.explain import ExplainConfig
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.classification import MulticlassClsBatchPredEntityWithXAI
from otx.core.data.entity.classification import MulticlassClsBatchPredEntity
from otx.core.types.task import OTXTaskType
from otx.engine.utils.auto_configurator import AutoConfigurator

Expand All @@ -30,7 +30,7 @@ def test_sal_map_dump(
datamodule = auto_configurator.get_datamodule()

predict_result = [
MulticlassClsBatchPredEntityWithXAI(
MulticlassClsBatchPredEntity(
batch_size=BATCH_SIZE,
images=None,
imgs_info=IMGS_INFO,
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/algo/hooks/test_saliency_map_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from otx.algo.utils.xai_utils import process_saliency_maps, process_saliency_maps_in_pred_entity
from otx.core.config.explain import ExplainConfig
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.classification import MulticlassClsBatchPredEntityWithXAI, MultilabelClsBatchPredEntityWithXAI
from otx.core.data.entity.classification import MulticlassClsBatchPredEntity, MultilabelClsBatchPredEntity
from otx.core.types.explain import TargetExplainGroup

NUM_CLASSES = 5
Expand Down Expand Up @@ -100,8 +100,8 @@ def test_process_image(postprocess) -> None:
assert all(s_map_dict["map_per_image"].shape == (RAW_SIZE, RAW_SIZE) for s_map_dict in processed_saliency_maps)


def _get_pred_result_multiclass(pred_labels) -> MulticlassClsBatchPredEntityWithXAI:
return MulticlassClsBatchPredEntityWithXAI(
def _get_pred_result_multiclass(pred_labels) -> MulticlassClsBatchPredEntity:
return MulticlassClsBatchPredEntity(
batch_size=BATCH_SIZE,
images=None,
imgs_info=IMGS_INFO,
Expand All @@ -112,8 +112,8 @@ def _get_pred_result_multiclass(pred_labels) -> MulticlassClsBatchPredEntityWith
)


def _get_pred_result_multilabel(pred_labels) -> MultilabelClsBatchPredEntityWithXAI:
return MultilabelClsBatchPredEntityWithXAI(
def _get_pred_result_multilabel(pred_labels) -> MultilabelClsBatchPredEntity:
return MultilabelClsBatchPredEntity(
batch_size=BATCH_SIZE,
images=None,
imgs_info=IMGS_INFO,
Expand Down

0 comments on commit a554467

Please sign in to comment.