From ef15e20192c294cc82b9194ca29190ac9806d6fa Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Wed, 29 May 2024 08:00:52 -0700 Subject: [PATCH] [v3] Feature: Store open mode (#1911) * wip * feature(store): set open mode on store initialization --- src/zarr/abc/store.py | 27 ++++++++++++++++++++++++++- src/zarr/common.py | 1 + src/zarr/store/core.py | 10 ++++++++-- src/zarr/store/local.py | 8 ++++++-- src/zarr/store/memory.py | 9 +++++++-- src/zarr/store/remote.py | 10 +++++++++- src/zarr/testing/store.py | 37 ++++++++++++++++++++++++++++++++++--- tests/v3/conftest.py | 14 +++++++------- tests/v3/test_codecs.py | 2 +- tests/v3/test_store.py | 29 ++++++++++++++++------------- tests/v3/test_v2.py | 2 +- 11 files changed, 116 insertions(+), 33 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 7087706b33..e86fe5d07a 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -3,10 +3,31 @@ from typing import Protocol, runtime_checkable from zarr.buffer import Buffer -from zarr.common import BytesLike +from zarr.common import BytesLike, OpenMode class Store(ABC): + _mode: OpenMode + + def __init__(self, mode: OpenMode = "r"): + if mode not in ("r", "r+", "w", "w-", "a"): + raise ValueError("mode must be one of 'r', 'r+', 'w', 'w-', 'a'") + self._mode = mode + + @property + def mode(self) -> OpenMode: + """Access mode of the store.""" + return self._mode + + @property + def writeable(self) -> bool: + """Is the store writeable?""" + return self.mode in ("a", "w", "w-") + + def _check_writable(self) -> None: + if not self.writeable: + raise ValueError("store mode does not support writing") + @abstractmethod async def get( self, key: str, byte_range: tuple[int | None, int | None] | None = None @@ -147,6 +168,10 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ ... + def close(self) -> None: # noqa: B027 + """Close the store.""" + pass + @runtime_checkable class ByteGetter(Protocol): diff --git a/src/zarr/common.py b/src/zarr/common.py index 9d8315abc8..9527efbbce 100644 --- a/src/zarr/common.py +++ b/src/zarr/common.py @@ -27,6 +27,7 @@ Selection = slice | SliceSelection ZarrFormat = Literal[2, 3] JSON = None | str | int | float | Enum | dict[str, "JSON"] | list["JSON"] | tuple["JSON", ...] +OpenMode = Literal["r", "r+", "a", "w", "w-"] def product(tup: ChunkCoords) -> int: diff --git a/src/zarr/store/core.py b/src/zarr/store/core.py index 4e7a7fcca1..abb08291df 100644 --- a/src/zarr/store/core.py +++ b/src/zarr/store/core.py @@ -5,6 +5,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer +from zarr.common import OpenMode from zarr.store.local import LocalStore @@ -60,13 +61,18 @@ def __eq__(self, other: Any) -> bool: StoreLike = Store | StorePath | Path | str -def make_store_path(store_like: StoreLike) -> StorePath: +def make_store_path(store_like: StoreLike, *, mode: OpenMode | None = None) -> StorePath: if isinstance(store_like, StorePath): + if mode is not None: + assert mode == store_like.store.mode return store_like elif isinstance(store_like, Store): + if mode is not None: + assert mode == store_like.mode return StorePath(store_like) elif isinstance(store_like, str): - return StorePath(LocalStore(Path(store_like))) + assert mode is not None + return StorePath(LocalStore(Path(store_like), mode=mode)) raise TypeError diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 50fe9701fc..40abe12932 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -7,7 +7,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer -from zarr.common import concurrent_map, to_thread +from zarr.common import OpenMode, concurrent_map, to_thread def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer: @@ -69,7 +69,8 @@ class LocalStore(Store): root: Path - def __init__(self, root: Path | str): + def __init__(self, root: Path | str, *, mode: OpenMode = "r"): + super().__init__(mode=mode) if isinstance(root, str): root = Path(root) assert isinstance(root, Path) @@ -117,6 +118,7 @@ async def get_partial_values( return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit async def set(self, key: str, value: Buffer) -> None: + self._check_writable() assert isinstance(key, str) if isinstance(value, bytes | bytearray): # TODO: to support the v2 tests, we convert bytes to Buffer here @@ -127,6 +129,7 @@ async def set(self, key: str, value: Buffer) -> None: await to_thread(_put, path, value) async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + self._check_writable() args = [] for key, start, value in key_start_values: assert isinstance(key, str) @@ -138,6 +141,7 @@ async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes] await concurrent_map(args, to_thread, limit=None) # TODO: fix limit async def delete(self, key: str) -> None: + self._check_writable() path = self.root / key if path.is_dir(): # TODO: support deleting directories? shutil.rmtree? shutil.rmtree(path) diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 5e438919cf..74bb5454fe 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -4,7 +4,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer -from zarr.common import concurrent_map +from zarr.common import OpenMode, concurrent_map from zarr.store.core import _normalize_interval_index @@ -17,7 +17,10 @@ class MemoryStore(Store): _store_dict: MutableMapping[str, Buffer] - def __init__(self, store_dict: MutableMapping[str, Buffer] | None = None): + def __init__( + self, store_dict: MutableMapping[str, Buffer] | None = None, *, mode: OpenMode = "r" + ): + super().__init__(mode=mode) self._store_dict = store_dict or {} def __str__(self) -> str: @@ -47,6 +50,7 @@ async def exists(self, key: str) -> bool: return key in self._store_dict async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + self._check_writable() assert isinstance(key, str) if isinstance(value, bytes | bytearray): # TODO: to support the v2 tests, we convert bytes to Buffer here @@ -62,6 +66,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None self._store_dict[key] = value async def delete(self, key: str) -> None: + self._check_writable() try: del self._store_dict[key] except KeyError: diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index a3395459fd..3b086f0a03 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -4,6 +4,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer +from zarr.common import OpenMode from zarr.store.core import _dereference_path if TYPE_CHECKING: @@ -18,10 +19,14 @@ class RemoteStore(Store): root: UPath - def __init__(self, url: UPath | str, **storage_options: dict[str, Any]): + def __init__( + self, url: UPath | str, *, mode: OpenMode = "r", **storage_options: dict[str, Any] + ): import fsspec from upath import UPath + super().__init__(mode=mode) + if isinstance(url, str): self.root = UPath(url, **storage_options) else: @@ -29,6 +34,7 @@ def __init__(self, url: UPath | str, **storage_options: dict[str, Any]): len(storage_options) == 0 ), "If constructed with a UPath object, no additional storage_options are allowed." self.root = url.rstrip("/") + # test instantiate file system fs, _ = fsspec.core.url_to_fs(str(self.root), asynchronous=True, **self.root._kwargs) assert fs.__class__.async_impl, "FileSystem needs to support async operations." @@ -67,6 +73,7 @@ async def get( return value async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + self._check_writable() assert isinstance(key, str) fs, root = self._make_fs() path = _dereference_path(root, key) @@ -80,6 +87,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None await fs._pipe_file(path, value) async def delete(self, key: str) -> None: + self._check_writable() fs, root = self._make_fs() path = _dereference_path(root, key) if await fs._exists(path): diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 1c0ed93734..b317f383f6 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar import pytest @@ -31,13 +31,43 @@ def get(self, store: S, key: str) -> Buffer: raise NotImplementedError @pytest.fixture(scope="function") - def store(self) -> Store: - return self.store_cls() + def store_kwargs(self) -> dict[str, Any]: + return {"mode": "w"} + + @pytest.fixture(scope="function") + def store(self, store_kwargs: dict[str, Any]) -> Store: + return self.store_cls(**store_kwargs) def test_store_type(self, store: S) -> None: assert isinstance(store, Store) assert isinstance(store, self.store_cls) + def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None: + assert store.mode == "w", store.mode + assert store.writeable + + with pytest.raises(AttributeError): + store.mode = "w" # type: ignore + + # read-only + kwargs = {**store_kwargs, "mode": "r"} + read_store = self.store_cls(**kwargs) + assert read_store.mode == "r", read_store.mode + assert not read_store.writeable + + async def test_not_writable_store_raises(self, store_kwargs: dict[str, Any]) -> None: + kwargs = {**store_kwargs, "mode": "r"} + store = self.store_cls(**kwargs) + assert not store.writeable + + # set + with pytest.raises(ValueError): + await store.set("foo", Buffer.from_bytes(b"bar")) + + # delete + with pytest.raises(ValueError): + await store.delete("foo") + def test_store_repr(self, store: S) -> None: raise NotImplementedError @@ -72,6 +102,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: """ Ensure that data can be written to the store using the store.set method. """ + assert store.writeable data_buf = Buffer.from_bytes(data) await store.set(key, data_buf) observed = self.get(store, key) diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index 21dc58197e..6b58cce412 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -22,11 +22,11 @@ def parse_store( store: Literal["local", "memory", "remote"], path: str ) -> LocalStore | MemoryStore | RemoteStore: if store == "local": - return LocalStore(path) + return LocalStore(path, mode="w") if store == "memory": - return MemoryStore() + return MemoryStore(mode="w") if store == "remote": - return RemoteStore() + return RemoteStore(mode="w") raise AssertionError @@ -38,24 +38,24 @@ def path_type(request): # todo: harmonize this with local_store fixture @pytest.fixture def store_path(tmpdir): - store = LocalStore(str(tmpdir)) + store = LocalStore(str(tmpdir), mode="w") p = StorePath(store) return p @pytest.fixture(scope="function") def local_store(tmpdir): - return LocalStore(str(tmpdir)) + return LocalStore(str(tmpdir), mode="w") @pytest.fixture(scope="function") def remote_store(): - return RemoteStore() + return RemoteStore(mode="w") @pytest.fixture(scope="function") def memory_store(): - return MemoryStore() + return MemoryStore(mode="w") @pytest.fixture(scope="function") diff --git a/tests/v3/test_codecs.py b/tests/v3/test_codecs.py index a595b12494..251570f767 100644 --- a/tests/v3/test_codecs.py +++ b/tests/v3/test_codecs.py @@ -50,7 +50,7 @@ async def set(self, value: np.ndarray): @pytest.fixture def store() -> Iterator[Store]: - yield StorePath(MemoryStore()) + yield StorePath(MemoryStore(mode="w")) @pytest.fixture diff --git a/tests/v3/test_store.py b/tests/v3/test_store.py index 75438f8612..52882ea78c 100644 --- a/tests/v3/test_store.py +++ b/tests/v3/test_store.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import MutableMapping +from typing import Any import pytest @@ -10,7 +10,6 @@ from zarr.testing.store import StoreTests -@pytest.mark.parametrize("store_dict", (None, {})) class TestMemoryStore(StoreTests[MemoryStore]): store_cls = MemoryStore @@ -20,21 +19,25 @@ def set(self, store: MemoryStore, key: str, value: Buffer) -> None: def get(self, store: MemoryStore, key: str) -> Buffer: return store._store_dict[key] + @pytest.fixture(scope="function", params=[None, {}]) + def store_kwargs(self, request) -> dict[str, Any]: + return {"store_dict": request.param, "mode": "w"} + @pytest.fixture(scope="function") - def store(self, store_dict: MutableMapping[str, Buffer] | None): - return MemoryStore(store_dict=store_dict) + def store(self, store_kwargs: dict[str, Any]) -> MemoryStore: + return self.store_cls(**store_kwargs) def test_store_repr(self, store: MemoryStore) -> None: assert str(store) == f"memory://{id(store._store_dict)}" def test_store_supports_writes(self, store: MemoryStore) -> None: - assert True + assert store.supports_writes def test_store_supports_listing(self, store: MemoryStore) -> None: - assert True + assert store.supports_listing def test_store_supports_partial_writes(self, store: MemoryStore) -> None: - assert True + assert store.supports_partial_writes def test_list_prefix(self, store: MemoryStore) -> None: assert True @@ -52,21 +55,21 @@ def set(self, store: LocalStore, key: str, value: Buffer) -> None: parent.mkdir(parents=True) (store.root / key).write_bytes(value.to_bytes()) - @pytest.fixture(scope="function") - def store(self, tmpdir) -> LocalStore: - return self.store_cls(str(tmpdir)) + @pytest.fixture + def store_kwargs(self, tmpdir) -> dict[str, str]: + return {"root": str(tmpdir), "mode": "w"} def test_store_repr(self, store: LocalStore) -> None: assert str(store) == f"file://{store.root!s}" def test_store_supports_writes(self, store: LocalStore) -> None: - assert True + assert store.supports_writes def test_store_supports_partial_writes(self, store: LocalStore) -> None: - assert True + assert store.supports_partial_writes def test_store_supports_listing(self, store: LocalStore) -> None: - assert True + assert store.supports_listing def test_list_prefix(self, store: LocalStore) -> None: assert True diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index 2a38dc8fdc..41555bbd26 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -10,7 +10,7 @@ @pytest.fixture def store() -> Iterator[Store]: - yield StorePath(MemoryStore()) + yield StorePath(MemoryStore(mode="w")) def test_simple(store: Store):