Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bounding box primitive #251

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions model_api/python/model_api/visualizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Visualizer."""

# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .layout import Flatten, HStack, Layout
from .primitive import Overlay
from .primitive import BoundingBox, Overlay
from .scene import Scene
from .visualizer import Visualizer

__all__ = ["Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
__all__ = ["BoundingBox", "Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
56 changes: 55 additions & 1 deletion model_api/python/model_api/visualizer/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,70 @@

import numpy as np
import PIL
from PIL import Image, ImageDraw


class Primitive(ABC):
"""Primitive class."""

@abstractmethod
def compute(self, image: PIL.Image) -> PIL.Image:
def compute(self, image: Image) -> Image:
pass


class BoundingBox(Primitive):
"""Bounding box primitive.

Args:
x1 (int): x-coordinate of the top-left corner of the bounding box.
y1 (int): y-coordinate of the top-left corner of the bounding box.
x2 (int): x-coordinate of the bottom-right corner of the bounding box.
y2 (int): y-coordinate of the bottom-right corner of the bounding box.
label (str | None): Label of the bounding box.
color (str | tuple[int, int, int]): Color of the bounding box.

Example:
>>> bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100, label="Label Name")
>>> bounding_box.compute(image)
"""

def __init__(
self,
x1: int,
y1: int,
x2: int,
y2: int,
label: str | None = None,
color: str | tuple[int, int, int] = "blue",
) -> None:
self.x1 = x1
self.y1 = y1
self.x2 = x2
self.y2 = y2
self.label = label
self.color = color
self.y_buffer = 5 # Text at the bottom of the text box is clipped. This prevents that.

def compute(self, image: Image) -> Image:
draw = ImageDraw.Draw(image)
# draw rectangle
draw.rectangle((self.x1, self.y1, self.x2, self.y2), outline=self.color, width=2)
# add label
if self.label:
# draw the background of the label
textbox = draw.textbbox((0, 0), self.label)
label_image = Image.new(
"RGB",
(textbox[2] - textbox[0], textbox[3] + self.y_buffer - textbox[1]),
self.color,
)
draw = ImageDraw.Draw(label_image)
# write the label on the background
draw.text((0, 0), self.label, fill="white")
image.paste(label_image, (self.x1, self.y1))
return image


class Overlay(Primitive):
"""Overlay primitive.

Expand Down
12 changes: 11 additions & 1 deletion tests/python/unit/visualizer/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import numpy as np
import PIL
from PIL import ImageDraw

from model_api.visualizer import Overlay
from model_api.visualizer import BoundingBox, Overlay


def test_overlay(mock_image: PIL.Image):
Expand All @@ -22,3 +23,12 @@ def test_overlay(mock_image: PIL.Image):
data *= 255
overlay = Overlay(data)
assert overlay.compute(empty_image) == expected_image


def test_bounding_box(mock_image: PIL.Image):
"""Test if the bounding box is created correctly."""
expected_image = mock_image.copy()
draw = ImageDraw.Draw(expected_image)
draw.rectangle((10, 10, 100, 100), outline="blue", width=2)
bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100)
assert bounding_box.compute(mock_image) == expected_image
Loading