Skip to content

Commit

Permalink
Fix generic typing in zarr.codecs
Browse files Browse the repository at this point in the history
  • Loading branch information
dstansby committed May 10, 2024
1 parent 666a8b9 commit a25bad6
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 34 deletions.
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,6 @@ check_untyped_defs = false
module = [
"zarr.v2.*",
"zarr.abc.codec",
"zarr.codecs.bytes",
"zarr.codecs.pipeline",
"zarr.codecs.sharding",
"zarr.codecs.transpose",
"zarr.array_v2",
"zarr.array",
"zarr.sync",
Expand Down
9 changes: 5 additions & 4 deletions src/zarr/codecs/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from enum import Enum
import sys

from typing import TYPE_CHECKING, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import numpy as np
import numpy.typing as npt

from zarr.abc.codec import ArrayBytesCodec
from zarr.codecs.registry import register_codec
Expand Down Expand Up @@ -60,7 +61,7 @@ def evolve(self, array_spec: ArraySpec) -> Self:
)
return self

def _get_byteorder(self, array: np.ndarray) -> Endian:
def _get_byteorder(self, array: npt.NDArray[Any]) -> Endian:
if array.dtype.byteorder == "<":
return Endian.little
elif array.dtype.byteorder == ">":
Expand All @@ -73,7 +74,7 @@ async def decode(
chunk_bytes: BytesLike,
chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
if chunk_spec.dtype.itemsize > 0:
if self.endian == Endian.little:
prefix = "<"
Expand All @@ -93,7 +94,7 @@ async def decode(

async def encode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
_chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> Optional[BytesLike]:
Expand Down
12 changes: 6 additions & 6 deletions src/zarr/codecs/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable
import numpy as np
from typing import TYPE_CHECKING, Any, Iterable
import numpy.typing as npt
from dataclasses import dataclass
from warnings import warn

Expand Down Expand Up @@ -152,7 +152,7 @@ async def decode(
chunk_bytes: BytesLike,
array_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
(
aa_codecs_with_spec,
ab_codec_with_spec,
Expand All @@ -176,7 +176,7 @@ async def decode_partial(
selection: SliceSelection,
chunk_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[np.ndarray]:
) -> Optional[npt.NDArray[Any]]:
assert self.supports_partial_decode
assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin)
return await self.array_bytes_codec.decode_partial(
Expand All @@ -185,7 +185,7 @@ async def decode_partial(

async def encode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
array_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[BytesLike]:
Expand Down Expand Up @@ -222,7 +222,7 @@ async def encode(
async def encode_partial(
self,
store_path: StorePath,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
selection: SliceSelection,
chunk_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
Expand Down
32 changes: 18 additions & 14 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Mapping, NamedTuple, Union
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Union, Optional
from dataclasses import dataclass, replace
from functools import lru_cache


import numpy as np
import numpy.typing as npt
from zarr.abc.codec import (
Codec,
ArrayBytesCodec,
Expand All @@ -18,7 +19,9 @@
from zarr.codecs.registry import register_codec
from zarr.common import (
ArraySpec,
BytesLike,
ChunkCoordsLike,
ChunkCoords,
concurrent_map,
parse_enum,
parse_named_configuration,
Expand All @@ -39,14 +42,12 @@
)

if TYPE_CHECKING:
from typing import Awaitable, Callable, Dict, Iterator, List, Optional, Set, Tuple
from typing import Awaitable, Callable, Dict, Iterator, List, Set, Tuple
from typing_extensions import Self

from zarr.store import StorePath
from zarr.common import (
JSON,
ChunkCoords,
BytesLike,
SliceSelection,
)
from zarr.config import RuntimeConfiguration
Expand All @@ -65,7 +66,7 @@ def parse_index_location(data: JSON) -> ShardingCodecIndexLocation:

class _ShardIndex(NamedTuple):
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
offsets_and_lengths: np.ndarray
offsets_and_lengths: npt.NDArray[np.uint64]

@property
def chunks_per_shard(self) -> ChunkCoords:
Expand Down Expand Up @@ -126,7 +127,10 @@ def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardIndex:
return cls(offsets_and_lengths)


class _ShardProxy(Mapping):
_ShardMapping = Mapping[ChunkCoords, Optional[BytesLike]]


class _ShardProxy(_ShardMapping):
index: _ShardIndex
buf: BytesLike

Expand Down Expand Up @@ -175,7 +179,7 @@ def merge_with_morton_order(
cls,
chunks_per_shard: ChunkCoords,
tombstones: Set[ChunkCoords],
*shard_dicts: Mapping[ChunkCoords, BytesLike],
*shard_dicts: _ShardMapping,
) -> _ShardBuilder:
obj = cls.create_empty(chunks_per_shard)
for chunk_coords in morton_order_iter(chunks_per_shard):
Expand Down Expand Up @@ -303,7 +307,7 @@ async def decode(
shard_bytes: BytesLike,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
# print("decode")
shard_shape = shard_spec.shape
chunk_shape = self.chunk_shape
Expand Down Expand Up @@ -353,7 +357,7 @@ async def decode_partial(
selection: SliceSelection,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[np.ndarray]:
) -> Optional[npt.NDArray[Any]]:
shard_shape = shard_spec.shape
chunk_shape = self.chunk_shape
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
Expand All @@ -375,7 +379,7 @@ async def decode_partial(
all_chunk_coords = set(chunk_coords for chunk_coords, _, _ in indexed_chunks)

# reading bytes of all requested chunks
shard_dict: Mapping[ChunkCoords, BytesLike] = {}
shard_dict: _ShardMapping = {}
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
# read entire shard
shard_dict_maybe = await self._load_full_shard_maybe(store_path, chunks_per_shard)
Expand Down Expand Up @@ -423,7 +427,7 @@ async def _read_chunk(
out_selection: SliceSelection,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
out: np.ndarray,
out: npt.NDArray[Any],
) -> None:
chunk_spec = self._get_chunk_spec(shard_spec)
chunk_bytes = shard_dict.get(chunk_coords, None)
Expand All @@ -436,7 +440,7 @@ async def _read_chunk(

async def encode(
self,
shard_array: np.ndarray,
shard_array: npt.NDArray[Any],
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[BytesLike]:
Expand All @@ -453,7 +457,7 @@ async def encode(
)

async def _write_chunk(
shard_array: np.ndarray,
shard_array: npt.NDArray[Any],
chunk_coords: ChunkCoords,
chunk_selection: SliceSelection,
out_selection: SliceSelection,
Expand Down Expand Up @@ -498,7 +502,7 @@ async def _write_chunk(
async def encode_partial(
self,
store_path: StorePath,
shard_array: np.ndarray,
shard_array: npt.NDArray[Any],
selection: SliceSelection,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
Expand Down
12 changes: 6 additions & 6 deletions src/zarr/codecs/transpose.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, Iterable, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union, cast

from dataclasses import dataclass, replace

Expand All @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Optional, Tuple
from typing_extensions import Self

import numpy as np
import numpy.typing as npt

from zarr.abc.codec import ArrayArrayCodec
from zarr.codecs.registry import register_codec
Expand Down Expand Up @@ -75,10 +75,10 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:

async def decode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
inverse_order = [0] * chunk_spec.ndim
for x, i in enumerate(self.order):
inverse_order[x] = i
Expand All @@ -87,10 +87,10 @@ async def decode(

async def encode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> Optional[np.ndarray]:
) -> Optional[npt.NDArray[Any]]:
chunk_array = chunk_array.transpose(self.order)
return chunk_array

Expand Down

0 comments on commit a25bad6

Please sign in to comment.