Skip to content

Commit

Permalink
fix: typo
Browse files Browse the repository at this point in the history
  • Loading branch information
cih9088 committed Jan 12, 2023
1 parent bf1012c commit 1a75de3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
3 changes: 2 additions & 1 deletion otx/mpa/det/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ def explain(self, cfg, model_builder=None):
model = self.build_model(cfg, model_builder, fp16=False)
model.CLASSES = target_classes
model.eval()
feature_model = self._get_feature_module(model)
model = build_data_parallel(model, cfg, distributed=False)

# InferenceProgressCallback (Time Monitor enable into Infer task)
self.set_inference_progress_callback(model, cfg)

# Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map.
with self.explainer_hook(model.module) as saliency_hook:
with self.explainer_hook(feature_model) as saliency_hook:
for data in test_dataloader:
_ = model(return_loss=False, rescale=True, **data)
saliency_maps = saliency_hook.records
Expand Down
2 changes: 1 addition & 1 deletion otx/mpa/det/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def configure(self, model_cfg, model_ckpt, data_cfg, training=True, **kwargs):
cfg = self.cfg
self.configure_model(cfg, model_cfg, training, **kwargs)
self.configure_ckpt(cfg, model_ckpt, kwargs.get("pretrained", None))
self.configure_data(cfg, training, data_cfg)
self.configure_data(cfg, training, data_cfg, **kwargs)
self.configure_regularization(cfg, training)
self.configure_hyperparams(cfg, training, **kwargs)
self.configure_task(cfg, training, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/cli/classification/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ class TestToolsSelfSLClassification:
@e2e_pytest_component
@pytest.mark.parametrize("template", templates, ids=templates_ids)
@set_dummy_data
def test_otx_train(self, template, tmp_dir_path):
def test_otx_selfsl_train(self, template, tmp_dir_path):
otx_train_testing(template, tmp_dir_path, otx_dir, args_selfsl)
template_work_dir = get_template_dir(template, tmp_dir_path)
args1 = copy.deepcopy(args)
Expand All @@ -662,5 +662,5 @@ def test_otx_train(self, template, tmp_dir_path):
@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_eval(self, template, tmp_dir_path):
def test_otx_selfsl_eval(self, template, tmp_dir_path):
otx_eval_testing(template, tmp_dir_path, otx_dir, args)

0 comments on commit 1a75de3

Please sign in to comment.