diff --git a/otx/mpa/det/explainer.py b/otx/mpa/det/explainer.py index 4f1644eee2d..f49673920bb 100644 --- a/otx/mpa/det/explainer.py +++ b/otx/mpa/det/explainer.py @@ -141,13 +141,14 @@ def explain(self, cfg, model_builder=None): model = self.build_model(cfg, model_builder, fp16=False) model.CLASSES = target_classes model.eval() + feature_model = self._get_feature_module(model) model = build_data_parallel(model, cfg, distributed=False) # InferenceProgressCallback (Time Monitor enable into Infer task) self.set_inference_progress_callback(model, cfg) # Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map. - with self.explainer_hook(model.module) as saliency_hook: + with self.explainer_hook(feature_model) as saliency_hook: for data in test_dataloader: _ = model(return_loss=False, rescale=True, **data) saliency_maps = saliency_hook.records diff --git a/otx/mpa/det/stage.py b/otx/mpa/det/stage.py index 6bd05d7b46f..615565e03fd 100644 --- a/otx/mpa/det/stage.py +++ b/otx/mpa/det/stage.py @@ -24,7 +24,7 @@ def configure(self, model_cfg, model_ckpt, data_cfg, training=True, **kwargs): cfg = self.cfg self.configure_model(cfg, model_cfg, training, **kwargs) self.configure_ckpt(cfg, model_ckpt, kwargs.get("pretrained", None)) - self.configure_data(cfg, training, data_cfg) + self.configure_data(cfg, training, data_cfg, **kwargs) self.configure_regularization(cfg, training) self.configure_hyperparams(cfg, training, **kwargs) self.configure_task(cfg, training, **kwargs) diff --git a/tests/integration/cli/classification/test_classification.py b/tests/integration/cli/classification/test_classification.py index 2e308c4a19d..c633697b3cb 100644 --- a/tests/integration/cli/classification/test_classification.py +++ b/tests/integration/cli/classification/test_classification.py @@ -652,7 +652,7 @@ class TestToolsSelfSLClassification: @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) @set_dummy_data - def test_otx_train(self, template, tmp_dir_path): + def test_otx_selfsl_train(self, template, tmp_dir_path): otx_train_testing(template, tmp_dir_path, otx_dir, args_selfsl) template_work_dir = get_template_dir(template, tmp_dir_path) args1 = copy.deepcopy(args) @@ -662,5 +662,5 @@ def test_otx_train(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @pytest.mark.parametrize("template", templates, ids=templates_ids) - def test_otx_eval(self, template, tmp_dir_path): + def test_otx_selfsl_eval(self, template, tmp_dir_path): otx_eval_testing(template, tmp_dir_path, otx_dir, args)