diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f4c853d16..788983e085 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Enhancements - Enhance Datumaro data format stream importer performance () +- Change image default dtype from float32 to uint8 + () ### Bug fixes - Fix errata in the voc document. Color values in the labelmap.txt should be separated by commas, not colons. diff --git a/src/datumaro/components/media.py b/src/datumaro/components/media.py index dbfca986bc..adc685e33a 100644 --- a/src/datumaro/components/media.py +++ b/src/datumaro/components/media.py @@ -309,7 +309,7 @@ def __init__( @property def data(self) -> Optional[np.ndarray]: - """Image data in BGRA HWC [0; 255] (float) format""" + """Image data in BGRA HWC [0; 255] (uint8) format""" if not self.has_data: return None @@ -375,12 +375,12 @@ def __init__( @property def data(self) -> Optional[np.ndarray]: - """Image data in BGRA HWC [0; 255] (float) format""" + """Image data in BGRA HWC [0; 255] (uint8) format""" data = super().data - if isinstance(data, np.ndarray): - data = data.astype(np.float32) + if isinstance(data, np.ndarray) and data.dtype != np.uint8: + data = np.clip(data, 0.0, 255.0).astype(np.uint8) if self._size is None and data is not None: if not 2 <= data.ndim <= 3: raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.") @@ -420,14 +420,12 @@ def _guess_ext(cls, data: bytes) -> Optional[str]: @property def data(self) -> Optional[np.ndarray]: - """Image data in BGRA HWC [0; 255] (float) format""" + """Image data in BGRA HWC [0; 255] (uint8) format""" data = super().data if isinstance(data, bytes): - data = decode_image(data) - if isinstance(data, np.ndarray): - data = data.astype(np.float32) + data = decode_image(data, dtype=np.uint8) if self._size is None and data is not None: if not 2 <= data.ndim <= 3: raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.") diff --git a/src/datumaro/plugins/framework_converter.py b/src/datumaro/plugins/framework_converter.py index 469deb566b..556005e1b7 100644 --- a/src/datumaro/plugins/framework_converter.py +++ b/src/datumaro/plugins/framework_converter.py @@ -112,9 +112,6 @@ def __init__( def __getitem__(self, idx): image, label = self._gen_item(idx) - if image.dtype == np.uint8 or image.max() > 1: - image = image.astype(np.float32) / 255 - if len(image.shape) == 2: image = np.expand_dims(image, axis=-1) diff --git a/tests/unit/operations/test_statistics.py b/tests/unit/operations/test_statistics.py index 438b3f3d86..fc76f3f48c 100644 --- a/tests/unit/operations/test_statistics.py +++ b/tests/unit/operations/test_statistics.py @@ -19,8 +19,9 @@ @pytest.fixture def fxt_image_dataset_expected_mean_std(): + np.random.seed(3003) expected_mean = [100, 50, 150] - expected_std = [20, 50, 10] + expected_std = [2, 1, 3] return expected_mean, expected_std @@ -90,9 +91,9 @@ def test_image_stats( actual_std = actual["subsets"]["default"]["image std"][::-1] for em, am in zip(expected_mean, actual_mean): - assert am == pytest.approx(em, 1e-2) + assert am == pytest.approx(em, 5e-1) for estd, astd in zip(expected_std, actual_std): - assert astd == pytest.approx(estd, 1e-2) + assert astd == pytest.approx(estd, 1e-1) @mark_requirement(Requirements.DATUM_BUG_873) def test_invalid_media_type( diff --git a/tests/unit/test_framework_converter.py b/tests/unit/test_framework_converter.py index 3c0fb9efb5..0933884293 100644 --- a/tests/unit/test_framework_converter.py +++ b/tests/unit/test_framework_converter.py @@ -296,7 +296,8 @@ def test_can_convert_torch_framework( label = np.sum(masks, axis=0, dtype=np.uint8) if fxt_convert_kwargs.get("transform", None): - assert np.array_equal(image, dm_torch_item[0].reshape(5, 5, 3).numpy()) + actual = dm_torch_item[0].permute(1, 2, 0).mul(255.0).to(torch.uint8).numpy() + assert np.array_equal(image, actual) else: assert np.array_equal(image, dm_torch_item[0]) diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index cd6f4ab12c..b15399fd26 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -44,8 +44,9 @@ class TestOperations(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_mean_std(self): + np.random.seed(3000) expected_mean = [100, 50, 150] - expected_std = [20, 50, 10] + expected_std = [2, 1, 3] dataset = Dataset.from_iterable( [ @@ -62,9 +63,9 @@ def test_mean_std(self): actual_mean, actual_std = mean_std(dataset) for em, am in zip(expected_mean, actual_mean): - self.assertAlmostEqual(em, am, places=0) + assert np.allclose(em, am, atol=0.6) for estd, astd in zip(expected_std, actual_std): - self.assertAlmostEqual(estd, astd, places=0) + assert np.allclose(estd, astd, atol=0.1) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_stats(self):