From fc7fa4f30dc376a228cce08a7078378a1fe946b5 Mon Sep 17 00:00:00 2001 From: Davis Bennett Date: Tue, 28 May 2024 23:52:30 +0200 Subject: [PATCH 1/2] [V3] Expand store tests (#1900) * Fill in some test methods with NotImplementedError to force implementations to implement them; make StoreTests generic w.r.t. the store class being tested; update store.get abc to match actual type signature * remove auto_mkdir from LocalStore; add set and get methods to StoreTests class * fix: use from_bytes method on buffer * fix: use Buffer instead of bytes for store tests * docstrings, add some Nones to test_get_partial_values; normalize function signatures --- src/zarr/abc/store.py | 4 +- src/zarr/store/core.py | 24 ++ src/zarr/store/local.py | 16 +- src/zarr/store/memory.py | 10 +- src/zarr/store/remote.py | 2 +- src/zarr/testing/store.py | 125 ++++-- tests/v3/test_store.py | 816 ++------------------------------------ 7 files changed, 179 insertions(+), 818 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index fee5422e9e..7087706b33 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -9,7 +9,7 @@ class Store(ABC): @abstractmethod async def get( - self, key: str, byte_range: tuple[int, int | None] | None = None + self, key: str, byte_range: tuple[int | None, int | None] | None = None ) -> Buffer | None: """Retrieve the value associated with a given key. @@ -26,7 +26,7 @@ async def get( @abstractmethod async def get_partial_values( - self, key_ranges: list[tuple[str, tuple[int, int]]] + self, key_ranges: list[tuple[str, tuple[int | None, int | None]]] ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. diff --git a/src/zarr/store/core.py b/src/zarr/store/core.py index 31cce65095..4e7a7fcca1 100644 --- a/src/zarr/store/core.py +++ b/src/zarr/store/core.py @@ -68,3 +68,27 @@ def make_store_path(store_like: StoreLike) -> StorePath: elif isinstance(store_like, str): return StorePath(LocalStore(Path(store_like))) raise TypeError + + +def _normalize_interval_index( + data: Buffer, interval: None | tuple[int | None, int | None] +) -> tuple[int, int]: + """ + Convert an implicit interval into an explicit start and length + """ + if interval is None: + start = 0 + length = len(data) + else: + maybe_start, maybe_len = interval + if maybe_start is None: + start = 0 + else: + start = maybe_start + + if maybe_len is None: + length = len(data) - start + else: + length = maybe_len + + return (start, length) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 64eb8632b9..50fe9701fc 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -10,7 +10,7 @@ from zarr.common import concurrent_map, to_thread -def _get(path: Path, byte_range: tuple[int, int | None] | None) -> Buffer: +def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer: """ Fetch a contiguous region of bytes from a file. @@ -51,10 +51,8 @@ def _put( path: Path, value: Buffer, start: int | None = None, - auto_mkdir: bool = True, ) -> int | None: - if auto_mkdir: - path.parent.mkdir(parents=True, exist_ok=True) + path.parent.mkdir(parents=True, exist_ok=True) if start is not None: with path.open("r+b") as f: f.seek(start) @@ -70,15 +68,13 @@ class LocalStore(Store): supports_listing: bool = True root: Path - auto_mkdir: bool - def __init__(self, root: Path | str, auto_mkdir: bool = True): + def __init__(self, root: Path | str): if isinstance(root, str): root = Path(root) assert isinstance(root, Path) self.root = root - self.auto_mkdir = auto_mkdir def __str__(self) -> str: return f"file://{self.root}" @@ -90,7 +86,7 @@ def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root async def get( - self, key: str, byte_range: tuple[int, int | None] | None = None + self, key: str, byte_range: tuple[int | None, int | None] | None = None ) -> Buffer | None: assert isinstance(key, str) path = self.root / key @@ -101,7 +97,7 @@ async def get( return None async def get_partial_values( - self, key_ranges: list[tuple[str, tuple[int, int]]] + self, key_ranges: list[tuple[str, tuple[int | None, int | None]]] ) -> list[Buffer | None]: """ Read byte ranges from multiple keys. @@ -128,7 +124,7 @@ async def set(self, key: str, value: Buffer) -> None: if not isinstance(value, Buffer): raise TypeError("LocalStore.set(): `value` must a Buffer instance") path = self.root / key - await to_thread(_put, path, value, auto_mkdir=self.auto_mkdir) + await to_thread(_put, path, value) async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: args = [] diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index c6e838417e..5e438919cf 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -5,6 +5,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer from zarr.common import concurrent_map +from zarr.store.core import _normalize_interval_index # TODO: this store could easily be extended to wrap any MutableMapping store from v2 @@ -26,19 +27,18 @@ def __repr__(self) -> str: return f"MemoryStore({str(self)!r})" async def get( - self, key: str, byte_range: tuple[int, int | None] | None = None + self, key: str, byte_range: tuple[int | None, int | None] | None = None ) -> Buffer | None: assert isinstance(key, str) try: value = self._store_dict[key] - if byte_range is not None: - value = value[byte_range[0] : byte_range[1]] - return value + start, length = _normalize_interval_index(value, byte_range) + return value[start : start + length] except KeyError: return None async def get_partial_values( - self, key_ranges: list[tuple[str, tuple[int, int]]] + self, key_ranges: list[tuple[str, tuple[int | None, int | None]]] ) -> list[Buffer | None]: vals = await concurrent_map(key_ranges, self.get, limit=None) return vals diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 8058c61035..a3395459fd 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -49,7 +49,7 @@ def _make_fs(self) -> tuple[AsyncFileSystem, str]: return fs, root async def get( - self, key: str, byte_range: tuple[int, int | None] | None = None + self, key: str, byte_range: tuple[int | None, int | None] | None = None ) -> Buffer | None: assert isinstance(key, str) fs, root = self._make_fs() diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 1e6fe09a9f..1c0ed93734 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,61 +1,130 @@ +from typing import Generic, TypeVar + import pytest from zarr.abc.store import Store from zarr.buffer import Buffer +from zarr.store.core import _normalize_interval_index from zarr.testing.utils import assert_bytes_equal +S = TypeVar("S", bound=Store) + + +class StoreTests(Generic[S]): + store_cls: type[S] -class StoreTests: - store_cls: type[Store] + def set(self, store: S, key: str, value: Buffer) -> None: + """ + Insert a value into a storage backend, with a specific key. + This should not not use any store methods. Bypassing the store methods allows them to be + tested. + """ + raise NotImplementedError + + def get(self, store: S, key: str) -> Buffer: + """ + Retrieve a value from a storage backend, by key. + This should not not use any store methods. Bypassing the store methods allows them to be + tested. + """ + + raise NotImplementedError @pytest.fixture(scope="function") def store(self) -> Store: return self.store_cls() - def test_store_type(self, store: Store) -> None: + def test_store_type(self, store: S) -> None: assert isinstance(store, Store) assert isinstance(store, self.store_cls) - def test_store_repr(self, store: Store) -> None: - assert repr(store) + def test_store_repr(self, store: S) -> None: + raise NotImplementedError + + def test_store_supports_writes(self, store: S) -> None: + raise NotImplementedError - def test_store_capabilities(self, store: Store) -> None: - assert store.supports_writes - assert store.supports_partial_writes - assert store.supports_listing + def test_store_supports_partial_writes(self, store: S) -> None: + raise NotImplementedError + + def test_store_supports_listing(self, store: S) -> None: + raise NotImplementedError @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) - async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None: - await store.set(key, Buffer.from_bytes(data)) - assert_bytes_equal(await store.get(key), data) - - @pytest.mark.parametrize("key", ["foo/c/0"]) + @pytest.mark.parametrize("byte_range", (None, (0, None), (1, None), (1, 2), (None, 1))) + async def test_get( + self, store: S, key: str, data: bytes, byte_range: None | tuple[int | None, int | None] + ) -> None: + """ + Ensure that data can be read from the store using the store.get method. + """ + data_buf = Buffer.from_bytes(data) + self.set(store, key, data_buf) + observed = await store.get(key, byte_range=byte_range) + start, length = _normalize_interval_index(data_buf, interval=byte_range) + expected = data_buf[start : start + length] + assert_bytes_equal(observed, expected) + + @pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) - async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> None: + async def test_set(self, store: S, key: str, data: bytes) -> None: + """ + Ensure that data can be written to the store using the store.set method. + """ + data_buf = Buffer.from_bytes(data) + await store.set(key, data_buf) + observed = self.get(store, key) + assert_bytes_equal(observed, data_buf) + + @pytest.mark.parametrize( + "key_ranges", + ( + [], + [("zarr.json", (0, 1))], + [("c/0", (0, 1)), ("zarr.json", (0, None))], + [("c/0/0", (0, 1)), ("c/0/1", (None, 2)), ("c/0/2", (0, 3))], + ), + ) + async def test_get_partial_values( + self, store: S, key_ranges: list[tuple[str, tuple[int | None, int | None]]] + ) -> None: # put all of the data - await store.set(key, Buffer.from_bytes(data)) + for key, _ in key_ranges: + self.set(store, key, Buffer.from_bytes(bytes(key, encoding="utf-8"))) + # read back just part of it - vals = await store.get_partial_values([(key, (0, 2))]) - assert_bytes_equal(vals[0], data[0:2]) + observed_maybe = await store.get_partial_values(key_ranges=key_ranges) + + observed: list[Buffer] = [] + expected: list[Buffer] = [] + + for obs in observed_maybe: + assert obs is not None + observed.append(obs) + + for idx in range(len(observed)): + key, byte_range = key_ranges[idx] + result = await store.get(key, byte_range=byte_range) + assert result is not None + expected.append(result) - # read back multiple parts of it at once - vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))]) - assert_bytes_equal(vals[0], data[0:2]) - assert_bytes_equal(vals[1], data[2:4]) + assert all( + obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True) + ) - async def test_exists(self, store: Store) -> None: + async def test_exists(self, store: S) -> None: assert not await store.exists("foo") await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) assert await store.exists("foo/zarr.json") - async def test_delete(self, store: Store) -> None: + async def test_delete(self, store: S) -> None: await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) assert await store.exists("foo/zarr.json") await store.delete("foo/zarr.json") assert not await store.exists("foo/zarr.json") - async def test_list(self, store: Store) -> None: + async def test_list(self, store: S) -> None: assert [k async for k in store.list()] == [] await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) keys = [k async for k in store.list()] @@ -69,11 +138,11 @@ async def test_list(self, store: Store) -> None: f"foo/c/{i}", Buffer.from_bytes(i.to_bytes(length=3, byteorder="little")) ) - async def test_list_prefix(self, store: Store) -> None: + async def test_list_prefix(self, store: S) -> None: # TODO: we currently don't use list_prefix anywhere - pass + raise NotImplementedError - async def test_list_dir(self, store: Store) -> None: + async def test_list_dir(self, store: S) -> None: assert [k async for k in store.list_dir("")] == [] assert [k async for k in store.list_dir("foo")] == [] await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) diff --git a/tests/v3/test_store.py b/tests/v3/test_store.py index 9bdf2f5a8f..75438f8612 100644 --- a/tests/v3/test_store.py +++ b/tests/v3/test_store.py @@ -1,800 +1,72 @@ from __future__ import annotations -from pathlib import Path +from collections.abc import MutableMapping import pytest +from zarr.buffer import Buffer from zarr.store.local import LocalStore from zarr.store.memory import MemoryStore from zarr.testing.store import StoreTests -from zarr.testing.utils import assert_bytes_equal -@pytest.mark.parametrize("auto_mkdir", (True, False)) -def test_local_store_init(tmpdir, auto_mkdir: bool) -> None: - tmpdir_str = str(tmpdir) - tmpdir_path = Path(tmpdir_str) - store = LocalStore(root=tmpdir_str, auto_mkdir=auto_mkdir) - - assert store.root == tmpdir_path - assert store.auto_mkdir == auto_mkdir - - # ensure that str and pathlib.Path get normalized to the same output - assert store == LocalStore(root=tmpdir_path, auto_mkdir=auto_mkdir) - - store_str = f"file://{tmpdir_str}" - assert str(store) == store_str - assert repr(store) == f"LocalStore({store_str!r})" - - -@pytest.mark.parametrize("byte_range", (None, (0, None), (1, None), (1, 2), (None, 1))) -async def test_local_store_get( - local_store, byte_range: None | tuple[int | None, int | None] -) -> None: - payload = b"\x01\x02\x03\x04" - object_name = "foo" - (local_store.root / object_name).write_bytes(payload) - observed = await local_store.get(object_name, byte_range=byte_range) - - if byte_range is None: - start = 0 - length = len(payload) - else: - maybe_start, maybe_len = byte_range - if maybe_start is None: - start = 0 - else: - start = maybe_start - - if maybe_len is None: - length = len(payload) - start - else: - length = maybe_len - - expected = payload[start : start + length] - assert_bytes_equal(observed, expected) - - # test that getting from a file that doesn't exist returns None - assert await local_store.get(object_name + "_absent", byte_range=byte_range) is None - - -@pytest.mark.parametrize( - "key_ranges", - ( - [], - [("key_0", (0, 1))], - [("dir/key_0", (0, 1)), ("key_1", (0, 2))], - [("key_0", (0, 1)), ("key_1", (0, 2)), ("key_1", (0, 2))], - ), -) -async def test_local_store_get_partial( - tmpdir, key_ranges: tuple[list[tuple[str, tuple[int, int]]]] -) -> None: - store = LocalStore(str(tmpdir), auto_mkdir=True) - # use the utf-8 encoding of the key as the bytes - for key, _ in key_ranges: - payload = bytes(key, encoding="utf-8") - target_path: Path = store.root / key - # create the parent directories - target_path.parent.mkdir(parents=True, exist_ok=True) - # write bytes - target_path.write_bytes(payload) - - results = await store.get_partial_values(key_ranges) - for idx, observed in enumerate(results): - key, byte_range = key_ranges[idx] - expected = await store.get(key, byte_range=byte_range) - assert_bytes_equal(observed, expected) - - -@pytest.mark.parametrize("path", ("foo", "foo/bar")) -@pytest.mark.parametrize("auto_mkdir", (True, False)) -async def test_local_store_set(tmpdir, path: str, auto_mkdir: bool) -> None: - store = LocalStore(str(tmpdir), auto_mkdir=auto_mkdir) - payload = b"\x01\x02\x03\x04" - - if "/" in path and not auto_mkdir: - with pytest.raises(FileNotFoundError): - await store.set(path, payload) - else: - x = await store.set(path, payload) - - # this method should not return anything - assert x is None - - assert (store.root / path).read_bytes() == payload - - -# import zarr -# from zarr._storage.store import _get_hierarchy_metadata, v3_api_available, StorageTransformer -# from zarr._storage.v3_storage_transformers import ( -# DummyStorageTransfomer, -# ShardingStorageTransformer, -# v3_sharding_available, -# ) -# from zarr.core import Array -# from zarr.meta import _default_entry_point_metadata_v3 -# from zarr.storage import ( -# atexit_rmglob, -# atexit_rmtree, -# data_root, -# default_compressor, -# getsize, -# init_array, -# meta_root, -# normalize_store_arg, -# ) -# from zarr._storage.v3 import ( -# ABSStoreV3, -# ConsolidatedMetadataStoreV3, -# DBMStoreV3, -# DirectoryStoreV3, -# FSStoreV3, -# KVStore, -# KVStoreV3, -# LMDBStoreV3, -# LRUStoreCacheV3, -# MemoryStoreV3, -# MongoDBStoreV3, -# RedisStoreV3, -# SQLiteStoreV3, -# StoreV3, -# ZipStoreV3, -# ) -# from .util import CountingDictV3, have_fsspec, skip_test_env_var, mktemp - -# # pytest will fail to run if the following fixtures aren't imported here -# from .test_storage import StoreTests as _StoreTests -# from .test_storage import TestABSStore as _TestABSStore -# from .test_storage import TestConsolidatedMetadataStore as _TestConsolidatedMetadataStore -# from .test_storage import TestDBMStore as _TestDBMStore -# from .test_storage import TestDBMStoreBerkeleyDB as _TestDBMStoreBerkeleyDB -# from .test_storage import TestDBMStoreDumb as _TestDBMStoreDumb -# from .test_storage import TestDBMStoreGnu as _TestDBMStoreGnu -# from .test_storage import TestDBMStoreNDBM as _TestDBMStoreNDBM -# from .test_storage import TestDirectoryStore as _TestDirectoryStore -# from .test_storage import TestFSStore as _TestFSStore -# from .test_storage import TestLMDBStore as _TestLMDBStore -# from .test_storage import TestLRUStoreCache as _TestLRUStoreCache -# from .test_storage import TestMemoryStore as _TestMemoryStore -# from .test_storage import TestSQLiteStore as _TestSQLiteStore -# from .test_storage import TestSQLiteStoreInMemory as _TestSQLiteStoreInMemory -# from .test_storage import TestZipStore as _TestZipStore -# from .test_storage import dimension_separator_fixture, s3, skip_if_nested_chunks - - -# pytestmark = pytest.mark.skipif(not v3_api_available, reason="v3 api is not available") - - -# @pytest.fixture( -# params=[ -# (None, "/"), -# (".", "."), -# ("/", "/"), -# ] -# ) -# def dimension_separator_fixture_v3(request): -# return request.param - - -# class DummyStore: -# # contains all methods expected of Mutable Mapping - -# def keys(self): -# """keys""" - -# def values(self): -# """values""" - -# def get(self, value, default=None): -# """get""" - -# def __setitem__(self, key, value): -# """__setitem__""" - -# def __getitem__(self, key): -# """__getitem__""" - -# def __delitem__(self, key): -# """__delitem__""" - -# def __contains__(self, key): -# """__contains__""" - - -# class InvalidDummyStore: -# # does not contain expected methods of a MutableMapping - -# def keys(self): -# """keys""" - - -# def test_ensure_store_v3(): -# class InvalidStore: -# pass - -# with pytest.raises(ValueError): -# StoreV3._ensure_store(InvalidStore()) - -# # cannot initialize with a store from a different Zarr version -# with pytest.raises(ValueError): -# StoreV3._ensure_store(KVStore(dict())) - -# assert StoreV3._ensure_store(None) is None - -# # class with all methods of a MutableMapping will become a KVStoreV3 -# assert isinstance(StoreV3._ensure_store(DummyStore), KVStoreV3) - -# with pytest.raises(ValueError): -# # does not have the methods expected of a MutableMapping -# StoreV3._ensure_store(InvalidDummyStore) - - -# def test_valid_key(): -# store = KVStoreV3(dict) - -# # only ascii keys are valid -# assert not store._valid_key(5) -# assert not store._valid_key(2.8) - -# for key in store._valid_key_characters: -# assert store._valid_key(key) - -# # other characters not in store._valid_key_characters are not allowed -# assert not store._valid_key("*") -# assert not store._valid_key("~") -# assert not store._valid_key("^") - - -# def test_validate_key(): -# store = KVStoreV3(dict) - -# # zarr.json is a valid key -# store._validate_key("zarr.json") -# # but other keys not starting with meta/ or data/ are not -# with pytest.raises(ValueError): -# store._validate_key("zar.json") - -# # valid ascii keys -# for valid in [ -# meta_root + "arr1.array.json", -# data_root + "arr1.array.json", -# meta_root + "subfolder/item_1-0.group.json", -# ]: -# store._validate_key(valid) -# # but otherwise valid keys cannot end in / -# with pytest.raises(ValueError): -# assert store._validate_key(valid + "/") - -# for invalid in [0, "*", "~", "^", "&"]: -# with pytest.raises(ValueError): -# store._validate_key(invalid) - - -# class StoreV3Tests(_StoreTests): - -# version = 3 -# root = meta_root - -# def test_getsize(self): -# # TODO: determine proper getsize() behavior for v3 -# # Currently returns the combined size of entries under -# # meta/root/path and data/root/path. -# # Any path not under meta/root/ or data/root/ (including zarr.json) -# # returns size 0. - -# store = self.create_store() -# if isinstance(store, dict) or hasattr(store, "getsize"): -# assert 0 == getsize(store, "zarr.json") -# store[meta_root + "foo/a"] = b"x" -# assert 1 == getsize(store) -# assert 1 == getsize(store, "foo") -# store[meta_root + "foo/b"] = b"x" -# assert 2 == getsize(store, "foo") -# assert 1 == getsize(store, "foo/b") -# store[meta_root + "bar/a"] = b"yy" -# assert 2 == getsize(store, "bar") -# store[data_root + "bar/a"] = b"zzz" -# assert 5 == getsize(store, "bar") -# store[data_root + "baz/a"] = b"zzz" -# assert 3 == getsize(store, "baz") -# assert 10 == getsize(store) -# store[data_root + "quux"] = array.array("B", b"zzzz") -# assert 14 == getsize(store) -# assert 4 == getsize(store, "quux") -# store[data_root + "spong"] = np.frombuffer(b"zzzzz", dtype="u1") -# assert 19 == getsize(store) -# assert 5 == getsize(store, "spong") -# store.close() - -# def test_init_array(self, dimension_separator_fixture_v3): - -# pass_dim_sep, want_dim_sep = dimension_separator_fixture_v3 - -# store = self.create_store() -# path = "arr1" -# transformer = DummyStorageTransfomer( -# "dummy_type", test_value=DummyStorageTransfomer.TEST_CONSTANT -# ) -# init_array( -# store, -# path=path, -# shape=1000, -# chunks=100, -# dimension_separator=pass_dim_sep, -# storage_transformers=[transformer], -# ) - -# # check metadata -# mkey = meta_root + path + ".array.json" -# assert mkey in store -# meta = store._metadata_class.decode_array_metadata(store[mkey]) -# assert (1000,) == meta["shape"] -# assert (100,) == meta["chunk_grid"]["chunk_shape"] -# assert np.dtype(None) == meta["data_type"] -# assert default_compressor == meta["compressor"] -# assert meta["fill_value"] is None -# # Missing MUST be assumed to be "/" -# assert meta["chunk_grid"]["separator"] is want_dim_sep -# assert len(meta["storage_transformers"]) == 1 -# assert isinstance(meta["storage_transformers"][0], DummyStorageTransfomer) -# assert meta["storage_transformers"][0].test_value == DummyStorageTransfomer.TEST_CONSTANT -# store.close() - -# def test_list_prefix(self): - -# store = self.create_store() -# path = "arr1" -# init_array(store, path=path, shape=1000, chunks=100) - -# expected = [meta_root + "arr1.array.json", "zarr.json"] -# assert sorted(store.list_prefix("")) == expected - -# expected = [meta_root + "arr1.array.json"] -# assert sorted(store.list_prefix(meta_root.rstrip("/"))) == expected - -# # cannot start prefix with '/' -# with pytest.raises(ValueError): -# store.list_prefix(prefix="/" + meta_root.rstrip("/")) - -# def test_equal(self): -# store = self.create_store() -# assert store == store - -# def test_rename_nonexisting(self): -# store = self.create_store() -# if store.is_erasable(): -# with pytest.raises(ValueError): -# store.rename("a", "b") -# else: -# with pytest.raises(NotImplementedError): -# store.rename("a", "b") - -# def test_get_partial_values(self): -# store = self.create_store() -# store.supports_efficient_get_partial_values in [True, False] -# store[data_root + "foo"] = b"abcdefg" -# store[data_root + "baz"] = b"z" -# assert [b"a"] == store.get_partial_values([(data_root + "foo", (0, 1))]) -# assert [ -# b"d", -# b"b", -# b"z", -# b"abc", -# b"defg", -# b"defg", -# b"g", -# b"ef", -# ] == store.get_partial_values( -# [ -# (data_root + "foo", (3, 1)), -# (data_root + "foo", (1, 1)), -# (data_root + "baz", (0, 1)), -# (data_root + "foo", (0, 3)), -# (data_root + "foo", (3, 4)), -# (data_root + "foo", (3, None)), -# (data_root + "foo", (-1, None)), -# (data_root + "foo", (-3, 2)), -# ] -# ) - -# def test_set_partial_values(self): -# store = self.create_store() -# store.supports_efficient_set_partial_values() -# store[data_root + "foo"] = b"abcdefg" -# store.set_partial_values([(data_root + "foo", 0, b"hey")]) -# assert store[data_root + "foo"] == b"heydefg" - -# store.set_partial_values([(data_root + "baz", 0, b"z")]) -# assert store[data_root + "baz"] == b"z" -# store.set_partial_values( -# [ -# (data_root + "foo", 1, b"oo"), -# (data_root + "baz", 1, b"zzz"), -# (data_root + "baz", 4, b"aaaa"), -# (data_root + "foo", 6, b"done"), -# ] -# ) -# assert store[data_root + "foo"] == b"hoodefdone" -# assert store[data_root + "baz"] == b"zzzzaaaa" -# store.set_partial_values( -# [ -# (data_root + "foo", -2, b"NE"), -# (data_root + "baz", -5, b"q"), -# ] -# ) -# assert store[data_root + "foo"] == b"hoodefdoNE" -# assert store[data_root + "baz"] == b"zzzq" - - -# class TestMappingStoreV3(StoreV3Tests): -# def create_store(self, **kwargs): -# return KVStoreV3(dict()) - -# def test_set_invalid_content(self): -# # Generic mappings support non-buffer types -# pass - - -# class TestMemoryStoreV3(_TestMemoryStore, StoreV3Tests): -# def create_store(self, **kwargs): -# skip_if_nested_chunks(**kwargs) -# return MemoryStoreV3(**kwargs) - - -# class TestDirectoryStoreV3(_TestDirectoryStore, StoreV3Tests): -# def create_store(self, normalize_keys=False, **kwargs): -# # For v3, don't have to skip if nested. -# # skip_if_nested_chunks(**kwargs) - -# path = tempfile.mkdtemp() -# atexit.register(atexit_rmtree, path) -# store = DirectoryStoreV3(path, normalize_keys=normalize_keys, **kwargs) -# return store - -# def test_rename_nonexisting(self): -# store = self.create_store() -# with pytest.raises(FileNotFoundError): -# store.rename(meta_root + "a", meta_root + "b") - - -# @pytest.mark.skipif(have_fsspec is False, reason="needs fsspec") -# class TestFSStoreV3(_TestFSStore, StoreV3Tests): -# def create_store(self, normalize_keys=False, dimension_separator=".", path=None, **kwargs): - -# if path is None: -# path = tempfile.mkdtemp() -# atexit.register(atexit_rmtree, path) - -# store = FSStoreV3( -# path, normalize_keys=normalize_keys, dimension_separator=dimension_separator, **kwargs -# ) -# return store - -# def test_init_array(self): -# store = self.create_store() -# path = "arr1" -# init_array(store, path=path, shape=1000, chunks=100) - -# # check metadata -# mkey = meta_root + path + ".array.json" -# assert mkey in store -# meta = store._metadata_class.decode_array_metadata(store[mkey]) -# assert (1000,) == meta["shape"] -# assert (100,) == meta["chunk_grid"]["chunk_shape"] -# assert np.dtype(None) == meta["data_type"] -# assert meta["chunk_grid"]["separator"] == "/" - - -# @pytest.mark.skipif(have_fsspec is False, reason="needs fsspec") -# class TestFSStoreV3WithKeySeparator(StoreV3Tests): -# def create_store(self, normalize_keys=False, key_separator=".", **kwargs): - -# # Since the user is passing key_separator, that will take priority. -# skip_if_nested_chunks(**kwargs) - -# path = tempfile.mkdtemp() -# atexit.register(atexit_rmtree, path) -# return FSStoreV3(path, normalize_keys=normalize_keys, key_separator=key_separator) - - -# # TODO: enable once N5StoreV3 has been implemented -# # @pytest.mark.skipif(True, reason="N5StoreV3 not yet fully implemented") -# # class TestN5StoreV3(_TestN5Store, TestDirectoryStoreV3, StoreV3Tests): - - -# class TestZipStoreV3(_TestZipStore, StoreV3Tests): - -# ZipStoreClass = ZipStoreV3 - -# def create_store(self, **kwargs): -# path = mktemp(suffix=".zip") -# atexit.register(os.remove, path) -# store = ZipStoreV3(path, mode="w", **kwargs) -# return store - - -# class TestDBMStoreV3(_TestDBMStore, StoreV3Tests): -# def create_store(self, dimension_separator=None): -# path = mktemp(suffix=".anydbm") -# atexit.register(atexit_rmglob, path + "*") -# # create store using default dbm implementation -# store = DBMStoreV3(path, flag="n", dimension_separator=dimension_separator) -# return store - - -# class TestDBMStoreV3Dumb(_TestDBMStoreDumb, StoreV3Tests): -# def create_store(self, **kwargs): -# path = mktemp(suffix=".dumbdbm") -# atexit.register(atexit_rmglob, path + "*") - -# import dbm.dumb as dumbdbm - -# store = DBMStoreV3(path, flag="n", open=dumbdbm.open, **kwargs) -# return store - - -# class TestDBMStoreV3Gnu(_TestDBMStoreGnu, StoreV3Tests): -# def create_store(self, **kwargs): -# gdbm = pytest.importorskip("dbm.gnu") -# path = mktemp(suffix=".gdbm") # pragma: no cover -# atexit.register(os.remove, path) # pragma: no cover -# store = DBMStoreV3( -# path, flag="n", open=gdbm.open, write_lock=False, **kwargs -# ) # pragma: no cover -# return store # pragma: no cover - - -# class TestDBMStoreV3NDBM(_TestDBMStoreNDBM, StoreV3Tests): -# def create_store(self, **kwargs): -# ndbm = pytest.importorskip("dbm.ndbm") -# path = mktemp(suffix=".ndbm") # pragma: no cover -# atexit.register(atexit_rmglob, path + "*") # pragma: no cover -# store = DBMStoreV3(path, flag="n", open=ndbm.open, **kwargs) # pragma: no cover -# return store # pragma: no cover - - -# class TestDBMStoreV3BerkeleyDB(_TestDBMStoreBerkeleyDB, StoreV3Tests): -# def create_store(self, **kwargs): -# bsddb3 = pytest.importorskip("bsddb3") -# path = mktemp(suffix=".dbm") -# atexit.register(os.remove, path) -# store = DBMStoreV3(path, flag="n", open=bsddb3.btopen, write_lock=False, **kwargs) -# return store - - -# class TestLMDBStoreV3(_TestLMDBStore, StoreV3Tests): -# def create_store(self, **kwargs): -# pytest.importorskip("lmdb") -# path = mktemp(suffix=".lmdb") -# atexit.register(atexit_rmtree, path) -# buffers = True -# store = LMDBStoreV3(path, buffers=buffers, **kwargs) -# return store - - -# class TestSQLiteStoreV3(_TestSQLiteStore, StoreV3Tests): -# def create_store(self, **kwargs): -# pytest.importorskip("sqlite3") -# path = mktemp(suffix=".db") -# atexit.register(atexit_rmtree, path) -# store = SQLiteStoreV3(path, **kwargs) -# return store - - -# class TestSQLiteStoreV3InMemory(_TestSQLiteStoreInMemory, StoreV3Tests): -# def create_store(self, **kwargs): -# pytest.importorskip("sqlite3") -# store = SQLiteStoreV3(":memory:", **kwargs) -# return store - - -# @skip_test_env_var("ZARR_TEST_MONGO") -# class TestMongoDBStoreV3(StoreV3Tests): -# def create_store(self, **kwargs): -# pytest.importorskip("pymongo") -# store = MongoDBStoreV3( -# host="127.0.0.1", database="zarr_tests", collection="zarr_tests", **kwargs -# ) -# # start with an empty store -# store.clear() -# return store - - -# @skip_test_env_var("ZARR_TEST_REDIS") -# class TestRedisStoreV3(StoreV3Tests): -# def create_store(self, **kwargs): -# # TODO: this is the default host for Redis on Travis, -# # we probably want to generalize this though -# pytest.importorskip("redis") -# store = RedisStoreV3(host="localhost", port=6379, **kwargs) -# # start with an empty store -# store.clear() -# return store - - -# @pytest.mark.skipif(not v3_sharding_available, reason="sharding is disabled") -# class TestStorageTransformerV3(TestMappingStoreV3): -# def create_store(self, **kwargs): -# inner_store = super().create_store(**kwargs) -# dummy_transformer = DummyStorageTransfomer( -# "dummy_type", test_value=DummyStorageTransfomer.TEST_CONSTANT -# ) -# sharding_transformer = ShardingStorageTransformer( -# "indexed", -# chunks_per_shard=2, -# ) -# path = "bla" -# init_array( -# inner_store, -# path=path, -# shape=1000, -# chunks=100, -# dimension_separator=".", -# storage_transformers=[dummy_transformer, sharding_transformer], -# ) -# store = Array(store=inner_store, path=path).chunk_store -# store.erase_prefix("data/root/bla/") -# store.clear() -# return store - -# def test_method_forwarding(self): -# store = self.create_store() -# inner_store = store.inner_store.inner_store -# assert store.list() == inner_store.list() -# assert store.list_dir(data_root) == inner_store.list_dir(data_root) - -# assert store.is_readable() -# assert store.is_writeable() -# assert store.is_listable() -# inner_store._readable = False -# inner_store._writeable = False -# inner_store._listable = False -# assert not store.is_readable() -# assert not store.is_writeable() -# assert not store.is_listable() - - -# class TestLRUStoreCacheV3(_TestLRUStoreCache, StoreV3Tests): - -# CountingClass = CountingDictV3 -# LRUStoreClass = LRUStoreCacheV3 - - -# @skip_test_env_var("ZARR_TEST_ABS") -# class TestABSStoreV3(_TestABSStore, StoreV3Tests): - -# ABSStoreClass = ABSStoreV3 - - -# def test_normalize_store_arg_v3(tmpdir): - -# fn = tmpdir.join("store.zip") -# store = normalize_store_arg(str(fn), zarr_version=3, mode="w") -# assert isinstance(store, ZipStoreV3) -# assert "zarr.json" in store - -# # can't pass storage_options to non-fsspec store -# with pytest.raises(ValueError): -# normalize_store_arg(str(fn), zarr_version=3, mode="w", storage_options={"some": "kwargs"}) - -# if have_fsspec: -# import fsspec - -# path = tempfile.mkdtemp() -# store = normalize_store_arg("file://" + path, zarr_version=3, mode="w") -# assert isinstance(store, FSStoreV3) -# assert "zarr.json" in store - -# store = normalize_store_arg(fsspec.get_mapper("file://" + path), zarr_version=3) -# assert isinstance(store, FSStoreV3) - -# # regression for https://github.com/zarr-developers/zarr-python/issues/1382 -# # contents of zarr.json are not important for this test -# out = {"version": 1, "refs": {"zarr.json": "{...}"}} -# store = normalize_store_arg( -# "reference://", -# storage_options={"fo": out, "remote_protocol": "memory"}, zarr_version=3 -# ) -# assert isinstance(store, FSStoreV3) - -# fn = tmpdir.join("store.n5") -# with pytest.raises(NotImplementedError): -# normalize_store_arg(str(fn), zarr_version=3, mode="w") - -# # error on zarr_version=3 with a v2 store -# with pytest.raises(ValueError): -# normalize_store_arg(KVStore(dict()), zarr_version=3, mode="w") - -# # error on zarr_version=2 with a v3 store -# with pytest.raises(ValueError): -# normalize_store_arg(KVStoreV3(dict()), zarr_version=2, mode="w") - - -# class TestConsolidatedMetadataStoreV3(_TestConsolidatedMetadataStore): - -# version = 3 -# ConsolidatedMetadataClass = ConsolidatedMetadataStoreV3 - -# @property -# def metadata_key(self): -# return meta_root + "consolidated/.zmetadata" +@pytest.mark.parametrize("store_dict", (None, {})) +class TestMemoryStore(StoreTests[MemoryStore]): + store_cls = MemoryStore -# def test_bad_store_version(self): -# with pytest.raises(ValueError): -# self.ConsolidatedMetadataClass(KVStore(dict())) + def set(self, store: MemoryStore, key: str, value: Buffer) -> None: + store._store_dict[key] = value + def get(self, store: MemoryStore, key: str) -> Buffer: + return store._store_dict[key] -# def test_get_hierarchy_metadata(): -# store = KVStoreV3({}) + @pytest.fixture(scope="function") + def store(self, store_dict: MutableMapping[str, Buffer] | None): + return MemoryStore(store_dict=store_dict) -# # error raised if 'jarr.json' is not in the store -# with pytest.raises(ValueError): -# _get_hierarchy_metadata(store) + def test_store_repr(self, store: MemoryStore) -> None: + assert str(store) == f"memory://{id(store._store_dict)}" -# store["zarr.json"] = _default_entry_point_metadata_v3 -# assert _get_hierarchy_metadata(store) == _default_entry_point_metadata_v3 + def test_store_supports_writes(self, store: MemoryStore) -> None: + assert True -# # ValueError if only a subset of keys are present -# store["zarr.json"] = {"zarr_format": "https://purl.org/zarr/spec/protocol/core/3.0"} -# with pytest.raises(ValueError): -# _get_hierarchy_metadata(store) + def test_store_supports_listing(self, store: MemoryStore) -> None: + assert True -# # ValueError if any unexpected keys are present -# extra_metadata = copy.copy(_default_entry_point_metadata_v3) -# extra_metadata["extra_key"] = "value" -# store["zarr.json"] = extra_metadata -# with pytest.raises(ValueError): -# _get_hierarchy_metadata(store) + def test_store_supports_partial_writes(self, store: MemoryStore) -> None: + assert True + def test_list_prefix(self, store: MemoryStore) -> None: + assert True -# def test_top_level_imports(): -# for store_name in [ -# "ABSStoreV3", -# "DBMStoreV3", -# "KVStoreV3", -# "DirectoryStoreV3", -# "LMDBStoreV3", -# "LRUStoreCacheV3", -# "MemoryStoreV3", -# "MongoDBStoreV3", -# "RedisStoreV3", -# "SQLiteStoreV3", -# "ZipStoreV3", -# ]: -# if v3_api_available: -# assert hasattr(zarr, store_name) # pragma: no cover -# else: -# assert not hasattr(zarr, store_name) # pragma: no cover +class TestLocalStore(StoreTests[LocalStore]): + store_cls = LocalStore -# def _get_public_and_dunder_methods(some_class): -# return set( -# name -# for name, _ in inspect.getmembers(some_class, predicate=inspect.isfunction) -# if not name.startswith("_") or name.startswith("__") -# ) + def get(self, store: LocalStore, key: str) -> Buffer: + return Buffer.from_bytes((store.root / key).read_bytes()) + def set(self, store: LocalStore, key: str, value: Buffer) -> None: + parent = (store.root / key).parent + if not parent.exists(): + parent.mkdir(parents=True) + (store.root / key).write_bytes(value.to_bytes()) -# def test_storage_transformer_interface(): -# store_v3_methods = _get_public_and_dunder_methods(StoreV3) -# store_v3_methods.discard("__init__") -# # Note, getitems() isn't mandatory when get_partial_values() is available -# store_v3_methods.discard("getitems") -# storage_transformer_methods = _get_public_and_dunder_methods(StorageTransformer) -# storage_transformer_methods.discard("__init__") -# storage_transformer_methods.discard("get_config") -# assert storage_transformer_methods == store_v3_methods + @pytest.fixture(scope="function") + def store(self, tmpdir) -> LocalStore: + return self.store_cls(str(tmpdir)) + def test_store_repr(self, store: LocalStore) -> None: + assert str(store) == f"file://{store.root!s}" -class TestMemoryStore(StoreTests): - store_cls = MemoryStore + def test_store_supports_writes(self, store: LocalStore) -> None: + assert True + def test_store_supports_partial_writes(self, store: LocalStore) -> None: + assert True -class TestLocalStore(StoreTests): - store_cls = LocalStore + def test_store_supports_listing(self, store: LocalStore) -> None: + assert True - @pytest.fixture(scope="function") - @pytest.mark.parametrize("auto_mkdir", (True, False)) - def store(self, tmpdir) -> LocalStore: - return self.store_cls(str(tmpdir)) + def test_list_prefix(self, store: LocalStore) -> None: + assert True From ef15e20192c294cc82b9194ca29190ac9806d6fa Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Wed, 29 May 2024 08:00:52 -0700 Subject: [PATCH 2/2] [v3] Feature: Store open mode (#1911) * wip * feature(store): set open mode on store initialization --- src/zarr/abc/store.py | 27 ++++++++++++++++++++++++++- src/zarr/common.py | 1 + src/zarr/store/core.py | 10 ++++++++-- src/zarr/store/local.py | 8 ++++++-- src/zarr/store/memory.py | 9 +++++++-- src/zarr/store/remote.py | 10 +++++++++- src/zarr/testing/store.py | 37 ++++++++++++++++++++++++++++++++++--- tests/v3/conftest.py | 14 +++++++------- tests/v3/test_codecs.py | 2 +- tests/v3/test_store.py | 29 ++++++++++++++++------------- tests/v3/test_v2.py | 2 +- 11 files changed, 116 insertions(+), 33 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 7087706b33..e86fe5d07a 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -3,10 +3,31 @@ from typing import Protocol, runtime_checkable from zarr.buffer import Buffer -from zarr.common import BytesLike +from zarr.common import BytesLike, OpenMode class Store(ABC): + _mode: OpenMode + + def __init__(self, mode: OpenMode = "r"): + if mode not in ("r", "r+", "w", "w-", "a"): + raise ValueError("mode must be one of 'r', 'r+', 'w', 'w-', 'a'") + self._mode = mode + + @property + def mode(self) -> OpenMode: + """Access mode of the store.""" + return self._mode + + @property + def writeable(self) -> bool: + """Is the store writeable?""" + return self.mode in ("a", "w", "w-") + + def _check_writable(self) -> None: + if not self.writeable: + raise ValueError("store mode does not support writing") + @abstractmethod async def get( self, key: str, byte_range: tuple[int | None, int | None] | None = None @@ -147,6 +168,10 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ ... + def close(self) -> None: # noqa: B027 + """Close the store.""" + pass + @runtime_checkable class ByteGetter(Protocol): diff --git a/src/zarr/common.py b/src/zarr/common.py index 9d8315abc8..9527efbbce 100644 --- a/src/zarr/common.py +++ b/src/zarr/common.py @@ -27,6 +27,7 @@ Selection = slice | SliceSelection ZarrFormat = Literal[2, 3] JSON = None | str | int | float | Enum | dict[str, "JSON"] | list["JSON"] | tuple["JSON", ...] +OpenMode = Literal["r", "r+", "a", "w", "w-"] def product(tup: ChunkCoords) -> int: diff --git a/src/zarr/store/core.py b/src/zarr/store/core.py index 4e7a7fcca1..abb08291df 100644 --- a/src/zarr/store/core.py +++ b/src/zarr/store/core.py @@ -5,6 +5,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer +from zarr.common import OpenMode from zarr.store.local import LocalStore @@ -60,13 +61,18 @@ def __eq__(self, other: Any) -> bool: StoreLike = Store | StorePath | Path | str -def make_store_path(store_like: StoreLike) -> StorePath: +def make_store_path(store_like: StoreLike, *, mode: OpenMode | None = None) -> StorePath: if isinstance(store_like, StorePath): + if mode is not None: + assert mode == store_like.store.mode return store_like elif isinstance(store_like, Store): + if mode is not None: + assert mode == store_like.mode return StorePath(store_like) elif isinstance(store_like, str): - return StorePath(LocalStore(Path(store_like))) + assert mode is not None + return StorePath(LocalStore(Path(store_like), mode=mode)) raise TypeError diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 50fe9701fc..40abe12932 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -7,7 +7,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer -from zarr.common import concurrent_map, to_thread +from zarr.common import OpenMode, concurrent_map, to_thread def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer: @@ -69,7 +69,8 @@ class LocalStore(Store): root: Path - def __init__(self, root: Path | str): + def __init__(self, root: Path | str, *, mode: OpenMode = "r"): + super().__init__(mode=mode) if isinstance(root, str): root = Path(root) assert isinstance(root, Path) @@ -117,6 +118,7 @@ async def get_partial_values( return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit async def set(self, key: str, value: Buffer) -> None: + self._check_writable() assert isinstance(key, str) if isinstance(value, bytes | bytearray): # TODO: to support the v2 tests, we convert bytes to Buffer here @@ -127,6 +129,7 @@ async def set(self, key: str, value: Buffer) -> None: await to_thread(_put, path, value) async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + self._check_writable() args = [] for key, start, value in key_start_values: assert isinstance(key, str) @@ -138,6 +141,7 @@ async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes] await concurrent_map(args, to_thread, limit=None) # TODO: fix limit async def delete(self, key: str) -> None: + self._check_writable() path = self.root / key if path.is_dir(): # TODO: support deleting directories? shutil.rmtree? shutil.rmtree(path) diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 5e438919cf..74bb5454fe 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -4,7 +4,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer -from zarr.common import concurrent_map +from zarr.common import OpenMode, concurrent_map from zarr.store.core import _normalize_interval_index @@ -17,7 +17,10 @@ class MemoryStore(Store): _store_dict: MutableMapping[str, Buffer] - def __init__(self, store_dict: MutableMapping[str, Buffer] | None = None): + def __init__( + self, store_dict: MutableMapping[str, Buffer] | None = None, *, mode: OpenMode = "r" + ): + super().__init__(mode=mode) self._store_dict = store_dict or {} def __str__(self) -> str: @@ -47,6 +50,7 @@ async def exists(self, key: str) -> bool: return key in self._store_dict async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + self._check_writable() assert isinstance(key, str) if isinstance(value, bytes | bytearray): # TODO: to support the v2 tests, we convert bytes to Buffer here @@ -62,6 +66,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None self._store_dict[key] = value async def delete(self, key: str) -> None: + self._check_writable() try: del self._store_dict[key] except KeyError: diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index a3395459fd..3b086f0a03 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -4,6 +4,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer +from zarr.common import OpenMode from zarr.store.core import _dereference_path if TYPE_CHECKING: @@ -18,10 +19,14 @@ class RemoteStore(Store): root: UPath - def __init__(self, url: UPath | str, **storage_options: dict[str, Any]): + def __init__( + self, url: UPath | str, *, mode: OpenMode = "r", **storage_options: dict[str, Any] + ): import fsspec from upath import UPath + super().__init__(mode=mode) + if isinstance(url, str): self.root = UPath(url, **storage_options) else: @@ -29,6 +34,7 @@ def __init__(self, url: UPath | str, **storage_options: dict[str, Any]): len(storage_options) == 0 ), "If constructed with a UPath object, no additional storage_options are allowed." self.root = url.rstrip("/") + # test instantiate file system fs, _ = fsspec.core.url_to_fs(str(self.root), asynchronous=True, **self.root._kwargs) assert fs.__class__.async_impl, "FileSystem needs to support async operations." @@ -67,6 +73,7 @@ async def get( return value async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + self._check_writable() assert isinstance(key, str) fs, root = self._make_fs() path = _dereference_path(root, key) @@ -80,6 +87,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None await fs._pipe_file(path, value) async def delete(self, key: str) -> None: + self._check_writable() fs, root = self._make_fs() path = _dereference_path(root, key) if await fs._exists(path): diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 1c0ed93734..b317f383f6 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar import pytest @@ -31,13 +31,43 @@ def get(self, store: S, key: str) -> Buffer: raise NotImplementedError @pytest.fixture(scope="function") - def store(self) -> Store: - return self.store_cls() + def store_kwargs(self) -> dict[str, Any]: + return {"mode": "w"} + + @pytest.fixture(scope="function") + def store(self, store_kwargs: dict[str, Any]) -> Store: + return self.store_cls(**store_kwargs) def test_store_type(self, store: S) -> None: assert isinstance(store, Store) assert isinstance(store, self.store_cls) + def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None: + assert store.mode == "w", store.mode + assert store.writeable + + with pytest.raises(AttributeError): + store.mode = "w" # type: ignore + + # read-only + kwargs = {**store_kwargs, "mode": "r"} + read_store = self.store_cls(**kwargs) + assert read_store.mode == "r", read_store.mode + assert not read_store.writeable + + async def test_not_writable_store_raises(self, store_kwargs: dict[str, Any]) -> None: + kwargs = {**store_kwargs, "mode": "r"} + store = self.store_cls(**kwargs) + assert not store.writeable + + # set + with pytest.raises(ValueError): + await store.set("foo", Buffer.from_bytes(b"bar")) + + # delete + with pytest.raises(ValueError): + await store.delete("foo") + def test_store_repr(self, store: S) -> None: raise NotImplementedError @@ -72,6 +102,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: """ Ensure that data can be written to the store using the store.set method. """ + assert store.writeable data_buf = Buffer.from_bytes(data) await store.set(key, data_buf) observed = self.get(store, key) diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index 21dc58197e..6b58cce412 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -22,11 +22,11 @@ def parse_store( store: Literal["local", "memory", "remote"], path: str ) -> LocalStore | MemoryStore | RemoteStore: if store == "local": - return LocalStore(path) + return LocalStore(path, mode="w") if store == "memory": - return MemoryStore() + return MemoryStore(mode="w") if store == "remote": - return RemoteStore() + return RemoteStore(mode="w") raise AssertionError @@ -38,24 +38,24 @@ def path_type(request): # todo: harmonize this with local_store fixture @pytest.fixture def store_path(tmpdir): - store = LocalStore(str(tmpdir)) + store = LocalStore(str(tmpdir), mode="w") p = StorePath(store) return p @pytest.fixture(scope="function") def local_store(tmpdir): - return LocalStore(str(tmpdir)) + return LocalStore(str(tmpdir), mode="w") @pytest.fixture(scope="function") def remote_store(): - return RemoteStore() + return RemoteStore(mode="w") @pytest.fixture(scope="function") def memory_store(): - return MemoryStore() + return MemoryStore(mode="w") @pytest.fixture(scope="function") diff --git a/tests/v3/test_codecs.py b/tests/v3/test_codecs.py index a595b12494..251570f767 100644 --- a/tests/v3/test_codecs.py +++ b/tests/v3/test_codecs.py @@ -50,7 +50,7 @@ async def set(self, value: np.ndarray): @pytest.fixture def store() -> Iterator[Store]: - yield StorePath(MemoryStore()) + yield StorePath(MemoryStore(mode="w")) @pytest.fixture diff --git a/tests/v3/test_store.py b/tests/v3/test_store.py index 75438f8612..52882ea78c 100644 --- a/tests/v3/test_store.py +++ b/tests/v3/test_store.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import MutableMapping +from typing import Any import pytest @@ -10,7 +10,6 @@ from zarr.testing.store import StoreTests -@pytest.mark.parametrize("store_dict", (None, {})) class TestMemoryStore(StoreTests[MemoryStore]): store_cls = MemoryStore @@ -20,21 +19,25 @@ def set(self, store: MemoryStore, key: str, value: Buffer) -> None: def get(self, store: MemoryStore, key: str) -> Buffer: return store._store_dict[key] + @pytest.fixture(scope="function", params=[None, {}]) + def store_kwargs(self, request) -> dict[str, Any]: + return {"store_dict": request.param, "mode": "w"} + @pytest.fixture(scope="function") - def store(self, store_dict: MutableMapping[str, Buffer] | None): - return MemoryStore(store_dict=store_dict) + def store(self, store_kwargs: dict[str, Any]) -> MemoryStore: + return self.store_cls(**store_kwargs) def test_store_repr(self, store: MemoryStore) -> None: assert str(store) == f"memory://{id(store._store_dict)}" def test_store_supports_writes(self, store: MemoryStore) -> None: - assert True + assert store.supports_writes def test_store_supports_listing(self, store: MemoryStore) -> None: - assert True + assert store.supports_listing def test_store_supports_partial_writes(self, store: MemoryStore) -> None: - assert True + assert store.supports_partial_writes def test_list_prefix(self, store: MemoryStore) -> None: assert True @@ -52,21 +55,21 @@ def set(self, store: LocalStore, key: str, value: Buffer) -> None: parent.mkdir(parents=True) (store.root / key).write_bytes(value.to_bytes()) - @pytest.fixture(scope="function") - def store(self, tmpdir) -> LocalStore: - return self.store_cls(str(tmpdir)) + @pytest.fixture + def store_kwargs(self, tmpdir) -> dict[str, str]: + return {"root": str(tmpdir), "mode": "w"} def test_store_repr(self, store: LocalStore) -> None: assert str(store) == f"file://{store.root!s}" def test_store_supports_writes(self, store: LocalStore) -> None: - assert True + assert store.supports_writes def test_store_supports_partial_writes(self, store: LocalStore) -> None: - assert True + assert store.supports_partial_writes def test_store_supports_listing(self, store: LocalStore) -> None: - assert True + assert store.supports_listing def test_list_prefix(self, store: LocalStore) -> None: assert True diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index 2a38dc8fdc..41555bbd26 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -10,7 +10,7 @@ @pytest.fixture def store() -> Iterator[Store]: - yield StorePath(MemoryStore()) + yield StorePath(MemoryStore(mode="w")) def test_simple(store: Store):