From d84d9255c8f033f249132b22647dfa7831d47fc8 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 19 Sep 2023 21:47:35 +0900 Subject: [PATCH] Fix e2e failed case of visual prompting task (#2509) * Move `load_model` after `set_seed` * Fix unit tests --- .../visual_prompting/tasks/inference.py | 17 +++++++++++++++-- .../algorithms/visual_prompting/tasks/train.py | 10 ++++++---- .../visual_prompting/tasks/test_inference.py | 7 +++++-- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/otx/algorithms/visual_prompting/tasks/inference.py b/src/otx/algorithms/visual_prompting/tasks/inference.py index b84984e5fef..2dc63541392 100644 --- a/src/otx/algorithms/visual_prompting/tasks/inference.py +++ b/src/otx/algorithms/visual_prompting/tasks/inference.py @@ -31,6 +31,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import TQDMProgressBar +from otx.algorithms.common.utils import set_random_seed from otx.algorithms.common.utils.logger import get_logger from otx.algorithms.visual_prompting.adapters.pytorch_lightning.callbacks import ( InferenceCallback, @@ -104,12 +105,22 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] self.precision = [ModelPrecision.FP32] self.optimization_type = ModelOptimizationType.MO - self.model = self.load_model(otx_model=task_environment.model) - self.trainer: Trainer self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + def set_seed(self): + """Set seed and deterministic.""" + if self.seed is None: + # If the seed is not present via task.train, it will be found in the recipe. + self.seed = self.config.get("seed", 5) + if not self.deterministic: + # deterministic is the same. + self.deterministic = self.config.get("deterministic", False) + self.config["seed"] = self.seed + self.config["deterministic"] = self.deterministic + set_random_seed(self.seed, logger, self.deterministic) + def get_config(self) -> Union[DictConfig, ListConfig]: """Get Visual Prompting Config from task environment. @@ -226,6 +237,7 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameter DatasetEntity: Output dataset with predictions. """ logger.info("Performing inference on the validation set using the base torch model.") + self.model = self.load_model(otx_model=self.task_environment.model) datamodule = OTXVisualPromptingDataModule(config=self.config.dataset, dataset=dataset) logger.info("Inference Configs '%s'", self.config) @@ -330,6 +342,7 @@ def export( # noqa: D102 "The saliency maps and representation vector outputs will not be dumped in the exported model." ) + self.model = self.load_model(otx_model=self.task_environment.model) if export_type == ExportType.ONNX: output_model.model_format = ModelFormat.ONNX output_model.optimization_type = ModelOptimizationType.ONNX diff --git a/src/otx/algorithms/visual_prompting/tasks/train.py b/src/otx/algorithms/visual_prompting/tasks/train.py index a7305123b0f..67b734a767b 100644 --- a/src/otx/algorithms/visual_prompting/tasks/train.py +++ b/src/otx/algorithms/visual_prompting/tasks/train.py @@ -16,7 +16,7 @@ from typing import Optional -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ( EarlyStopping, LearningRateMonitor, @@ -62,13 +62,15 @@ def train( # noqa: D102 logger.info("Training the model.") - if seed: - logger.info(f"Setting seed to {seed}") - seed_everything(seed, workers=True) + self.seed = seed + self.deterministic = deterministic + self.set_seed() self.config.trainer.deterministic = "warn" if deterministic else deterministic logger.info("Training Configs '%s'", self.config) + self.model = self.load_model(otx_model=self.task_environment.model) + datamodule = OTXVisualPromptingDataModule(config=self.config.dataset, dataset=dataset) loggers = CSVLogger(save_dir=self.output_path, name=".", version=self.timestamp) callbacks = [ diff --git a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py index 4196830f93f..3ca7915b2c6 100644 --- a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py +++ b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py @@ -94,7 +94,8 @@ def test_load_model_without_otx_model_or_with_lightning_ckpt( "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.SegmentAnything" ) - load_inference_task(path=path, resume=resume) + inference_task = load_inference_task(path=path, resume=resume) + inference_task.load_model(otx_model=inference_task.task_environment.model) mocker_segment_anything.assert_called_once() @@ -111,7 +112,8 @@ def test_load_model_with_pytorch_pth(self, mocker, load_inference_task, resume: return_value=dict(config=dict(model=dict(backbone="sam_vit_b")), model={}), ) - load_inference_task(path="checkpoint.pth", resume=resume) + inference_task = load_inference_task(path="checkpoint.pth", resume=resume) + inference_task.load_model(otx_model=inference_task.task_environment.model) mocker_segment_anything.assert_called_once() mocker_io_bytes_io.assert_called_once() @@ -162,6 +164,7 @@ def test_model_info(self, mocker, load_inference_task): ) inference_task = load_inference_task(output_path=None) + inference_task.model = inference_task.load_model(otx_model=inference_task.task_environment.model) setattr(inference_task, "trainer", None) mocker.patch.object(inference_task, "trainer")