Commit
Set train task & don't use project
sungchul2 committed Jun 9, 2023
1 parent 66e432a commit 0c0e683
Showing 4 changed files with 39 additions and 94 deletions.
@@ -28,7 +28,7 @@
     SAMPromptEncoder,
 )

-import pytorch_lightning as pl
+from pytorch_lightning import LightningModule


 CKPT_PATHS = {
@@ -38,7 +38,7 @@
 }


-class SegmentAnything(pl.LightningModule):
+class SegmentAnything(LightningModule):
     def __init__(
         self,
         image_encoder: nn.Module,
@@ -62,6 +62,8 @@ def __init__(
             checkpoint (optional, str): Checkpoint path to be loaded, default is None.
         """
         super().__init__()
+        # self.save_hyperparameters()
+
         self.image_encoder = image_encoder
         self.prompt_encoder = prompt_encoder
         self.mask_decoder = mask_decoder
@@ -149,7 +151,13 @@ def training_step(self, batch, batch_idx):
         loss_iou += F.mse_loss(iou_prediction, batch_iou.unsqueeze(1), reduction='sum') / num_masks

         loss = 20. * loss_focal + loss_dice + loss_iou
-        results = dict(iou=self.train_iou, f1=self.train_f1, loss=loss, loss_focal=loss_focal, loss_dice=loss_dice, loss_iou=loss_iou)
+        results = dict(
+            train_IoU=self.train_iou,
+            train_F1=self.train_f1,
+            train_loss=loss,
+            train_loss_focal=loss_focal,
+            train_loss_dice=loss_dice,
+            train_loss_iou=loss_iou)
         self.log_dict(results, prog_bar=True)

         return loss
@@ -165,7 +173,7 @@ def validation_step(self, batch, batch_idx):
         self.val_iou(pred_mask, gt_mask)
         self.val_f1(pred_mask, gt_mask)

-        results = dict(iou=self.val_iou, f1=self.val_f1)
+        results = dict(val_IoU=self.val_iou, val_F1=self.val_f1)
         self.log_dict(results, on_epoch=True, prog_bar=True)

         return results
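Note on the hunks above: the unchanged `loss = 20. * loss_focal + loss_dice + loss_iou` line follows SAM's training recipe, which weights focal against dice loss at 20:1 and supervises the IoU prediction head with an MSE loss. The renaming is what matters for the next file: once every logged key carries a `train_`/`val_` prefix, a checkpoint callback can monitor `val_IoU` unambiguously. A minimal sketch of the pattern, assuming `torchmetrics`-style metric attributes as in the diff (the module and batch below are placeholders, not the actual OTX classes):

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import F1Score, JaccardIndex

class TinyMaskModel(LightningModule):  # placeholder, not the OTX SegmentAnything
    def __init__(self):
        super().__init__()
        self.val_iou = JaccardIndex(task="binary")
        self.val_f1 = F1Score(task="binary")

    def validation_step(self, batch, batch_idx):
        pred_mask, gt_mask = batch  # placeholder; the real model predicts pred_mask
        self.val_iou(pred_mask, gt_mask)
        self.val_f1(pred_mask, gt_mask)
        # Prefixed keys let callbacks and loggers tell train/val streams apart.
        self.log_dict({"val_IoU": self.val_iou, "val_F1": self.val_f1},
                      on_epoch=True, prog_bar=True)

# A checkpoint callback can now target the prefixed key, as config.yaml below does:
checkpoint = ModelCheckpoint(monitor="val_IoU", mode="max")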
29 changes: 7 additions & 22 deletions otx/algorithms/visual_prompting/configs/sam_vit_b/config.yaml
@@ -23,28 +23,13 @@ model:
   backbone: vit_b
   checkpoint: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

-metrics:
-  image:
-    - F1Score
-    - AUROC
-  pixel:
-    - IoU
-    - AUROC
-  threshold:
-    method: adaptive #options: [adaptive, manual]
-    manual_image: null
-    manual_pixel: null
-
-visualization:
-  show_images: False # show images on the screen
-  save_images: True # save images to the file system
-  log_images: True # log images to the available loggers (if any)
-  image_save_path: null # path to which images will be saved
-  mode: full # options: ["full", "simple"]
-
-project:
-  seed: 42
-  path: ./results
+callback:
+  checkpoint:
+    # arguments for ModelCheckpoint
+    monitor: val_IoU
+    mode: max
+    save_last: true
+    verbose: true

 logging:
   logger: [csv] # options: [comet, tensorboard, wandb, csv] or combinations.
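The new `callback.checkpoint` block replaces the anomalib-style `metrics`, `visualization`, and `project` sections. As train.py (below) shows, its keys are passed straight through to `ModelCheckpoint`, so the YAML maps one-to-one onto the callback's keyword arguments. A sketch of that mapping; the `dirpath` here is illustrative (in the task it comes from the CSV logger's `log_dir`):

from pytorch_lightning.callbacks import ModelCheckpoint

# Mirrors the YAML block above; every key is a ModelCheckpoint kwarg.
checkpoint_cfg = {
    "monitor": "val_IoU",  # the prefixed metric logged in validation_step
    "mode": "max",
    "save_last": True,
    "verbose": True,
}
callback = ModelCheckpoint(dirpath="./outputs/20230609_153000", **checkpoint_cfg)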
60 changes: 16 additions & 44 deletions otx/algorithms/visual_prompting/tasks/inference.py
@@ -16,37 +16,22 @@

 import ctypes
 import io
+import time
 import os
 import shutil
 import subprocess # nosec
 import tempfile
 from glob import glob
 from typing import Dict, List, Optional, Union
 from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything import sam_model_registry
-from torch import optim, nn, utils, Tensor
-import segmentation_models_pytorch as smp
-
-import torch.nn.functional as F
 import torch
-from anomalib.models import AnomalyModule, get_model
-from anomalib.post_processing import NormalizationMethod, ThresholdMethod
-from anomalib.utils.callbacks import (
-    MetricsConfigurationCallback,
-    MinMaxNormalizationCallback,
-    PostProcessingConfigurationCallback,
-)
 from omegaconf import DictConfig, ListConfig
 from pytorch_lightning import Trainer

-from otx.algorithms.anomaly.adapters.anomalib.callbacks import (
-    AnomalyInferenceCallback,
-    ProgressCallback,
-)
-# from otx.algorithms.anomaly.adapters.anomalib.config import get_anomalib_config
+from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
 from otx.algorithms.visual_prompting.adapters.pytorch_lightning.config import get_visual_promtping_config
-from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
 from otx.algorithms.common.utils.logger import get_logger
-from otx.algorithms.anomaly.configs.base.configuration import BaseAnomalyConfig
+from otx.algorithms.visual_prompting.configs.base.configuration import VisualPromptingConfig
 from otx.api.entities.datasets import DatasetEntity
 from otx.api.entities.inference_parameters import InferenceParameters
@@ -70,7 +55,8 @@
 from otx.api.usecases.tasks.interfaces.export_interface import ExportType, IExportTask
 from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask
 from otx.api.usecases.tasks.interfaces.unload_interface import IUnload
-import pytorch_lightning as pl
+from pytorch_lightning import LightningModule
+from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets import OTXVisualPromptingDataModule

 logger = get_logger()

@@ -98,10 +84,10 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]

         # Hyperparameters.
         self._work_dir_is_temp = False
-        if output_path is None:
-            output_path = tempfile.mkdtemp(prefix="otx-visual_prompting")
+        self.output_path = output_path
+        if self.output_path is None:
+            self.output_path = tempfile.mkdtemp(prefix="otx-visual_prompting")
             self._work_dir_is_temp = True
-        self.project_path: str = output_path
         self.config = self.get_config()

         # Set default model attributes.
@@ -112,6 +98,8 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
         self.model = self.load_model(otx_model=task_environment.model)

         self.trainer: Trainer
+
+        self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())

     def get_config(self) -> Union[DictConfig, ListConfig]:
         """Get Visual Prompting Config from task environment.
@@ -121,13 +109,12 @@ def get_config(self) -> Union[DictConfig, ListConfig]:
"""
self.hyper_parameters: VisualPromptingConfig = self.task_environment.get_hyper_parameters()
config = get_visual_promtping_config(task_name=self.model_name, otx_config=self.hyper_parameters)
config.project.path = self.project_path

config.dataset.task = "visual_prompting"

return config

def load_model(self, otx_model: Optional[ModelEntity]) -> pl.LightningModule:
def load_model(self, otx_model: Optional[ModelEntity]) -> LightningModule:
"""Create and Load Visual Prompting Module.
Currently, load model through `sam_model_registry` because there is only SAM.
@@ -137,7 +124,7 @@ def load_model(self, otx_model: Optional[ModelEntity]) -> pl.LightningModule:
             otx_model (Optional[ModelEntity]): OTX Model from the task environment.
         Returns:
-            pl.LightningModule: Visual prompting model with/without weights.
+            LightningModule: Visual prompting model with/without weights.
         """
         if otx_model is None:
             backbone = self.config.model.backbone
@@ -175,29 +162,14 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameter
             DatasetEntity: Output dataset with predictions.
         """
         logger.info("Performing inference on the validation set using the base torch model.")
-        config = self.get_config()
-        datamodule = OTXAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)
+        datamodule = OTXVisualPromptingDataModule(config=self.config, dataset=dataset)

-        logger.info("Inference Configs '%s'", config)
+        logger.info("Inference Configs '%s'", self.config)

         # Callbacks.
-        progress = ProgressCallback(parameters=inference_parameters)
-        inference = AnomalyInferenceCallback(dataset, self.labels, self.task_type)
-        normalize = MinMaxNormalizationCallback()
-        metrics_configuration = MetricsConfigurationCallback(
-            task=config.dataset.task,
-            image_metrics=config.metrics.image,
-            pixel_metrics=config.metrics.get("pixel"),
-        )
-        post_processing_configuration = PostProcessingConfigurationCallback(
-            normalization_method=NormalizationMethod.MIN_MAX,
-            threshold_method=ThresholdMethod.ADAPTIVE,
-            manual_image_threshold=config.metrics.threshold.manual_image,
-            manual_pixel_threshold=config.metrics.threshold.manual_pixel,
-        )
-        callbacks = [progress, normalize, inference, metrics_configuration, post_processing_configuration]
+        callbacks = [ProgressCallback(parameters=inference_parameters)]

-        self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
+        self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks)
         self.trainer.predict(model=self.model, datamodule=datamodule)
         return dataset

@@ -311,7 +283,7 @@ def model_info(self) -> Dict:
         return {
             "model": self.model.state_dict(),
             "config": self.get_config(),
-            "VERSION": 1,
+            "version": self.trainer.logger.version,
         }

     def save_model(self, output_model: ModelEntity) -> None:
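The `timestamp` attribute introduced in `__init__` doubles as the run's version string: train.py passes it to `CSVLogger`, checkpoints are written under the resulting `log_dir`, and `model_info()` now records `self.trainer.logger.version` instead of the hard-coded `VERSION: 1`. A small sketch of that chain, assuming pytorch_lightning's `CSVLogger`; the paths are illustrative:

import time

from pytorch_lightning.loggers import CSVLogger

# Same format string as self.timestamp in __init__ above.
timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())

# train.py (next file) hands the timestamp to the logger as its version.
logger = CSVLogger(save_dir="./outputs", name=".", version=timestamp)

print(logger.log_dir)  # e.g. ./outputs/./20230609_153000, where checkpoints go
print(logger.version)  # e.g. "20230609_153000", saved by model_info() as "version"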
28 changes: 4 additions & 24 deletions otx/algorithms/visual_prompting/tasks/train.py
@@ -18,15 +18,9 @@
 from typing import Optional

 import torch
-from anomalib.models import AnomalyModule, get_model
-from anomalib.post_processing import NormalizationMethod, ThresholdMethod
-from anomalib.utils.callbacks import (
-    MetricsConfigurationCallback,
-    MinMaxNormalizationCallback,
-    PostProcessingConfigurationCallback,
-)
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.loggers import CSVLogger

 from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
 from otx.algorithms.common.utils.logger import get_logger
@@ -73,28 +67,14 @@ def train(
logger.info("Training Configs '%s'", self.config)

datamodule = OTXVisualPromptingDataModule(config=self.config, dataset=dataset)
loggers = CSVLogger(save_dir=self.output_path, name=".", version=self.timestamp)
callbacks = [
# LearningRateMonitor(logging_interval='step'),
ProgressCallback(parameters=train_parameters),
ModelCheckpoint(monitor="iou", mode="max"),
# MinMaxNormalizationCallback(),
# MetricsConfigurationCallback(
# task=config.dataset.task,
# image_metrics=config.metrics.image,
# pixel_metrics=config.metrics.get("pixel"),
# ),
# PostProcessingConfigurationCallback(
# normalization_method=NormalizationMethod.MIN_MAX,
# threshold_method=ThresholdMethod.ADAPTIVE,
# manual_image_threshold=config.metrics.threshold.manual_image,
# manual_pixel_threshold=config.metrics.threshold.manual_pixel,
# ),
ModelCheckpoint(dirpath=loggers.log_dir, **self.config.callback.checkpoint),
]

self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks)
self.trainer = Trainer(**self.config.trainer, logger=loggers, callbacks=callbacks)
self.trainer.fit(model=self.model, datamodule=datamodule)
logger.info("Evaluation with best checkpoint.")
self.trainer.validate(model=self.model, datamodule=datamodule)

self.save_model(output_model)

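With these changes, training runs with a real logger instead of `logger=False`: `CSVLogger` owns the timestamped run directory, `ModelCheckpoint` is configured entirely from `config.callback.checkpoint` and writes into that same directory, and a validation pass follows `fit`. A self-contained sketch of the wiring, with stand-ins for `self.output_path`, `self.timestamp`, and the config (not the OTX schema):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

# Stand-ins for self.output_path, self.timestamp, and config.callback.checkpoint.
output_path = "./outputs"
timestamp = "20230609_153000"
checkpoint_cfg = {"monitor": "val_IoU", "mode": "max", "save_last": True, "verbose": True}

loggers = CSVLogger(save_dir=output_path, name=".", version=timestamp)
callbacks = [ModelCheckpoint(dirpath=loggers.log_dir, **checkpoint_cfg)]

# logger=loggers replaces logger=False, so metrics.csv and checkpoints share one run dir.
trainer = Trainer(max_epochs=1, logger=loggers, callbacks=callbacks)
# trainer.fit(model=model, datamodule=datamodule)
# trainer.validate(model=model, datamodule=datamodule)  # the "Evaluation with best checkpoint." step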
