(WIP) Set otx eval
sungchul2 committed Jun 9, 2023
1 parent 0c0e683 commit 639f4c5
Showing 4 changed files with 137 additions and 73 deletions.
@@ -44,18 +44,26 @@


class OTXVIsualPromptingDataset(Dataset):
"""Visual Prompting Dataset Adaptor."""
"""Visual Prompting Dataset Adaptor.
Args:
config
dataset
transform
stage
"""

def __init__(
self,
config: Union[DictConfig, ListConfig],
dataset: DatasetEntity,
transform: MultipleInputsCompose
transform: MultipleInputsCompose,
) -> None:

self.config = config
self.dataset = dataset
self.transform = transform

self.labels = dataset.get_labels()
self.label_idx = {label.id: i for i, label in enumerate(self.labels)}

@@ -133,11 +141,11 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, Tensor]]:

item.update(dict(
original_size=(height, width),
image=dataset_item.numpy,
mask=masks,
bbox=bboxes,
label=labels,
point=None, # TODO (sungchul): update point information
images=dataset_item.numpy,
masks=masks,
bboxes=bboxes,
labels=labels,
points=None, # TODO (sungchul): update point information
))
item = self.transform(item)
return item
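
The singular-to-plural key rename (image to images, mask to masks, and so on) has to stay in sync across the dataset here, the transforms in the next file, and the model's step methods in the third. For reference, the per-item dict produced by __getitem__ now looks like this (shapes and values illustrative, not actual outputs):

    # Illustrative per-item structure after __getitem__ (placeholders, not real data).
    item = {
        "index": 0,
        "original_size": (600, 800),
        "images": ...,   # HxWxC numpy image; ResizeLongestSide converts it to a CHW tensor
        "masks": [...],  # one binary mask per annotation
        "bboxes": [...], # one xyxy box per annotation
        "labels": [...],
        "points": None,  # TODO in source: point prompts not populated yet
    }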
@@ -171,31 +179,36 @@ def setup(self, stage: Optional[str] = None) -> None:
image_size = [image_size]

if stage == "fit" or stage is None:
self.train_otx_dataset = self.dataset.get_subset(Subset.TRAINING)
self.val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION)
train_otx_dataset = self.dataset.get_subset(Subset.TRAINING)
val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION)

# TODO (sungchul): distinguish between train and val config here
self.train_transform = self.val_transform = MultipleInputsCompose([
train_transform = val_transform = MultipleInputsCompose([
ResizeLongestSide(target_length=max(image_size)),
Pad(),
transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
])

self.train_dataset = OTXVIsualPromptingDataset(self.config, train_otx_dataset, train_transform)
self.val_dataset = OTXVIsualPromptingDataset(self.config, val_otx_dataset, val_transform)

if stage == "test":
self.test_otx_dataset = self.dataset.get_subset(Subset.TESTING)
self.test_transform = MultipleInputsCompose([
test_otx_dataset = self.dataset.get_subset(Subset.TESTING)
test_transform = MultipleInputsCompose([
ResizeLongestSide(target_length=max(image_size)),
Pad(),
transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
])
self.test_dataset = OTXVIsualPromptingDataset(self.config, test_otx_dataset, test_transform)

if stage == "predict":
self.predict_otx_dataset = self.dataset
self.predict_transform = MultipleInputsCompose([
predict_otx_dataset = self.dataset
predict_transform = MultipleInputsCompose([
ResizeLongestSide(target_length=max(image_size)),
Pad(),
transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
])
self.predict_dataset = OTXVIsualPromptingDataset(self.config, predict_otx_dataset, predict_transform)

def summary(self):
"""Print size of the dataset, number of images."""
@@ -216,9 +229,8 @@ def train_dataloader(
Returns:
Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]: Train dataloader.
"""
dataset = OTXVIsualPromptingDataset(self.config, self.train_otx_dataset, self.train_transform)
return DataLoader(
dataset,
self.train_dataset,
shuffle=False,
batch_size=self.config.dataset.train_batch_size,
num_workers=self.config.dataset.num_workers,
@@ -231,9 +243,8 @@ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
Returns:
Union[DataLoader, List[DataLoader]]: Validation Dataloader.
"""
dataset = OTXVIsualPromptingDataset(self.config, self.val_otx_dataset, self.val_transform)
return DataLoader(
dataset,
self.val_dataset,
shuffle=False,
batch_size=self.config.dataset.eval_batch_size,
num_workers=self.config.dataset.num_workers,
@@ -246,9 +257,8 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
Returns:
Union[DataLoader, List[DataLoader]]: Test Dataloader.
"""
dataset = OTXVIsualPromptingDataset(self.config, self.test_otx_dataset, self.test_transform)
return DataLoader(
dataset,
self.test_dataset,
shuffle=False,
batch_size=self.config.dataset.test_batch_size,
num_workers=self.config.dataset.num_workers,
@@ -261,9 +271,8 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
Returns:
Union[DataLoader, List[DataLoader]]: Predict Dataloader.
"""
dataset = OTXVIsualPromptingDataset(self.config, self.predict_otx_dataset, self.predict_transform)
return DataLoader(
dataset,
self.predict_dataset,
shuffle=False,
batch_size=self.config.dataset.eval_batch_size,
num_workers=self.config.dataset.num_workers,
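
Taken together, the changes in this first file do two things: batch keys go from singular to plural, and each dataset is now built once in setup() and cached on self, so the dataloader methods stop re-instantiating OTXVIsualPromptingDataset on every call. A minimal, self-contained sketch of that setup-once pattern (ToyDataModule and its datasets are illustrative stand-ins, not the actual OTX classes):

    # Minimal sketch of the "build datasets in setup(), reuse in dataloaders" pattern.
    from typing import Optional

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, TensorDataset


    class ToyDataModule(pl.LightningDataModule):
        def __init__(self, batch_size: int = 4) -> None:
            super().__init__()
            self.batch_size = batch_size

        def setup(self, stage: Optional[str] = None) -> None:
            # Construct datasets once per stage; the dataloaders below only consume them.
            if stage in ("fit", None):
                self.train_dataset = TensorDataset(torch.randn(16, 3))
                self.val_dataset = TensorDataset(torch.randn(8, 3))

        def train_dataloader(self) -> DataLoader:
            # No dataset construction here; reuse the instance cached by setup().
            return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False)

        def val_dataloader(self) -> DataLoader:
            return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)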
@@ -20,14 +20,20 @@


def collate_fn(batch):
index = [item['index'] for item in batch]
image = torch.stack([item['image'] for item in batch])
bbox = [torch.tensor(item['bbox']) for item in batch]
mask = [torch.stack(item['mask']) for item in batch if item['mask'] != []]
label = [item['label'] for item in batch] if batch else []
if mask:
return {'index': index, 'image': image, 'bbox': bbox, 'mask': mask, 'label': label}
return {'index': -1, 'image': [], 'bbox': [], 'mask': [], 'label': []}
def _convert_empty_to_none(x):
func = torch.stack if x == "masks" else torch.tensor
items = [func(item[x]) for item in batch if item[x]]
return None if len(items) == 0 else items

index = [item["index"] for item in batch]
images = torch.stack([item["images"] for item in batch])
bboxes = _convert_empty_to_none("bboxes")
points = _convert_empty_to_none("points")
masks = _convert_empty_to_none("masks")
labels = [item["labels"] for item in batch]
if masks:
return {"index": index, "images": images, "bboxes": bboxes, "points": points, "masks": masks, "label": labels}
return {"index": -1, "images": [], "bboxes": [], "points": [], "masks": [], "labels": []}


class ResizeLongestSide:
@@ -44,14 +50,13 @@ def __init__(self, target_length: int) -> None:
self.target_length = target_length

def __call__(self, item: Dict[str, Union[int, Tensor]]):
item["image"] = torch.as_tensor(
self.apply_image(item["image"]).transpose((2, 0, 1)),
item["images"] = torch.as_tensor(
self.apply_image(item["images"]).transpose((2, 0, 1)),
dtype=torch.get_default_dtype())
item["mask"] = [torch.as_tensor(self.apply_image(mask)) for mask in item["mask"]]
item["bbox"] = self.apply_boxes(item["bbox"], item["original_size"])
if item["point"]:
item["point"] = self.apply_coords(item["point"], item["original_size"])

item["masks"] = [torch.as_tensor(self.apply_image(mask)) for mask in item["masks"]]
item["bboxes"] = self.apply_boxes(item["bboxes"], item["original_size"])
if item["points"]:
item["points"] = self.apply_coords(item["points"], item["original_size"])
return item

def apply_image(self, image: np.ndarray) -> np.ndarray:
@@ -130,15 +135,17 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
class Pad:
""""""
def __call__(self, item: Dict[str, Union[int, Tensor]]):
_, h, w = item["image"].shape
_, h, w = item["images"].shape
max_dim = max(w, h)
pad_w = (max_dim - w) // 2
pad_h = (max_dim - h) // 2
padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)

item["image"] = transforms.functional.pad(item["image"], padding, fill=0, padding_mode="constant")
item["mask"] = [transforms.functional.pad(mask, padding, fill=0, padding_mode="constant") for mask in item["mask"]]
item["bbox"] = [[bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] for bbox in item["bbox"]]
item["images"] = transforms.functional.pad(item["images"], padding, fill=0, padding_mode="constant")
item["masks"] = [transforms.functional.pad(mask, padding, fill=0, padding_mode="constant") for mask in item["masks"]]
item["bboxes"] = [[bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] for bbox in item["bboxes"]]
if item["points"]:
item["points"] = [[point[0] + pad_w, point[1] + pad_h, point[2] + pad_w, point[3] + pad_h] for point in item["points"]]
return item


@@ -147,7 +154,7 @@ class MultipleInputsCompose(Compose):
def __call__(self, item: Dict[str, Union[int, Tensor]]):
for t in self.transforms:
if isinstance(t, transforms.Normalize):
item["image"] = t(item["image"])
item["images"] = t(item["images"])
else:
item = t(item)
return item
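
MultipleInputsCompose routes torchvision's Normalize to the image tensor only, while ResizeLongestSide and Pad see the whole item dict. The geometry of the pipeline: scale so the longest side hits target_length, then center-pad to a square. A standalone sketch of that shape arithmetic (mirroring the get_preprocess_shape and Pad math above; the constants are illustrative):

    # Standalone shape arithmetic for resize-longest-side + center pad (illustrative).
    def preprocess_shape(h: int, w: int, target: int = 1024) -> tuple:
        scale = target / max(h, w)
        newh, neww = int(h * scale + 0.5), int(w * scale + 0.5)
        pad_w = (target - neww) // 2
        pad_h = (target - newh) // 2
        # (left, top, right, bottom), matching the Pad transform above
        padding = (pad_w, pad_h, target - neww - pad_w, target - newh - pad_h)
        return (newh, neww), padding

    print(preprocess_shape(600, 800))  # ((768, 1024), (0, 128, 0, 128))

So a 600x800 image is scaled by 1024/800 to 768x1024, then padded with 128 rows on top and bottom to reach 1024x1024.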
@@ -47,7 +47,9 @@ def __init__(
freeze_image_encoder: bool = True,
freeze_prompt_encoder: bool = True,
freeze_mask_decoder: bool = False,
checkpoint: str = None
checkpoint: str = None,
mask_threshold: float = 0.,
return_logits: bool = False
) -> None:
"""
SAM predicts object masks from an image and input prompts.
@@ -60,13 +62,16 @@
freeze_prompt_encoder (bool): Whether to freeze the prompt encoder, default is True.
freeze_mask_decoder (bool): Whether to freeze the mask decoder, default is False.
checkpoint (optional, str): Checkpoint path to load, default is None.
mask_threshold (float): Threshold for binarizing predicted masks, default is 0.
return_logits (bool): Whether to return mask logits instead of binarized masks, default is False.
"""
super().__init__()
# self.save_hyperparameters()

self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.mask_threshold = mask_threshold
self.return_logits = return_logits

if freeze_image_encoder:
for param in self.image_encoder.parameters():
@@ -95,14 +100,14 @@ def __init__(
state_dict = torch.load(f)
self.load_state_dict(state_dict)

def forward(self, images, bboxes):
def forward(self, images, bboxes, points=None):
_, _, height, width = images.shape
image_embeddings = self.image_encoder(images)
pred_masks = []
ious = []
for embedding, bbox in zip(image_embeddings, bboxes):
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=None,
points=points,
boxes=bbox,
masks=None,
)
@@ -128,11 +133,12 @@ def forward(self, images, bboxes):

def training_step(self, batch, batch_idx):
"""Training step of SAM."""
images = batch["image"]
bboxes = batch["bbox"]
gt_masks = batch["mask"]
images = batch["images"]
bboxes = batch["bboxes"]
points = batch["points"]
gt_masks = batch["masks"]

pred_masks, ious = self(images, bboxes)
pred_masks, ious = self(images, bboxes, points)

loss_focal = 0.
loss_dice = 0.
@@ -164,11 +170,12 @@ def training_step(self, batch, batch_idx):

def validation_step(self, batch, batch_idx):
"""Validation step of SAM."""
images = batch["image"]
bboxes = batch["bbox"]
gt_masks = batch["mask"]
images = batch["images"]
bboxes = batch["bboxes"]
points = batch["points"]
gt_masks = batch["masks"]

pred_masks, _ = self(images, bboxes)
pred_masks, _ = self(images, bboxes, points)
for pred_mask, gt_mask in zip(pred_masks, gt_masks):
self.val_iou(pred_mask, gt_mask)
self.val_f1(pred_mask, gt_mask)
@@ -178,6 +185,20 @@ def validation_step(self, batch, batch_idx):

return results

def predict_step(self, batch, batch_idx):
"""Predict step of SAM."""
images = batch["images"]
bboxes = batch["bboxes"]
points = batch["points"]

pred_masks, _ = self(images, bboxes, points)

masks = self.postprocess_masks(pred_masks, self.input_size, self.original_size)
if not self.return_logits:
masks = masks > self.mask_threshold

return masks

def postprocess_masks(
self,
masks: torch.Tensor,
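
The new predict_step mirrors validation_step's forward pass, upscales the raw mask logits through postprocess_masks (whose body is truncated above), and then binarizes with mask_threshold unless return_logits is set. The final step is just a comparison; a sketch on dummy logits (made-up values, not real model output):

    # Final binarization step of predict_step, shown on dummy logits.
    import torch

    mask_threshold = 0.0
    return_logits = False

    logits = torch.tensor([[-1.2, 0.3], [2.0, -0.1]])  # stand-in for upscaled mask logits
    masks = logits if return_logits else logits > mask_threshold
    print(masks)  # tensor([[False,  True], [ True, False]])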