diff --git a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py index 9047e73da07..43dc194c259 100644 --- a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py +++ b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py @@ -44,18 +44,26 @@ class OTXVIsualPromptingDataset(Dataset): - """Visual Prompting Dataset Adaptor.""" + """Visual Prompting Dataset Adaptor. + + Args: + config + dataset + transform + stage + """ def __init__( self, config: Union[DictConfig, ListConfig], dataset: DatasetEntity, - transform: MultipleInputsCompose + transform: MultipleInputsCompose, ) -> None: self.config = config self.dataset = dataset self.transform = transform + self.labels = dataset.get_labels() self.label_idx = {label.id: i for i, label in enumerate(self.labels)} @@ -133,11 +141,11 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, Tensor]]: item.update(dict( original_size=(height, width), - image=dataset_item.numpy, - mask=masks, - bbox=bboxes, - label=labels, - point=None, # TODO (sungchul): update point information + images=dataset_item.numpy, + masks=masks, + bboxes=bboxes, + labels=labels, + points=None, # TODO (sungchul): update point information )) item = self.transform(item) return item @@ -171,31 +179,36 @@ def setup(self, stage: Optional[str] = None) -> None: image_size = [image_size] if stage == "fit" or stage is None: - self.train_otx_dataset = self.dataset.get_subset(Subset.TRAINING) - self.val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION) + train_otx_dataset = self.dataset.get_subset(Subset.TRAINING) + val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION) # TODO (sungchul): distinguish between train and val config here - self.train_transform = self.val_transform = MultipleInputsCompose([ + train_transform = val_transform = MultipleInputsCompose([ ResizeLongestSide(target_length=max(image_size)), Pad(), transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) ]) + self.train_dataset = OTXVIsualPromptingDataset(self.config, train_otx_dataset, train_transform) + self.val_dataset = OTXVIsualPromptingDataset(self.config, val_otx_dataset, val_transform) + if stage == "test": - self.test_otx_dataset = self.dataset.get_subset(Subset.TESTING) - self.test_transform = MultipleInputsCompose([ + test_otx_dataset = self.dataset.get_subset(Subset.TESTING) + test_transform = MultipleInputsCompose([ ResizeLongestSide(target_length=max(image_size)), Pad(), transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) ]) + self.test_dataset = OTXVIsualPromptingDataset(self.config, test_otx_dataset, test_transform) if stage == "predict": - self.predict_otx_dataset = self.dataset - self.predict_transform = MultipleInputsCompose([ + predict_otx_dataset = self.dataset + predict_transform = MultipleInputsCompose([ ResizeLongestSide(target_length=max(image_size)), Pad(), transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) ]) + self.predict_dataset = OTXVIsualPromptingDataset(self.config, predict_otx_dataset, predict_transform) def summary(self): """Print size of the dataset, number of images.""" @@ -216,9 +229,8 @@ def train_dataloader( Returns: Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]: Train dataloader. """ - dataset = OTXVIsualPromptingDataset(self.config, self.train_otx_dataset, self.train_transform) return DataLoader( - dataset, + self.train_dataset, shuffle=False, batch_size=self.config.dataset.train_batch_size, num_workers=self.config.dataset.num_workers, @@ -231,9 +243,8 @@ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: Returns: Union[DataLoader, List[DataLoader]]: Validation Dataloader. """ - dataset = OTXVIsualPromptingDataset(self.config, self.val_otx_dataset, self.val_transform) return DataLoader( - dataset, + self.val_dataset, shuffle=False, batch_size=self.config.dataset.eval_batch_size, num_workers=self.config.dataset.num_workers, @@ -246,9 +257,8 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: Returns: Union[DataLoader, List[DataLoader]]: Test Dataloader. """ - dataset = OTXVIsualPromptingDataset(self.config, self.test_otx_dataset, self.test_transform) return DataLoader( - dataset, + self.test_dataset, shuffle=False, batch_size=self.config.dataset.test_batch_size, num_workers=self.config.dataset.num_workers, @@ -261,9 +271,8 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: Returns: Union[DataLoader, List[DataLoader]]: Predict Dataloader. """ - dataset = OTXVIsualPromptingDataset(self.config, self.predict_otx_dataset, self.predict_transform) return DataLoader( - dataset, + self.predict_dataset, shuffle=False, batch_size=self.config.dataset.eval_batch_size, num_workers=self.config.dataset.num_workers, diff --git a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py index ceb912c1c68..3271c9c7645 100644 --- a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py +++ b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py @@ -20,14 +20,20 @@ def collate_fn(batch): - index = [item['index'] for item in batch] - image = torch.stack([item['image'] for item in batch]) - bbox = [torch.tensor(item['bbox']) for item in batch] - mask = [torch.stack(item['mask']) for item in batch if item['mask'] != []] - label = [item['label'] for item in batch] if batch else [] - if mask: - return {'index': index, 'image': image, 'bbox': bbox, 'mask': mask, 'label': label} - return {'index': -1, 'image': [], 'bbox': [], 'mask': [], 'label': []} + def _convert_empty_to_none(x): + func = torch.stack if x == "masks" else torch.tensor + items = [func(item[x]) for item in batch if item[x]] + return None if len(items) == 0 else items + + index = [item["index"] for item in batch] + images = torch.stack([item["images"] for item in batch]) + bboxes = _convert_empty_to_none("bboxes") + points = _convert_empty_to_none("points") + masks = _convert_empty_to_none("masks") + labels = [item["labels"] for item in batch] + if masks: + return {"index": index, "images": images, "bboxes": bboxes, "points": points, "masks": masks, "label": labels} + return {"index": -1, "images": [], "bboxes": [], "points": [], "masks": [], "labels": []} class ResizeLongestSide: @@ -44,14 +50,13 @@ def __init__(self, target_length: int) -> None: self.target_length = target_length def __call__(self, item: Dict[str, Union[int, Tensor]]): - item["image"] = torch.as_tensor( - self.apply_image(item["image"]).transpose((2, 0, 1)), + item["images"] = torch.as_tensor( + self.apply_image(item["images"]).transpose((2, 0, 1)), dtype=torch.get_default_dtype()) - item["mask"] = [torch.as_tensor(self.apply_image(mask)) for mask in item["mask"]] - item["bbox"] = self.apply_boxes(item["bbox"], item["original_size"]) - if item["point"]: - item["point"] = self.apply_coords(item["point"], item["original_size"]) - + item["masks"] = [torch.as_tensor(self.apply_image(mask)) for mask in item["masks"]] + item["bboxes"] = self.apply_boxes(item["bboxes"], item["original_size"]) + if item["points"]: + item["points"] = self.apply_coords(item["points"], item["original_size"]) return item def apply_image(self, image: np.ndarray) -> np.ndarray: @@ -130,15 +135,17 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[i class Pad: """""" def __call__(self, item: Dict[str, Union[int, Tensor]]): - _, h, w = item["image"].shape + _, h, w = item["images"].shape max_dim = max(w, h) pad_w = (max_dim - w) // 2 pad_h = (max_dim - h) // 2 padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h) - item["image"] = transforms.functional.pad(item["image"], padding, fill=0, padding_mode="constant") - item["mask"] = [transforms.functional.pad(mask, padding, fill=0, padding_mode="constant") for mask in item["mask"]] - item["bbox"] = [[bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] for bbox in item["bbox"]] + item["images"] = transforms.functional.pad(item["images"], padding, fill=0, padding_mode="constant") + item["masks"] = [transforms.functional.pad(mask, padding, fill=0, padding_mode="constant") for mask in item["masks"]] + item["bboxes"] = [[bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] for bbox in item["bboxes"]] + if item["points"]: + item["points"] = [[point[0] + pad_w, point[1] + pad_h, point[2] + pad_w, point[3] + pad_h] for point in item["points"]] return item @@ -147,7 +154,7 @@ class MultipleInputsCompose(Compose): def __call__(self, item: Dict[str, Union[int, Tensor]]): for t in self.transforms: if isinstance(t, transforms.Normalize): - item["image"] = t(item["image"]) + item["images"] = t(item["images"]) else: item = t(item) return item diff --git a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py index cc3055dba8c..4939dd9367f 100644 --- a/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py +++ b/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py @@ -47,7 +47,9 @@ def __init__( freeze_image_encoder: bool = True, freeze_prompt_encoder: bool = True, freeze_mask_decoder: bool = False, - checkpoint: str = None + checkpoint: str = None, + mask_threshold: float = 0., + return_logits: bool = False ) -> None: """ SAM predicts object masks from an image and input prompts. @@ -60,6 +62,7 @@ def __init__( freeze_prompt_encoder (bool): Whether freezing prompt encoder, default is True. freeze_mask_decoder (bool): Whether freezing mask decoder, default is False. checkpoint (optional, str): Checkpoint path to be loaded, default is None. + mask_threshold (float): """ super().__init__() # self.save_hyperparameters() @@ -67,6 +70,8 @@ def __init__( self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder + self.mask_threshold = mask_threshold + self.return_logits = return_logits if freeze_image_encoder: for param in self.image_encoder.parameters(): @@ -95,14 +100,14 @@ def __init__( state_dict = torch.load(f) self.load_state_dict(state_dict) - def forward(self, images, bboxes): + def forward(self, images, bboxes, points=None): _, _, height, width = images.shape image_embeddings = self.image_encoder(images) pred_masks = [] ious = [] for embedding, bbox in zip(image_embeddings, bboxes): sparse_embeddings, dense_embeddings = self.prompt_encoder( - points=None, + points=points, boxes=bbox, masks=None, ) @@ -128,11 +133,12 @@ def forward(self, images, bboxes): def training_step(self, batch, batch_idx): """Training step of SAM.""" - images = batch["image"] - bboxes = batch["bbox"] - gt_masks = batch["mask"] + images = batch["images"] + bboxes = batch["bboxes"] + points = batch["points"] + gt_masks = batch["masks"] - pred_masks, ious = self(images, bboxes) + pred_masks, ious = self(images, bboxes, points) loss_focal = 0. loss_dice = 0. @@ -164,11 +170,12 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): """Validation step of SAM.""" - images = batch["image"] - bboxes = batch["bbox"] - gt_masks = batch["mask"] + images = batch["images"] + bboxes = batch["bboxes"] + points = batch["points"] + gt_masks = batch["masks"] - pred_masks, _ = self(images, bboxes) + pred_masks, _ = self(images, bboxes, points) for pred_mask, gt_mask in zip(pred_masks, gt_masks): self.val_iou(pred_mask, gt_mask) self.val_f1(pred_mask, gt_mask) @@ -178,6 +185,20 @@ def validation_step(self, batch, batch_idx): return results + def predict_step(self, batch, batch_idx): + """Predict step of SAM.""" + images = batch["images"] + bboxes = batch["bboxes"] + points = batch["points"] + + pred_masks, _ = self(images, bboxes, points) + + masks = self.postprocess_masks(pred_masks, self.input_size, self.original_size) + if not self.return_logits: + masks = masks > self.mask_threshold + + return masks + def postprocess_masks( self, masks: torch.Tensor, diff --git a/otx/algorithms/visual_prompting/tasks/inference.py b/otx/algorithms/visual_prompting/tasks/inference.py index a23247dc0a8..ab95cc25293 100644 --- a/otx/algorithms/visual_prompting/tasks/inference.py +++ b/otx/algorithms/visual_prompting/tasks/inference.py @@ -15,12 +15,14 @@ # and limitations under the License. import ctypes +from copy import deepcopy import io import time import os import shutil import subprocess # nosec import tempfile +from collections import OrderedDict from glob import glob from typing import Dict, List, Optional, Union from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything import sam_model_registry @@ -126,16 +128,43 @@ def load_model(self, otx_model: Optional[ModelEntity]) -> LightningModule: Returns: LightningModule: Visual prompting model with/without weights. """ + def get_model( + config: DictConfig, + checkpoint: Optional[str] = None, + state_dict: Optional[OrderedDict] = None + ): + backbone = config.model.backbone + if checkpoint is None: + checkpoint = config.model.checkpoint + + model = sam_model_registry[backbone](checkpoint=checkpoint) + if state_dict: + model.load_state_dict(state_dict) + + return model + if otx_model is None: - backbone = self.config.model.backbone + model = get_model(config=self.config) + logger.info( + "No trained model in project yet. Created new model with '%s'", + self.model_name, + ) else: - backbone = otx_model + buffer = io.BytesIO(otx_model.get_data("weights.pth")) + model_data = torch.load(buffer, map_location=torch.device("cpu")) + + if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]: + logger.warning( + "Backbone of the model in the Task Environment is different from the one in the template. " + f"creating model with backbone={model_data['config']['model']['backbone']}" + ) + self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"] + try: + model = get_model(config=self.config, state_dict=model_data["model"]) + logger.info("Loaded model weights from Task Environment") + except BaseException as exception: + raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception - # TODO (sungchul): where can load_from be applied? - checkpoint = self.config.model.checkpoint - - # TODO (sungchul): load model in different ways - model = sam_model_registry[backbone](checkpoint=checkpoint) return model def cancel_training(self) -> None: @@ -171,6 +200,7 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameter self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks) self.trainer.predict(model=self.model, datamodule=datamodule) + return dataset def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optional[str] = None) -> None: @@ -181,20 +211,17 @@ def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optiona evaluation_metric (Optional[str], optional): Evaluation metric. Defaults to None. Instead, metric is chosen depending on the task type. """ - metric: IPerformanceProvider - if self.task_type == TaskType.ANOMALY_CLASSIFICATION: - metric = MetricsHelper.compute_f_measure(output_resultset) - elif self.task_type == TaskType.ANOMALY_DETECTION: - metric = MetricsHelper.compute_anomaly_detection_scores(output_resultset) - elif self.task_type == TaskType.ANOMALY_SEGMENTATION: - metric = MetricsHelper.compute_anomaly_segmentation_scores(output_resultset) - else: - raise ValueError(f"Unknown task type: {self.task_type}") - output_resultset.performance = metric.get_performance() + # metric = MetricsHelper.compute_f_measure(output_resultset) + # output_resultset.performance = metric.get_performance() + + # if self.task_type == TaskType.ANOMALY_CLASSIFICATION: + # accuracy = MetricsHelper.compute_accuracy(output_resultset).get_performance() + # output_resultset.performance.dashboard_metrics.extend(accuracy.dashboard_metrics) - if self.task_type == TaskType.ANOMALY_CLASSIFICATION: - accuracy = MetricsHelper.compute_accuracy(output_resultset).get_performance() - output_resultset.performance.dashboard_metrics.extend(accuracy.dashboard_metrics) + metric = MetricsHelper.compute_dice_averaged_over_pixels(output_resultset) + logger.info(f"mDice after evaluation: {metric.overall_dice.value}") + output_resultset.performance = metric.get_performance() + logger.info("Evaluation completed") def _export_to_onnx(self, onnx_path: str): """Export model to ONNX.