diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3a381b3..d1fdffc7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,7 @@ repos: rev: v1.8.0 hooks: - id: numpydoc-validation - exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*|^scripts/.*" + exclude: "^src/careamics/dataset_ng/.*|^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*|^scripts/.*" # # jupyter linting and formatting # - repo: https://github.com/nbQA-dev/nbQA diff --git a/pyproject.toml b/pyproject.toml index d55d1109..8bb74fbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,8 @@ convention = "numpy" [tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D", "S"] "setup.py" = ["D"] +# temporarily ignore docstrings in next generation dataset development +"src/careamics/dataset_ng/*" = ["D"] [tool.black] line-length = 88 diff --git a/src/careamics/dataset_ng/demo_patch_extractor.ipynb b/src/careamics/dataset_ng/demo_patch_extractor.ipynb new file mode 100644 index 00000000..c4f6dc05 --- /dev/null +++ b/src/careamics/dataset_ng/demo_patch_extractor.ipynb @@ -0,0 +1,135 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from careamics.dataset_ng.patch_extractor import PatchExtractor\n", + "from careamics.dataset_ng.patch_extractor.image_stack import InMemoryImageStack\n", + "from careamics.dataset_ng.patching_strategies import RandomPatchSpecsGenerator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "array = np.arange(36).reshape(6, 6)\n", + "image_stack = InMemoryImageStack.from_array(data=array, axes=\"YX\")\n", + "image_stack.extract_patch(sample_idx=0, coords=(2, 2), patch_size=(3, 
3))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define example data\n", + "array1 = np.arange(36).reshape(1, 6, 6)\n", + "array2 = np.arange(50).reshape(2, 5, 5)\n", + "target1 = rng.integers(0, 1, size=array1.shape, endpoint=True)\n", + "target2 = rng.integers(0, 1, size=array2.shape, endpoint=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(array1)\n", + "print(array2)\n", + "print(target1)\n", + "print(target2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define example readers\n", + "input_patch_extractor = PatchExtractor.from_arrays([array1, array2], axes=\"SYX\")\n", + "target_patch_extractor = PatchExtractor.from_arrays([target1, target2], axes=\"SYX\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# generate random patch specification\n", + "data_shapes = [\n", + " image_stack.data_shape for image_stack in input_patch_extractor.image_stacks\n", + "]\n", + "patch_specs_generator = RandomPatchSpecsGenerator(data_shapes)\n", + "patch_specs = patch_specs_generator.generate(patch_size=(2, 2), seed=42)\n", + "patch_specs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# extract a subset of patches\n", + "input_patch_extractor.extract_patches(patch_specs[7:11])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target_patch_extractor.extract_patches(patch_specs[7:11])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/careamics/dataset_ng/patch_extractor/__init__.py b/src/careamics/dataset_ng/patch_extractor/__init__.py new file mode 100644 index 00000000..e23adb6b --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "PatchExtractor", + "PatchExtractorConstructor", + "PatchSpecs", +] + +from .patch_extractor import PatchExtractor, PatchExtractorConstructor, PatchSpecs diff --git a/src/careamics/dataset_ng/patch_extractor/image_stack/__init__.py b/src/careamics/dataset_ng/patch_extractor/image_stack/__init__.py new file mode 100644 index 00000000..41fcc093 --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor/image_stack/__init__.py @@ -0,0 +1,9 @@ +__all__ = [ + "ImageStack", + "InMemoryImageStack", + "ZarrImageStack", +] + +from .image_stack_protocol import ImageStack +from .in_memory_image_stack import InMemoryImageStack +from .zarr_image_stack import ZarrImageStack diff --git a/src/careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py b/src/careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py new file mode 100644 index 00000000..16e0efe1 --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py @@ -0,0 +1,47 @@ +from collections.abc import Sequence +from pathlib import Path +from typing import Literal, Protocol, Union + +from numpy.typing import NDArray + + +class ImageStack(Protocol): + """ + An interface for extracting patches from an image stack. + + Attributes + ---------- + source: Path or "array" + Origin of the image data. + data_shape: Sequence[int] + The shape of the data, it is expected to be in the order (SC(Z)YX). 
+ + """ + + # TODO: not sure how compatible using Path will be for a zarr array + # (for a zarr array need to specify file path and internal zarr path) + source: Union[Path, Literal["array"]] + data_shape: Sequence[int] + + def extract_patch( + self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int] + ) -> NDArray: + """ + Extracts a patch for a given sample within the image stack. + + Parameters + ---------- + sample_idx: int + Sample index. The first dimension of the image data will be indexed at this + value. + coords: Sequence of int + The coordinates that define the start of a patch. + patch_size: Sequence of int + The size of the patch in each spatial dimension. + + Returns + ------- + numpy.ndarray + A patch of the image data from a particlular sample. It will have the + dimensions C(Z)YX. + """ diff --git a/src/careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py b/src/careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py new file mode 100644 index 00000000..7342b8ff --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py @@ -0,0 +1,54 @@ +from collections.abc import Sequence +from pathlib import Path +from typing import Any, Literal, Union + +from numpy.typing import NDArray +from typing_extensions import Self + +from careamics.dataset.dataset_utils import reshape_array +from careamics.file_io.read import ReadFunc, read_tiff + + +class InMemoryImageStack: + """ + A class for extracting patches from an image stack that has been loaded into memory. 
+ """ + + def __init__(self, source: Union[Path, Literal["array"]], data: NDArray): + self.source: Union[Path, Literal["array"]] = source + # data expected to be in SC(Z)YX shape, reason to use from_array constructor + self._data = data + self.data_shape: Sequence[int] = self._data.shape + + def extract_patch( + self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int] + ) -> NDArray: + if len(coords) != len(patch_size): + raise ValueError("Length of coords and extent must match.") + # TODO: test for 2D or 3D? + return self._data[ + ( + sample_idx, # type: ignore + ..., # type: ignore + *[slice(c, c + e) for c, e in zip(coords, patch_size)], # type: ignore + ) + ] + + @classmethod + def from_array(cls, data: NDArray, axes: str) -> Self: + data = reshape_array(data, axes) + return cls(source="array", data=data) + + @classmethod + def from_tiff(cls, path: Path, axes: str) -> Self: + data = read_tiff(path) + data = reshape_array(data, axes) + return cls(source=path, data=data) + + @classmethod + def from_custom_file_type( + cls, path: Path, axes: str, read_func: ReadFunc, **read_kwargs: Any + ) -> Self: + data = read_func(path, **read_kwargs) + data = reshape_array(data, axes) + return cls(source=path, data=data) diff --git a/src/careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py b/src/careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py new file mode 100644 index 00000000..a8e70c98 --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py @@ -0,0 +1,24 @@ +from collections.abc import Sequence +from pathlib import Path + +from numpy.typing import NDArray + + +class ZarrImageStack: + """ + A class for extracting patches from an image stack that is stored as a zarr array. 
+ """ + + def __init__( + self, + source: Path, + # other args + ): + # Note: will probably need to store axes from metadata + # - transformation will have to happen in `extract_patch` + raise NotImplementedError("Not implemented yet.") + + def extract_patch( + self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int] + ) -> NDArray: + raise NotImplementedError("Not implemented yet.") diff --git a/src/careamics/dataset_ng/patch_extractor/patch_extractor.py b/src/careamics/dataset_ng/patch_extractor/patch_extractor.py new file mode 100644 index 00000000..de1a2147 --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor/patch_extractor.py @@ -0,0 +1,91 @@ +from collections.abc import Sequence +from pathlib import Path +from typing import Any, Protocol, TypedDict, Union + +from numpy.typing import NDArray +from typing_extensions import Self + +from careamics.file_io.read import ReadFunc + +from .image_stack import ImageStack, InMemoryImageStack + + +class PatchSpecs(TypedDict): + data_idx: int + sample_idx: int + coords: Sequence[int] + patch_size: Sequence[int] + + +class PatchExtractorConstructor(Protocol): + + # TODO: expand Union for new constructors, or just type hint as Any + def __call__( + self, + source: Union[Sequence[NDArray], Sequence[Path]], + **kwargs: Any, + ) -> "PatchExtractor": ... + + +class PatchExtractor: + """ + A class for extracting patches from multiple image stacks. + """ + + def __init__(self, data_readers: Sequence[ImageStack]): + self.image_stacks: list[ImageStack] = list(data_readers) + + @classmethod + def from_arrays(cls, source: Sequence[NDArray], *, axes: str) -> Self: + data_readers = [ + InMemoryImageStack.from_array(data=array, axes=axes) for array in source + ] + return cls(data_readers=data_readers) + + # TODO: rename to load_from_tiff_files? 
+ # - to distiguish from possible pointer to files + @classmethod + def from_tiff_files(cls, source: Sequence[Path], *, axes: str) -> Self: + data_readers = [ + InMemoryImageStack.from_tiff(path=path, axes=axes) for path in source + ] + return cls(data_readers=data_readers) + + # TODO: similar to tiff - rename to load_from_custom_file_type? + @classmethod + def from_custom_file_type( + cls, + source: Sequence[Path], + axes: str, + read_func: ReadFunc, + **read_kwargs, + ) -> Self: + data_readers = [ + InMemoryImageStack.from_custom_file_type( + path=path, + axes=axes, + read_func=read_func, + **read_kwargs, + ) + for path in source + ] + return cls(data_readers=data_readers) + + @classmethod + def from_zarr_files(cls, source, **kwargs) -> Self: + # TODO: will this create a ZarrImageStack for each array in the zarr file? + raise NotImplementedError("Reading from zarr has not been implemented.") + + def extract_patch( + self, + data_idx: int, + sample_idx: int, + coords: Sequence[int], + patch_size: Sequence[int], + ) -> NDArray: + return self.image_stacks[data_idx].extract_patch( + sample_idx=sample_idx, coords=coords, patch_size=patch_size + ) + + def extract_patches(self, patch_specs: Sequence[PatchSpecs]) -> list[NDArray]: + return [self.extract_patch(**patch_spec) for patch_spec in patch_specs] diff --git a/src/careamics/dataset_ng/patch_extractor/patch_extractor_factory.py b/src/careamics/dataset_ng/patch_extractor/patch_extractor_factory.py new file mode 100644 index 00000000..fc68c65c --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor/patch_extractor_factory.py @@ -0,0 +1,69 @@ +from collections.abc import Sequence +from pathlib import Path +from typing import Optional, Union + +from numpy.typing import NDArray + +from careamics.config import GeneralDataConfig +from careamics.config.support import SupportedData +from careamics.dataset_ng.patch_extractor import ( + PatchExtractor, + PatchExtractorConstructor, +) + + +def 
get_patch_extractor_constructor( + data_config: GeneralDataConfig, +) -> PatchExtractorConstructor: + if data_config.data_type == SupportedData.ARRAY: + return PatchExtractor.from_arrays + elif data_config.data_type == SupportedData.TIFF: + return PatchExtractor.from_tiff_files + elif data_config.data_type == SupportedData.CUSTOM: + return PatchExtractor.from_custom_file_type + else: + raise ValueError(f"Data type {data_config.data_type} is not supported.") + + +def create_patch_extractors( + data_config: GeneralDataConfig, + train_data: Union[Sequence[NDArray], Sequence[Path]], + val_data: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None, + train_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None, + val_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None, + **kwargs, +) -> tuple[ + PatchExtractor, + Optional[PatchExtractor], + Optional[PatchExtractor], + Optional[PatchExtractor], +]: + + # get correct constructor + constructor = get_patch_extractor_constructor(data_config) + + # build key word args + constructor_kwargs = {"axes": data_config.axes, **kwargs} + + # --- train data extractor + train_patch_extractor: PatchExtractor = constructor( + source=train_data, **constructor_kwargs + ) + # --- additional data extractors + additional_patch_extractors: list[Union[PatchExtractor, None]] = [] + additional_data_sources = [val_data, train_data_target, val_data_target] + for data_source in additional_data_sources: + if data_source is not None: + additional_patch_extractor: Optional[PatchExtractor] = constructor( + source=data_source, **constructor_kwargs + ) + else: + additional_patch_extractor = None + additional_patch_extractors.append(additional_patch_extractor) + + return ( + train_patch_extractor, + additional_patch_extractors[0], + additional_patch_extractors[1], + additional_patch_extractors[2], + ) diff --git a/src/careamics/dataset_ng/patch_extractor_factory_demo.ipynb 
b/src/careamics/dataset_ng/patch_extractor_factory_demo.ipynb new file mode 100644 index 00000000..8b5549a8 --- /dev/null +++ b/src/careamics/dataset_ng/patch_extractor_factory_demo.ipynb @@ -0,0 +1,90 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from careamics.config import create_n2n_configuration\n", + "from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (\n", + " create_patch_extractors,\n", + ")\n", + "\n", + "rng = np.random.default_rng()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define example data\n", + "array1 = np.arange(36).reshape(1, 6, 6)\n", + "array2 = np.arange(50).reshape(2, 5, 5)\n", + "target1 = rng.integers(0, 1, size=array1.shape, endpoint=True)\n", + "target2 = rng.integers(0, 1, size=array2.shape, endpoint=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = create_n2n_configuration(\n", + " \"test_exp\",\n", + " data_type=\"array\",\n", + " axes=\"SYX\",\n", + " patch_size=(8, 8),\n", + " batch_size=1,\n", + " num_epochs=1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_inputs, train_targets, _, _ = create_patch_extractors(\n", + " config.data_config, [array1, array2], [target1, target2]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_inputs.extract_patch(data_idx=0, sample_idx=0, coords=(2, 2), patch_size=(3,3))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": 
"python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/careamics/dataset_ng/patching_strategies/__init__.py b/src/careamics/dataset_ng/patching_strategies/__init__.py new file mode 100644 index 00000000..330bb8fc --- /dev/null +++ b/src/careamics/dataset_ng/patching_strategies/__init__.py @@ -0,0 +1,6 @@ +__all__ = ["PatchSpecsGenerator", "RandomPatchSpecsGenerator"] + +from .patch_specs_generator import ( + PatchSpecsGenerator, + RandomPatchSpecsGenerator, +) diff --git a/src/careamics/dataset_ng/patching_strategies/patch_specs_generator.py b/src/careamics/dataset_ng/patching_strategies/patch_specs_generator.py new file mode 100644 index 00000000..bd9b727c --- /dev/null +++ b/src/careamics/dataset_ng/patching_strategies/patch_specs_generator.py @@ -0,0 +1,90 @@ +from collections.abc import Sequence +from typing import ParamSpec, Protocol + +import numpy as np +from numpy.typing import NDArray + +from ..patch_extractor import PatchSpecs + +P = ParamSpec("P") + + +class PatchSpecsGenerator(Protocol[P]): + + def generate( + self, patch_size: Sequence[int], *args: P.args, **kwargs: P.kwargs + ) -> list[PatchSpecs]: ... + + # Should return the number of patches that will be produced for a set of args + # Will be for mapped dataset length + def n_patches( + self, patch_size: Sequence[int], *args: P.args, **kwargs: P.kwargs + ) -> int: ... 
class RandomPatchSpecsGenerator:
    """
    Generate randomly positioned patch specifications for a set of image stacks.

    Attributes
    ----------
    data_shapes: Sequence of Sequence of int
        The shape of each image stack. The first dimension is iterated over as the
        sample dimension and the trailing `len(patch_size)` dimensions are the
        spatial dimensions on which patches are placed.
    """

    def __init__(self, data_shapes: Sequence[Sequence[int]]):
        self.data_shapes = data_shapes

    def generate(self, patch_size: Sequence[int], seed: int) -> list[PatchSpecs]:
        """
        Generate random patch specifications for every sample of every image stack.

        For each sample, `ceil(prod(spatial_shape) / prod(patch_size))` patches are
        specified. The start coordinate of each patch is drawn uniformly, per
        spatial dimension, from the inclusive range `[0, spatial_dim - patch_dim]`.

        Parameters
        ----------
        patch_size: Sequence of int
            The size of the patch in each spatial dimension.
        seed: int
            Seed for the random number generator; the same seed produces the same
            patch specifications.

        Returns
        -------
        list of PatchSpecs
            The patch specifications, ordered by data index and then sample index.

        Raises
        ------
        ValueError
            If the number of patch dimensions does not match the number of spatial
            dimensions, or if patches have to be drawn from an image stack that is
            smaller than the patch in any spatial dimension.
        """
        rng = np.random.default_rng(seed=seed)
        patch_specs: list[PatchSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            n_samples = data_shape[0]

            # shape on which data is patched
            data_spatial_shape = data_shape[-len(patch_size) :]
            n_patches = self._n_patches_in_sample(patch_size, data_spatial_shape)

            # the sampling bounds are the same for every patch of this image stack
            lowest_start = np.zeros(len(patch_size), dtype=int)
            highest_start = np.array(data_spatial_shape) - np.array(patch_size)

            # Without this check numpy fails with an opaque "low > high" error.
            # Only raise when coordinates would actually be drawn, so that stacks
            # contributing no patches keep being skipped silently.
            if n_samples * n_patches > 0 and (highest_start < 0).any():
                raise ValueError(
                    f"Patch size {tuple(patch_size)} does not fit within the "
                    f"spatial shape {tuple(data_spatial_shape)} of the data at "
                    f"index {data_idx}."
                )

            patch_specs.extend(
                PatchSpecs(
                    data_idx=data_idx,
                    sample_idx=sample_idx,
                    coords=tuple(
                        rng.integers(lowest_start, highest_start, endpoint=True)
                    ),
                    patch_size=patch_size,
                )
                for sample_idx in range(n_samples)
                for _ in range(n_patches)
            )
        return patch_specs

    # NOTE: generate and n_patches methods must have matching signatures
    #   as dictated by protocol
    def n_patches(self, patch_size: Sequence[int], seed: int) -> int:
        """
        Return the number of patch specifications that `generate` would produce.

        Parameters
        ----------
        patch_size: Sequence of int
            The size of the patch in each spatial dimension.
        seed: int
            Unused; only present so that the signature matches `generate`.

        Returns
        -------
        int
            The total number of patches over all samples of all image stacks.

        Raises
        ------
        ValueError
            If the number of patch dimensions does not match the number of spatial
            dimensions.
        """
        return int(
            sum(
                self._n_patches_in_sample(patch_size, data_shape[-len(patch_size) :])
                * data_shape[0]
                for data_shape in self.data_shapes
            )
        )

    @staticmethod
    def _n_patches_in_sample(
        patch_size: Sequence[int], spatial_shape: Sequence[int]
    ) -> int:
        """
        Return the number of patches specified for a single sample.

        Parameters
        ----------
        patch_size: Sequence of int
            The size of the patch in each spatial dimension.
        spatial_shape: Sequence of int
            The spatial shape of the sample.

        Returns
        -------
        int
            The ratio of the sample volume to the patch volume, rounded up.

        Raises
        ------
        ValueError
            If `patch_size` and `spatial_shape` have different lengths.
        """
        if len(patch_size) != len(spatial_shape):
            raise ValueError(
                "Number of patch dimension do not match the number of spatial "
                "dimensions."
            )
        return int(np.ceil(np.prod(spatial_shape) / np.prod(patch_size)))


if __name__ == "__main__":
    # testing mypy accepts protocol type
    patch_specs_generator: PatchSpecsGenerator = RandomPatchSpecsGenerator(
        [(1, 1, 6, 6), (1, 1, 4, 4)]
    )