Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
  • Loading branch information
vinnamkim committed Jan 19, 2024
1 parent eb44355 commit d90f158
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 68 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1174>)
- Add ImportError to catch GitPython import error
(<https://github.com/openvinotoolkit/datumaro/pull/1174>)
- Enable image backend and color channel format to be selectable
(<https://github.com/openvinotoolkit/datumaro/pull/1246>)

### Bug fixes
- Modify the draw function in the visualizer not to raise an error for unsupported annotation types.
Expand Down
129 changes: 80 additions & 49 deletions src/datumaro/util/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright (C) 2019-2023 Intel Corporation
# Copyright (C) 2019-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations

import importlib
import os
import os.path as osp
import shlex
Expand All @@ -26,20 +25,21 @@
DTypeLike = Any


class _IMAGE_BACKENDS(Enum):
class ImageBackend(Enum):
cv2 = auto()
PIL = auto()


_IMAGE_BACKEND: ContextVar[_IMAGE_BACKENDS] = ContextVar("_IMAGE_BACKENDS")
IMAGE_BACKEND: ContextVar[ImageBackend] = ContextVar("IMAGE_BACKEND")
_image_loading_errors = (FileNotFoundError,)
try:
importlib.import_module("cv2")
_IMAGE_BACKEND.set(_IMAGE_BACKENDS.cv2)
import cv2

IMAGE_BACKEND.set(ImageBackend.cv2)
except ModuleNotFoundError:
import PIL

_IMAGE_BACKEND.set(_IMAGE_BACKENDS.PIL)
IMAGE_BACKEND.set(ImageBackend.PIL)
_image_loading_errors = (*_image_loading_errors, PIL.UnidentifiedImageError)

from datumaro.util.image_cache import ImageCache
Expand All @@ -49,63 +49,101 @@ class _IMAGE_BACKENDS(Enum):
from PIL.Image import Image as PILImage


class ImageColorScale(Enum):
"""Image color scale
class ImageColorChannel(Enum):
"""Image color channel
- UNCHANGED: Use the original image's scale (default)
- COLOR: Use 3 channels (it can ignore the alpha channel or convert the gray scale image to BGR)
- UNCHANGED: Use the original image's channel (default)
- COLOR_BGR: Use BGR 3 channels (it can ignore the alpha channel or convert the gray scale image)
- COLOR_RGB: Use RGB 3 channels (it can ignore the alpha channel or convert the gray scale image)
"""

UNCHANGED = 0
COLOR = 1
COLOR_BGR = 1
COLOR_RGB = 2

def convert_cv2(self, img: np.ndarray) -> np.ndarray:
import cv2
def decode_by_cv2(self, image_bytes: bytes) -> np.ndarray:
"""Convert image color channel for OpenCV image (np.ndarray)."""
image_buffer = np.frombuffer(image_bytes, dtype=np.uint8)

if self == ImageColorChannel.UNCHANGED:
return cv2.imdecode(image_buffer, cv2.IMREAD_UNCHANGED)

img = cv2.imdecode(image_buffer, cv2.IMREAD_COLOR)

if self == ImageColorChannel.COLOR_BGR:
if len(img.shape) == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if img.shape[-1] == 4:
return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

if self == ImageColorScale.COLOR:
return cv2.imdecode(img, cv2.IMREAD_COLOR)
return img

if self == ImageColorChannel.COLOR_RGB:
if len(img.shape) == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
if img.shape[-1] == 4:
return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)

return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

raise ValueError

def decode_by_pil(self, image_bytes: bytes) -> PILImage:
"""Convert image color channel for PIL Image."""
from PIL import Image

return cv2.imdecode(img, cv2.IMREAD_UNCHANGED)
img = Image.open(BytesIO(image_bytes))

def convert_pil(self, img: PILImage) -> PILImage:
if self == ImageColorScale.COLOR:
if self == ImageColorChannel.UNCHANGED:
return img

if self == ImageColorChannel.COLOR_BGR:
return Image.fromarray(np.flip(np.asarray(img.convert("RGB")), -1))

if self == ImageColorChannel.COLOR_RGB:
return img.convert("RGB")

return img
raise ValueError


IMAGE_COLOR_SCALE: ContextVar[ImageColorScale] = ContextVar(
"IMAGE_COLOR_SCALE", default=ImageColorScale.UNCHANGED
IMAGE_COLOR_CHANNEL: ContextVar[ImageColorChannel] = ContextVar(
"IMAGE_COLOR_CHANNEL", default=ImageColorChannel.UNCHANGED
)


@contextmanager
def decode_image_context(color_scale: ImageColorScale):
"""Change Datumaro image decoding color scale.
def decode_image_context(image_backend: ImageBackend, image_color_channel: ImageColorChannel):
"""Change Datumaro image color channel while decoding.
For model training, it is recommended to use this context manager
to load images in the BGR 3-channel format. For example,
.. code-block:: python
import datumaro as dm
with decode_image_context(ImageColorScale.COLOR):
with decode_image_context(image_backend=ImageBackend.cv2, image_color_channel=ImageColorScale.COLOR):
item: dm.DatasetItem
img_data = item.media_as(dm.Image).data
assert img_data.shape[-1] == 3 # It should be a 3-channel image
"""
curr_ctx = IMAGE_COLOR_SCALE.get()
IMAGE_COLOR_SCALE.set(color_scale)

curr_ctx = (IMAGE_BACKEND.get(), IMAGE_COLOR_CHANNEL.get())

IMAGE_BACKEND.set(image_backend)
IMAGE_COLOR_CHANNEL.set(image_color_channel)

yield
IMAGE_COLOR_SCALE.set(curr_ctx)

IMAGE_BACKEND.set(curr_ctx[0])
IMAGE_COLOR_CHANNEL.set(curr_ctx[1])


def load_image(path: str, dtype: DTypeLike = np.uint8, crypter: Crypter = NULL_CRYPTER):
"""
Reads an image in the HWC Grayscale/BGR(A) [0; 255] format (default dtype is uint8).
"""

if _IMAGE_BACKEND.get() == _IMAGE_BACKENDS.cv2:
if IMAGE_BACKEND.get() == ImageBackend.cv2:
# cv2.imread does not support paths that are not representable
# in the locale encoding on Windows, so we read the image bytes
# ourselves.
Expand All @@ -114,13 +152,13 @@ def load_image(path: str, dtype: DTypeLike = np.uint8, crypter: Crypter = NULL_C
image_bytes = crypter.decrypt(f.read())

return decode_image(image_bytes, dtype=dtype)
elif _IMAGE_BACKEND.get() == _IMAGE_BACKENDS.PIL:
elif IMAGE_BACKEND.get() == ImageBackend.PIL:
with open(path, "rb") as f:
image_bytes = crypter.decrypt(f.read())

return decode_image(image_bytes, dtype=dtype)

raise NotImplementedError(_IMAGE_BACKEND)
raise NotImplementedError(IMAGE_BACKEND)


def copyto_image(
Expand Down Expand Up @@ -175,10 +213,10 @@ def save_image(
# NOTE: OpenCV documentation says "If the image format is not supported,
# the image will be converted to 8-bit unsigned and saved that way".
# Conversion from np.int32 to np.uint8 is not working properly
backend = _IMAGE_BACKEND.get()
backend = IMAGE_BACKEND.get()
if dtype == np.int32:
backend = _IMAGE_BACKENDS.PIL
if backend == _IMAGE_BACKENDS.cv2:
backend = ImageBackend.PIL
if backend == ImageBackend.cv2:
# cv2.imwrite does not support paths that are not representable
# in the locale encoding on Windows, so we write the image bytes
# ourselves.
Expand All @@ -190,7 +228,7 @@ def save_image(
f.write(crypter.encrypt(image_bytes))
else:
dst.write(crypter.encrypt(image_bytes))
elif backend == _IMAGE_BACKENDS.PIL:
elif backend == ImageBackend.PIL:
from PIL import Image

if ext.startswith("."):
Expand All @@ -217,7 +255,7 @@ def encode_image(image: np.ndarray, ext: str, dtype: DTypeLike = np.uint8, **kwa
if not kwargs:
kwargs = {}

if _IMAGE_BACKEND.get() == _IMAGE_BACKENDS.cv2:
if IMAGE_BACKEND.get() == ImageBackend.cv2:
import cv2

params = []
Expand All @@ -233,7 +271,7 @@ def encode_image(image: np.ndarray, ext: str, dtype: DTypeLike = np.uint8, **kwa
if not success:
raise Exception("Failed to encode image to '%s' format" % (ext))
return result.tobytes()
elif _IMAGE_BACKEND.get() == _IMAGE_BACKENDS.PIL:
elif IMAGE_BACKEND.get() == ImageBackend.PIL:
from PIL import Image

if ext.startswith("."):
Expand All @@ -256,21 +294,14 @@ def encode_image(image: np.ndarray, ext: str, dtype: DTypeLike = np.uint8, **kwa


def decode_image(image_bytes: bytes, dtype: DTypeLike = np.uint8) -> np.ndarray:
ctx_color_scale = IMAGE_COLOR_SCALE.get()
ctx_color_scale = IMAGE_COLOR_CHANNEL.get()

if _IMAGE_BACKEND.get() == _IMAGE_BACKENDS.cv2:
image = np.frombuffer(image_bytes, dtype=np.uint8)
image = ctx_color_scale.convert_cv2(image)
if IMAGE_BACKEND.get() == ImageBackend.cv2:
image = ctx_color_scale.decode_by_cv2(image_bytes)
image = image.astype(dtype)
elif _IMAGE_BACKEND.get() == _IMAGE_BACKENDS.PIL:
from PIL import Image

image = Image.open(BytesIO(image_bytes))
image = ctx_color_scale.convert_pil(image)
elif IMAGE_BACKEND.get() == ImageBackend.PIL:
image = ctx_color_scale.decode_by_pil(image_bytes)
image = np.asarray(image, dtype=dtype)
if len(image.shape) == 3 and image.shape[2] in {3, 4}:
image = np.array(image) # Release read-only
image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR
else:
raise NotImplementedError()

Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_crypter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from datumaro.components.crypter import NULL_CRYPTER, Crypter
from datumaro.components.media import Image
from datumaro.util.image import _IMAGE_BACKEND, _IMAGE_BACKENDS
from datumaro.util.image import IMAGE_BACKEND, ImageBackend


@pytest.fixture(scope="class")
Expand All @@ -36,12 +36,12 @@ def fxt_encrypted_image_file(test_dir, fxt_image_file, fxt_crypter):


class CrypterTest:
@pytest.fixture(scope="class", params=[_IMAGE_BACKENDS.cv2, _IMAGE_BACKENDS.PIL], autouse=True)
@pytest.fixture(scope="class", params=[ImageBackend.cv2, ImageBackend.PIL], autouse=True)
def fxt_image_backend(self, request):
curr_backend = _IMAGE_BACKEND.get()
_IMAGE_BACKEND.set(request.param)
curr_backend = IMAGE_BACKEND.get()
IMAGE_BACKEND.set(request.param)
yield
_IMAGE_BACKEND.set(curr_backend)
IMAGE_BACKEND.set(curr_backend)

def test_load_encrypted_image(self, fxt_image_file, fxt_encrypted_image_file, fxt_crypter):
img = Image.from_file(path=fxt_image_file)
Expand Down
56 changes: 42 additions & 14 deletions tests/unit/test_image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2019-2023 Intel Corporation
# Copyright (C) 2019-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

Expand All @@ -18,14 +18,14 @@

class ImageOperationsTest(TestCase):
def setUp(self):
self.default_backend = image_module._IMAGE_BACKEND.get()
self.default_backend = image_module.IMAGE_BACKEND.get()

def tearDown(self):
image_module._IMAGE_BACKEND.set(self.default_backend)
image_module.IMAGE_BACKEND.set(self.default_backend)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_save_and_load_backends(self):
backends = image_module._IMAGE_BACKENDS
backends = image_module.ImageBackend
for save_backend, load_backend, c in product(backends, backends, [1, 3]):
with TestDir() as test_dir:
if c == 1:
Expand All @@ -34,10 +34,10 @@ def test_save_and_load_backends(self):
src_image = np.random.randint(0, 255 + 1, (2, 4, c))
path = osp.join(test_dir, "img.png") # lossless

image_module._IMAGE_BACKEND.set(save_backend)
image_module.IMAGE_BACKEND.set(save_backend)
image_module.save_image(path, src_image, jpeg_quality=100)

image_module._IMAGE_BACKEND.set(load_backend)
image_module.IMAGE_BACKEND.set(load_backend)
dst_image = image_module.load_image(path)

self.assertTrue(
Expand All @@ -47,17 +47,17 @@ def test_save_and_load_backends(self):

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_encode_and_decode_backends(self):
backends = image_module._IMAGE_BACKENDS
backends = image_module.ImageBackend
for save_backend, load_backend, c in product(backends, backends, [1, 3]):
if c == 1:
src_image = np.random.randint(0, 255 + 1, (2, 4))
else:
src_image = np.random.randint(0, 255 + 1, (2, 4, c))

image_module._IMAGE_BACKEND.set(save_backend)
image_module.IMAGE_BACKEND.set(save_backend)
buffer = image_module.encode_image(src_image, ".png", jpeg_quality=100) # lossless

image_module._IMAGE_BACKEND.set(load_backend)
image_module.IMAGE_BACKEND.set(load_backend)
dst_image = image_module.decode_image(buffer)

self.assertTrue(
Expand All @@ -83,17 +83,45 @@ class ImageDecodeTest:
def fxt_img_four_channels(self) -> np.ndarray:
return np.random.randint(low=0, high=256, size=(5, 4, 4), dtype=np.uint8)

def test_decode_image_context(self, fxt_img_four_channels: np.ndarray):
@pytest.mark.parametrize(
"image_backend", [image_module.ImageBackend.cv2, image_module.ImageBackend.PIL]
)
def test_decode_image_context(
self, fxt_img_four_channels: np.ndarray, image_backend: image_module.ImageBackend
):
img_bytes = image_module.encode_image(fxt_img_four_channels, ".png")

# 3 channels from ImageColorScale.COLOR
with image_module.decode_image_context(image_module.ImageColorScale.COLOR):
# 3 channels from ImageColorScale.COLOR_BGR
with image_module.decode_image_context(
image_backend, image_module.ImageColorChannel.COLOR_BGR
):
img_decoded = image_module.decode_image(img_bytes)
assert img_decoded.shape[-1] == 3
assert np.allclose(fxt_img_four_channels[:, :, :3], img_decoded)

# 3 channels from ImageColorScale.COLOR_RGB
with image_module.decode_image_context(
image_backend, image_module.ImageColorChannel.COLOR_RGB
):
img_decoded = image_module.decode_image(img_bytes)
assert img_decoded.shape[-1] == 3
assert np.allclose(
fxt_img_four_channels[:, :, :3][:, :, ::-1], # Flip color channels of the fixture
img_decoded,
)

# 4 channels from ImageColorScale.UNCHANGED
with image_module.decode_image_context(image_module.ImageColorScale.UNCHANGED):
with image_module.decode_image_context(
image_backend, image_module.ImageColorChannel.UNCHANGED
):
img_decoded = image_module.decode_image(img_bytes)
assert img_decoded.shape[-1] == 4
assert np.allclose(fxt_img_four_channels, img_decoded)

if image_backend == image_module.ImageBackend.cv2:
assert np.allclose(fxt_img_four_channels, img_decoded)
else:
# PIL will return RGBA, thus we need to correct the fixture
to_rgb = fxt_img_four_channels[:, :, :3][:, :, ::-1]
fxt_img_four_channels[:, :, :3] = to_rgb

assert np.allclose(fxt_img_four_channels, img_decoded)

0 comments on commit d90f158

Please sign in to comment.