Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dtype argument when calling media.data #1546

Merged
merged 8 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1492>)
- Pass Keyword Argument to TabularDataBase
(<https://github.com/openvinotoolkit/datumaro/pull/1522>)
- Enable dtype argument when calling media.data
(<https://github.com/openvinotoolkit/datumaro/pull/1546>)

### Bug fixes
- Preserve end_frame information of a video when it is zero.
Expand Down
29 changes: 26 additions & 3 deletions src/datumaro/components/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import shutil
from copy import deepcopy
from enum import IntEnum
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -39,6 +40,7 @@
copyto_image,
decode_image,
lazy_image,
load_image,
save_image,
)

Expand Down Expand Up @@ -224,6 +226,7 @@
f"{self.__class__.__name__}.from_numpy(), {self.__class__.__name__}.from_bytes())."
)
super().__init__(*args, **kwargs)
self._dtype = np.uint8

if ext is not None:
if not ext.startswith("."):
Expand Down Expand Up @@ -322,6 +325,8 @@
if not self.has_data:
return None

if self.__data._dtype != self._dtype:
self.__data._loader = partial(load_image, dtype=self._dtype)
data = self.__data()

if self._size is None and data is not None:
Expand Down Expand Up @@ -368,6 +373,11 @@
if isinstance(self.__data, lazy_image):
self.__data._crypter = crypter

def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]:
    """Return the image data after selecting ``dtype`` as the active data type.

    The requested ``dtype`` is stored on the instance as ``_dtype`` before the
    ``data`` property is read, so the selection stays in effect after this
    call returns.

    Args:
        dtype: Data type to record on the instance. Defaults to ``np.uint8``.

    Returns:
        Whatever the ``data`` property yields once ``_dtype`` has been updated.
    """
    # Record the choice first: the `data` property is evaluated afterwards.
    self._dtype = dtype
    pixels = self.data
    return pixels


class ImageFromData(FromDataMixin, Image):
def save(
Expand Down Expand Up @@ -400,8 +410,8 @@

data = super().data

if isinstance(data, np.ndarray) and data.dtype != np.uint8:
data = np.clip(data, 0.0, 255.0).astype(np.uint8)
if isinstance(data, np.ndarray) and data.dtype != self._dtype:
data = np.clip(data, 0.0, 255.0).astype(self._dtype)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
Expand All @@ -413,6 +423,11 @@
"""Indicates that size info is cached and won't require image loading"""
return self._size is not None or isinstance(self._data, np.ndarray)

def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]:
    """Fetch the image data using ``dtype`` as the instance's data type.

    This updates the instance attribute ``_dtype`` and then reads the ``data``
    property; the attribute keeps the new value once the call has finished.

    Args:
        dtype: Data type to store in ``_dtype``. Defaults to ``np.uint8``.

    Returns:
        The value of the ``data`` property evaluated after the update.
    """
    self._dtype = dtype  # must be set before `data` is evaluated
    result = self.data
    return result


class ImageFromBytes(ImageFromData):
_FORMAT_MAGICS = (
Expand Down Expand Up @@ -446,13 +461,21 @@
data = super().data

if isinstance(data, bytes):
data = decode_image(data, dtype=np.uint8)
data = decode_image(data, dtype=self._dtype)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
self._size = tuple(map(int, data.shape[:2]))
return data

def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]:
    """Get image data; only the 8-bit unsigned integer data type is accepted.

    Any spelling NumPy resolves to ``uint8`` is accepted (``np.uint8``,
    ``np.dtype("uint8")``, ``"uint8"``). On success the instance attribute
    ``_dtype`` is set to ``np.uint8`` and the ``data`` property is returned.

    Args:
        dtype: Requested data type. Defaults to ``np.uint8``.

    Returns:
        The value of the ``data`` property.

    Raises:
        ValueError: If ``dtype`` does not denote ``np.uint8`` (this includes
            ``None`` and values NumPy cannot interpret as a data type).
    """
    try:
        # Normalize first so equivalent spellings of uint8 compare equal.
        is_uint8 = np.dtype(dtype) == np.dtype(np.uint8)
    except (TypeError, ValueError):
        # Not interpretable as a dtype at all: reject like any other mismatch.
        is_uint8 = False

    if not is_uint8:
        raise ValueError("ImageFromBytes only support `dtype=np.uint8`.")

    # Only one value can reach this point, so store its canonical form.
    self._dtype = np.uint8
    return self.data

Check warning on line 477 in src/datumaro/components/media.py

View check run for this annotation

Codecov / codecov/patch

src/datumaro/components/media.py#L475-L477

Added lines #L475 - L477 were not covered by tests


class VideoFrame(ImageFromNumpy):
_type = MediaType.VIDEO_FRAME
Expand Down
28 changes: 20 additions & 8 deletions src/datumaro/util/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@
COLOR_BGR = 1
COLOR_RGB = 2

def decode_by_cv2(self, image_bytes: bytes) -> np.ndarray:
def decode_by_cv2(self, image_bytes: bytes, dtype: DTypeLike = np.uint8) -> np.ndarray:
"""Convert image color channel for OpenCV image (np.ndarray)."""
image_buffer = np.frombuffer(image_bytes, dtype=np.uint8)
image_buffer = np.frombuffer(image_bytes, dtype=dtype)

if self == ImageColorChannel.UNCHANGED:
return cv2.imdecode(image_buffer, cv2.IMREAD_UNCHANGED)
Expand Down Expand Up @@ -283,15 +283,26 @@
raise NotImplementedError()


def decode_image(image_bytes: bytes, dtype: DTypeLike = np.uint8) -> np.ndarray:
def decode_image(image_bytes: bytes, dtype: np.dtype = np.uint8) -> np.ndarray:
ctx_color_scale = IMAGE_COLOR_CHANNEL.get()

if IMAGE_BACKEND.get() == ImageBackend.cv2:
image = ctx_color_scale.decode_by_cv2(image_bytes)
elif IMAGE_BACKEND.get() == ImageBackend.PIL:
image = ctx_color_scale.decode_by_pil(image_bytes)
if np.issubdtype(dtype, np.floating):
# PIL doesn't support floating point image loading
# CV doesn't support floating point image with color channel setting (IMREAD_COLOR)
with decode_image_context(
image_backend=ImageBackend.cv2, image_color_channel=ImageColorChannel.UNCHANGED
):
image = ctx_color_scale.decode_by_cv2(image_bytes, dtype=dtype)
image = image[..., ::-1]
if ctx_color_scale == ImageColorChannel.COLOR_BGR:
image = image[..., ::-1]

Check warning on line 298 in src/datumaro/util/image.py

View check run for this annotation

Codecov / codecov/patch

src/datumaro/util/image.py#L298

Added line #L298 was not covered by tests
else:
raise NotImplementedError()
if IMAGE_BACKEND.get() == ImageBackend.cv2:
image = ctx_color_scale.decode_by_cv2(image_bytes)
elif IMAGE_BACKEND.get() == ImageBackend.PIL:
image = ctx_color_scale.decode_by_pil(image_bytes)
else:
raise NotImplementedError()

image = image.astype(dtype)

Expand Down Expand Up @@ -376,6 +387,7 @@
assert isinstance(cache, (ImageCache, bool))
self._cache = cache
self._crypter = crypter
self._dtype = dtype

def __call__(self) -> np.ndarray:
image = None
Expand Down
2 changes: 2 additions & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ pytest-stress
pytest-html
coverage
pytest-csv

tifffile
20 changes: 20 additions & 0 deletions tests/unit/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,26 @@ def test_ext_detection_failure(self):
image = Image.from_bytes(data=image_bytes)
self.assertEqual(image.ext, None)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_floating_image_from_numpy(self):
    """A float16 array wrapped by ``Image.from_numpy`` is returned unchanged.

    Checks both the element type and the values returned by
    ``get_data_as_dtype(dtype=np.float16)``. A value comparison alone would
    not notice a silent upcast to a wider floating point type.
    """
    image_float = np.random.rand(32, 32, 3).astype(np.float16) * 255.0
    media = Image.from_numpy(image_float)

    data = media.get_data_as_dtype(dtype=np.float16)

    self.assertIsNotNone(data)
    # The requested dtype is the feature under test, so pin it explicitly.
    self.assertEqual(data.dtype, np.float16)
    np.testing.assert_array_equal(data, image_float)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_floating_image_from_file(self):
    """A float32 TIFF written to disk is read back with the requested dtype.

    The image is written with ``tifffile`` and loaded through
    ``Image.from_file``. Both the element type and the values returned by
    ``get_data_as_dtype(dtype=np.float32)`` are checked, because a value
    comparison alone would not notice a silent upcast.
    """
    import tifffile

    with TestDir() as test_dir:
        image_float = np.random.rand(32, 32, 3).astype(np.float32) * 255.0
        image_path = osp.join(test_dir, "floating_image.tiff")
        tifffile.imwrite(image_path, image_float)

        media = Image.from_file(image_path)
        # Read inside the `with` block, while the temporary file is still
        # within the TestDir context.
        data = media.get_data_as_dtype(dtype=np.float32)

        self.assertIsNotNone(data)
        # The requested dtype is the feature under test, so pin it explicitly.
        self.assertEqual(data.dtype, np.float32)
        np.testing.assert_array_equal(data, image_float)


class RoIImageTest(TestCase):
def _test_ctors(self, img_ctor, args_list, test_dir, is_bytes=False):
Expand Down
Loading