Skip to content

Commit

Permalink
fix(storage): change StoreTests get/set methods to async
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Oct 8, 2024
1 parent 7e2be57 commit 474947c
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 21 deletions.
30 changes: 17 additions & 13 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class StoreTests(Generic[S, B]):
store_cls: type[S]
buffer_cls: type[B]

def set(self, store: S, key: str, value: Buffer) -> None:
async def set(self, store: S, key: str, value: Buffer) -> None:
"""
Insert a value into a storage backend, with a specific key.
This should not not use any store methods. Bypassing the store methods allows them to be
tested.
"""
raise NotImplementedError

def get(self, store: S, key: str) -> Buffer:
async def get(self, store: S, key: str) -> Buffer:
"""
Retrieve a value from a storage backend, by key.
This should not not use any store methods. Bypassing the store methods allows them to be
Expand Down Expand Up @@ -106,7 +106,7 @@ async def test_get(
Ensure that data can be read from the store using the store.get method.
"""
data_buf = self.buffer_cls.from_bytes(data)
self.set(store, key, data_buf)
await self.set(store, key, data_buf)
observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
start, length = _normalize_interval_index(data_buf, interval=byte_range)
expected = data_buf[start : start + length]
Expand All @@ -119,7 +119,7 @@ async def test_get_many(self, store: S) -> None:
keys = tuple(map(str, range(10)))
values = tuple(f"{k}".encode() for k in keys)
for k, v in zip(keys, values, strict=False):
self.set(store, k, self.buffer_cls.from_bytes(v))
await self.set(store, k, self.buffer_cls.from_bytes(v))
observed_buffers = await _collect_aiterator(
store._get_many(
zip(
Expand All @@ -143,7 +143,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
assert not store.mode.readonly
data_buf = self.buffer_cls.from_bytes(data)
await store.set(key, data_buf)
observed = self.get(store, key)
observed = await self.get(store, key)
assert_bytes_equal(observed, data_buf)

async def test_set_many(self, store: S) -> None:
Expand All @@ -156,7 +156,7 @@ async def test_set_many(self, store: S) -> None:
store_dict = dict(zip(keys, data_buf, strict=True))
await store._set_many(store_dict.items())
for k, v in store_dict.items():
assert self.get(store, k).to_bytes() == v.to_bytes()
assert (await self.get(store, k)).to_bytes() == v.to_bytes()

@pytest.mark.parametrize(
"key_ranges",
Expand All @@ -172,7 +172,7 @@ async def test_get_partial_values(
) -> None:
# put all of the data
for key, _ in key_ranges:
self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))
await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))

# read back just part of it
observed_maybe = await store.get_partial_values(
Expand Down Expand Up @@ -211,11 +211,15 @@ async def test_delete(self, store: S) -> None:

async def test_empty(self, store: S) -> None:
assert await store.empty()
self.set(store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8")))
await self.set(
store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
)
assert not await store.empty()

async def test_clear(self, store: S) -> None:
self.set(store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8")))
await self.set(
store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
)
await store.clear()
assert await store.empty()

Expand Down Expand Up @@ -277,8 +281,8 @@ async def test_list_dir(self, store: S) -> None:

async def test_with_mode(self, store: S) -> None:
data = b"0000"
self.set(store, "key", self.buffer_cls.from_bytes(data))
assert self.get(store, "key").to_bytes() == data
await self.set(store, "key", self.buffer_cls.from_bytes(data))
assert (await self.get(store, "key")).to_bytes() == data

for mode in ["r", "a"]:
mode = cast(AccessModeLiteral, mode)
Expand All @@ -294,7 +298,7 @@ async def test_with_mode(self, store: S) -> None:
assert result.to_bytes() == data

# writes to original after with_mode is visible
self.set(store, "key-2", self.buffer_cls.from_bytes(data))
await self.set(store, "key-2", self.buffer_cls.from_bytes(data))
result = await clone.get("key-2", default_buffer_prototype())
assert result is not None
assert result.to_bytes() == data
Expand All @@ -313,7 +317,7 @@ async def test_with_mode(self, store: S) -> None:
async def test_set_if_not_exists(self, store: S) -> None:
key = "k"
data_buf = self.buffer_cls.from_bytes(b"0000")
self.set(store, key, data_buf)
await self.set(store, key, data_buf)

new = self.buffer_cls.from_bytes(b"1111")
await store.set_if_not_exists("k", new) # no error
Expand Down
4 changes: 2 additions & 2 deletions tests/v3/test_store/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ class TestLocalStore(StoreTests[LocalStore, cpu.Buffer]):
store_cls = LocalStore
buffer_cls = cpu.Buffer

def get(self, store: LocalStore, key: str) -> Buffer:
async def get(self, store: LocalStore, key: str) -> Buffer:
return self.buffer_cls.from_bytes((store.root / key).read_bytes())

def set(self, store: LocalStore, key: str, value: Buffer) -> None:
async def set(self, store: LocalStore, key: str, value: Buffer) -> None:
parent = (store.root / key).parent
if not parent.exists():
parent.mkdir(parents=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]):
store_cls = MemoryStore
buffer_cls = cpu.Buffer

def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
async def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
store._store_dict[key] = value

def get(self, store: MemoryStore, key: str) -> Buffer:
async def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(params=[None, True])
Expand Down
4 changes: 2 additions & 2 deletions tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,14 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
return self.store_cls(**store_kwargs)

def get(self, store: RemoteStore, key: str) -> Buffer:
async def get(self, store: RemoteStore, key: str) -> Buffer:
# make a new, synchronous instance of the filesystem because this test is run in sync code
new_fs = fsspec.filesystem(
"s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False
)
return self.buffer_cls.from_bytes(new_fs.cat(f"{store.path}/{key}"))

def set(self, store: RemoteStore, key: str, value: Buffer) -> None:
async def set(self, store: RemoteStore, key: str, value: Buffer) -> None:
# make a new, synchronous instance of the filesystem because this test is run in sync code
new_fs = fsspec.filesystem(
"s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False
Expand Down
4 changes: 2 additions & 2 deletions tests/v3/test_store/test_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def store_kwargs(self, request) -> dict[str, str | bool]:

return {"path": temp_path, "mode": "w"}

def get(self, store: ZipStore, key: str) -> Buffer:
async def get(self, store: ZipStore, key: str) -> Buffer:
return store._get(key, prototype=default_buffer_prototype())

def set(self, store: ZipStore, key: str, value: Buffer) -> None:
async def set(self, store: ZipStore, key: str, value: Buffer) -> None:
return store._set(key, value)

def test_store_mode(self, store: ZipStore, store_kwargs: dict[str, Any]) -> None:
Expand Down

0 comments on commit 474947c

Please sign in to comment.