Skip to content

Commit

Permalink
Support explain for OV models (#3027)
Browse files Browse the repository at this point in the history
* Support predict & explain for OV models

* Fixes from comments

* Comment fixes

* Fix tests
  • Loading branch information
GalyaZalesskaya authored Mar 6, 2024
1 parent c833f23 commit 64fa5c3
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 16 deletions.
34 changes: 18 additions & 16 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,7 @@ def predict(

model = self.model

if checkpoint is not None:
checkpoint = str(checkpoint)
elif self.checkpoint is not None:
checkpoint = str(self.checkpoint)
else:
checkpoint = None

checkpoint = checkpoint if checkpoint is not None else self.checkpoint
datamodule = datamodule if datamodule is not None else self.datamodule

is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
Expand All @@ -400,8 +394,6 @@ def predict(
)
lit_module.label_info = datamodule.label_info

# NOTE, trainer.test takes only lightning based checkpoint.
# So, it can't take the OTX1.x checkpoint.
if checkpoint is not None and not is_ir_ckpt:
loaded_checkpoint = torch.load(checkpoint)
lit_module.load_state_dict(loaded_checkpoint)
Expand Down Expand Up @@ -595,29 +587,39 @@ def explain(
"""
from otx.algo.utils.xai_utils import dump_saliency_maps, process_saliency_maps_in_pred_entity

ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint
if explain_config is None:
explain_config = ExplainConfig()
model = self.model

checkpoint = checkpoint if checkpoint is not None else self.checkpoint
datamodule = datamodule if datamodule is not None else self.datamodule

is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.get_ov_datamodule()
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

lit_module = self._build_lightning_module(
model=self.model,
model=model,
optimizer=self.optimizer,
scheduler=self.scheduler,
)
if datamodule is None:
datamodule = self.datamodule
lit_module.label_info = datamodule.label_info

if checkpoint is not None and not is_ir_ckpt:
loaded_checkpoint = torch.load(checkpoint)
lit_module.load_state_dict(loaded_checkpoint)

lit_module.model.explain_mode = True

self._build_trainer(**kwargs)

predict_result = self.trainer.predict(
model=lit_module,
datamodule=datamodule,
ckpt_path=ckpt_path,
)

if explain_config is None:
explain_config = ExplainConfig()

predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config)
if dump:
dump_saliency_maps(
Expand Down
33 changes: 33 additions & 0 deletions tests/integration/api/test_engine_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def test_engine_from_config(
pytest.skip(
reason="H-labels require num_multiclass_head, num_multilabel_classes, which skip until we have the ability to automate this.",
)
if "anomaly" in task.lower():
pytest.skip(reason="There's no dataset for anomaly tasks.")

tmp_path_train = tmp_path / task
engine = Engine.from_config(
Expand Down Expand Up @@ -86,6 +88,37 @@ def test_engine_from_config(
test_metric_from_ov_model = engine.test(checkpoint=exported_model_path, accelerator="cpu")
assert len(test_metric_from_ov_model) > 0

# List of models with explain supported.
if task not in [
OTXTaskType.MULTI_CLASS_CLS,
OTXTaskType.MULTI_LABEL_CLS,
# Will be supported after merging PR#2997
# OTXTaskType.DETECTION,
# OTXTaskType.ROTATED_DETECTION,
# OTXTaskType.INSTANCE_SEGMENTATION,
]:
return

# Predict Torch model with explain
predictions = engine.predict(explain=True)
assert len(predictions[0].saliency_maps) > 0

# Export IR model with explain
exported_model_with_explain = engine.export(explain=True)
assert exported_model_with_explain.exists()

# Infer IR Model with explain: predict
predictions = engine.predict(explain=True, checkpoint=exported_model_with_explain, accelerator="cpu")
assert len(predictions) > 0
sal_maps_from_prediction = predictions[0].saliency_maps
assert len(sal_maps_from_prediction) > 0

# Infer IR Model with explain: explain
explain_results = engine.explain(checkpoint=exported_model_with_explain, accelerator="cpu")
assert len(explain_results[0].saliency_maps) > 0
sal_maps_from_explain = explain_results[0].saliency_maps
assert (sal_maps_from_prediction[0][0] == sal_maps_from_explain[0][0]).all()


@pytest.mark.parametrize("recipe", pytest.TILE_RECIPE_LIST)
def test_engine_from_tile_recipe(
Expand Down

0 comments on commit 64fa5c3

Please sign in to comment.