Skip to content

Commit

Permalink
Merge branch 'v3' into mypy-error-codes
Browse files Browse the repository at this point in the history
  • Loading branch information
dstansby authored May 30, 2024
2 parents 52351c6 + ef15e20 commit dd97cce
Show file tree
Hide file tree
Showing 11 changed files with 279 additions and 835 deletions.
31 changes: 28 additions & 3 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,34 @@
from typing import Protocol, runtime_checkable

from zarr.buffer import Buffer
from zarr.common import BytesLike
from zarr.common import BytesLike, OpenMode


class Store(ABC):
_mode: OpenMode

def __init__(self, mode: OpenMode = "r"):
if mode not in ("r", "r+", "w", "w-", "a"):
raise ValueError("mode must be one of 'r', 'r+', 'w', 'w-', 'a'")
self._mode = mode

@property
def mode(self) -> OpenMode:
    """Return the access mode stored on this instance (``self._mode``)."""
    return self._mode

@property
def writeable(self) -> bool:
    """Report whether the store's access mode permits writing.

    Returns ``True`` for the modes ``"r+"``, ``"a"``, ``"w"`` and ``"w-"``;
    any other mode (in practice the read-only ``"r"``) yields ``False``.
    """
    # "r+" is the read/write mode, so it must count as writeable; leaving it
    # out would make every write on an "r+" store fail the writability check.
    return self.mode in ("r+", "a", "w", "w-")

def _check_writable(self) -> None:
if not self.writeable:
raise ValueError("store mode does not support writing")

@abstractmethod
async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
"""Retrieve the value associated with a given key.
Expand All @@ -26,7 +47,7 @@ async def get(

@abstractmethod
async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
) -> list[Buffer | None]:
"""Retrieve possibly partial values from given key_ranges.
Expand Down Expand Up @@ -147,6 +168,10 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
...

def close(self) -> None:  # noqa: B027
    """Close the store.

    This implementation does nothing. It is deliberately a concrete no-op
    rather than an abstract method, which is why the B027 lint is suppressed.
    """
    return None


@runtime_checkable
class ByteGetter(Protocol):
Expand Down
1 change: 1 addition & 0 deletions src/zarr/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Selection = slice | SliceSelection
ZarrFormat = Literal[2, 3]
JSON = None | str | int | float | Enum | dict[str, "JSON"] | list["JSON"] | tuple["JSON", ...]
OpenMode = Literal["r", "r+", "a", "w", "w-"]


def product(tup: ChunkCoords) -> int:
Expand Down
34 changes: 32 additions & 2 deletions src/zarr/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import OpenMode
from zarr.store.local import LocalStore


Expand Down Expand Up @@ -60,11 +61,40 @@ def __eq__(self, other: Any) -> bool:
StoreLike = Store | StorePath | Path | str


def make_store_path(store_like: StoreLike) -> StorePath:
def make_store_path(store_like: StoreLike, *, mode: OpenMode | None = None) -> StorePath:
if isinstance(store_like, StorePath):
if mode is not None:
assert mode == store_like.store.mode
return store_like
elif isinstance(store_like, Store):
if mode is not None:
assert mode == store_like.mode
return StorePath(store_like)
elif isinstance(store_like, str):
return StorePath(LocalStore(Path(store_like)))
assert mode is not None
return StorePath(LocalStore(Path(store_like), mode=mode))
raise TypeError


def _normalize_interval_index(
data: Buffer, interval: None | tuple[int | None, int | None]
) -> tuple[int, int]:
"""
Convert an implicit interval into an explicit start and length
"""
if interval is None:
start = 0
length = len(data)
else:
maybe_start, maybe_len = interval
if maybe_start is None:
start = 0
else:
start = maybe_start

if maybe_len is None:
length = len(data) - start
else:
length = maybe_len

return (start, length)
22 changes: 11 additions & 11 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map, to_thread
from zarr.common import OpenMode, concurrent_map, to_thread


def _get(path: Path, byte_range: tuple[int, int | None] | None) -> Buffer:
def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer:
"""
Fetch a contiguous region of bytes from a file.
Expand Down Expand Up @@ -51,10 +51,8 @@ def _put(
path: Path,
value: Buffer,
start: int | None = None,
auto_mkdir: bool = True,
) -> int | None:
if auto_mkdir:
path.parent.mkdir(parents=True, exist_ok=True)
path.parent.mkdir(parents=True, exist_ok=True)
if start is not None:
with path.open("r+b") as f:
f.seek(start)
Expand All @@ -70,15 +68,14 @@ class LocalStore(Store):
supports_listing: bool = True

root: Path
auto_mkdir: bool

def __init__(self, root: Path | str, auto_mkdir: bool = True):
def __init__(self, root: Path | str, *, mode: OpenMode = "r"):
    """Set up a store rooted at ``root`` with the given access ``mode``.

    ``mode`` is forwarded to the base-class initializer. A string ``root``
    is converted to a ``Path``; anything else must already be a ``Path``.
    """
    super().__init__(mode=mode)
    root_path = Path(root) if isinstance(root, str) else root
    assert isinstance(root_path, Path)

    self.root = root_path
self.auto_mkdir = auto_mkdir

def __str__(self) -> str:
return f"file://{self.root}"
Expand All @@ -90,7 +87,7 @@ def __eq__(self, other: object) -> bool:
return isinstance(other, type(self)) and self.root == other.root

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
path = self.root / key
Expand All @@ -101,7 +98,7 @@ async def get(
return None

async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
) -> list[Buffer | None]:
"""
Read byte ranges from multiple keys.
Expand All @@ -121,16 +118,18 @@ async def get_partial_values(
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def set(self, key: str, value: Buffer) -> None:
self._check_writable()
assert isinstance(key, str)
if isinstance(value, bytes | bytearray):
# TODO: to support the v2 tests, we convert bytes to Buffer here
value = Buffer.from_bytes(value)
if not isinstance(value, Buffer):
raise TypeError("LocalStore.set(): `value` must a Buffer instance")
path = self.root / key
await to_thread(_put, path, value, auto_mkdir=self.auto_mkdir)
await to_thread(_put, path, value)

async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None:
self._check_writable()
args = []
for key, start, value in key_start_values:
assert isinstance(key, str)
Expand All @@ -142,6 +141,7 @@ async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]
await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def delete(self, key: str) -> None:
self._check_writable()
path = self.root / key
if path.is_dir(): # TODO: support deleting directories? shutil.rmtree?
shutil.rmtree(path)
Expand Down
19 changes: 12 additions & 7 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map
from zarr.common import OpenMode, concurrent_map
from zarr.store.core import _normalize_interval_index


# TODO: this store could easily be extended to wrap any MutableMapping store from v2
Expand All @@ -16,7 +17,10 @@ class MemoryStore(Store):

_store_dict: MutableMapping[str, Buffer]

def __init__(self, store_dict: MutableMapping[str, Buffer] | None = None):
def __init__(
    self, store_dict: MutableMapping[str, Buffer] | None = None, *, mode: OpenMode = "r"
):
    """Set up the store, optionally backed by a caller-supplied mapping.

    Parameters
    ----------
    store_dict
        Mapping used as the backing storage. When ``None``, a fresh empty
        ``dict`` is created.
    mode
        Access mode, forwarded to the base-class initializer.
    """
    super().__init__(mode=mode)
    # Test against None explicitly: `store_dict or {}` would treat a
    # caller-supplied *empty* mapping as missing (empty containers are falsy)
    # and silently back the store with an unrelated dict instead.
    self._store_dict = {} if store_dict is None else store_dict

def __str__(self) -> str:
Expand All @@ -26,19 +30,18 @@ def __repr__(self) -> str:
return f"MemoryStore({str(self)!r})"

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
try:
value = self._store_dict[key]
if byte_range is not None:
value = value[byte_range[0] : byte_range[1]]
return value
start, length = _normalize_interval_index(value, byte_range)
return value[start : start + length]
except KeyError:
return None

async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
) -> list[Buffer | None]:
vals = await concurrent_map(key_ranges, self.get, limit=None)
return vals
Expand All @@ -47,6 +50,7 @@ async def exists(self, key: str) -> bool:
return key in self._store_dict

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
if isinstance(value, bytes | bytearray):
# TODO: to support the v2 tests, we convert bytes to Buffer here
Expand All @@ -62,6 +66,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
self._store_dict[key] = value

async def delete(self, key: str) -> None:
self._check_writable()
try:
del self._store_dict[key]
except KeyError:
Expand Down
12 changes: 10 additions & 2 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import OpenMode
from zarr.store.core import _dereference_path

if TYPE_CHECKING:
Expand All @@ -18,17 +19,22 @@ class RemoteStore(Store):

root: UPath

def __init__(self, url: UPath | str, **storage_options: dict[str, Any]):
def __init__(
self, url: UPath | str, *, mode: OpenMode = "r", **storage_options: dict[str, Any]
):
import fsspec
from upath import UPath

super().__init__(mode=mode)

if isinstance(url, str):
self.root = UPath(url, **storage_options)
else:
assert (
len(storage_options) == 0
), "If constructed with a UPath object, no additional storage_options are allowed."
self.root = url.rstrip("/")

# test instantiate file system
fs, _ = fsspec.core.url_to_fs(str(self.root), asynchronous=True, **self.root._kwargs)
assert fs.__class__.async_impl, "FileSystem needs to support async operations."
Expand All @@ -49,7 +55,7 @@ def _make_fs(self) -> tuple[AsyncFileSystem, str]:
return fs, root

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
fs, root = self._make_fs()
Expand All @@ -67,6 +73,7 @@ async def get(
return value

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
fs, root = self._make_fs()
path = _dereference_path(root, key)
Expand All @@ -80,6 +87,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
await fs._pipe_file(path, value)

async def delete(self, key: str) -> None:
self._check_writable()
fs, root = self._make_fs()
path = _dereference_path(root, key)
if await fs._exists(path):
Expand Down
Loading

0 comments on commit dd97cce

Please sign in to comment.