Skip to content

Commit

Permalink
Fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 committed Jul 3, 2023
1 parent ad9e466 commit 927cadc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
logger = get_logger()


class OTXVIsualPromptingDataset(Dataset):
class OTXVisualPromptingDataset(Dataset):
"""Visual Prompting Dataset Adaptor.
Args:
Expand Down Expand Up @@ -236,10 +236,10 @@ def setup(self, stage: Optional[str] = None) -> None:
]
)

self.train_dataset = OTXVIsualPromptingDataset(
self.train_dataset = OTXVisualPromptingDataset(
train_otx_dataset, train_transform, offset_bbox=self.config.offset_bbox
)
self.val_dataset = OTXVIsualPromptingDataset(val_otx_dataset, val_transform)
self.val_dataset = OTXVisualPromptingDataset(val_otx_dataset, val_transform)

if stage == "test":
test_otx_dataset = self.dataset.get_subset(Subset.TESTING)
Expand All @@ -250,7 +250,7 @@ def setup(self, stage: Optional[str] = None) -> None:
transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
]
)
self.test_dataset = OTXVIsualPromptingDataset(test_otx_dataset, test_transform)
self.test_dataset = OTXVisualPromptingDataset(test_otx_dataset, test_transform)

if stage == "predict":
predict_otx_dataset = self.dataset
Expand All @@ -261,7 +261,7 @@ def setup(self, stage: Optional[str] = None) -> None:
transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
]
)
self.predict_dataset = OTXVIsualPromptingDataset(predict_otx_dataset, predict_transform)
self.predict_dataset = OTXVisualPromptingDataset(predict_otx_dataset, predict_transform)

def summary(self):
"""Print size of the dataset, number of images."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.dataset import (
OTXVisualPromptingDataModule,
OTXVIsualPromptingDataset,
OTXVisualPromptingDataset,
)
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import (
collate_fn,
Expand Down Expand Up @@ -45,15 +45,15 @@ def dataset_mask(self) -> DatasetEntity:
@e2e_pytest_unit
def test_len(self, dataset_polygon) -> None:
"""Test __len__."""
otx_dataset = OTXVIsualPromptingDataset(dataset_polygon, self.transform)
otx_dataset = OTXVisualPromptingDataset(dataset_polygon, self.transform)
assert len(otx_dataset) == 1

@e2e_pytest_unit
@pytest.mark.parametrize("use_mask", [False, True])
def test_getitem(self, dataset_polygon, dataset_mask, use_mask: bool) -> None:
"""Test __getitem__."""
dataset = dataset_mask if use_mask else dataset_polygon
otx_dataset = OTXVIsualPromptingDataset(dataset=dataset, transform=self.transform)
otx_dataset = OTXVisualPromptingDataset(dataset=dataset, transform=self.transform)

item = otx_dataset[0]

Expand All @@ -74,7 +74,7 @@ def test_getitem(self, dataset_polygon, dataset_mask, use_mask: bool) -> None:
@e2e_pytest_unit
def test_convert_polygon_to_mask(self, dataset_polygon) -> None:
"""Test convert_polygon_to_mask."""
otx_dataset = OTXVIsualPromptingDataset(dataset_polygon, self.transform)
otx_dataset = OTXVisualPromptingDataset(dataset_polygon, self.transform)

polygon = Polygon(points=[Point(x=0.1, y=0.1), Point(x=0.2, y=0.2), Point(x=0.3, y=0.3)])
width = 100
Expand All @@ -89,7 +89,7 @@ def test_convert_polygon_to_mask(self, dataset_polygon) -> None:
@e2e_pytest_unit
def test_generate_bbox(self, dataset_polygon) -> None:
"""Test generate_bbox."""
otx_dataset = OTXVIsualPromptingDataset(dataset_polygon, self.transform)
otx_dataset = OTXVisualPromptingDataset(dataset_polygon, self.transform)

x1, y1, x2, y2 = 10, 20, 30, 40
width = 100
Expand All @@ -107,7 +107,7 @@ def test_generate_bbox(self, dataset_polygon) -> None:
@e2e_pytest_unit
def test_generate_bbox_from_mask(self, dataset_polygon) -> None:
"""Test generate_bbox_from_mask."""
otx_dataset = OTXVIsualPromptingDataset(dataset_polygon, self.transform)
otx_dataset = OTXVisualPromptingDataset(dataset_polygon, self.transform)

gt_mask = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
width = 3
Expand Down Expand Up @@ -141,8 +141,8 @@ def test_setup(self, mocker, datamodule) -> None:

datamodule.setup()

assert isinstance(datamodule.train_dataset, OTXVIsualPromptingDataset)
assert isinstance(datamodule.val_dataset, OTXVIsualPromptingDataset)
assert isinstance(datamodule.train_dataset, OTXVisualPromptingDataset)
assert isinstance(datamodule.val_dataset, OTXVisualPromptingDataset)

@e2e_pytest_unit
def test_train_dataloader(self, mocker, datamodule) -> None:
Expand Down

0 comments on commit 927cadc

Please sign in to comment.