From 0c0e683069b332adfde50d7fc8d3584b7ba3b86a Mon Sep 17 00:00:00 2001
From: "Kim, Sungchul"
Date: Fri, 9 Jun 2023 14:07:29 +0900
Subject: [PATCH] Set train task & don't use project

---
 .../visual_prompters/segment_anything.py | 16 +++--
 .../configs/sam_vit_b/config.yaml        | 29 +++------
 .../visual_prompting/tasks/inference.py  | 60 +++++--------------
 .../visual_prompting/tasks/train.py      | 28 ++-------
 4 files changed, 39 insertions(+), 94 deletions(-)

diff --git a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py
index ad1102212b4..cc3055dba8c 100644
--- a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py
+++ b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py
@@ -28,7 +28,7 @@
     SAMPromptEncoder,
 )
 
-import pytorch_lightning as pl
+from pytorch_lightning import LightningModule
 
 
 CKPT_PATHS = {
@@ -38,7 +38,7 @@
 }
 
 
-class SegmentAnything(pl.LightningModule):
+class SegmentAnything(LightningModule):
     def __init__(
         self,
         image_encoder: nn.Module,
@@ -62,6 +62,8 @@ def __init__(
             checkpoint (optional, str): Checkpoint path to be loaded, default is None.
         """
         super().__init__()
+        # self.save_hyperparameters()
+
         self.image_encoder = image_encoder
         self.prompt_encoder = prompt_encoder
         self.mask_decoder = mask_decoder
@@ -149,7 +151,13 @@ def training_step(self, batch, batch_idx):
             loss_iou += F.mse_loss(iou_prediction, batch_iou.unsqueeze(1), reduction='sum') / num_masks
 
         loss = 20. * loss_focal + loss_dice + loss_iou
-        results = dict(iou=self.train_iou, f1=self.train_f1, loss=loss, loss_focal=loss_focal, loss_dice=loss_dice, loss_iou=loss_iou)
+        results = dict(
+            train_IoU=self.train_iou,
+            train_F1=self.train_f1,
+            train_loss=loss,
+            train_loss_focal=loss_focal,
+            train_loss_dice=loss_dice,
+            train_loss_iou=loss_iou)
         self.log_dict(results, prog_bar=True)
 
         return loss
@@ -165,7 +173,7 @@ def validation_step(self, batch, batch_idx):
             self.val_iou(pred_mask, gt_mask)
             self.val_f1(pred_mask, gt_mask)
 
-        results = dict(iou=self.val_iou, f1=self.val_f1)
+        results = dict(val_IoU=self.val_iou, val_F1=self.val_f1)
         self.log_dict(results, on_epoch=True, prog_bar=True)
 
         return results
diff --git a/otx/algorithms/visual_prompting/configs/sam_vit_b/config.yaml b/otx/algorithms/visual_prompting/configs/sam_vit_b/config.yaml
index eeff662a660..c874d832fd0 100644
--- a/otx/algorithms/visual_prompting/configs/sam_vit_b/config.yaml
+++ b/otx/algorithms/visual_prompting/configs/sam_vit_b/config.yaml
@@ -23,28 +23,13 @@ model:
   backbone: vit_b
   checkpoint: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
 
-metrics:
-  image:
-    - F1Score
-    - AUROC
-  pixel:
-    - IoU
-    - AUROC
-  threshold:
-    method: adaptive #options: [adaptive, manual]
-    manual_image: null
-    manual_pixel: null
-
-visualization:
-  show_images: False # show images on the screen
-  save_images: True # save images to the file system
-  log_images: True # log images to the available loggers (if any)
-  image_save_path: null # path to which images will be saved
-  mode: full # options: ["full", "simple"]
-
-project:
-  seed: 42
-  path: ./results
+callback:
+  checkpoint:
+    # arguments for ModelCheckpoint
+    monitor: val_IoU
+    mode: max
+    save_last: true
+    verbose: true
 
 logging:
   logger: [csv] # options: [comet, tensorboard, wandb, csv] or combinations.
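Note (annotation, not part of the patch): the metric renames above (`iou`/`f1` to `train_*`, `val_IoU`, `val_F1`) and the new `callback.checkpoint` section are two halves of one change, since a checkpoint monitor must name a logged metric key. A minimal sketch of how that section is consumed, assuming (as `train.py` below confirms) its keys are forwarded verbatim as `ModelCheckpoint` kwargs:

```python
# Sketch only: mirrors the callback.checkpoint block from config.yaml above.
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint

config = OmegaConf.create(
    {"callback": {"checkpoint": {"monitor": "val_IoU", "mode": "max", "save_last": True, "verbose": True}}}
)

# The keys pass straight through as ModelCheckpoint kwargs, so `monitor`
# must match a key logged via self.log_dict(...) -- hence the rename to
# `val_IoU` in validation_step.
checkpoint = ModelCheckpoint(dirpath="./checkpoints", **config.callback.checkpoint)
assert checkpoint.monitor == "val_IoU"
```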
diff --git a/otx/algorithms/visual_prompting/tasks/inference.py b/otx/algorithms/visual_prompting/tasks/inference.py
index 6404d141a55..a23247dc0a8 100644
--- a/otx/algorithms/visual_prompting/tasks/inference.py
+++ b/otx/algorithms/visual_prompting/tasks/inference.py
@@ -16,6 +16,7 @@
 
 import ctypes
 import io
+import time
 import os
 import shutil
 import subprocess  # nosec
@@ -23,30 +24,14 @@
 from glob import glob
 from typing import Dict, List, Optional, Union
 
 from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything import sam_model_registry
-from torch import optim, nn, utils, Tensor
-import segmentation_models_pytorch as smp
-import torch.nn.functional as F
 import torch
-from anomalib.models import AnomalyModule, get_model
-from anomalib.post_processing import NormalizationMethod, ThresholdMethod
-from anomalib.utils.callbacks import (
-    MetricsConfigurationCallback,
-    MinMaxNormalizationCallback,
-    PostProcessingConfigurationCallback,
-)
 from omegaconf import DictConfig, ListConfig
 from pytorch_lightning import Trainer
 
-from otx.algorithms.anomaly.adapters.anomalib.callbacks import (
-    AnomalyInferenceCallback,
-    ProgressCallback,
-)
-# from otx.algorithms.anomaly.adapters.anomalib.config import get_anomalib_config
+from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
 from otx.algorithms.visual_prompting.adapters.pytorch_lightning.config import get_visual_promtping_config
-from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
 from otx.algorithms.common.utils.logger import get_logger
-from otx.algorithms.anomaly.configs.base.configuration import BaseAnomalyConfig
 from otx.algorithms.visual_prompting.configs.base.configuration import VisualPromptingConfig
 from otx.api.entities.datasets import DatasetEntity
 from otx.api.entities.inference_parameters import InferenceParameters
@@ -70,7 +55,8 @@
 from otx.api.usecases.tasks.interfaces.export_interface import ExportType, IExportTask
 from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask
 from otx.api.usecases.tasks.interfaces.unload_interface import IUnload
-import pytorch_lightning as pl
+from pytorch_lightning import LightningModule
+from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets import OTXVisualPromptingDataModule
 
 logger = get_logger()
@@ -98,10 +84,10 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None) -> None:
 
         # Hyperparameters.
         self._work_dir_is_temp = False
-        if output_path is None:
-            output_path = tempfile.mkdtemp(prefix="otx-visual_prompting")
+        self.output_path = output_path
+        if self.output_path is None:
+            self.output_path = tempfile.mkdtemp(prefix="otx-visual_prompting")
             self._work_dir_is_temp = True
-        self.project_path: str = output_path
         self.config = self.get_config()
 
         # Set default model attributes.
@@ -112,6 +98,8 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None) -> None:
         self.model = self.load_model(otx_model=task_environment.model)
 
         self.trainer: Trainer
+
+        self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
 
     def get_config(self) -> Union[DictConfig, ListConfig]:
         """Get Visual Prompting Config from task environment.
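Note (annotation, not part of the patch): with `project.path` removed from the config, the task now owns its working directory directly. A standalone sketch of the new behavior; `output_path` stands in for the optional constructor argument:

```python
# Sketch of the reworked output-path handling from __init__ above.
import tempfile
import time

output_path = None  # as if the caller passed nothing
work_dir_is_temp = False
if output_path is None:
    # Fall back to a throwaway working directory.
    output_path = tempfile.mkdtemp(prefix="otx-visual_prompting")
    work_dir_is_temp = True

# One timestamp per task instance; train.py below reuses it as the
# CSVLogger version, so logs and checkpoints of a run land together.
timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
print(output_path, timestamp)
```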
@@ -121,13 +109,12 @@ def get_config(self) -> Union[DictConfig, ListConfig]:
         """
         self.hyper_parameters: VisualPromptingConfig = self.task_environment.get_hyper_parameters()
         config = get_visual_promtping_config(task_name=self.model_name, otx_config=self.hyper_parameters)
-        config.project.path = self.project_path
 
         config.dataset.task = "visual_prompting"
 
         return config
 
-    def load_model(self, otx_model: Optional[ModelEntity]) -> pl.LightningModule:
+    def load_model(self, otx_model: Optional[ModelEntity]) -> LightningModule:
         """Create and Load Visual Prompting Module.
 
         Currently, load model through `sam_model_registry` because there is only SAM.
@@ -137,7 +124,7 @@ def load_model(self, otx_model: Optional[ModelEntity]) -> LightningModule:
             otx_model (Optional[ModelEntity]): OTX Model from the task environment.
 
         Returns:
-            pl.LightningModule: Visual prompting model with/without weights.
+            LightningModule: Visual prompting model with/without weights.
         """
         if otx_model is None:
             backbone = self.config.model.backbone
@@ -175,29 +162,14 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameters) -> DatasetEntity:
             DatasetEntity: Output dataset with predictions.
         """
         logger.info("Performing inference on the validation set using the base torch model.")
-        config = self.get_config()
-        datamodule = OTXAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)
+        datamodule = OTXVisualPromptingDataModule(config=self.config, dataset=dataset)
 
-        logger.info("Inference Configs '%s'", config)
+        logger.info("Inference Configs '%s'", self.config)
 
         # Callbacks.
-        progress = ProgressCallback(parameters=inference_parameters)
-        inference = AnomalyInferenceCallback(dataset, self.labels, self.task_type)
-        normalize = MinMaxNormalizationCallback()
-        metrics_configuration = MetricsConfigurationCallback(
-            task=config.dataset.task,
-            image_metrics=config.metrics.image,
-            pixel_metrics=config.metrics.get("pixel"),
-        )
-        post_processing_configuration = PostProcessingConfigurationCallback(
-            normalization_method=NormalizationMethod.MIN_MAX,
-            threshold_method=ThresholdMethod.ADAPTIVE,
-            manual_image_threshold=config.metrics.threshold.manual_image,
-            manual_pixel_threshold=config.metrics.threshold.manual_pixel,
-        )
-        callbacks = [progress, normalize, inference, metrics_configuration, post_processing_configuration]
+        callbacks = [ProgressCallback(parameters=inference_parameters)]
 
-        self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
+        self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks)
         self.trainer.predict(model=self.model, datamodule=datamodule)
 
         return dataset
@@ -311,7 +283,7 @@ def model_info(self) -> Dict:
         return {
             "model": self.model.state_dict(),
             "config": self.get_config(),
-            "VERSION": 1,
+            "version": self.trainer.logger.version,
         }
 
     def save_model(self, output_model: ModelEntity) -> None:
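Note (annotation, not part of the patch): replacing the hard-coded `"VERSION": 1` with `self.trainer.logger.version` ties the exported model metadata to a training run. A sketch of why that resolves to the task timestamp, assuming the CSVLogger wiring from `train.py` below (the version string here is an arbitrary example value):

```python
# Sketch only: CSVLogger reports back whatever version it was constructed with.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

logger = CSVLogger(save_dir="./outputs", name=".", version="20230609_140729")
trainer = Trainer(logger=logger)

# model_info() now exports this run identifier instead of a fixed schema number.
assert trainer.logger.version == "20230609_140729"
```

One caveat: `infer()` builds its Trainer with `logger=False`, so `self.trainer.logger` is `None` there; the lookup appears to assume `model_info()` runs against the trainer created by the training path.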
diff --git a/otx/algorithms/visual_prompting/tasks/train.py b/otx/algorithms/visual_prompting/tasks/train.py
index 46ac52870d4..eb0f98e5535 100644
--- a/otx/algorithms/visual_prompting/tasks/train.py
+++ b/otx/algorithms/visual_prompting/tasks/train.py
@@ -18,15 +18,9 @@
 from typing import Optional
 
 import torch
-from anomalib.models import AnomalyModule, get_model
-from anomalib.post_processing import NormalizationMethod, ThresholdMethod
-from anomalib.utils.callbacks import (
-    MetricsConfigurationCallback,
-    MinMaxNormalizationCallback,
-    PostProcessingConfigurationCallback,
-)
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.loggers import CSVLogger
 
 from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
 from otx.algorithms.common.utils.logger import get_logger
@@ -73,28 +67,14 @@ def train(
         logger.info("Training Configs '%s'", self.config)
 
         datamodule = OTXVisualPromptingDataModule(config=self.config, dataset=dataset)
+        loggers = CSVLogger(save_dir=self.output_path, name=".", version=self.timestamp)
         callbacks = [
-            # LearningRateMonitor(logging_interval='step'),
             ProgressCallback(parameters=train_parameters),
-            ModelCheckpoint(monitor="iou", mode="max"),
-            # MinMaxNormalizationCallback(),
-            # MetricsConfigurationCallback(
-            #     task=config.dataset.task,
-            #     image_metrics=config.metrics.image,
-            #     pixel_metrics=config.metrics.get("pixel"),
-            # ),
-            # PostProcessingConfigurationCallback(
-            #     normalization_method=NormalizationMethod.MIN_MAX,
-            #     threshold_method=ThresholdMethod.ADAPTIVE,
-            #     manual_image_threshold=config.metrics.threshold.manual_image,
-            #     manual_pixel_threshold=config.metrics.threshold.manual_pixel,
-            # ),
+            ModelCheckpoint(dirpath=loggers.log_dir, **self.config.callback.checkpoint),
         ]
 
-        self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks)
+        self.trainer = Trainer(**self.config.trainer, logger=loggers, callbacks=callbacks)
         self.trainer.fit(model=self.model, datamodule=datamodule)
 
-        logger.info("Evaluation with best checkpoint.")
-        self.trainer.validate(model=self.model, datamodule=datamodule)
-
         self.save_model(output_model)
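Note (annotation, not part of the patch): end to end, the pieces above compose as below. A standalone sketch with the config inlined and `fit()` commented out, since the model and datamodule are task-specific:

```python
# Illustrative wiring of CSVLogger + ModelCheckpoint, as in train() above.
import tempfile
import time

from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

output_path = tempfile.mkdtemp(prefix="otx-visual_prompting")
timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
config = OmegaConf.create(
    {"callback": {"checkpoint": {"monitor": "val_IoU", "mode": "max", "save_last": True, "verbose": True}}}
)

# log_dir resolves to <save_dir>/<name>/<version>; passing it as dirpath
# puts checkpoints next to the CSV metrics of the same run.
loggers = CSVLogger(save_dir=output_path, name=".", version=timestamp)
callbacks = [ModelCheckpoint(dirpath=loggers.log_dir, **config.callback.checkpoint)]

trainer = Trainer(logger=loggers, callbacks=callbacks)
# trainer.fit(model=model, datamodule=datamodule)  # model/datamodule elided
```

The post-fit `validate()` pass is dropped, presumably because the checkpoint callback already tracks the best `val_IoU` epoch during `fit()`.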