Skip to content

Commit

Permalink
Fix failing e2e test case for the visual prompting task (#2509)
Browse files Browse the repository at this point in the history
* Move `load_model` after `set_seed`

* Fix unit tests
  • Loading branch information
sungchul2 authored Sep 19, 2023
1 parent c50cf00 commit d84d925
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
17 changes: 15 additions & 2 deletions src/otx/algorithms/visual_prompting/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar

from otx.algorithms.common.utils import set_random_seed
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.callbacks import (
InferenceCallback,
Expand Down Expand Up @@ -104,12 +105,22 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
self.precision = [ModelPrecision.FP32]
self.optimization_type = ModelOptimizationType.MO

self.model = self.load_model(otx_model=task_environment.model)

self.trainer: Trainer

self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())

def set_seed(self):
    """Resolve the seed and deterministic flag, store them in the config, and apply them.

    Values supplied via ``task.train`` take precedence; otherwise the recipe
    config is consulted (defaults: seed 5, deterministic False). The resolved
    values are written back to ``self.config`` so downstream consumers agree,
    then handed to ``set_random_seed``.
    """
    # Fall back to the recipe config when no seed arrived via task.train.
    resolved_seed = self.seed if self.seed is not None else self.config.get("seed", 5)
    # Same fallback for the deterministic flag.
    # NOTE(review): an explicit ``deterministic=False`` from the caller is falsy
    # and gets re-read from the config here — confirm that override is intended.
    resolved_deterministic = (
        self.deterministic if self.deterministic else self.config.get("deterministic", False)
    )
    self.seed = resolved_seed
    self.deterministic = resolved_deterministic
    # Persist the resolved values into the config for later stages.
    self.config["seed"] = resolved_seed
    self.config["deterministic"] = resolved_deterministic
    set_random_seed(resolved_seed, logger, resolved_deterministic)

def get_config(self) -> Union[DictConfig, ListConfig]:
"""Get Visual Prompting Config from task environment.
Expand Down Expand Up @@ -226,6 +237,7 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameter
DatasetEntity: Output dataset with predictions.
"""
logger.info("Performing inference on the validation set using the base torch model.")
self.model = self.load_model(otx_model=self.task_environment.model)
datamodule = OTXVisualPromptingDataModule(config=self.config.dataset, dataset=dataset)

logger.info("Inference Configs '%s'", self.config)
Expand Down Expand Up @@ -330,6 +342,7 @@ def export( # noqa: D102
"The saliency maps and representation vector outputs will not be dumped in the exported model."
)

self.model = self.load_model(otx_model=self.task_environment.model)
if export_type == ExportType.ONNX:
output_model.model_format = ModelFormat.ONNX
output_model.optimization_type = ModelOptimizationType.ONNX
Expand Down
10 changes: 6 additions & 4 deletions src/otx/algorithms/visual_prompting/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Optional

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
Expand Down Expand Up @@ -62,13 +62,15 @@ def train( # noqa: D102

logger.info("Training the model.")

if seed:
logger.info(f"Setting seed to {seed}")
seed_everything(seed, workers=True)
self.seed = seed
self.deterministic = deterministic
self.set_seed()
self.config.trainer.deterministic = "warn" if deterministic else deterministic

logger.info("Training Configs '%s'", self.config)

self.model = self.load_model(otx_model=self.task_environment.model)

datamodule = OTXVisualPromptingDataModule(config=self.config.dataset, dataset=dataset)
loggers = CSVLogger(save_dir=self.output_path, name=".", version=self.timestamp)
callbacks = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def test_load_model_without_otx_model_or_with_lightning_ckpt(
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.SegmentAnything"
)

load_inference_task(path=path, resume=resume)
inference_task = load_inference_task(path=path, resume=resume)
inference_task.load_model(otx_model=inference_task.task_environment.model)

mocker_segment_anything.assert_called_once()

Expand All @@ -111,7 +112,8 @@ def test_load_model_with_pytorch_pth(self, mocker, load_inference_task, resume:
return_value=dict(config=dict(model=dict(backbone="sam_vit_b")), model={}),
)

load_inference_task(path="checkpoint.pth", resume=resume)
inference_task = load_inference_task(path="checkpoint.pth", resume=resume)
inference_task.load_model(otx_model=inference_task.task_environment.model)

mocker_segment_anything.assert_called_once()
mocker_io_bytes_io.assert_called_once()
Expand Down Expand Up @@ -162,6 +164,7 @@ def test_model_info(self, mocker, load_inference_task):
)

inference_task = load_inference_task(output_path=None)
inference_task.model = inference_task.load_model(otx_model=inference_task.task_environment.model)
setattr(inference_task, "trainer", None)
mocker.patch.object(inference_task, "trainer")

Expand Down

0 comments on commit d84d925

Please sign in to comment.