Skip to content

Commit

Permalink
Add label primitive
Browse files Browse the repository at this point in the history
Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
  • Loading branch information
ashwinvaidya17 committed Jan 17, 2025
1 parent ae3241a commit b75fa98
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/python/model_api/visualizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# SPDX-License-Identifier: Apache-2.0

from .layout import Flatten, HStack, Layout
from .primitive import BoundingBox, Overlay, Polygon
from .primitive import BoundingBox, Label, Overlay, Polygon
from .scene import Scene
from .visualizer import Visualizer

__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
__all__ = ["BoundingBox", "Label", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
20 changes: 14 additions & 6 deletions src/python/model_api/visualizer/layout/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Type, Union
from typing import TYPE_CHECKING, Type, Union, cast

from model_api.visualizer.primitive import Label

from .layout import Layout

Expand All @@ -29,13 +31,19 @@ def __init__(self, *args: Union[Type[Primitive], Layout]) -> None:
def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, scene: Scene) -> PIL.Image | None:
if scene.has_primitives(primitive):
primitives = scene.get_primitives(primitive)
for _primitive in primitives:
image = _primitive.compute(image)
if primitive == Label: # Labels need to be rendered next to each other
# cast is needed as mypy does not know that the primitives are of type Label.
primitives_ = cast("list[Label]", primitives)
image = Label.overlay_labels(image, primitives_)
else:
# Other primitives are rendered on top of each other
for _primitive in primitives:
image = _primitive.compute(image)
return image
return None

def __call__(self, scene: Scene) -> PIL.Image:
_image: PIL.Image = scene.base.copy()
image_: PIL.Image = scene.base.copy()
for child in self.children:
_image = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, _image, scene)
return _image
image_ = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, image_, scene)
return image_
12 changes: 6 additions & 6 deletions src/python/model_api/visualizer/layout/hstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc
if scene.has_primitives(primitive):
images = []
for _primitive in scene.get_primitives(primitive):
_image = _primitive.compute(image.copy())
images.append(_image)
image_ = _primitive.compute(image.copy())
images.append(image_)
return self._stitch(*images)
return None

Expand Down Expand Up @@ -70,9 +70,9 @@ def __call__(self, scene: Scene) -> PIL.Image:
images: list[PIL.Image] = []
for child in self.children:
if isinstance(child, Layout):
_image = child(scene)
image_ = child(scene)
else:
_image = self._compute_on_primitive(child, scene.base.copy(), scene)
if _image is not None:
images.append(_image)
image_ = self._compute_on_primitive(child, scene.base.copy(), scene)
if image_ is not None:
images.append(image_)
return self._stitch(*images)
3 changes: 2 additions & 1 deletion src/python/model_api/visualizer/primitive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# SPDX-License-Identifier: Apache-2.0

from .bounding_box import BoundingBox
from .label import Label
from .overlay import Overlay
from .polygon import Polygon
from .primitive import Primitive

__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"]
__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon"]
106 changes: 106 additions & 0 deletions src/python/model_api/visualizer/primitive/label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Label primitive."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from io import BytesIO
from typing import Union

from PIL import Image, ImageDraw, ImageFont

from .primitive import Primitive


class Label(Primitive):
"""Label primitive.
Labels require a different processing than other primitives as the class also handles the instance when the layout
requests all the labels to be drawn on a single image.
Args:
label (str): Text of the label.
fg_color (str | tuple[int, int, int]): Foreground color of the label.
bg_color (str | tuple[int, int, int]): Background color of the label.
font_path (str | None | BytesIO): Path to the font file.
size (int): Size of the font.
Examples:
>>> label = Label(label="Label 1")
>>> label.compute(image).save("label.jpg")
>>> label = Label(text="Label 1", fg_color="red", bg_color="blue", font_path="arial.ttf", size=20)
>>> label.compute(image).save("label.jpg")
or multiple labels on a single image:
>>> label1 = Label(text="Label 1")
>>> label2 = Label(text="Label 2")
>>> label3 = Label(text="Label 3")
>>> Label.overlay_labels(image, [label1, label2, label3]).save("labels.jpg")
"""

def __init__(
self,
label: str,
fg_color: Union[str, tuple[int, int, int]] = "black",
bg_color: Union[str, tuple[int, int, int]] = "yellow",
font_path: Union[str, BytesIO, None] = None,
size: int = 16,
) -> None:
self.label = label
self.fg_color = fg_color
self.bg_color = bg_color
self.font = ImageFont.load_default(size=size) if font_path is None else ImageFont.truetype(font_path, size)

def compute(self, image: Image, buffer_y: int = 5) -> Image:
"""Generate label on top of the image.
Args:
image (PIL.Image): Image to paste the label on.
buffer_y (int): Buffer to add to the y-axis of the label.
"""
label_image = self.generate_label_image(buffer_y)
image.paste(label_image, (0, 0))
return image

def generate_label_image(self, buffer_y: int = 5) -> Image:
"""Generate label image.
Args:
buffer_y (int): Buffer to add to the y-axis of the label. This is needed as the text is clipped from the
bottom otherwise.
Returns:
PIL.Image: Image that consists only of the label.
"""
dummy_image = Image.new("RGB", (1, 1))
draw = ImageDraw.Draw(dummy_image)
textbox = draw.textbbox((0, 0), self.label, font=self.font)
label_image = Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), self.bg_color)
draw = ImageDraw.Draw(label_image)
draw.text((0, 0), self.label, font=self.font, fill=self.fg_color)
return label_image

@classmethod
def overlay_labels(cls, image: Image, labels: list["Label"], buffer_y: int = 5, buffer_x: int = 5) -> Image:
"""Overlay multiple label images on top of the image.
Paste the labels in a row but wrap the labels if they exceed the image width.
Args:
image (PIL.Image): Image to paste the labels on.
labels (list[Label]): Labels to be pasted on the image.
buffer_y (int): Buffer to add to the y-axis of the labels.
buffer_x (int): Space between the labels.
Returns:
PIL.Image: Image with the labels pasted on it.
"""
offset_x = 0
offset_y = 0
for label in labels:
label_image = label.generate_label_image(buffer_y)
image.paste(label_image, (offset_x, offset_y))
offset_x += label_image.width + buffer_x
if offset_x + label_image.width > image.width:
offset_x = 0
offset_y += label_image.height
return image
60 changes: 54 additions & 6 deletions src/python/model_api/visualizer/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np
from PIL import Image

from model_api.visualizer.primitive import Overlay, Primitive
from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon, Primitive

if TYPE_CHECKING:
from pathlib import Path
Expand All @@ -27,16 +27,23 @@ class Scene:
def __init__(
self,
base: Image,
bounding_box: BoundingBox | list[BoundingBox] | None = None,
label: Label | list[Label] | None = None,
overlay: Overlay | list[Overlay] | np.ndarray | None = None,
polygon: Polygon | list[Polygon] | None = None,
layout: Layout | None = None,
) -> None:
self.base = base
self.overlay = self._to_overlay(overlay)
self.bounding_box = self._to_bounding_box(bounding_box)
self.label = self._to_label(label)
self.polygon = self._to_polygon(polygon)
self.layout = layout

def show(self) -> Image: ...

def save(self, path: Path) -> None: ...
def save(self, path: Path) -> None:
self.render().save(path)

def render(self) -> Image:
if self.layout is None:
Expand All @@ -46,16 +53,42 @@ def render(self) -> Image:
def has_primitives(self, primitive: type[Primitive]) -> bool:
if primitive == Overlay:
return bool(self.overlay)
if primitive == BoundingBox:
return bool(self.bounding_box)
if primitive == Label:
return bool(self.label)
if primitive == Polygon:
return bool(self.polygon)
return False

def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]:
"""Get primitives of the given type.
Args:
primitive (type[Primitive]): The type of primitive to get.
Example:
>>> scene = Scene(base=Image.new("RGB", (100, 100)), overlay=[Overlay(Image.new("RGB", (100, 100)))])
>>> scene.get_primitives(Overlay)
[Overlay(image=Image.new("RGB", (100, 100)))]
Returns:
list[Primitive]: The primitives of the given type or an empty list if no primitives are found.
"""
primitives: list[Primitive] | None = None
# cast is needed as mypy does not know that the primitives are a subclass of Primitive.
if primitive == Overlay:
primitives = self.overlay # type: ignore[assignment] # TODO(ashwinvaidya17): Address this in the next PR
if primitives is None:
primitives = cast("list[Primitive]", self.overlay)
elif primitive == BoundingBox:
primitives = cast("list[Primitive]", self.bounding_box)
elif primitive == Label:
primitives = cast("list[Primitive]", self.label)
elif primitive == Polygon:
primitives = cast("list[Primitive]", self.polygon)
else:
msg = f"Primitive {primitive} not found"
raise ValueError(msg)
return primitives
return primitives or []

@property
def default_layout(self) -> Layout:
Expand All @@ -70,3 +103,18 @@ def _to_overlay(self, overlay: Overlay | list[Overlay] | np.ndarray | None) -> l
if isinstance(overlay, Overlay):
return [overlay]
return overlay

def _to_bounding_box(self, bounding_box: BoundingBox | list[BoundingBox] | None) -> list[BoundingBox] | None:
if isinstance(bounding_box, BoundingBox):
return [bounding_box]
return bounding_box

def _to_label(self, label: Label | list[Label] | None) -> list[Label] | None:
if isinstance(label, Label):
return [label]
return label

def _to_polygon(self, polygon: Polygon | list[Polygon] | None) -> list[Polygon] | None:
if isinstance(polygon, Polygon):
return [polygon]
return polygon
8 changes: 7 additions & 1 deletion tests/python/unit/visualizer/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import PIL
from PIL import ImageDraw

from model_api.visualizer import BoundingBox, Overlay, Polygon
from model_api.visualizer import BoundingBox, Label, Overlay, Polygon


def test_overlay(mock_image: PIL.Image):
Expand Down Expand Up @@ -51,3 +51,9 @@ def test_polygon(mock_image: PIL.Image):
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
polygon = Polygon(mask=mask, color="red")
assert polygon.compute(mock_image) == expected_image


def test_label(mock_image: PIL.Image):
label = Label(label="Label")
# When using a single label, compute and overlay_labels should return the same image
assert label.compute(mock_image) == Label.overlay_labels(mock_image, [label])

0 comments on commit b75fa98

Please sign in to comment.