pylint
sungchul2 committed Jun 27, 2023
1 parent 7d35e6a commit aff872e
Showing 37 changed files with 843 additions and 656 deletions.
@@ -18,8 +18,6 @@

 import numpy as np
 from bson import ObjectId
-from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import Callback
 
 from otx.api.entities.annotation import Annotation
 from otx.api.entities.datasets import DatasetEntity
@@ -30,11 +28,13 @@
     create_annotation_from_segmentation_map,
     create_hard_prediction_from_soft_prediction,
 )
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import Callback
 
 
 class InferenceCallback(Callback):
     """Callback that updates otx_dataset during inference.
 
     Args:
         otx_dataset (DatasetEntity): Dataset that predictions will be updated.
     """
@@ -60,24 +60,29 @@ def on_predict_epoch_end(self, _trainer: Trainer, _pl_module: LightningModule, o
             iou_predictions.append(output["iou_predictions"][0])
             gt_labels.append(output["labels"][0])
 
-        for dataset_item, pred_mask, iou_prediction, labels in zip(self.otx_dataset, pred_masks, iou_predictions, gt_labels):
+        for dataset_item, pred_mask, iou_prediction, labels in zip(
+            self.otx_dataset, pred_masks, iou_predictions, gt_labels
+        ):
             annotations: List[Annotation] = []
             for soft_prediction, iou, label in zip(pred_mask, iou_prediction, labels):
-                probability = max(min(float(iou), 1.), 0.)
+                probability = max(min(float(iou), 1.0), 0.0)
                 label.probability = probability
                 soft_prediction = soft_prediction.numpy()
                 hard_prediction = create_hard_prediction_from_soft_prediction(
-                    soft_prediction=soft_prediction,
-                    soft_threshold=0.5
+                    soft_prediction=soft_prediction, soft_threshold=0.5
                 )
 
                 if self.use_mask:
                     # set mask as annotation
-                    annotation = [Annotation(
-                        shape=Image(data=hard_prediction.astype(np.uint8), size=hard_prediction.shape),
-                        labels=[ScoredLabel(label=label.label, probability=probability)],
-                        id=ID(ObjectId()),
-                    )]
+                    annotation = [
+                        Annotation(
+                            shape=Image(
+                                data=hard_prediction.astype(np.uint8), size=hard_prediction.shape
+                            ),  # type: ignore[arg-type]
+                            labels=[ScoredLabel(label=label.label, probability=probability)],
+                            id=ID(ObjectId()),
+                        )
+                    ]
                 else:
                     # generate polygon annotations
                     annotation = create_annotation_from_segmentation_map(
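The callback above collects per-item masks and IoU scores during prediction, then writes them back onto the dataset as mask or polygon annotations. A minimal usage sketch, assuming the InferenceCallback(otx_dataset) constructor implied by the docstring; model and predict_loader are placeholders, not names from this diff:

    from pytorch_lightning import Trainer

    # Hypothetical wiring: model and predict_loader are placeholders, and the
    # InferenceCallback(otx_dataset) signature is assumed from the docstring.
    inference_callback = InferenceCallback(otx_dataset)
    trainer = Trainer(callbacks=[inference_callback])
    trainer.predict(model, dataloaders=predict_loader)
    # The callback has now attached annotations to every item of otx_dataset.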
@@ -16,22 +16,19 @@

 import os
 from pathlib import Path
-from typing import Union, Optional
+from typing import Optional, Union
 
 from omegaconf import DictConfig, ListConfig, OmegaConf
 
-from otx.api.configuration.configurable_parameters import ConfigurableParameters
 from otx.algorithms.common.utils.logger import get_logger
+from otx.api.configuration.configurable_parameters import ConfigurableParameters
 
 logger = get_logger()
 
 
 def get_visual_promtping_config(
-    task_name: str,
-    otx_config: ConfigurableParameters,
-    output_path: str
+    task_name: str, otx_config: ConfigurableParameters, output_path: str
 ) -> Union[DictConfig, ListConfig]:
-
     """Get visual prompting configuration.
 
     Create an visual prompting config object that matches the values specified in the
@@ -43,7 +40,8 @@ def get_visual_promtping_config(
         output_path (str): Path to save the configuration file.
 
     Returns:
-        Union[DictConfig, ListConfig]: Visual prompting config object for the specified model type with overwritten default values.
+        Union[DictConfig, ListConfig]: Visual prompting config object for the specified model type
+            with overwritten default values.
     """
     if os.path.isfile(os.path.join(output_path, "config.yaml")):
         # If there is already a config.yaml file in the output path, load it
@@ -52,8 +50,10 @@
         print(f"[*] Load configuration file at {config_path}")
     else:
         # Load the default config.yaml file
-        config_path = Path(f"otx/algorithms/visual_prompting/configs/{task_name.lower()}/config.yaml")
-        visual_prompting_config = get_configurable_parameters(model_name=task_name.lower(), config_path=config_path, output_path=Path(output_path))
+        config_path = f"otx/algorithms/visual_prompting/configs/{task_name.lower()}/config.yaml"
+        visual_prompting_config = get_configurable_parameters(
+            model_name=task_name.lower(), config_path=Path(config_path), output_path=Path(output_path)
+        )
     update_visual_prompting_config(visual_prompting_config, otx_config)
     return visual_prompting_config

@@ -86,28 +86,33 @@ def get_configurable_parameters(
         )
 
     if config_path is None:
-        config_path = Path(f"otx/algorithms/visual_prompting/configs/{model_name}/{config_filename}.{config_file_extension}")
+        config_path = Path(
+            f"otx/algorithms/visual_prompting/configs/{model_name}/{config_filename}.{config_file_extension}"
+        )
 
    config = OmegaConf.load(config_path)
    print(f"[*] Load configuration file at {config_path}")
 
    if weight_file:
        config.trainer.resume_from_checkpoint = weight_file
 
    (output_path / f"{config_filename}.{config_file_extension}").write_text(OmegaConf.to_yaml(config))
 
    return config
 
 
-def update_visual_prompting_config(visual_prompting_config: Union[DictConfig, ListConfig], otx_config: ConfigurableParameters) -> None:
+def update_visual_prompting_config(
+    visual_prompting_config: Union[DictConfig, ListConfig], otx_config: ConfigurableParameters
+) -> None:
     """Update visual prompting configuration.
 
     Overwrite the default parameter values in the visual prompting config with the
     values specified in the OTX config. The function is recursively called for
     each parameter group present in the OTX config.
 
     Args:
-        visual_prompting_config (Union[DictConfig, ListConfig]): Visual prompting config object for the specified model type with overwritten default values.
+        visual_prompting_config (Union[DictConfig, ListConfig]): Visual prompting config object
+            for the specified model type with overwritten default values.
         otx_config (ConfigurableParameters): OTX config object parsed from configuration.yaml file.
     """
     groups = getattr(otx_config, "groups", None)
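These helpers resolve a config.yaml (a previously saved copy in output_path, or the packaged default) and then overlay the OTX parameter values onto it. A hedged call-site sketch; the task name and output directory are made-up values, and the real task code supplies its own ConfigurableParameters instance:

    from otx.api.configuration.configurable_parameters import ConfigurableParameters

    otx_config = ConfigurableParameters(header="assumed; normally parsed from configuration.yaml")
    config = get_visual_promtping_config(
        task_name="sam",                      # assumed task name
        otx_config=otx_config,
        output_path="/tmp/visual_prompting",  # assumed output directory
    )
    # config is an OmegaConf object; get_configurable_parameters() also writes
    # the resolved YAML into the output path, as shown above.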
@@ -4,6 +4,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .dataset import OTXVisualPromptingDataModule
-from .pipelines import ResizeLongestSide, MultipleInputsCompose, Pad
+from .pipelines import MultipleInputsCompose, Pad, ResizeLongestSide
 
 __all__ = ["OTXVisualPromptingDataModule", "ResizeLongestSide", "MultipleInputsCompose", "Pad"]
@@ -18,11 +18,10 @@

 import cv2
 import numpy as np
-import torchvision.transforms as transforms
 from omegaconf import DictConfig, ListConfig
-from pytorch_lightning import LightningDataModule
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
 
 from otx.algorithms.common.utils.logger import get_logger
 from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import (
@@ -34,29 +33,24 @@
 from otx.api.entities.datasets import DatasetEntity
 from otx.api.entities.image import Image
 from otx.api.entities.scored_label import ScoredLabel
-from otx.api.entities.shapes.polygon import Point, Polygon
-from otx.api.entities.shapes.rectangle import Rectangle
+from otx.api.entities.shapes.polygon import Polygon
 from otx.api.entities.subset import Subset
 from otx.api.utils.shape_factory import ShapeFactory
+from pytorch_lightning import LightningDataModule
 
 logger = get_logger()
 
 
 class OTXVIsualPromptingDataset(Dataset):
     """Visual Prompting Dataset Adaptor.
 
     Args:
         dataset (DatasetEntity): Dataset entity.
         transform (MultipleInputsCompose): Transformations to apply to the dataset.
         offset_bbox (int): Offset to apply to the bounding box, defaults to 0.
     """
 
-    def __init__(
-        self,
-        dataset: DatasetEntity,
-        transform: MultipleInputsCompose,
-        offset_bbox: int = 0
-    ) -> None:
+    def __init__(self, dataset: DatasetEntity, transform: MultipleInputsCompose, offset_bbox: int = 0) -> None:
 
         self.dataset = dataset
         self.transform = transform
@@ -87,7 +81,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]:

         width, height = dataset_item.width, dataset_item.height
         bboxes: List[List[int]] = []
-        points: List = [] # TBD
+        points: List = []  # TBD
         gt_masks: List[np.ndarray] = []
         labels: List[ScoredLabel] = []
         for annotation in dataset_item.get_annotations(labels=self.labels, include_empty=False, preserve_id=True):
@@ -115,24 +109,34 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]:
                 labels.extend(annotation.get_labels(include_empty=False))
 
         if len(gt_masks) == 0:
-            return {"images": [], "bboxes": [], "points": [], "gt_masks": [], "original_size": [], "path": [], "labels": []}
+            return {
+                "images": [],
+                "bboxes": [],
+                "points": [],
+                "gt_masks": [],
+                "original_size": [],
+                "path": [],
+                "labels": [],
+            }
 
         bboxes = np.array(bboxes)
-        item.update(dict(
-            original_size=(height, width),
-            images=dataset_item.numpy,
-            path=dataset_item.media.path,
-            gt_masks=gt_masks,
-            bboxes=bboxes,
-            points=points, # TODO (sungchul): update point information
-            labels=labels,
-        ))
+        item.update(
+            dict(
+                original_size=(height, width),
+                images=dataset_item.numpy,
+                path=dataset_item.media.path,
+                gt_masks=gt_masks,
+                bboxes=bboxes,
+                points=points,  # TODO (sungchul): update point information
+                labels=labels,
+            )
+        )
         item = self.transform(item)
         return item
 
     def convert_polygon_to_mask(self, shape: Polygon, width: int, height: int) -> np.ndarray:
         """Convert polygon to mask.
 
         Args:
             shape (Polygon): Polygon to convert.
             width (int): Width of image.
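As an aside, each sample returned by __getitem__ above is a plain dict. A sketch of its shape, grounded in the keys visible in this hunk (the "index" key is set where item is created, outside the shown lines, and all values here are illustrative):

    # Illustrative only; real values come from the dataset item.
    item = {
        "index": 0,                      # set outside the shown lines
        "original_size": (720, 1280),    # (height, width)
        "images": "HxWxC uint8 array",   # dataset_item.numpy
        "path": "path/to/image.jpg",     # dataset_item.media.path
        "gt_masks": ["one HxW binary mask per annotation"],
        "bboxes": "Nx4 array of [x1, y1, x2, y2]",
        "points": [],                    # TBD, per the code comment above
        "labels": ["one ScoredLabel per annotation"],
    }
    # self.transform (resize, pad, normalize; see setup() further down)
    # then runs on this dict before it reaches the model.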
@@ -147,17 +151,18 @@ def convert_polygon_to_mask(self, shape: Polygon, width: int, height: int) -> np
             gt_mask = cv2.drawContours(gt_mask, np.asarray([contour]), 0, 1, -1)
         return gt_mask
 
-    def generate_bbox(self, x1: int, y1: int, x2: int, y2: int, width: int, height: int) -> List[int]:
+    def generate_bbox(self, x1: int, y1: int, x2: int, y2: int, width: int, height: int) -> List[int]:  # noqa: D417
         """Generate bounding box.
+
         Args:
-            x1, y1, x2, y2 (int): Bounding box coordinates.
+            x1, y1, x2, y2 (int): Bounding box coordinates.  # type: ignore
             width (int): Width of image.
             height (int): Height of image.
 
         Returns:
             List[int]: Generated bounding box.
         """
 
         def get_randomness(length: int) -> int:
             if self.offset_bbox == 0:
                 return 0
@@ -167,7 +172,7 @@ def get_randomness(length: int) -> int:
             max(0, x1 + get_randomness(width)),
             max(0, y1 + get_randomness(height)),
             min(width, x2 + get_randomness(width)),
-            min(height, y2 + get_randomness(height))
+            min(height, y2 + get_randomness(height)),
         ]
         return bbox
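Only a trailing comma changes in this hunk, but it sits in the bbox-jitter step: each corner is offset by get_randomness (zero when offset_bbox == 0, per the lines above) and clamped to the image bounds. A standalone restatement of that logic; the function name is mine, and rand stands in for get_randomness, whose non-zero branch is collapsed in this diff:

    from typing import Callable, List

    def jittered_bbox(
        x1: int, y1: int, x2: int, y2: int, width: int, height: int, rand: Callable[[int], int]
    ) -> List[int]:
        # Mirrors generate_bbox() above: offset each corner, clamp to the image.
        return [
            max(0, x1 + rand(width)),
            max(0, y1 + rand(height)),
            min(width, x2 + rand(width)),
            min(height, y2 + rand(height)),
        ]

    # With jitter disabled, corners are simply clamped to the image bounds:
    assert jittered_bbox(-5, 10, 1300, 700, 1280, 720, rand=lambda _: 0) == [0, 10, 1280, 700]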

@@ -190,7 +195,7 @@ def generate_bbox_from_mask(self, gt_mask: np.ndarray, width: int, height: int)

 class OTXVisualPromptingDataModule(LightningDataModule):
     """Visual Prompting DataModule.
-    
+
     Args:
         config (Union[DictConfig, ListConfig]): Configuration.
         dataset (DatasetEntity): Dataset entity.
@@ -224,31 +229,39 @@ def setup(self, stage: Optional[str] = None) -> None:
             val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION)
 
             # TODO (sungchul): distinguish between train and val config here
-            train_transform = val_transform = MultipleInputsCompose([
-                ResizeLongestSide(target_length=max(image_size)),
-                Pad(),
-                transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
-            ])
+            train_transform = val_transform = MultipleInputsCompose(
+                [
+                    ResizeLongestSide(target_length=max(image_size)),
+                    Pad(),
+                    transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+                ]
+            )
 
-            self.train_dataset = OTXVIsualPromptingDataset(train_otx_dataset, train_transform, offset_bbox=self.config.offset_bbox)
+            self.train_dataset = OTXVIsualPromptingDataset(
+                train_otx_dataset, train_transform, offset_bbox=self.config.offset_bbox
+            )
             self.val_dataset = OTXVIsualPromptingDataset(val_otx_dataset, val_transform)
 
         if stage == "test":
             test_otx_dataset = self.dataset.get_subset(Subset.TESTING)
-            test_transform = MultipleInputsCompose([
-                ResizeLongestSide(target_length=max(image_size)),
-                Pad(),
-                transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
-            ])
+            test_transform = MultipleInputsCompose(
+                [
+                    ResizeLongestSide(target_length=max(image_size)),
+                    Pad(),
+                    transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+                ]
+            )
             self.test_dataset = OTXVIsualPromptingDataset(test_otx_dataset, test_transform)
 
         if stage == "predict":
             predict_otx_dataset = self.dataset
-            predict_transform = MultipleInputsCompose([
-                ResizeLongestSide(target_length=max(image_size)),
-                Pad(),
-                transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
-            ])
+            predict_transform = MultipleInputsCompose(
+                [
+                    ResizeLongestSide(target_length=max(image_size)),
+                    Pad(),
+                    transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+                ]
+            )
             self.predict_dataset = OTXVIsualPromptingDataset(predict_otx_dataset, predict_transform)
 
     def summary(self):
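Taken together, setup() above builds the same resize-pad-normalize pipeline for the train, val, test, and predict datasets. A hedged driving sketch; the config keys are assumptions based on attributes referenced in this file, the fit-branch guard is collapsed in the diff, and train_dataloader() is the standard LightningDataModule hook defined outside the shown hunks:

    from omegaconf import OmegaConf

    # Hypothetical config fragment; the real schema lives in config.yaml.
    config = OmegaConf.create({"image_size": [1024, 1024], "offset_bbox": 20})
    datamodule = OTXVisualPromptingDataModule(config, dataset)  # dataset: DatasetEntity
    datamodule.setup(stage="fit")      # builds train/val datasets
    datamodule.setup(stage="predict")  # builds the predict dataset over the full dataset
    train_loader = datamodule.train_dataloader()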
@@ -3,6 +3,6 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from .transforms import ResizeLongestSide, collate_fn, MultipleInputsCompose, Pad
+from .transforms import MultipleInputsCompose, Pad, ResizeLongestSide, collate_fn
 
 __all__ = ["ResizeLongestSide", "collate_fn", "MultipleInputsCompose", "Pad"]