Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix e2e failed case of visual prompting task #2509

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/otx/algorithms/visual_prompting/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar

from otx.algorithms.common.utils import set_random_seed
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.callbacks import (
InferenceCallback,
Expand Down Expand Up @@ -104,12 +105,22 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
self.precision = [ModelPrecision.FP32]
self.optimization_type = ModelOptimizationType.MO

self.model = self.load_model(otx_model=task_environment.model)

self.trainer: Trainer

self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())

def set_seed(self):
"""Set seed and deterministic."""
if self.seed is None:
# If the seed is not present via task.train, it will be found in the recipe.
self.seed = self.config.get("seed", 5)
if not self.deterministic:
# deterministic is the same.
self.deterministic = self.config.get("deterministic", False)
self.config["seed"] = self.seed
self.config["deterministic"] = self.deterministic
set_random_seed(self.seed, logger, self.deterministic)

def get_config(self) -> Union[DictConfig, ListConfig]:
"""Get Visual Prompting Config from task environment.

Expand Down Expand Up @@ -226,6 +237,7 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameter
DatasetEntity: Output dataset with predictions.
"""
logger.info("Performing inference on the validation set using the base torch model.")
self.model = self.load_model(otx_model=self.task_environment.model)
datamodule = OTXVisualPromptingDataModule(config=self.config.dataset, dataset=dataset)

logger.info("Inference Configs '%s'", self.config)
Expand Down Expand Up @@ -330,6 +342,7 @@ def export( # noqa: D102
"The saliency maps and representation vector outputs will not be dumped in the exported model."
)

self.model = self.load_model(otx_model=self.task_environment.model)
if export_type == ExportType.ONNX:
output_model.model_format = ModelFormat.ONNX
output_model.optimization_type = ModelOptimizationType.ONNX
Expand Down
10 changes: 6 additions & 4 deletions src/otx/algorithms/visual_prompting/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Optional

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
Expand Down Expand Up @@ -62,13 +62,15 @@ def train( # noqa: D102

logger.info("Training the model.")

if seed:
logger.info(f"Setting seed to {seed}")
seed_everything(seed, workers=True)
self.seed = seed
self.deterministic = deterministic
self.set_seed()
self.config.trainer.deterministic = "warn" if deterministic else deterministic

logger.info("Training Configs '%s'", self.config)

self.model = self.load_model(otx_model=self.task_environment.model)

datamodule = OTXVisualPromptingDataModule(config=self.config.dataset, dataset=dataset)
loggers = CSVLogger(save_dir=self.output_path, name=".", version=self.timestamp)
callbacks = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def test_load_model_without_otx_model_or_with_lightning_ckpt(
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.SegmentAnything"
)

load_inference_task(path=path, resume=resume)
inference_task = load_inference_task(path=path, resume=resume)
inference_task.load_model(otx_model=inference_task.task_environment.model)

mocker_segment_anything.assert_called_once()

Expand All @@ -111,7 +112,8 @@ def test_load_model_with_pytorch_pth(self, mocker, load_inference_task, resume:
return_value=dict(config=dict(model=dict(backbone="sam_vit_b")), model={}),
)

load_inference_task(path="checkpoint.pth", resume=resume)
inference_task = load_inference_task(path="checkpoint.pth", resume=resume)
inference_task.load_model(otx_model=inference_task.task_environment.model)

mocker_segment_anything.assert_called_once()
mocker_io_bytes_io.assert_called_once()
Expand Down Expand Up @@ -162,6 +164,7 @@ def test_model_info(self, mocker, load_inference_task):
)

inference_task = load_inference_task(output_path=None)
inference_task.model = inference_task.load_model(otx_model=inference_task.task_environment.model)
setattr(inference_task, "trainer", None)
mocker.patch.object(inference_task, "trainer")

Expand Down