diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py
index 033b7ab2870..48daf0b760b 100644
--- a/src/otx/engine/engine.py
+++ b/src/otx/engine/engine.py
@@ -379,13 +379,7 @@ def predict(
         model = self.model

-        if checkpoint is not None:
-            checkpoint = str(checkpoint)
-        elif self.checkpoint is not None:
-            checkpoint = str(self.checkpoint)
-        else:
-            checkpoint = None
-
+        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
         datamodule = datamodule if datamodule is not None else self.datamodule

         is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
@@ -400,8 +394,6 @@ def predict(
         )
         lit_module.label_info = datamodule.label_info

-        # NOTE, trainer.test takes only lightning based checkpoint.
-        # So, it can't take the OTX1.x checkpoint.
         if checkpoint is not None and not is_ir_ckpt:
             loaded_checkpoint = torch.load(checkpoint)
             lit_module.load_state_dict(loaded_checkpoint)
@@ -595,19 +587,27 @@ def explain(
         """
         from otx.algo.utils.xai_utils import dump_saliency_maps, process_saliency_maps_in_pred_entity

-        ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint
-        if explain_config is None:
-            explain_config = ExplainConfig()
+        model = self.model
+
+        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
+        datamodule = datamodule if datamodule is not None else self.datamodule
+
+        is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
+        if is_ir_ckpt and not isinstance(model, OVModel):
+            datamodule = self._auto_configurator.get_ov_datamodule()
+            model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

         lit_module = self._build_lightning_module(
-            model=self.model,
+            model=model,
             optimizer=self.optimizer,
             scheduler=self.scheduler,
         )
-        if datamodule is None:
-            datamodule = self.datamodule
         lit_module.label_info = datamodule.label_info

+        if checkpoint is not None and not is_ir_ckpt:
+            loaded_checkpoint = torch.load(checkpoint)
+            lit_module.load_state_dict(loaded_checkpoint)
+
         lit_module.model.explain_mode = True

         self._build_trainer(**kwargs)
@@ -615,9 +615,11 @@ def explain(
         predict_result = self.trainer.predict(
             model=lit_module,
             datamodule=datamodule,
-            ckpt_path=ckpt_path,
         )

+        if explain_config is None:
+            explain_config = ExplainConfig()
+
         predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config)
         if dump:
             dump_saliency_maps(
diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py
index 202e4097042..b26a297e6c2 100644
--- a/tests/integration/api/test_engine_api.py
+++ b/tests/integration/api/test_engine_api.py
@@ -37,6 +37,8 @@ def test_engine_from_config(
         pytest.skip(
             reason="H-labels require num_multiclass_head, num_multilabel_classes, which skip until we have the ability to automate this.",
         )
+    if "anomaly" in task.lower():
+        pytest.skip(reason="There's no dataset for anomaly tasks.")

     tmp_path_train = tmp_path / task
     engine = Engine.from_config(
@@ -86,6 +88,37 @@ def test_engine_from_config(
     test_metric_from_ov_model = engine.test(checkpoint=exported_model_path, accelerator="cpu")
     assert len(test_metric_from_ov_model) > 0

+    # List of tasks with explain support.
+    if task not in [
+        OTXTaskType.MULTI_CLASS_CLS,
+        OTXTaskType.MULTI_LABEL_CLS,
+        # Will be supported after merging PR#2997
+        # OTXTaskType.DETECTION,
+        # OTXTaskType.ROTATED_DETECTION,
+        # OTXTaskType.INSTANCE_SEGMENTATION,
+    ]:
+        return
+
+    # Predict Torch model with explain
+    predictions = engine.predict(explain=True)
+    assert len(predictions[0].saliency_maps) > 0
+
+    # Export IR model with explain
+    exported_model_with_explain = engine.export(explain=True)
+    assert exported_model_with_explain.exists()
+
+    # Infer IR model with explain: predict
+    predictions = engine.predict(explain=True, checkpoint=exported_model_with_explain, accelerator="cpu")
+    assert len(predictions) > 0
+    sal_maps_from_prediction = predictions[0].saliency_maps
+    assert len(sal_maps_from_prediction) > 0
+
+    # Infer IR model with explain: explain
+    explain_results = engine.explain(checkpoint=exported_model_with_explain, accelerator="cpu")
+    assert len(explain_results[0].saliency_maps) > 0
+    sal_maps_from_explain = explain_results[0].saliency_maps
+    assert (sal_maps_from_prediction[0][0] == sal_maps_from_explain[0][0]).all()
+

 @pytest.mark.parametrize("recipe", pytest.TILE_RECIPE_LIST)
 def test_engine_from_tile_recipe(
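For reviewers, a minimal sketch of the workflow the new test exercises. The recipe and dataset paths, `max_epochs`, and the `train()` call are illustrative assumptions about the surrounding `Engine` API, not part of this diff; only the `predict`/`export`/`explain` calls mirror the changes above:

```python
# Hypothetical end-to-end usage; paths and train() arguments are placeholders.
from otx.engine import Engine

engine = Engine.from_config(
    config_path="path/to/recipe.yaml",  # assumed recipe path
    data_root="path/to/dataset",        # assumed dataset root
)
engine.train(max_epochs=1)  # assumed; leaves a Torch .ckpt as engine.checkpoint

# Torch checkpoint: explain() now loads weights via torch.load/load_state_dict
# instead of handing ckpt_path to trainer.predict().
predictions = engine.predict(explain=True)
assert predictions[0].saliency_maps

# IR checkpoint: export with the explain head, then point explain() at the
# .xml file; the engine swaps in an OV model and datamodule, as predict() does.
ir_model = engine.export(explain=True)
explain_results = engine.explain(checkpoint=ir_model, accelerator="cpu")
assert explain_results[0].saliency_maps
```

Routing both checkpoint kinds through the same resolution logic (fall back to `self.checkpoint`, then an `is_ir_ckpt` suffix check) is what keeps `explain()` consistent with `predict()` after this change.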