diff --git a/src/python/model_api/visualizer/__init__.py b/src/python/model_api/visualizer/__init__.py index 4d7718e5..924d4d0e 100644 --- a/src/python/model_api/visualizer/__init__.py +++ b/src/python/model_api/visualizer/__init__.py @@ -4,8 +4,8 @@ # SPDX-License-Identifier: Apache-2.0 from .layout import Flatten, HStack, Layout -from .primitive import BoundingBox, Overlay, Polygon +from .primitive import BoundingBox, Label, Overlay, Polygon from .scene import Scene from .visualizer import Visualizer -__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"] +__all__ = ["BoundingBox", "Label", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"] diff --git a/src/python/model_api/visualizer/layout/flatten.py b/src/python/model_api/visualizer/layout/flatten.py index 7a8e7f58..ba128658 100644 --- a/src/python/model_api/visualizer/layout/flatten.py +++ b/src/python/model_api/visualizer/layout/flatten.py @@ -5,7 +5,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Type, Union +from typing import TYPE_CHECKING, Type, Union, cast + +from model_api.visualizer.primitive import Label from .layout import Layout @@ -29,13 +31,19 @@ def __init__(self, *args: Union[Type[Primitive], Layout]) -> None: def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, scene: Scene) -> PIL.Image | None: if scene.has_primitives(primitive): primitives = scene.get_primitives(primitive) - for _primitive in primitives: - image = _primitive.compute(image) + if primitive == Label: # Labels need to be rendered next to each other + # cast is needed as mypy does not know that the primitives are of type Label. + primitives_ = cast("list[Label]", primitives) + image = Label.overlay_labels(image, primitives_) + else: + # Other primitives are rendered on top of each other + for _primitive in primitives: + image = _primitive.compute(image) return image return None def __call__(self, scene: Scene) -> PIL.Image: - _image: PIL.Image = scene.base.copy() + image_: PIL.Image = scene.base.copy() for child in self.children: - _image = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, _image, scene) - return _image + image_ = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, image_, scene) + return image_ diff --git a/src/python/model_api/visualizer/layout/hstack.py b/src/python/model_api/visualizer/layout/hstack.py index c5e1902e..6eb9e87a 100644 --- a/src/python/model_api/visualizer/layout/hstack.py +++ b/src/python/model_api/visualizer/layout/hstack.py @@ -30,8 +30,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc if scene.has_primitives(primitive): images = [] for _primitive in scene.get_primitives(primitive): - _image = _primitive.compute(image.copy()) - images.append(_image) + image_ = _primitive.compute(image.copy()) + images.append(image_) return self._stitch(*images) return None @@ -70,9 +70,9 @@ def __call__(self, scene: Scene) -> PIL.Image: images: list[PIL.Image] = [] for child in self.children: if isinstance(child, Layout): - _image = child(scene) + image_ = child(scene) else: - _image = self._compute_on_primitive(child, scene.base.copy(), scene) - if _image is not None: - images.append(_image) + image_ = self._compute_on_primitive(child, scene.base.copy(), scene) + if image_ is not None: + images.append(image_) return self._stitch(*images) diff --git a/src/python/model_api/visualizer/primitive/__init__.py b/src/python/model_api/visualizer/primitive/__init__.py index 51837c59..ba6c135c 100644 --- a/src/python/model_api/visualizer/primitive/__init__.py +++ b/src/python/model_api/visualizer/primitive/__init__.py @@ -4,8 +4,9 @@ # SPDX-License-Identifier: Apache-2.0 from .bounding_box import BoundingBox +from .label import Label from .overlay import Overlay from .polygon import Polygon from .primitive import Primitive -__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"] +__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon"] diff --git a/src/python/model_api/visualizer/primitive/label.py b/src/python/model_api/visualizer/primitive/label.py new file mode 100644 index 00000000..89dbfe09 --- /dev/null +++ b/src/python/model_api/visualizer/primitive/label.py @@ -0,0 +1,106 @@ +"""Label primitive.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from io import BytesIO +from typing import Union + +from PIL import Image, ImageDraw, ImageFont + +from .primitive import Primitive + + +class Label(Primitive): + """Label primitive. + + Labels require a different processing than other primitives as the class also handles the instance when the layout + requests all the labels to be drawn on a single image. + + Args: + label (str): Text of the label. + fg_color (str | tuple[int, int, int]): Foreground color of the label. + bg_color (str | tuple[int, int, int]): Background color of the label. + font_path (str | None | BytesIO): Path to the font file. + size (int): Size of the font. + + Examples: + >>> label = Label(label="Label 1") + >>> label.compute(image).save("label.jpg") + + >>> label = Label(text="Label 1", fg_color="red", bg_color="blue", font_path="arial.ttf", size=20) + >>> label.compute(image).save("label.jpg") + + or multiple labels on a single image: + >>> label1 = Label(text="Label 1") + >>> label2 = Label(text="Label 2") + >>> label3 = Label(text="Label 3") + >>> Label.overlay_labels(image, [label1, label2, label3]).save("labels.jpg") + """ + + def __init__( + self, + label: str, + fg_color: Union[str, tuple[int, int, int]] = "black", + bg_color: Union[str, tuple[int, int, int]] = "yellow", + font_path: Union[str, BytesIO, None] = None, + size: int = 16, + ) -> None: + self.label = label + self.fg_color = fg_color + self.bg_color = bg_color + self.font = ImageFont.load_default(size=size) if font_path is None else ImageFont.truetype(font_path, size) + + def compute(self, image: Image, buffer_y: int = 5) -> Image: + """Generate label on top of the image. + + Args: + image (PIL.Image): Image to paste the label on. + buffer_y (int): Buffer to add to the y-axis of the label. + """ + label_image = self.generate_label_image(buffer_y) + image.paste(label_image, (0, 0)) + return image + + def generate_label_image(self, buffer_y: int = 5) -> Image: + """Generate label image. + + Args: + buffer_y (int): Buffer to add to the y-axis of the label. This is needed as the text is clipped from the + bottom otherwise. + + Returns: + PIL.Image: Image that consists only of the label. + """ + dummy_image = Image.new("RGB", (1, 1)) + draw = ImageDraw.Draw(dummy_image) + textbox = draw.textbbox((0, 0), self.label, font=self.font) + label_image = Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), self.bg_color) + draw = ImageDraw.Draw(label_image) + draw.text((0, 0), self.label, font=self.font, fill=self.fg_color) + return label_image + + @classmethod + def overlay_labels(cls, image: Image, labels: list["Label"], buffer_y: int = 5, buffer_x: int = 5) -> Image: + """Overlay multiple label images on top of the image. + Paste the labels in a row but wrap the labels if they exceed the image width. + + Args: + image (PIL.Image): Image to paste the labels on. + labels (list[Label]): Labels to be pasted on the image. + buffer_y (int): Buffer to add to the y-axis of the labels. + buffer_x (int): Space between the labels. + + Returns: + PIL.Image: Image with the labels pasted on it. + """ + offset_x = 0 + offset_y = 0 + for label in labels: + label_image = label.generate_label_image(buffer_y) + image.paste(label_image, (offset_x, offset_y)) + offset_x += label_image.width + buffer_x + if offset_x + label_image.width > image.width: + offset_x = 0 + offset_y += label_image.height + return image diff --git a/src/python/model_api/visualizer/scene/scene.py b/src/python/model_api/visualizer/scene/scene.py index 9fdb7072..161c41fc 100644 --- a/src/python/model_api/visualizer/scene/scene.py +++ b/src/python/model_api/visualizer/scene/scene.py @@ -5,12 +5,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np from PIL import Image -from model_api.visualizer.primitive import Overlay, Primitive +from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon, Primitive if TYPE_CHECKING: from pathlib import Path @@ -27,16 +27,23 @@ class Scene: def __init__( self, base: Image, + bounding_box: BoundingBox | list[BoundingBox] | None = None, + label: Label | list[Label] | None = None, overlay: Overlay | list[Overlay] | np.ndarray | None = None, + polygon: Polygon | list[Polygon] | None = None, layout: Layout | None = None, ) -> None: self.base = base self.overlay = self._to_overlay(overlay) + self.bounding_box = self._to_bounding_box(bounding_box) + self.label = self._to_label(label) + self.polygon = self._to_polygon(polygon) self.layout = layout def show(self) -> Image: ... - def save(self, path: Path) -> None: ... + def save(self, path: Path) -> None: + self.render().save(path) def render(self) -> Image: if self.layout is None: @@ -46,16 +53,42 @@ def render(self) -> Image: def has_primitives(self, primitive: type[Primitive]) -> bool: if primitive == Overlay: return bool(self.overlay) + if primitive == BoundingBox: + return bool(self.bounding_box) + if primitive == Label: + return bool(self.label) + if primitive == Polygon: + return bool(self.polygon) return False def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]: + """Get primitives of the given type. + + Args: + primitive (type[Primitive]): The type of primitive to get. + + Example: + >>> scene = Scene(base=Image.new("RGB", (100, 100)), overlay=[Overlay(Image.new("RGB", (100, 100)))]) + >>> scene.get_primitives(Overlay) + [Overlay(image=Image.new("RGB", (100, 100)))] + + Returns: + list[Primitive]: The primitives of the given type or an empty list if no primitives are found. + """ primitives: list[Primitive] | None = None + # cast is needed as mypy does not know that the primitives are a subclass of Primitive. if primitive == Overlay: - primitives = self.overlay # type: ignore[assignment] # TODO(ashwinvaidya17): Address this in the next PR - if primitives is None: + primitives = cast("list[Primitive]", self.overlay) + elif primitive == BoundingBox: + primitives = cast("list[Primitive]", self.bounding_box) + elif primitive == Label: + primitives = cast("list[Primitive]", self.label) + elif primitive == Polygon: + primitives = cast("list[Primitive]", self.polygon) + else: msg = f"Primitive {primitive} not found" raise ValueError(msg) - return primitives + return primitives or [] @property def default_layout(self) -> Layout: @@ -70,3 +103,18 @@ def _to_overlay(self, overlay: Overlay | list[Overlay] | np.ndarray | None) -> l if isinstance(overlay, Overlay): return [overlay] return overlay + + def _to_bounding_box(self, bounding_box: BoundingBox | list[BoundingBox] | None) -> list[BoundingBox] | None: + if isinstance(bounding_box, BoundingBox): + return [bounding_box] + return bounding_box + + def _to_label(self, label: Label | list[Label] | None) -> list[Label] | None: + if isinstance(label, Label): + return [label] + return label + + def _to_polygon(self, polygon: Polygon | list[Polygon] | None) -> list[Polygon] | None: + if isinstance(polygon, Polygon): + return [polygon] + return polygon diff --git a/tests/python/unit/visualizer/test_primitive.py b/tests/python/unit/visualizer/test_primitive.py index 2f8a540b..85ef4447 100644 --- a/tests/python/unit/visualizer/test_primitive.py +++ b/tests/python/unit/visualizer/test_primitive.py @@ -7,7 +7,7 @@ import PIL from PIL import ImageDraw -from model_api.visualizer import BoundingBox, Overlay, Polygon +from model_api.visualizer import BoundingBox, Label, Overlay, Polygon def test_overlay(mock_image: PIL.Image): @@ -51,3 +51,9 @@ def test_polygon(mock_image: PIL.Image): draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red") polygon = Polygon(mask=mask, color="red") assert polygon.compute(mock_image) == expected_image + + +def test_label(mock_image: PIL.Image): + label = Label(label="Label") + # When using a single label, compute and overlay_labels should return the same image + assert label.compute(mock_image) == Label.overlay_labels(mock_image, [label])