Skip to content

Commit

Permalink
Introduce ConstantMap
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Apr 11, 2023
1 parent d9838ef commit c3ee95f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 5 deletions.
4 changes: 3 additions & 1 deletion zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
from zarr.attrs import Attributes
from zarr.codecs import AsType, get_codec
from zarr.context import Context
from zarr.errors import ArrayNotFoundError, ReadOnlyError, ArrayIndexError
from zarr.indexing import (
BasicIndexer,
Expand Down Expand Up @@ -41,6 +42,7 @@
normalize_store_arg,
)
from zarr.util import (
ConstantMap,
all_equal,
InfoReporter,
check_array_shape,
Expand Down Expand Up @@ -2022,7 +2024,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
partial_read_decode = False
contexts = {}
if not isinstance(self._meta_array, np.ndarray):
contexts = {k: {"meta_array": self._meta_array} for k in ckeys}
contexts = ConstantMap(ckeys, constant=Context(meta_array=self._meta_array))
cdatas = self.chunk_store.getitems(ckeys, contexts=contexts)

for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
Expand Down
13 changes: 11 additions & 2 deletions zarr/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import zarr
from zarr._storage.store import _get_hierarchy_metadata
from zarr.codecs import BZ2, AsType, Blosc, Zlib
from zarr.context import Context
from zarr.convenience import consolidate_metadata
from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataError
from zarr.hierarchy import group
Expand All @@ -37,7 +38,7 @@
from zarr.storage import FSStore, rename, listdir
from zarr._storage.v3 import KVStoreV3
from zarr.tests.util import CountingDict, have_fsspec, skip_test_env_var, abs_container, mktemp
from zarr.util import json_dumps
from zarr.util import ConstantMap, json_dumps


@contextmanager
Expand Down Expand Up @@ -2598,7 +2599,7 @@ def getitems(self, keys, *, contexts):
return super().getitems(keys, contexts=contexts)

store = MyStore()
z = zarr.create(shape=(10,), store=store)
z = zarr.create(shape=(10,), chunks=1, store=store)

# By default, not contexts are given to the store's getitems()
z[0]
Expand All @@ -2608,3 +2609,11 @@ def getitems(self, keys, *, contexts):
z._meta_array = "my_meta_array"
z[0]
assert store.last_contexts == {'0': {'meta_array': 'my_meta_array'}}
assert isinstance(store.last_contexts, ConstantMap)
# Accseeing different chunks should trigger different key request
z[1]
assert store.last_contexts == {'1': {'meta_array': 'my_meta_array'}}
assert isinstance(store.last_contexts, ConstantMap)
z[2:4]
assert store.last_contexts == ConstantMap(['2', '3'], Context({'meta_array': 'my_meta_array'}))
assert isinstance(store.last_contexts, ConstantMap)
52 changes: 50 additions & 2 deletions zarr/util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import collections.abc
import inspect
import json
import math
import numbers
from textwrap import TextWrapper
import mmap
import time
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterator,
Optional,
Tuple,
TypeVar,
Union,
Iterable
)

import numpy as np
from asciitree import BoxStyle, LeftAligned
from asciitree.traversal import Traversal
from collections.abc import Iterable
from numcodecs.compat import (
ensure_text,
ensure_ndarray_like,
Expand All @@ -21,6 +31,9 @@
from numcodecs.registry import codec_registry
from numcodecs.blosc import cbuffer_sizes, cbuffer_metainfo

KeyType = TypeVar('KeyType')
ValueType = TypeVar('ValueType')


def flatten(arg: Iterable) -> Iterable:
for element in arg:
Expand Down Expand Up @@ -745,3 +758,38 @@ def ensure_contiguous_ndarray_or_bytes(buf) -> Union[NDArrayLike, bytes]:
except TypeError:
# An error is raised if `buf` couldn't be zero-copy converted
return ensure_bytes(buf)


class ConstantMap(collections.abc.Mapping[KeyType, ValueType]):
"""A read-only map that maps all keys to the same constant value
Useful if you want to call `getitems()` with the same context for all keys.
Parameters
----------
keys
The keys of the map. Will be copied to a frozenset if it isn't already.
constant
The constant that all keys are mapping to.
"""

def __init__(self, keys: Iterable[KeyType], constant: ValueType) -> None:
self._keys = keys if isinstance(keys, frozenset) else frozenset(keys)
self._constant = constant

def __getitem__(self, key: KeyType) -> ValueType:
if key not in self._keys:
raise KeyError(repr(key))
return self._constant

def __iter__(self) -> Iterator[KeyType]:
return iter(self._keys)

def __len__(self) -> int:
return len(self._keys)

def __contains__(self, key: object) -> bool:
return key in self._keys

def __repr__(self) -> str:
return repr({k: v for k, v in self.items()})

0 comments on commit c3ee95f

Please sign in to comment.