Skip to content

Commit

Permalink
Getitems: support meta_array (#1131)
Browse files Browse the repository at this point in the history
* Use _chunk_getitems() always

* Implement getitems() always

* FSStore.getitems(): accept meta_array and on_error

* getitems(): handle on_error="omit"

* Removed the `on_error` argument

* remove redundant check

* getitems(): use Sequence instead of Iterable

* Typo

Co-authored-by: Josh Moore <josh@openmicroscopy.org>

* Introduce a contexts argument

* CountingDict: impl. getitems()

* added test_getitems()

* Introduce Context

* doc

* support the new get_partial_values() method

* Resolve conflict with get_partial_values()

* make contexts keyword-only

* Introduce ConstantMap

* use typing.Mapping

* test_constant_map

---------

Co-authored-by: jakirkham <jakirkham@gmail.com>
Co-authored-by: Josh Moore <josh@openmicroscopy.org>
  • Loading branch information
3 people authored Apr 13, 2023
1 parent 4b0705c commit b14f15f
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 70 deletions.
28 changes: 28 additions & 0 deletions zarr/_storage/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from zarr.meta import Metadata2, Metadata3
from zarr.util import normalize_storage_path
from zarr.context import Context

# v2 store keys
array_meta_key = '.zarray'
Expand Down Expand Up @@ -131,6 +132,33 @@ def _ensure_store(store: Any):
f"wrap it in Zarr.storage.KVStore. Got {store}"
)

def getitems(
    self, keys: Sequence[str], *, contexts: Mapping[str, Context]
) -> Mapping[str, Any]:
    """Retrieve data from multiple keys.

    Parameters
    ----------
    keys : Sequence[str]
        The keys to retrieve.
    contexts : Mapping[str, Context]
        A mapping of keys to their context. Each context is a mapping of store
        specific information. E.g. a context could be a dict telling the store
        the preferred output array type: `{"meta_array": cupy.empty(())}`

    Returns
    -------
    Mapping
        A collection mapping the input keys to their results. Keys that are
        not present in the store are omitted from the result rather than
        raising ``KeyError``.

    Notes
    -----
    This default implementation uses ``__getitem__()`` to read each key
    sequentially and ignores *contexts*. Overwrite this method to implement
    concurrent reads of multiple keys and/or to utilize the contexts.
    """
    return {k: self[k] for k in keys if k in self}


class Store(BaseStore):
"""Abstract store class used by implementations following the Zarr v2 spec.
Expand Down
19 changes: 19 additions & 0 deletions zarr/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

from typing import TypedDict

from numcodecs.compat import NDArrayLike


class Context(TypedDict, total=False):
    """A context carrying component-specific information.

    Every key is optional, so any component reading a ``Context`` must fall
    back to a sensible default whenever a key is absent.

    Items
    -----
    meta_array : array-like, optional
        An array-like instance used to determine the preferred output
        array type.
    """
    meta_array: NDArrayLike
92 changes: 28 additions & 64 deletions zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
from zarr.attrs import Attributes
from zarr.codecs import AsType, get_codec
from zarr.context import Context
from zarr.errors import ArrayNotFoundError, ReadOnlyError, ArrayIndexError
from zarr.indexing import (
BasicIndexer,
Expand Down Expand Up @@ -41,6 +42,7 @@
normalize_store_arg,
)
from zarr.util import (
ConstantMap,
all_equal,
InfoReporter,
check_array_shape,
Expand Down Expand Up @@ -1275,24 +1277,14 @@ def _get_selection(self, indexer, out=None, fields=None):
check_array_shape('out', out, out_shape)

# iterate over chunks
if (
not hasattr(self.chunk_store, "getitems") and not (
hasattr(self.chunk_store, "get_partial_values") and
self.chunk_store.supports_efficient_get_partial_values
)
) or any(map(lambda x: x == 0, self.shape)):
# sequentially get one key at a time from storage
for chunk_coords, chunk_selection, out_selection in indexer:

# load chunk selection into output array
self._chunk_getitem(chunk_coords, chunk_selection, out, out_selection,
drop_axes=indexer.drop_axes, fields=fields)
else:
if math.prod(out_shape) > 0:
# allow storage to get multiple items at once
lchunk_coords, lchunk_selection, lout_selection = zip(*indexer)
self._chunk_getitems(lchunk_coords, lchunk_selection, out, lout_selection,
drop_axes=indexer.drop_axes, fields=fields)

self._chunk_getitems(
lchunk_coords, lchunk_selection, out, lout_selection,
drop_axes=indexer.drop_axes, fields=fields
)
if out.shape:
return out
else:
Expand Down Expand Up @@ -1963,68 +1955,36 @@ def _process_chunk(
# store selected data in output
out[out_selection] = tmp

def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection,
drop_axes=None, fields=None):
"""Obtain part or whole of a chunk.
def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
drop_axes=None, fields=None):
"""Obtain part or whole of chunks.
Parameters
----------
chunk_coords : tuple of ints
Indices of the chunk.
chunk_selection : selection
Location of region within the chunk to extract.
chunk_coords : list of tuple of ints
Indices of the chunks.
chunk_selection : list of selections
Location of region within the chunks to extract.
out : ndarray
Array to store result in.
out_selection : selection
Location of region within output array to store results in.
out_selection : list of selections
Location of regions within output array to store results in.
drop_axes : tuple of ints
Axes to squeeze out of the chunk.
fields
TODO
"""
out_is_ndarray = True
try:
out = ensure_ndarray_like(out)
except TypeError:
out_is_ndarray = False

assert len(chunk_coords) == len(self._cdata_shape)

# obtain key for chunk
ckey = self._chunk_key(chunk_coords)

try:
# obtain compressed data for chunk
cdata = self.chunk_store[ckey]

except KeyError:
# chunk not initialized
if self._fill_value is not None:
if fields:
fill_value = self._fill_value[fields]
else:
fill_value = self._fill_value
out[out_selection] = fill_value

else:
self._process_chunk(out, cdata, chunk_selection, drop_axes,
out_is_ndarray, fields, out_selection)

def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
drop_axes=None, fields=None):
"""As _chunk_getitem, but for lists of chunks
This gets called where the storage supports ``getitems``, so that
it can decide how to fetch the keys, allowing concurrency.
"""
out_is_ndarray = True
try:
out = ensure_ndarray_like(out)
except TypeError: # pragma: no cover
out_is_ndarray = False

# Keys to retrieve
ckeys = [self._chunk_key(ch) for ch in lchunk_coords]

# Check if we can do a partial read
if (
self._partial_decompress
and self._compressor
Expand Down Expand Up @@ -2056,13 +2016,17 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
for ckey in ckeys
if ckey in self.chunk_store
}
elif hasattr(self.chunk_store, "get_partial_values"):
partial_read_decode = False
values = self.chunk_store.get_partial_values([(ckey, (0, None)) for ckey in ckeys])
cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
else:
partial_read_decode = False
if not hasattr(self.chunk_store, "getitems"):
values = self.chunk_store.get_partial_values([(ckey, (0, None)) for ckey in ckeys])
cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
else:
cdatas = self.chunk_store.getitems(ckeys, on_error="omit")
contexts = {}
if not isinstance(self._meta_array, np.ndarray):
contexts = ConstantMap(ckeys, constant=Context(meta_array=self._meta_array))
cdatas = self.chunk_store.getitems(ckeys, contexts=contexts)

for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
if ckey in cdatas:
self._process_chunk(
Expand Down
8 changes: 6 additions & 2 deletions zarr/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from os import scandir
from pickle import PicklingError
from threading import Lock, RLock
from typing import Optional, Union, List, Tuple, Dict, Any
from typing import Sequence, Mapping, Optional, Union, List, Tuple, Dict, Any
import uuid
import time

Expand All @@ -42,6 +42,7 @@
ensure_contiguous_ndarray_like
)
from numcodecs.registry import codec_registry
from zarr.context import Context

from zarr.errors import (
MetadataError,
Expand Down Expand Up @@ -1380,7 +1381,10 @@ def _normalize_key(self, key):

return key.lower() if self.normalize_keys else key

def getitems(self, keys, **kwargs):
def getitems(
self, keys: Sequence[str], *, contexts: Mapping[str, Context]
) -> Mapping[str, Any]:

keys_transformed = [self._normalize_key(key) for key in keys]
results = self.map.getitems(keys_transformed, on_error="omit")
# The function calling this method may not recognize the transformed keys
Expand Down
35 changes: 34 additions & 1 deletion zarr/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import zarr
from zarr._storage.store import _get_hierarchy_metadata
from zarr.codecs import BZ2, AsType, Blosc, Zlib
from zarr.context import Context
from zarr.convenience import consolidate_metadata
from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataError
from zarr.hierarchy import group
Expand All @@ -37,7 +38,7 @@
from zarr.storage import FSStore, rename, listdir
from zarr._storage.v3 import KVStoreV3
from zarr.tests.util import CountingDict, have_fsspec, skip_test_env_var, abs_container, mktemp
from zarr.util import json_dumps
from zarr.util import ConstantMap, json_dumps


@contextmanager
Expand Down Expand Up @@ -2584,3 +2585,35 @@ def test_meta_prefix_6853():

fixtures = group(store=DirectoryStore(str(fixture)))
assert list(fixtures.arrays())


def test_getitems_contexts():
    """Contexts built from ``meta_array`` reach the store's ``getitems()``."""

    class ContextRecordingStore(CountingDict):
        # Wraps CountingDict to remember the contexts of the last getitems() call.
        def __init__(self):
            super().__init__()
            self.last_contexts = None

        def getitems(self, keys, *, contexts):
            self.last_contexts = contexts
            return super().getitems(keys, contexts=contexts)

    store = ContextRecordingStore()
    z = zarr.create(shape=(10,), chunks=1, store=store)

    # By default, no contexts are given to the store's getitems()
    z[0]
    assert len(store.last_contexts) == 0

    # Setting a non-default meta_array will create contexts for the store's getitems()
    z._meta_array = "my_meta_array"
    z[0]
    assert store.last_contexts == {'0': {'meta_array': 'my_meta_array'}}
    assert isinstance(store.last_contexts, ConstantMap)
    # Accessing a different chunk should trigger a different key request
    z[1]
    assert store.last_contexts == {'1': {'meta_array': 'my_meta_array'}}
    assert isinstance(store.last_contexts, ConstantMap)
    # A slice spanning several chunks yields one context per chunk key
    z[2:4]
    assert store.last_contexts == ConstantMap(['2', '3'], Context({'meta_array': 'my_meta_array'}))
    assert isinstance(store.last_contexts, ConstantMap)
2 changes: 2 additions & 0 deletions zarr/tests/test_storage_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ def _get_public_and_dunder_methods(some_class):
def test_storage_transformer_interface():
store_v3_methods = _get_public_and_dunder_methods(StoreV3)
store_v3_methods.discard("__init__")
# Note, getitems() isn't mandatory when get_partial_values() is available
store_v3_methods.discard("getitems")
storage_transformer_methods = _get_public_and_dunder_methods(StorageTransformer)
storage_transformer_methods.discard("__init__")
storage_transformer_methods.discard("get_config")
Expand Down
15 changes: 14 additions & 1 deletion zarr/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from zarr.core import Array
from zarr.util import (all_equal, flatten, guess_chunks, human_readable_size,
from zarr.util import (ConstantMap, all_equal, flatten, guess_chunks, human_readable_size,
info_html_report, info_text_report, is_total_slice,
json_dumps, normalize_chunks,
normalize_dimension_separator,
Expand Down Expand Up @@ -248,3 +248,16 @@ def test_json_dumps_numpy_dtype():
# Check that we raise the error of the superclass for unsupported object
with pytest.raises(TypeError):
json_dumps(Array)


def test_constant_map():
    """ConstantMap maps every given key to the same constant object."""
    sentinel = object()
    cmap = ConstantMap(keys=[1, 2], constant=sentinel)

    assert len(cmap) == 2
    assert cmap[1] is sentinel
    assert cmap[2] is sentinel
    assert 1 in cmap
    assert 0 not in cmap
    with pytest.raises(KeyError):
        cmap[0]
    # repr should mirror an equivalent plain dict
    assert repr(cmap) == repr({1: sentinel, 2: sentinel})
9 changes: 9 additions & 0 deletions zarr/tests/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
import os
import tempfile
from typing import Any, Mapping, Sequence
from zarr.context import Context

from zarr.storage import Store
from zarr._storage.v3 import StoreV3
Expand Down Expand Up @@ -42,6 +44,13 @@ def __delitem__(self, key):
self.counter['__delitem__', key] += 1
del self.wrapped[key]

def getitems(
    self, keys: Sequence[str], *, contexts: Mapping[str, Context]
) -> Mapping[str, Any]:
    """Bulk read: count every requested key and return only the present ones.

    Mirrors ``__getitem__`` accounting — each requested key bumps the
    ``('__getitem__', key)`` counter whether or not it exists. The
    *contexts* argument is accepted but deliberately ignored here.
    """
    found = {}
    for key in keys:
        self.counter['__getitem__', key] += 1
        if key in self.wrapped:
            found[key] = self.wrapped[key]
    return found


class CountingDictV3(CountingDict, StoreV3):
pass
Expand Down
Loading

0 comments on commit b14f15f

Please sign in to comment.