diff --git a/anomalib/deploy/inferencers/base_inferencer.py b/anomalib/deploy/inferencers/base_inferencer.py index 089a96f613..c4d9219750 100644 --- a/anomalib/deploy/inferencers/base_inferencer.py +++ b/anomalib/deploy/inferencers/base_inferencer.py @@ -91,6 +91,7 @@ def predict( anomaly_map=output["anomaly_map"], pred_mask=output["pred_mask"], pred_boxes=output["pred_boxes"], + box_labels=output["box_labels"], ) @staticmethod diff --git a/anomalib/deploy/inferencers/openvino_inferencer.py b/anomalib/deploy/inferencers/openvino_inferencer.py index 83cea4d2b2..2296ec5908 100644 --- a/anomalib/deploy/inferencers/openvino_inferencer.py +++ b/anomalib/deploy/inferencers/openvino_inferencer.py @@ -192,8 +192,10 @@ def post_process( if self.config.dataset.task == TaskType.DETECTION: pred_boxes = self._get_boxes(pred_mask) + box_labels = np.ones(pred_boxes.shape[0]) else: pred_boxes = None + box_labels = None return { "anomaly_map": anomaly_map, @@ -201,6 +203,7 @@ def post_process( "pred_score": pred_score, "pred_mask": pred_mask, "pred_boxes": pred_boxes, + "box_labels": box_labels, } @staticmethod diff --git a/anomalib/deploy/inferencers/torch_inferencer.py b/anomalib/deploy/inferencers/torch_inferencer.py index d4c8b0b705..939c25b3e7 100644 --- a/anomalib/deploy/inferencers/torch_inferencer.py +++ b/anomalib/deploy/inferencers/torch_inferencer.py @@ -209,8 +209,10 @@ def post_process(self, predictions: Tensor, meta_data: Optional[Union[Dict, Dict if self.config.dataset.task == TaskType.DETECTION: pred_boxes = masks_to_boxes(torch.from_numpy(pred_mask))[0].numpy() + box_labels = np.ones(pred_boxes.shape[0]) else: pred_boxes = None + box_labels = None return { "anomaly_map": anomaly_map, @@ -218,4 +220,5 @@ def post_process(self, predictions: Tensor, meta_data: Optional[Union[Dict, Dict "pred_score": pred_score, "pred_mask": pred_mask, "pred_boxes": pred_boxes, + "box_labels": box_labels, } diff --git a/anomalib/models/components/base/anomaly_module.py b/anomalib/models/components/base/anomaly_module.py index 27fd074a43..2cdcf54914 100644 --- a/anomalib/models/components/base/anomaly_module.py +++ b/anomalib/models/components/base/anomaly_module.py @@ -86,6 +86,12 @@ def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value if "pred_boxes" not in outputs.keys(): outputs["pred_boxes"] = masks_to_boxes(outputs["pred_masks"]) + outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]] + # apply thresholding to boxes + if "box_scores" in outputs: + # apply threshold to assign normal/anomalous label to boxes + is_anomalous = [scores > self.pixel_threshold.value for scores in outputs["box_scores"]] + outputs["box_labels"] = [labels.int() for labels in is_anomalous] return outputs def test_step(self, batch, _): # pylint: disable=arguments-differ @@ -163,17 +169,17 @@ def _post_process(outputs): outputs["pred_scores"] = ( outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values ) - elif "pred_scores" not in outputs and "boxes_scores" in outputs: + elif "pred_scores" not in outputs and "box_scores" in outputs: # infer image score from bbox confidence scores outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float() - for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["boxes_scores"])): + for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"])): if boxes.numel(): outputs["pred_scores"][idx] = scores.max().item() if "pred_boxes" in outputs and "anomaly_maps" not in outputs: # create anomaly maps from bbox predictions for thresholding and evaluation image_size = tuple(outputs["image"].shape[-2:]) - outputs["anomaly_maps"] = boxes_to_anomaly_maps(outputs["pred_boxes"], outputs["boxes_scores"], image_size) + outputs["anomaly_maps"] = boxes_to_anomaly_maps(outputs["pred_boxes"], outputs["box_scores"], image_size) outputs["mask"] = boxes_to_masks(outputs["boxes"], image_size) def _outputs_to_cpu(self, output): diff --git a/anomalib/post_processing/post_process.py b/anomalib/post_processing/post_process.py index fbc63ec175..226cb21f23 100644 --- a/anomalib/post_processing/post_process.py +++ b/anomalib/post_processing/post_process.py @@ -153,19 +153,17 @@ def compute_mask(anomaly_map: np.ndarray, threshold: float, kernel_size: int = 4 return mask -def draw_boxes(image: np.ndarray, boxes: np.ndarray, is_ground_truth: bool = False) -> np.ndarray: +def draw_boxes(image: np.ndarray, boxes: np.ndarray, color: Tuple[int, int, int]) -> np.ndarray: """Draw bounding boxes on an image. Args: image (np.ndarray): Source image. boxes (np.nparray): 2D array of shape (N, 4) where each row contains the xyxy coordinates of a bounding box. - is_ground_truth (bool): Flag indicating if the boxes are ground truth. When true, boxes will be drawn in red, - otherwise in blue. + color (Tuple[int, int, int]): Color of the drawn boxes in RGB format. Returns: np.ndarray: Image showing the bounding boxes drawn on top of the source image. """ - color = (255, 0, 0) if is_ground_truth else (0, 0, 255) for box in boxes: x_1, y_1, x_2, y_2 = box.astype(np.int) image = cv2.rectangle(image, (x_1, y_1), (x_2, y_2), color=color, thickness=2) diff --git a/anomalib/post_processing/visualizer.py b/anomalib/post_processing/visualizer.py index c186ed88f9..d14e51ef58 100644 --- a/anomalib/post_processing/visualizer.py +++ b/anomalib/post_processing/visualizer.py @@ -35,9 +35,12 @@ class ImageResult: pred_mask: Optional[np.ndarray] = None gt_boxes: Optional[np.ndarray] = None pred_boxes: Optional[np.ndarray] = None + box_labels: Optional[np.ndarray] = None heat_map: np.ndarray = field(init=False) segmentations: np.ndarray = field(init=False) + normal_boxes: np.ndarray = field(init=False) + anomalous_boxes: np.ndarray = field(init=False) def __post_init__(self) -> None: """Generate heatmap overlay and segmentations, convert masks to images.""" @@ -50,6 +53,10 @@ def __post_init__(self) -> None: self.segmentations = (self.segmentations * 255).astype(np.uint8) if self.gt_mask is not None and self.gt_mask.max() <= 1.0: self.gt_mask *= 255 + if self.pred_boxes is not None: + assert self.box_labels is not None, "Box labels must be provided when box locations are provided." + self.normal_boxes = self.pred_boxes[~self.box_labels.astype(np.bool)] + self.anomalous_boxes = self.pred_boxes[self.box_labels.astype(np.bool)] class Visualizer: @@ -98,6 +105,7 @@ def visualize_batch(self, batch: Dict) -> Iterator[np.ndarray]: gt_mask=batch["mask"][i].squeeze().int().cpu().numpy() if "mask" in batch else None, gt_boxes=batch["boxes"][i].cpu().numpy() if "boxes" in batch else None, pred_boxes=batch["pred_boxes"][i].cpu().numpy() if "pred_boxes" in batch else None, + box_labels=batch["box_labels"][i].cpu().numpy() if "box_labels" in batch else None, ) yield self.visualize_image(image_result) @@ -134,11 +142,12 @@ def _visualize_full(self, image_result: ImageResult) -> np.ndarray: assert image_result.pred_boxes is not None visualization.add_image(image_result.image, "Image") if image_result.gt_boxes is not None: - gt_image = draw_boxes(np.copy(image_result.image), image_result.gt_boxes, is_ground_truth=True) + gt_image = draw_boxes(np.copy(image_result.image), image_result.gt_boxes, color=(255, 0, 0)) visualization.add_image(image=gt_image, color_map="gray", title="Ground Truth") else: visualization.add_image(image_result.image, "Image") - pred_image = draw_boxes(np.copy(image_result.image), image_result.pred_boxes, is_ground_truth=False) + pred_image = draw_boxes(np.copy(image_result.image), image_result.normal_boxes, color=(0, 255, 0)) + pred_image = draw_boxes(pred_image, image_result.anomalous_boxes, color=(255, 0, 0)) visualization.add_image(pred_image, "Predictions") if self.task == TaskType.SEGMENTATION: assert image_result.pred_mask is not None @@ -172,11 +181,10 @@ def _visualize_simple(self, image_result: ImageResult) -> np.ndarray: if self.task == TaskType.DETECTION: # return image with bounding boxes augmented image_with_boxes = draw_boxes( - image=image_result.image, boxes=image_result.pred_boxes, is_ground_truth=False + image=np.copy(image_result.image), boxes=image_result.anomalous_boxes, color=(0, 0, 255) ) if image_result.gt_boxes is not None: - image_with_boxes = draw_boxes(image=image_with_boxes, boxes=image_result.gt_boxes, is_ground_truth=True) - image_with_boxes = draw_boxes(image=image_with_boxes, boxes=image_result.pred_boxes, is_ground_truth=False) + image_with_boxes = draw_boxes(image=image_with_boxes, boxes=image_result.gt_boxes, color=(255, 0, 0)) return image_with_boxes if self.task == TaskType.SEGMENTATION: visualization = mark_boundaries( diff --git a/anomalib/utils/callbacks/min_max_normalization.py b/anomalib/utils/callbacks/min_max_normalization.py index 303836d4ed..6d24aeb57a 100644 --- a/anomalib/utils/callbacks/min_max_normalization.py +++ b/anomalib/utils/callbacks/min_max_normalization.py @@ -49,8 +49,8 @@ def on_validation_batch_end( """Called when the validation batch ends, update the min and max observed values.""" if "anomaly_maps" in outputs: pl_module.normalization_metrics(outputs["anomaly_maps"]) - elif "boxes_scores" in outputs: - pl_module.normalization_metrics(torch.cat(outputs["boxes_scores"])) + elif "box_scores" in outputs: + pl_module.normalization_metrics(torch.cat(outputs["box_scores"])) elif "pred_scores" in outputs: pl_module.normalization_metrics(outputs["pred_scores"]) else: @@ -83,11 +83,13 @@ def on_predict_batch_end( @staticmethod def _normalize_batch(outputs, pl_module): """Normalize a batch of predictions.""" + image_threshold = pl_module.image_threshold.value.cpu() + pixel_threshold = pl_module.pixel_threshold.value.cpu() stats = pl_module.normalization_metrics.cpu() - outputs["pred_scores"] = normalize( - outputs["pred_scores"], pl_module.image_threshold.value.cpu(), stats.min, stats.max - ) - if "anomaly_maps" in outputs.keys(): - outputs["anomaly_maps"] = normalize( - outputs["anomaly_maps"], pl_module.pixel_threshold.value.cpu(), stats.min, stats.max - ) + outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max) + if "anomaly_maps" in outputs: + outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max) + if "box_scores" in outputs: + outputs["box_scores"] = [ + normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"] + ]