Skip to content

Commit

Permalink
Detection improvements (#820)
Browse files Browse the repository at this point in the history
* apply pixel threshold to bbox detections

* allow visualizing normal boxes

* normalize box scores

* fix bbox logic in base anomaly module

* boxes_scores -> box_scores

* fix inferencers
  • Loading branch information
djdameln authored Dec 29, 2022
1 parent b21f12c commit ced7bc9
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 21 deletions.
1 change: 1 addition & 0 deletions anomalib/deploy/inferencers/base_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def predict(
anomaly_map=output["anomaly_map"],
pred_mask=output["pred_mask"],
pred_boxes=output["pred_boxes"],
box_labels=output["box_labels"],
)

@staticmethod
Expand Down
3 changes: 3 additions & 0 deletions anomalib/deploy/inferencers/openvino_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,18 @@ def post_process(

if self.config.dataset.task == TaskType.DETECTION:
pred_boxes = self._get_boxes(pred_mask)
box_labels = np.ones(pred_boxes.shape[0])
else:
pred_boxes = None
box_labels = None

return {
"anomaly_map": anomaly_map,
"pred_label": pred_label,
"pred_score": pred_score,
"pred_mask": pred_mask,
"pred_boxes": pred_boxes,
"box_labels": box_labels,
}

@staticmethod
Expand Down
3 changes: 3 additions & 0 deletions anomalib/deploy/inferencers/torch_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,16 @@ def post_process(self, predictions: Tensor, meta_data: Optional[Union[Dict, Dict

if self.config.dataset.task == TaskType.DETECTION:
pred_boxes = masks_to_boxes(torch.from_numpy(pred_mask))[0].numpy()
box_labels = np.ones(pred_boxes.shape[0])
else:
pred_boxes = None
box_labels = None

return {
"anomaly_map": anomaly_map,
"pred_label": pred_label,
"pred_score": pred_score,
"pred_mask": pred_mask,
"pred_boxes": pred_boxes,
"box_labels": box_labels,
}
12 changes: 9 additions & 3 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int
outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
if "pred_boxes" not in outputs.keys():
outputs["pred_boxes"] = masks_to_boxes(outputs["pred_masks"])
outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]]
# apply thresholding to boxes
if "box_scores" in outputs:
# apply threshold to assign normal/anomalous label to boxes
is_anomalous = [scores > self.pixel_threshold.value for scores in outputs["box_scores"]]
outputs["box_labels"] = [labels.int() for labels in is_anomalous]
return outputs

def test_step(self, batch, _): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -163,17 +169,17 @@ def _post_process(outputs):
outputs["pred_scores"] = (
outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
)
elif "pred_scores" not in outputs and "boxes_scores" in outputs:
elif "pred_scores" not in outputs and "box_scores" in outputs:
# infer image score from bbox confidence scores
outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float()
for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["boxes_scores"])):
for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"])):
if boxes.numel():
outputs["pred_scores"][idx] = scores.max().item()

if "pred_boxes" in outputs and "anomaly_maps" not in outputs:
# create anomaly maps from bbox predictions for thresholding and evaluation
image_size = tuple(outputs["image"].shape[-2:])
outputs["anomaly_maps"] = boxes_to_anomaly_maps(outputs["pred_boxes"], outputs["boxes_scores"], image_size)
outputs["anomaly_maps"] = boxes_to_anomaly_maps(outputs["pred_boxes"], outputs["box_scores"], image_size)
outputs["mask"] = boxes_to_masks(outputs["boxes"], image_size)

def _outputs_to_cpu(self, output):
Expand Down
6 changes: 2 additions & 4 deletions anomalib/post_processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,19 +153,17 @@ def compute_mask(anomaly_map: np.ndarray, threshold: float, kernel_size: int = 4
return mask


def draw_boxes(image: np.ndarray, boxes: np.ndarray, is_ground_truth: bool = False) -> np.ndarray:
def draw_boxes(image: np.ndarray, boxes: np.ndarray, color: Tuple[int, int, int]) -> np.ndarray:
"""Draw bounding boxes on an image.
Args:
image (np.ndarray): Source image.
boxes (np.nparray): 2D array of shape (N, 4) where each row contains the xyxy coordinates of a bounding box.
is_ground_truth (bool): Flag indicating if the boxes are ground truth. When true, boxes will be drawn in red,
otherwise in blue.
color (Tuple[int, int, int]): Color of the drawn boxes in RGB format.
Returns:
np.ndarray: Image showing the bounding boxes drawn on top of the source image.
"""
color = (255, 0, 0) if is_ground_truth else (0, 0, 255)
for box in boxes:
x_1, y_1, x_2, y_2 = box.astype(np.int)
image = cv2.rectangle(image, (x_1, y_1), (x_2, y_2), color=color, thickness=2)
Expand Down
18 changes: 13 additions & 5 deletions anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ class ImageResult:
pred_mask: Optional[np.ndarray] = None
gt_boxes: Optional[np.ndarray] = None
pred_boxes: Optional[np.ndarray] = None
box_labels: Optional[np.ndarray] = None

heat_map: np.ndarray = field(init=False)
segmentations: np.ndarray = field(init=False)
normal_boxes: np.ndarray = field(init=False)
anomalous_boxes: np.ndarray = field(init=False)

def __post_init__(self) -> None:
"""Generate heatmap overlay and segmentations, convert masks to images."""
Expand All @@ -50,6 +53,10 @@ def __post_init__(self) -> None:
self.segmentations = (self.segmentations * 255).astype(np.uint8)
if self.gt_mask is not None and self.gt_mask.max() <= 1.0:
self.gt_mask *= 255
if self.pred_boxes is not None:
assert self.box_labels is not None, "Box labels must be provided when box locations are provided."
self.normal_boxes = self.pred_boxes[~self.box_labels.astype(np.bool)]
self.anomalous_boxes = self.pred_boxes[self.box_labels.astype(np.bool)]


class Visualizer:
Expand Down Expand Up @@ -98,6 +105,7 @@ def visualize_batch(self, batch: Dict) -> Iterator[np.ndarray]:
gt_mask=batch["mask"][i].squeeze().int().cpu().numpy() if "mask" in batch else None,
gt_boxes=batch["boxes"][i].cpu().numpy() if "boxes" in batch else None,
pred_boxes=batch["pred_boxes"][i].cpu().numpy() if "pred_boxes" in batch else None,
box_labels=batch["box_labels"][i].cpu().numpy() if "box_labels" in batch else None,
)
yield self.visualize_image(image_result)

Expand Down Expand Up @@ -134,11 +142,12 @@ def _visualize_full(self, image_result: ImageResult) -> np.ndarray:
assert image_result.pred_boxes is not None
visualization.add_image(image_result.image, "Image")
if image_result.gt_boxes is not None:
gt_image = draw_boxes(np.copy(image_result.image), image_result.gt_boxes, is_ground_truth=True)
gt_image = draw_boxes(np.copy(image_result.image), image_result.gt_boxes, color=(255, 0, 0))
visualization.add_image(image=gt_image, color_map="gray", title="Ground Truth")
else:
visualization.add_image(image_result.image, "Image")
pred_image = draw_boxes(np.copy(image_result.image), image_result.pred_boxes, is_ground_truth=False)
pred_image = draw_boxes(np.copy(image_result.image), image_result.normal_boxes, color=(0, 255, 0))
pred_image = draw_boxes(pred_image, image_result.anomalous_boxes, color=(255, 0, 0))
visualization.add_image(pred_image, "Predictions")
if self.task == TaskType.SEGMENTATION:
assert image_result.pred_mask is not None
Expand Down Expand Up @@ -172,11 +181,10 @@ def _visualize_simple(self, image_result: ImageResult) -> np.ndarray:
if self.task == TaskType.DETECTION:
# return image with bounding boxes augmented
image_with_boxes = draw_boxes(
image=image_result.image, boxes=image_result.pred_boxes, is_ground_truth=False
image=np.copy(image_result.image), boxes=image_result.anomalous_boxes, color=(0, 0, 255)
)
if image_result.gt_boxes is not None:
image_with_boxes = draw_boxes(image=image_with_boxes, boxes=image_result.gt_boxes, is_ground_truth=True)
image_with_boxes = draw_boxes(image=image_with_boxes, boxes=image_result.pred_boxes, is_ground_truth=False)
image_with_boxes = draw_boxes(image=image_with_boxes, boxes=image_result.gt_boxes, color=(255, 0, 0))
return image_with_boxes
if self.task == TaskType.SEGMENTATION:
visualization = mark_boundaries(
Expand Down
20 changes: 11 additions & 9 deletions anomalib/utils/callbacks/min_max_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def on_validation_batch_end(
"""Called when the validation batch ends, update the min and max observed values."""
if "anomaly_maps" in outputs:
pl_module.normalization_metrics(outputs["anomaly_maps"])
elif "boxes_scores" in outputs:
pl_module.normalization_metrics(torch.cat(outputs["boxes_scores"]))
elif "box_scores" in outputs:
pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
elif "pred_scores" in outputs:
pl_module.normalization_metrics(outputs["pred_scores"])
else:
Expand Down Expand Up @@ -83,11 +83,13 @@ def on_predict_batch_end(
@staticmethod
def _normalize_batch(outputs, pl_module):
"""Normalize a batch of predictions."""
image_threshold = pl_module.image_threshold.value.cpu()
pixel_threshold = pl_module.pixel_threshold.value.cpu()
stats = pl_module.normalization_metrics.cpu()
outputs["pred_scores"] = normalize(
outputs["pred_scores"], pl_module.image_threshold.value.cpu(), stats.min, stats.max
)
if "anomaly_maps" in outputs.keys():
outputs["anomaly_maps"] = normalize(
outputs["anomaly_maps"], pl_module.pixel_threshold.value.cpu(), stats.min, stats.max
)
outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
if "anomaly_maps" in outputs:
outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
if "box_scores" in outputs:
outputs["box_scores"] = [
normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
]

0 comments on commit ced7bc9

Please sign in to comment.