Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🛠️ [WIP] Visualization POC #233

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion model_api/python/model_api/models/result_types/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@

from __future__ import annotations

import cv2
import numpy as np

from model_api.visualizer.primitives import BoundingBoxes, Label, Overlay, Polygon

class AnomalyResult:
from .base import Result


class AnomalyResult(Result):
"""Results for anomaly models."""

def __init__(
Expand All @@ -19,6 +24,7 @@ def __init__(
pred_mask: np.ndarray | None = None,
pred_score: float | None = None,
) -> None:
super().__init__()
self.anomaly_map = anomaly_map
self.pred_boxes = pred_boxes
self.pred_label = pred_label
Expand All @@ -40,3 +46,14 @@ def __str__(self) -> str:
f"pred_label:{self.pred_label};"
f"pred_mask min:{pred_mask_min} max:{pred_mask_max};"
)

def _register_primitives(self) -> None:
"""Converts the result to primitives."""
anomaly_map = cv2.applyColorMap(self.anomaly_map, cv2.COLORMAP_JET)
self._add_primitive(Overlay(anomaly_map))
for box in self.pred_boxes:
self._add_primitive(BoundingBoxes(*box))
if self.pred_label is not None:
self._add_primitive(Label(self.pred_label, bg_color="red" if self.pred_label == "Anomaly" else "green"))
self._add_primitive(Label(f"Score: {self.pred_score}"))
self._add_primitive(Polygon(mask=self.pred_mask))
12 changes: 12 additions & 0 deletions model_api/python/model_api/models/result_types/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Base result type"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from abc import ABC

from model_api.visualizer.visualize_mixin import VisualizeMixin


class Result(VisualizeMixin, ABC):
"""Base result type."""
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

from typing import TYPE_CHECKING

from model_api.visualizer.primitives import Label

from .base import Result
from .utils import array_shape_to_str

if TYPE_CHECKING:
import numpy as np


class ClassificationResult:
class ClassificationResult(Result):
"""Results for classification models."""

def __init__(
Expand All @@ -35,3 +38,8 @@ def __str__(self) -> str:
f"{labels}, {array_shape_to_str(self.saliency_map)}, {array_shape_to_str(self.feature_vector)}, "
f"{array_shape_to_str(self.raw_scores)}"
)

def _register_primitives(self) -> None:
# TODO add saliency map
for idx, label, confidence in self.top_labels:
self._add_primitive(Label(f"Rank: {idx}, {label}: {confidence:.3f}"))
8 changes: 8 additions & 0 deletions model_api/python/model_api/visualizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Visualizer."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .visualizer import Visualizer

__all__ = ["Visualizer"]
134 changes: 134 additions & 0 deletions model_api/python/model_api/visualizer/primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Base class for primitives."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from io import BytesIO

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont


class Primitive(ABC):
"""Primitive class."""

@abstractmethod
def compute(self, **kwargs) -> Image:
pass


class Label(Primitive):
"""Label primitive."""

def __init__(
self,
label: str,
fg_color: str | tuple[int, int, int] = "black",
bg_color: str | tuple[int, int, int] = "yellow",
font_path: str | None | BytesIO = None,
size: int = 16,
) -> None:
self.label = label
self.fg_color = fg_color
self.bg_color = bg_color
self.font = ImageFont.load_default(size=size) if font_path is None else ImageFont.truetype(font_path, size)

def compute(self, image: Image, overlay_on_image: bool = True, buffer_y: int = 5) -> Image:
"""Generate label image.

If overlay_on_image is True, the label will be drawn on top of the image.
Else only the label will be drawn. This is useful for collecting labels so that they can be drawn on the same
image.
"""
dummy_image = Image.new("RGB", (1, 1))
draw = ImageDraw.Draw(dummy_image)
textbox = draw.textbbox((0, 0), self.label, font=self.font)
label_image = Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), self.bg_color)
draw = ImageDraw.Draw(label_image)
draw.text((0, 0), self.label, font=self.font, fill=self.fg_color)
if overlay_on_image:
image.paste(label_image, (0, 0))
return image
return label_image

@classmethod
def overlay_labels(cls, image: Image, label_images: list[Image], buffer: int = 5) -> Image:
"""Overlay multiple label images on top of the image.

Paste the labels in a row but wrap the labels if they exceed the image width.
"""
offset_x = 0
offset_y = 0
for label_image in label_images:
image.paste(label_image, (offset_x, offset_y))
offset_x += label_image.width + buffer
if offset_x + label_image.width > image.width:
offset_x = 0
offset_y += label_image.height
return image


class Polygon(Primitive):
"""Polygon primitive."""

def __init__(
self,
points: list[tuple[int, int]] | None = None,
mask: np.ndarray | None = None,
color: str | tuple[int, int, int] = "blue",
) -> None:
self.points = self._get_points(points, mask)
self.color = color

def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]:
if points is not None:
return points
return self._get_points_from_mask(mask)

def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]:
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
_points = contours[0].squeeze().tolist()
return [tuple(point) for point in _points]

def compute(self, image: Image) -> Image:
draw = ImageDraw.Draw(image)
draw.polygon(self.points, fill=self.color)
return image


class Overlay(Primitive):
"""Overlay an image.

Useful for XAI and Anomaly Maps.
"""

def __init__(self, image: Image | np.ndarray, opacity: float = 0.4) -> None:
self.image = self._to_image(image)
self.opacity = opacity

def _to_image(self, image: Image | np.ndarray) -> Image:
if isinstance(image, Image.Image):
return image
return Image.fromarray(image)

def compute(self, image: Image) -> Image:
_image = self.image.resize(image.size)
return Image.blend(image, _image, self.opacity)


class BoundingBoxes(Primitive):
def __init__(self, x1: int, y1: int, x2: int, y2: int, color: str | tuple[int, int, int] = "blue") -> None:
self.x1 = x1
self.y1 = y1
self.x2 = x2
self.y2 = y2
self.color = color

def compute(self, image: Image) -> Image:
draw = ImageDraw.Draw(image)
draw.rectangle([self.x1, self.y1, self.x2, self.y2], fill=None, outline=self.color, width=2)
return image
83 changes: 83 additions & 0 deletions model_api/python/model_api/visualizer/visualize_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Mixin for visualization."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod

from .primitives import BoundingBoxes, Label, Overlay, Polygon, Primitive


class VisualizeMixin(ABC):
"""Mixin for visualization."""

def __init__(self) -> None:
self._labels = []
self._polygons = []
self._overlays = []
self._bounding_boxes = []
self._registered_primitives = False

@abstractmethod
def _register_primitives(self) -> None:
"""Convert result entities to primitives."""

def _add_primitive(self, primitive: Primitive) -> None:
"""Add primitive."""
if isinstance(primitive, Label):
self._labels.append(primitive)
elif isinstance(primitive, Polygon):
self._polygons.append(primitive)
elif isinstance(primitive, Overlay):
self._overlays.append(primitive)
elif isinstance(primitive, BoundingBoxes):
self._bounding_boxes.append(primitive)

@property
def has_labels(self) -> bool:
"""Check if there are labels."""
self._register_primitives_if_needed()
return bool(self._labels)

@property
def has_bounding_boxes(self) -> bool:
"""Check if there are bounding boxes."""
self._register_primitives_if_needed()
return bool(self._bounding_boxes)

@property
def has_polygons(self) -> bool:
"""Check if there are polygons."""
self._register_primitives_if_needed()
return bool(self._polygons)

@property
def has_overlays(self) -> bool:
"""Check if there are overlays."""
self._register_primitives_if_needed()
return bool(self._overlays)

def get_labels(self) -> list[Label]:
"""Get labels."""
self._register_primitives_if_needed()
return self._labels

def get_polygons(self) -> list[Polygon]:
"""Get polygons."""
self._register_primitives_if_needed()
return self._polygons

def get_overlays(self) -> list[Overlay]:
"""Get overlays."""
self._register_primitives_if_needed()
return self._overlays

def get_bounding_boxes(self) -> list[BoundingBoxes]:
"""Get bounding boxes."""
self._register_primitives_if_needed()
return self._bounding_boxes

def _register_primitives_if_needed(self):
if not self._registered_primitives:
self._register_primitives()
self._registered_primitives = True
Loading
Loading