Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(storage): change StoreTests get/set methods to async #2313

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class StoreTests(Generic[S, B]):
store_cls: type[S]
buffer_cls: type[B]

def set(self, store: S, key: str, value: Buffer) -> None:
async def set(self, store: S, key: str, value: Buffer) -> None:
"""
Insert a value into a storage backend, with a specific key.
This should not not use any store methods. Bypassing the store methods allows them to be
tested.
"""
raise NotImplementedError

def get(self, store: S, key: str) -> Buffer:
async def get(self, store: S, key: str) -> Buffer:
"""
Retrieve a value from a storage backend, by key.
This should not not use any store methods. Bypassing the store methods allows them to be
Expand Down Expand Up @@ -106,7 +106,7 @@ async def test_get(
Ensure that data can be read from the store using the store.get method.
"""
data_buf = self.buffer_cls.from_bytes(data)
self.set(store, key, data_buf)
await self.set(store, key, data_buf)
observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
start, length = _normalize_interval_index(data_buf, interval=byte_range)
expected = data_buf[start : start + length]
Expand All @@ -119,7 +119,7 @@ async def test_get_many(self, store: S) -> None:
keys = tuple(map(str, range(10)))
values = tuple(f"{k}".encode() for k in keys)
for k, v in zip(keys, values, strict=False):
self.set(store, k, self.buffer_cls.from_bytes(v))
await self.set(store, k, self.buffer_cls.from_bytes(v))
observed_buffers = await _collect_aiterator(
store._get_many(
zip(
Expand All @@ -143,7 +143,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
assert not store.mode.readonly
data_buf = self.buffer_cls.from_bytes(data)
await store.set(key, data_buf)
observed = self.get(store, key)
observed = await self.get(store, key)
assert_bytes_equal(observed, data_buf)

async def test_set_many(self, store: S) -> None:
Expand All @@ -156,7 +156,7 @@ async def test_set_many(self, store: S) -> None:
store_dict = dict(zip(keys, data_buf, strict=True))
await store._set_many(store_dict.items())
for k, v in store_dict.items():
assert self.get(store, k).to_bytes() == v.to_bytes()
assert (await self.get(store, k)).to_bytes() == v.to_bytes()

@pytest.mark.parametrize(
"key_ranges",
Expand All @@ -172,7 +172,7 @@ async def test_get_partial_values(
) -> None:
# put all of the data
for key, _ in key_ranges:
self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))
await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))

# read back just part of it
observed_maybe = await store.get_partial_values(
Expand Down Expand Up @@ -211,11 +211,15 @@ async def test_delete(self, store: S) -> None:

async def test_empty(self, store: S) -> None:
assert await store.empty()
self.set(store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8")))
await self.set(
store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
)
assert not await store.empty()

async def test_clear(self, store: S) -> None:
self.set(store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8")))
await self.set(
store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
)
await store.clear()
assert await store.empty()

Expand Down Expand Up @@ -277,8 +281,8 @@ async def test_list_dir(self, store: S) -> None:

async def test_with_mode(self, store: S) -> None:
data = b"0000"
self.set(store, "key", self.buffer_cls.from_bytes(data))
assert self.get(store, "key").to_bytes() == data
await self.set(store, "key", self.buffer_cls.from_bytes(data))
assert (await self.get(store, "key")).to_bytes() == data

for mode in ["r", "a"]:
mode = cast(AccessModeLiteral, mode)
Expand All @@ -294,7 +298,7 @@ async def test_with_mode(self, store: S) -> None:
assert result.to_bytes() == data

# writes to original after with_mode is visible
self.set(store, "key-2", self.buffer_cls.from_bytes(data))
await self.set(store, "key-2", self.buffer_cls.from_bytes(data))
result = await clone.get("key-2", default_buffer_prototype())
assert result is not None
assert result.to_bytes() == data
Expand All @@ -313,7 +317,7 @@ async def test_with_mode(self, store: S) -> None:
async def test_set_if_not_exists(self, store: S) -> None:
key = "k"
data_buf = self.buffer_cls.from_bytes(b"0000")
self.set(store, key, data_buf)
await self.set(store, key, data_buf)

new = self.buffer_cls.from_bytes(b"1111")
await store.set_if_not_exists("k", new) # no error
Expand Down
4 changes: 2 additions & 2 deletions tests/v3/test_store/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ class TestLocalStore(StoreTests[LocalStore, cpu.Buffer]):
store_cls = LocalStore
buffer_cls = cpu.Buffer

def get(self, store: LocalStore, key: str) -> Buffer:
async def get(self, store: LocalStore, key: str) -> Buffer:
return self.buffer_cls.from_bytes((store.root / key).read_bytes())

def set(self, store: LocalStore, key: str, value: Buffer) -> None:
async def set(self, store: LocalStore, key: str, value: Buffer) -> None:
parent = (store.root / key).parent
if not parent.exists():
parent.mkdir(parents=True)
Expand Down
8 changes: 4 additions & 4 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]):
store_cls = MemoryStore
buffer_cls = cpu.Buffer

def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
async def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
store._store_dict[key] = value

def get(self, store: MemoryStore, key: str) -> Buffer:
async def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(params=[None, True])
Expand Down Expand Up @@ -52,10 +52,10 @@ class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]):
store_cls = GpuMemoryStore
buffer_cls = gpu.Buffer

def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
async def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
store._store_dict[key] = value

def get(self, store: MemoryStore, key: str) -> Buffer:
async def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(params=[None, True])
Expand Down
4 changes: 2 additions & 2 deletions tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,14 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
return self.store_cls(**store_kwargs)

def get(self, store: RemoteStore, key: str) -> Buffer:
async def get(self, store: RemoteStore, key: str) -> Buffer:
# make a new, synchronous instance of the filesystem because this test is run in sync code
new_fs = fsspec.filesystem(
"s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False
)
return self.buffer_cls.from_bytes(new_fs.cat(f"{store.path}/{key}"))

def set(self, store: RemoteStore, key: str, value: Buffer) -> None:
async def set(self, store: RemoteStore, key: str, value: Buffer) -> None:
# make a new, synchronous instance of the filesystem because this test is run in sync code
new_fs = fsspec.filesystem(
"s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False
Expand Down
4 changes: 2 additions & 2 deletions tests/v3/test_store/test_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def store_kwargs(self, request) -> dict[str, str | bool]:

return {"path": temp_path, "mode": "w"}

def get(self, store: ZipStore, key: str) -> Buffer:
async def get(self, store: ZipStore, key: str) -> Buffer:
return store._get(key, prototype=default_buffer_prototype())

def set(self, store: ZipStore, key: str, value: Buffer) -> None:
async def set(self, store: ZipStore, key: str, value: Buffer) -> None:
return store._set(key, value)

def test_store_mode(self, store: ZipStore, store_kwargs: dict[str, Any]) -> None:
Expand Down