Skip to content

Commit

Permalink
Merge branch 'main' into jd/fix/fix_bmz_env
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Feb 10, 2025
2 parents 143123e + 2c25e2a commit 1214bea
Show file tree
Hide file tree
Showing 13 changed files with 625 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
rev: v1.8.0
hooks:
- id: numpydoc-validation
exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*|^scripts/.*"
exclude: "^src/careamics/dataset_ng/.*|^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*|^scripts/.*"

# # jupyter linting and formatting
# - repo: https://github.com/nbQA-dev/nbQA
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ convention = "numpy"
[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
"setup.py" = ["D"]
# temporarily ignore docstrings in next generation dataset development
"src/careamics/dataset_ng/*" = ["D"]

[tool.black]
line-length = 88
Expand Down
135 changes: 135 additions & 0 deletions src/careamics/dataset_ng/demo_patch_extractor.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from careamics.dataset_ng.patch_extractor import PatchExtractor\n",
"from careamics.dataset_ng.patch_extractor.image_stack import InMemoryImageStack\n",
"from careamics.dataset_ng.patching_strategies import RandomPatchSpecsGenerator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"array = np.arange(36).reshape(6, 6)\n",
"image_stack = InMemoryImageStack.from_array(data=array, axes=\"YX\")\n",
"image_stack.extract_patch(sample_idx=0, coords=(2, 2), patch_size=(3, 3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rng = np.random.default_rng()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define example data\n",
"array1 = np.arange(36).reshape(1, 6, 6)\n",
"array2 = np.arange(50).reshape(2, 5, 5)\n",
"target1 = rng.integers(0, 1, size=array1.shape, endpoint=True)\n",
"target2 = rng.integers(0, 1, size=array2.shape, endpoint=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(array1)\n",
"print(array2)\n",
"print(target1)\n",
"print(target2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define example readers\n",
"input_patch_extractor = PatchExtractor.from_arrays([array1, array2], axes=\"SYX\")\n",
"target_patch_extractor = PatchExtractor.from_arrays([target1, target2], axes=\"SYX\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate random patch specification\n",
"data_shapes = [\n",
" image_stack.data_shape for image_stack in input_patch_extractor.image_stacks\n",
"]\n",
"patch_specs_generator = RandomPatchSpecsGenerator(data_shapes)\n",
"patch_specs = patch_specs_generator.generate(patch_size=(2, 2), seed=42)\n",
"patch_specs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extract a subset of patches\n",
"input_patch_extractor.extract_patches(patch_specs[7:11])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target_patch_extractor.extract_patches(patch_specs[7:11])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
7 changes: 7 additions & 0 deletions src/careamics/dataset_ng/patch_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
__all__ = [
"PatchExtractor",
"PatchExtractorConstructor",
"PatchSpecs",
]

from .patch_extractor import PatchExtractor, PatchExtractorConstructor, PatchSpecs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
__all__ = [
"ImageStack",
"InMemoryImageStack",
"ZarrImageStack",
]

from .image_stack_protocol import ImageStack
from .in_memory_image_stack import InMemoryImageStack
from .zarr_image_stack import ZarrImageStack
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from collections.abc import Sequence
from pathlib import Path
from typing import Literal, Protocol, Union

from numpy.typing import NDArray


class ImageStack(Protocol):
"""
An interface for extracting patches from an image stack.
Attributes
----------
source: Path or "array"
Origin of the image data.
data_shape: Sequence[int]
The shape of the data, it is expected to be in the order (SC(Z)YX).
"""

# TODO: not sure how compatible using Path will be for a zarr array
# (for a zarr array need to specify file path and internal zarr path)
source: Union[Path, Literal["array"]]
data_shape: Sequence[int]

def extract_patch(
self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
) -> NDArray:
"""
Extracts a patch for a given sample within the image stack.
Parameters
----------
sample_idx: int
Sample index. The first dimension of the image data will be indexed at this
value.
coords: Sequence of int
The coordinates that define the start of a patch.
patch_size: Sequence of int
The size of the patch in each spatial dimension.
Returns
-------
numpy.ndarray
A patch of the image data from a particlular sample. It will have the
dimensions C(Z)YX.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal, Union

from numpy.typing import NDArray
from typing_extensions import Self

from careamics.dataset.dataset_utils import reshape_array
from careamics.file_io.read import ReadFunc, read_tiff


class InMemoryImageStack:
"""
A class for extracting patches from an image stack that has been loaded into memory.
"""

def __init__(self, source: Union[Path, Literal["array"]], data: NDArray):
self.source: Union[Path, Literal["array"]] = source
# data expected to be in SC(Z)YX shape, reason to use from_array constructor
self._data = data
self.data_shape: Sequence[int] = self._data.shape

def extract_patch(
self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
) -> NDArray:
if len(coords) != len(patch_size):
raise ValueError("Length of coords and extent must match.")
# TODO: test for 2D or 3D?
return self._data[
(
sample_idx, # type: ignore
..., # type: ignore
*[slice(c, c + e) for c, e in zip(coords, patch_size)], # type: ignore
)
]

@classmethod
def from_array(cls, data: NDArray, axes: str) -> Self:
data = reshape_array(data, axes)
return cls(source="array", data=data)

@classmethod
def from_tiff(cls, path: Path, axes: str) -> Self:
data = read_tiff(path)
data = reshape_array(data, axes)
return cls(source=path, data=data)

@classmethod
def from_custom_file_type(
cls, path: Path, axes: str, read_func: ReadFunc, **read_kwargs: Any
) -> Self:
data = read_func(path, **read_kwargs)
data = reshape_array(data, axes)
return cls(source=path, data=data)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from collections.abc import Sequence
from pathlib import Path

from numpy.typing import NDArray


class ZarrImageStack:
"""
A class for extracting patches from an image stack that is stored as a zarr array.
"""

def __init__(
self,
source: Path,
# other args
):
# Note: will probably need to store axes from metadata
# - transformation will have to happen in `extract_patch`
raise NotImplementedError("Not implemented yet.")

def extract_patch(
self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
) -> NDArray:
raise NotImplementedError("Not implemented yet.")
91 changes: 91 additions & 0 deletions src/careamics/dataset_ng/patch_extractor/patch_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Protocol, TypedDict, Union

from numpy.typing import NDArray
from typing_extensions import Self

from careamics.file_io.read import ReadFunc

from .image_stack import ImageStack, InMemoryImageStack


class PatchSpecs(TypedDict):
data_idx: int
sample_idx: int
coords: Sequence[int]
patch_size: Sequence[int]


class PatchExtractorConstructor(Protocol):

# TODO: expand Union for new constructors, or just type hint as Any
def __call__(
self,
source: Union[Sequence[NDArray], Sequence[Path]],
**kwargs: Any,
) -> "PatchExtractor": ...


class PatchExtractor:
"""
A class for extracting patches from multiple image stacks.
"""

def __init__(self, data_readers: Sequence[ImageStack]):
self.image_stacks: list[ImageStack] = list(data_readers)

@classmethod
def from_arrays(cls, source: Sequence[NDArray], *, axes: str) -> Self:
data_readers = [
InMemoryImageStack.from_array(data=array, axes=axes) for array in source
]
return cls(data_readers=data_readers)

# TODO: rename to load_from_tiff_files?
# - to distiguish from possible pointer to files
@classmethod
def from_tiff_files(cls, source: Sequence[Path], *, axes: str) -> Self:
data_readers = [
InMemoryImageStack.from_tiff(path=path, axes=axes) for path in source
]
return cls(data_readers=data_readers)

# TODO: similar to tiff - rename to load_from_custom_file_type?
@classmethod
def from_custom_file_type(
cls,
source: Sequence[Path],
axes: str,
read_func: ReadFunc,
**read_kwargs,
) -> Self:
data_readers = [
InMemoryImageStack.from_custom_file_type(
path=path,
axes=axes,
read_func=read_func,
**read_kwargs,
)
for path in source
]
return cls(data_readers=data_readers)

@classmethod
def from_zarr_files(cls, source, **kwargs) -> Self:
# TODO: will this create a ZarrImageStack for each array in the zarr file?
raise NotImplementedError("Reading from zarr has not been implemented.")

def extract_patch(
self,
data_idx: int,
sample_idx: int,
coords: Sequence[int],
patch_size: Sequence[int],
) -> NDArray:
return self.image_stacks[data_idx].extract_patch(
sample_idx=sample_idx, coords=coords, patch_size=patch_size
)

def extract_patches(self, patch_specs: Sequence[PatchSpecs]) -> list[NDArray]:
return [self.extract_patch(**patch_spec) for patch_spec in patch_specs]
Loading

0 comments on commit 1214bea

Please sign in to comment.