From 272382658d7e817f6447ade1706e67e9d8aecf79 Mon Sep 17 00:00:00 2001
From: Jonathan Striebel
Date: Fri, 21 Jan 2022 18:19:33 +0100
Subject: [PATCH 1/4] sharding prototype

---
 sharding_test.py |  86 ++++++++++++
 zarrita.py       | 337 ++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 388 insertions(+), 35 deletions(-)
 create mode 100644 sharding_test.py

diff --git a/sharding_test.py b/sharding_test.py
new file mode 100644
index 0000000..c8b79b9
--- /dev/null
+++ b/sharding_test.py
@@ -0,0 +1,86 @@
+import json
+import os
+import shutil
+
+import zarrita
+
+
+shutil.rmtree("sharding_test.zr3", ignore_errors=True)
+h = zarrita.create_hierarchy("sharding_test.zr3")
+a = h.create_array(
+    path="testarray",
+    shape=(20, 3),
+    dtype="float64",
+    chunk_shape=(3, 2),
+    shards=(2, 2),
+)
+
+a[:10, :] = 42
+a[15, 1] = 389
+a[19, 2] = 1
+a[0, 1] = -4.2
+
+assert a.store._shards == (2, 2)
+assert a[15, 1] == 389
+assert a[19, 2] == 1
+assert a[0, 1] == -4.2
+assert a[0, 0] == 42
+
+array_json = a.store["meta/root/testarray.array.json"].decode()
+
+print(array_json)
+# {
+#     "shape": [
+#         20,
+#         3
+#     ],
+#     "data_type": "<f8",
+#     "chunk_grid": {
+#         "type": "regular",
+#         "chunk_shape": [
+#             3,
+#             2
+#         ],
+#         "separator": "/"
+#     },
+#     "chunk_memory_layout": "C",
+#     "fill_value": null,
+#     "extensions": [],
+#     "attributes": {},
+#     "shards": [
+#         2,
+#         2
+#     ],
+#     "shard_format": "indexed"
+# }
+
+assert json.loads(array_json)["shards"] == [2, 2]
+
+print("ONDISK")
+for root, dirs, files in os.walk("sharding_test.zr3"):
+    if len(files) > 0:
+        print("    ", root.ljust(40), *sorted(files))
+print("UNDERLYING STORE", sorted(i.rsplit("c")[-1] for i in a.store._store if i.startswith("data")))
+print("STORE", sorted(i.rsplit("c")[-1] for i in a.store if i.startswith("data")))
+# ONDISK
+#      sharding_test.zr3                        zarr.json
+#      sharding_test.zr3/data/root/testarray/c0 0
+#      sharding_test.zr3/data/root/testarray/c1 0
+#      sharding_test.zr3/data/root/testarray/c2 0
+#      sharding_test.zr3/data/root/testarray/c3 0
+#      sharding_test.zr3/meta/root              testarray.array.json
+# UNDERLYING STORE ['0/0', '1/0', '2/0', '3/0']
+# STORE ['0/0', '0/1', '1/0', '1/1', '2/0', '2/1', '3/0', '3/1', '5/0', '6/1']
+
+index_bytes = a.store._store["data/root/testarray/c0/0"][-2*2*16:]
+print("INDEX 0.0", [int.from_bytes(index_bytes[i:i+8], byteorder="little") for i in range(0, len(index_bytes), 8)])
+# INDEX 0.0 [0, 48, 48, 48, 96, 48, 144, 48]
+
+
+a_reopened = zarrita.get_hierarchy("sharding_test.zr3").get_array("testarray")
+assert a_reopened.store._shards == (2, 2)
+assert a_reopened[15, 1] == 389
+assert a_reopened[19, 2] == 1
+assert a_reopened[0, 1] == -4.2
+assert a_reopened[0, 0] == 42
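The INDEX 0.0 output above is the decoded trailing index of shard c0/0: for each of the 2 x 2 chunks in a shard, a pair of little-endian uint64 values (offset, length). Each chunk of this uncompressed float64 array is 3 * 2 * 8 = 48 bytes, so the four chunks sit back to back at offsets 0, 48, 96 and 144. The following standalone sketch decodes one shard blob under exactly these assumptions (uncompressed chunks, no zarrita imports); read_shard and its default shard shape are illustrative helpers, not part of the patch:

    import itertools
    import numpy as np

    MAX_UINT_64 = 2 ** 64 - 1  # sentinel pair marking an empty chunk slot

    def read_shard(shard_bytes: bytes, shards=(2, 2)) -> dict:
        """Return {chunk_slot_within_shard: chunk_bytes} for one shard blob."""
        num_chunks = int(np.prod(shards))
        # The index is the last 2 * 8 bytes per chunk, little-endian, C order:
        index = np.frombuffer(shard_bytes[-16 * num_chunks:], dtype="<u8").reshape(*shards, 2)
        chunks = {}
        for slot in itertools.product(*(range(n) for n in shards)):
            offset, length = index[slot]
            if (offset, length) != (MAX_UINT_64, MAX_UINT_64):
                chunks[slot] = shard_bytes[offset:offset + length]
        return chunks

An unwritten slot keeps the sentinel pair (2**64 - 1, 2**64 - 1), which is why the partially filled shards c2 and c3 from the STORE listing above still decode cleanly.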
diff --git a/zarrita.py b/zarrita.py
index a10214b..2eeeef2 100644
--- a/zarrita.py
+++ b/zarrita.py
@@ -3,10 +3,12 @@
 import json
 import numbers
 import itertools
+import functools
 import math
 import re
+from collections import defaultdict
 from collections.abc import Mapping, MutableMapping
-from typing import Iterator, Union, Optional, Tuple, Any, List, Dict, NamedTuple
+from typing import Iterator, Union, Optional, Tuple, Any, List, Dict, NamedTuple, Iterable, Type


 # third-party dependencies
@@ -170,6 +172,23 @@ def _check_compressor(compressor: Optional[Codec]) -> None:
     assert compressor is None or isinstance(compressor, Codec)


+def _check_shard_format(shard_format: str) -> None:
+    assert shard_format in SHARDED_STORES, (
+        f"Shard format {shard_format} is not supported, "
+        + f"use one of {list(SHARDED_STORES)}"
+    )
+
+
+def _check_shards(shards: Union[int, Tuple[int, ...], None]) -> Optional[Tuple[int, ...]]:
+    if shards is None:
+        return None
+    assert isinstance(shards, (int, tuple))
+    if isinstance(shards, int):
+        shards = shards,
+    assert all([isinstance(s, int) for s in shards])
+    return shards
+
+
 def _encode_codec_metadata(codec: Codec) -> Optional[Mapping]:
     if codec is None:
         return None
@@ -265,7 +284,9 @@ def create_array(self,
                      chunk_separator: str = "/",
                      compressor: Optional[Codec] = None,
                      fill_value: Any = None,
-                     attrs: Optional[Mapping] = None) -> Array:
+                     attrs: Optional[Mapping] = None,
+                     shard_format: str = "indexed",
+                     shards: Union[int, Tuple[int, ...], None] = None) -> Array:

         # sanity checks
         path = _check_path(path)
@@ -274,6 +295,8 @@ def create_array(self,
         chunk_shape = _check_chunk_shape(chunk_shape, shape)
         _check_compressor(compressor)
         attrs = _check_attrs(attrs)
+        _check_shard_format(shard_format)
+        shards = _check_shards(shards)

         # encode data type
         if dtype == np.bool_:
@@ -297,6 +320,9 @@ def create_array(self,
         )
         if compressor is not None:
             meta["compressor"] = _encode_codec_metadata(compressor)
+        if shards is not None:
+            meta["shards"] = shards
+            meta["shard_format"] = shard_format

         # serialise and store metadata document
         meta_doc = _json_encode_object(meta)
@@ -307,7 +333,8 @@ def create_array(self,
         array = Array(store=self.store, path=path, owner=self,
                       shape=shape, dtype=dtype, chunk_shape=chunk_shape,
                       chunk_separator=chunk_separator, compressor=compressor,
-                      fill_value=fill_value, attrs=attrs)
+                      fill_value=fill_value, attrs=attrs,
+                      shard_format=shard_format, shards=shards)

         return array

@@ -341,12 +368,17 @@ def get_array(self, path: str) -> Array:
             if spec["must_understand"]:
                 raise NotImplementedError(spec)
         attrs = meta["attributes"]
+        shards = meta.get("shards", None)
+        if shards is not None:
+            shards = tuple(shards)
+        shard_format = meta.get("shard_format", "indexed")

         # instantiate array
         a = Array(store=self.store, path=path, owner=self,
                   shape=shape, dtype=dtype, chunk_shape=chunk_shape,
                   chunk_separator=chunk_separator, compressor=compressor,
-                  fill_value=fill_value, attrs=attrs)
+                  fill_value=fill_value, attrs=attrs,
+                  shard_format=shard_format, shards=shards)

         return a

@@ -587,7 +619,16 @@ def __init__(self,
                  chunk_separator: str,
                  compressor: Optional[Codec],
                  fill_value: Any = None,
-                 attrs: Optional[Mapping] = None):
+                 attrs: Optional[Mapping] = None,
+                 shard_format: str = "indexed",
+                 shards: Optional[Tuple[int, ...]] = None,
+                 ):
+        if shards is not None:
+            store = SHARDED_STORES[shard_format](  # type: ignore
+                store=store,
+                shards=shards,
+                chunk_separator=chunk_separator,
+            )
         super().__init__(store=store, path=path, owner=owner)
         self.shape = shape
         self.dtype = dtype
@@ -613,41 +654,33 @@ def _get_selection(self, indexer):
         # setup output array
         out = np.zeros(indexer.shape, dtype=self.dtype, order="C")

-        # iterate over chunks
+        chunk_keys = {chunk_coords: self._chunk_key(chunk_coords) for chunk_coords, _, _ in indexer}
+        encoded_chunk_data = self.store.getitems(chunk_keys.values())
+
+        # load chunk selection into output array
         for chunk_coords, chunk_selection, out_selection in indexer:
+            chunk_key = chunk_keys[chunk_coords]
+            if chunk_key in encoded_chunk_data:
+                encoded_chunk = encoded_chunk_data[chunk_key]
+                # decode chunk
+                chunk = self._decode_chunk(encoded_chunk)
+
+                # select data from chunk
+                tmp = chunk[chunk_selection]

-            # load chunk selection into output array
-            self._chunk_getitem(chunk_coords, chunk_selection, out, out_selection)
+                # store selected data in output
+                out[out_selection] = tmp
+
+            else:
+                # chunk not initialized, maybe fill
+                if self.fill_value is not None:
+                    out[out_selection] = self.fill_value

         if out.shape:
             return out
         else:
             return out[()]

-    def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection):
-
-        # obtain key for chunk
-        chunk_key = self._chunk_key(chunk_coords)
-
-        try:
-            # obtain encoded data for chunk
-            encoded_chunk_data = self.store[chunk_key]
-
-        except KeyError:
-            # chunk not initialized, maybe fill
-            if self.fill_value is not None:
-                out[out_selection] = self.fill_value
-
-        else:
-            # decode chunk
-            chunk = self._decode_chunk(encoded_chunk_data)
-
-            # select data from chunk
-            tmp = chunk[chunk_selection]
-
-            # store selected data in output
-            out[out_selection] = tmp
-
     def _chunk_key(self, chunk_coords):
         chunk_identifier = "c" + self.chunk_separator.join(map(str, chunk_coords))
         chunk_key = f"data/root{self.path}/{chunk_identifier}"
@@ -703,6 +736,7 @@ def _set_selection(self, indexer, value):
             assert value.shape == sel_shape

         # iterate over chunks in range
+        tmp_result = {}
         for chunk_coords, chunk_selection, out_selection in indexer:

             # extract data to store
@@ -714,9 +748,10 @@ def _set_selection(self, indexer, value):
                 chunk_value = value[out_selection]

             # put data
-            self._chunk_setitem(chunk_coords, chunk_selection, chunk_value)
+            self._chunk_setitem(chunk_coords, chunk_selection, chunk_value, tmp_result)
+        self.store.setitems(tmp_result)

-    def _chunk_setitem(self, chunk_coords, chunk_selection, value):
+    def _chunk_setitem(self, chunk_coords, chunk_selection, value, tmp_result):

         # obtain key for chunk storage
         chunk_key = self._chunk_key(chunk_coords)
@@ -744,6 +779,7 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):

         try:
             # obtain compressed data for chunk
+            # this might be optimized to use getitems
            encoded_chunk_data = self.store[chunk_key]

         except KeyError:
@@ -771,7 +807,7 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):
             encoded_chunk_data = self._encode_chunk(chunk)

         # store
-        self.store[chunk_key] = encoded_chunk_data
+        tmp_result[chunk_key] = encoded_chunk_data.tobytes()

     def _encode_chunk(self, chunk):

@@ -1038,12 +1074,29 @@ class Store(MutableMapping):
     def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
         raise NotImplementedError

+    def getitems(self, keys: Iterable[str]) -> Dict[str, bytes]:
+        result = {}
+        for key in keys:
+            try:
+                result[key] = self[key]
+            except KeyError:
+                pass
+        return result
+
     def __setitem__(self, key: str, value: bytes) -> None:
         raise NotImplementedError

+    def setitems(self, values: Dict[str, bytes]) -> None:
+        for key, value in values.items():
+            self[key] = value
+
     def __delitem__(self, key: str) -> None:
         raise NotImplementedError

+    def delitems(self, keys: Iterable[str]) -> None:
+        for key in keys:
+            del self[key]
+
     def __iter__(self) -> Iterator[str]:
         raise NotImplementedError

@@ -1057,6 +1110,32 @@ def list_dir(self, prefix: str) -> ListDirResult:
         raise NotImplementedError


+class MultiStore(Store):
+    # Stores that can optimize reads and writes of multiple keys
+    # should inherit from MultiStore
+
+    def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
+        result = self.getitems([key]).get(key, default)
+        if result is None:
+            raise KeyError(key)
+        return result
+
+    def getitems(self, keys: Iterable[str]) -> Dict[str, bytes]:
+        raise NotImplementedError
+
+    def __setitem__(self, key: str, value: bytes) -> None:
+        self.setitems({key: value})
+
+    def setitems(self, values: Dict[str, bytes]) -> None:
+        raise NotImplementedError
+
+    def __delitem__(self, key: str) -> None:
+        self.delitems([key])
+
+    def delitems(self, keys: Iterable[str]) -> None:
+        raise NotImplementedError
+
+
 class FileSystemStore(Store):

     # TODO ultimately replace this with the fsspec FSMap class, but for now roll
@@ -1146,3 +1225,191 @@ def __repr__(self) -> str:
         if isinstance(protocol, tuple):
             protocol = protocol[-1]
         return f"{protocol}://{self.root}"
+
+
+MAX_UINT_64 = 2 ** 64 - 1
+
+
+def _partition(pred, iterable):
+    'Use a predicate to partition entries into false entries and true entries'
+    # partition(is_odd, range(10)) --> 0 2 4 6 8  and  1 3 5 7 9
+    t1, t2 = itertools.tee(iterable)
+    return itertools.filterfalse(pred, t1), filter(pred, t2)
+
+
+def _is_data_key(key: str) -> bool:
+    return key.startswith("data/root")
+
+
+class _ShardIndex(NamedTuple):
+    store: "IndexedShardedStore"
+    offsets_and_lengths: np.ndarray  # dtype uint64, shape (shards_0, shards_1, ..., 2)
+
+    def __localize_chunk__(self, chunk: Tuple[int, ...]) -> Tuple[int, ...]:
+        return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.store._shards))
+
+    def get_chunk_slice(self, chunk: Tuple[int, ...]) -> Optional[slice]:
+        localized_chunk = self.__localize_chunk__(chunk)
+        chunk_start, chunk_len = self.offsets_and_lengths[localized_chunk]
+        if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
+            return None
+        else:
+            return slice(chunk_start, chunk_start + chunk_len)
+
+    def set_chunk_slice(self, chunk: Tuple[int, ...], chunk_slice: Optional[slice]) -> None:
+        localized_chunk = self.__localize_chunk__(chunk)
+        if chunk_slice is None:
+            self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
+        else:
+            self.offsets_and_lengths[localized_chunk] = (
+                chunk_slice.start,
+                chunk_slice.stop - chunk_slice.start
+            )
+
+    def to_bytes(self) -> bytes:
+        return self.offsets_and_lengths.tobytes(order='C')
+
+    @classmethod
+    def from_bytes(
+        cls, buffer: Union[bytes, bytearray], store: "IndexedShardedStore"
+    ) -> "_ShardIndex":
+        return cls(
+            store=store,
+            offsets_and_lengths=np.frombuffer(
+                bytearray(buffer), dtype="<u8"
+            ).reshape(*store._shards, 2),
+        )
+
+    @classmethod
+    def create_empty(cls, store: "IndexedShardedStore"):
+        # reserving 2*64bit per chunk for offset and length:
+        return cls.from_bytes(
+            MAX_UINT_64.to_bytes(8, byteorder="little") * (2 * store._num_chunks_per_shard),
+            store=store,
+        )
+
+
+class IndexedShardedStore(MultiStore):
+    """This class should not be used directly,
+    but is added to an Array as a wrapper when needed automatically."""
+
+    def __init__(
+        self,
+        store: "Store",
+        shards: Tuple[int, ...],
+        chunk_separator: str,
+    ) -> None:
+        self._store = store
+        self._shards = shards
+        self._num_chunks_per_shard = functools.reduce(lambda x, y: x*y, shards, 1)
+        self._chunk_separator = chunk_separator
+
+    def __keys_to_shard_groups__(
+        self, keys: Iterable[str]
+    ) -> Dict[str, List[Tuple[str, Tuple[int, ...]]]]:
+        shard_indices_per_shard_key = defaultdict(list)
+        for chunk_key in keys:
+            prefix, _, chunk_string = chunk_key.rpartition("c")
+            chunk_subkeys = tuple(map(int, chunk_string.split(self._chunk_separator)))
+            shard_key_tuple = (
+                subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._shards)
+            )
+            shard_key = prefix + "c" + self._chunk_separator.join(map(str, shard_key_tuple))
+            shard_indices_per_shard_key[shard_key].append((chunk_key, chunk_subkeys))
+        return shard_indices_per_shard_key
+
+    def __get_index__(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
+        # At the end of each shard 2*64bit per chunk for offset and length define the index:
+        return _ShardIndex.from_bytes(buffer[-16 * self._num_chunks_per_shard:], self)
+
+    def __get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
+        _, _, chunk_string = shard_key.rpartition("c")
+        shard_key_tuple = tuple(map(int, chunk_string.split(self._chunk_separator)))
+        for chunk_offset in itertools.product(*(range(i) for i in self._shards)):
+            yield tuple(
+                shard_key_i * shards_i + offset_i
+                for shard_key_i, offset_i, shards_i
+                in zip(shard_key_tuple, chunk_offset, self._shards)
+            )
+
+    def getitems(self, keys: Iterable[str]) -> Dict[str, bytes]:
+        other_keys, data_keys = _partition(_is_data_key, keys)
+        result = self._store.getitems(other_keys)
+
+        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(data_keys).items():
+            full_shard_value = self._store[shard_key]
+            index = self.__get_index__(full_shard_value)
+            for chunk_key, chunk_subkeys in chunks_in_shard:
+                chunk_slice = index.get_chunk_slice(chunk_subkeys)
+                if chunk_slice is not None:
+                    result[chunk_key] = full_shard_value[chunk_slice]
+        return result
+
+    def setitems(self, values: Dict[str, bytes]) -> None:
+        other_keys, data_keys = _partition(_is_data_key, values.keys())
+        to_set = {key: values[key] for key in other_keys}
+        shard_groups = self.__keys_to_shard_groups__(data_keys)
+        full_shard_values = self._store.getitems(shard_groups.keys())
+        for shard_key, chunks_in_shard in shard_groups.items():
+            all_chunks = set(self.__get_chunks_in_shard(shard_key))
+            chunks_to_set = set(chunk_subkeys for _chunk_key, chunk_subkeys in chunks_in_shard)
+            chunks_to_read = all_chunks - chunks_to_set
+            new_content = {
+                chunk_subkeys: values[chunk_key] for chunk_key, chunk_subkeys in chunks_in_shard
+            }
+            try:
+                full_shard_value = full_shard_values[shard_key]
+            except KeyError:
+                index = _ShardIndex.create_empty(self)
+            else:
+                index = self.__get_index__(full_shard_value)
+            for chunk_to_read in chunks_to_read:
+                chunk_slice = index.get_chunk_slice(chunk_to_read)
+                if chunk_slice is not None:
+                    new_content[chunk_to_read] = full_shard_value[chunk_slice]
+
+            shard_content = b""
+            # TODO: order the chunks in the shard:
+            for chunk_subkeys, chunk_content in new_content.items():
+                chunk_slice = slice(len(shard_content), len(shard_content) + len(chunk_content))
+                index.set_chunk_slice(chunk_subkeys, chunk_slice)
+                shard_content += chunk_content
+            # Appending the index at the end of the shard:
+            shard_content += index.to_bytes()
+            to_set[shard_key] = shard_content
+        self._store.setitems(to_set)
+
+    def delitems(self, keys: Iterable[str]) -> None:
+        raise NotImplementedError
+
+    def __shard_key_to_original_keys__(self, key: str) -> Iterator[str]:
+        if not _is_data_key(key):
+            # Special keys such as meta-keys are passed on as-is
+            yield key
+        else:
+            index = self.__get_index__(self._store[key])
+            prefix, _, _ = key.rpartition("c")
+            for chunk_tuple in self.__get_chunks_in_shard(key):
+                if index.get_chunk_slice(chunk_tuple) is not None:
+                    yield prefix + "c" + self._chunk_separator.join(map(str, chunk_tuple))
+
+    def __iter__(self) -> Iterator[str]:
+        for key in self._store:
+            yield from self.__shard_key_to_original_keys__(key)
+
+    def list_prefix(self, prefix: str) -> List[str]:
+        if _is_data_key(prefix):
+            # Needs translation of the prefix to shard_key
+            raise NotImplementedError
+        return self._store.list_prefix(prefix)
+
+    def list_dir(self, prefix: str) -> ListDirResult:
+        if _is_data_key(prefix):
+            # Needs translation of the prefix to shard_key
+            raise NotImplementedError
+        return self._store.list_dir(prefix)
+
+
+SHARDED_STORES: Dict[str, Type[Store]] = {
+    "indexed": IndexedShardedStore,
+}
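To make the key translation in this first commit concrete: __keys_to_shard_groups__ maps each chunk key to the shard that holds it by integer-dividing every chunk coordinate by the shard shape, while the remainder addresses the chunk's slot inside the shard index. A minimal sketch of that arithmetic; chunk_key_to_shard_key is a hypothetical stand-in for the method, using the (2, 2) shards and "/" separator from the test:

    from typing import Tuple

    def chunk_key_to_shard_key(
        chunk_key: str,
        shards: Tuple[int, ...] = (2, 2),
        separator: str = "/",
    ) -> Tuple[str, Tuple[int, ...]]:
        # "data/root/testarray/c5/0" -> prefix "data/root/testarray/", coords (5, 0)
        prefix, _, chunk_string = chunk_key.rpartition("c")
        chunk_coords = tuple(map(int, chunk_string.split(separator)))
        # Integer division selects the shard; each shard covers `shards` chunks.
        shard_coords = tuple(c // s for c, s in zip(chunk_coords, shards))
        # The remainder is the chunk's slot inside that shard's index.
        within_shard = tuple(c % s for c, s in zip(chunk_coords, shards))
        return prefix + "c" + separator.join(map(str, shard_coords)), within_shard

    # Chunk c5/0 of the test array lands in shard c2/0, slot (1, 0), matching
    # the STORE ('5/0') and UNDERLYING STORE ('2/0') listings in sharding_test.py:
    print(chunk_key_to_shard_key("data/root/testarray/c5/0"))
    # ('data/root/testarray/c2/0', (1, 0))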
From 27a99d085c572cbe6ee8580d17b2cd384913da00 Mon Sep 17 00:00:00 2001
From: Jonathan Striebel
Date: Mon, 24 Jan 2022 12:45:13 +0100
Subject: [PATCH 2/4] simplify implementation, remove getitems/setitems

---
 zarrita.py | 177 ++++++++++++++++++++---------------------------------
 1 file changed, 64 insertions(+), 113 deletions(-)

diff --git a/zarrita.py b/zarrita.py
index 2eeeef2..9adba8d 100644
--- a/zarrita.py
+++ b/zarrita.py
@@ -6,7 +6,6 @@
 import functools
 import math
 import re
-from collections import defaultdict
 from collections.abc import Mapping, MutableMapping
 from typing import Iterator, Union, Optional, Tuple, Any, List, Dict, NamedTuple, Iterable, Type

@@ -654,33 +653,41 @@ def _get_selection(self, indexer):
         # setup output array
         out = np.zeros(indexer.shape, dtype=self.dtype, order="C")

-        chunk_keys = {chunk_coords: self._chunk_key(chunk_coords) for chunk_coords, _, _ in indexer}
-        encoded_chunk_data = self.store.getitems(chunk_keys.values())
-
-        # load chunk selection into output array
+        # iterate over chunks
         for chunk_coords, chunk_selection, out_selection in indexer:
-            chunk_key = chunk_keys[chunk_coords]
-            if chunk_key in encoded_chunk_data:
-                encoded_chunk = encoded_chunk_data[chunk_key]
-                # decode chunk
-                chunk = self._decode_chunk(encoded_chunk)
-
-                # select data from chunk
-                tmp = chunk[chunk_selection]

-                # store selected data in output
-                out[out_selection] = tmp
+            # load chunk selection into output array
+            self._chunk_getitem(chunk_coords, chunk_selection, out, out_selection)
-
-            else:
-                # chunk not initialized, maybe fill
-                if self.fill_value is not None:
-                    out[out_selection] = self.fill_value

         if out.shape:
             return out
         else:
             return out[()]

+    def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection):
+
+        # obtain key for chunk
+        chunk_key = self._chunk_key(chunk_coords)
+
+        try:
+            # obtain encoded data for chunk
+            encoded_chunk_data = self.store[chunk_key]
+
+        except KeyError:
+            # chunk not initialized, maybe fill
+            if self.fill_value is not None:
+                out[out_selection] = self.fill_value
+
+        else:
+            # decode chunk
+            chunk = self._decode_chunk(encoded_chunk_data)
+
+            # select data from chunk
+            tmp = chunk[chunk_selection]
+
+            # store selected data in output
+            out[out_selection] = tmp
+
     def _chunk_key(self, chunk_coords):
         chunk_identifier = "c" + self.chunk_separator.join(map(str, chunk_coords))
         chunk_key = f"data/root{self.path}/{chunk_identifier}"
@@ -736,7 +743,6 @@ def _set_selection(self, indexer, value):
         assert value.shape == sel_shape

         # iterate over chunks in range
-        tmp_result = {}
         for chunk_coords, chunk_selection, out_selection in indexer:

             # extract data to store
@@ -748,10 +754,9 @@ def _set_selection(self, indexer, value):
                 chunk_value = value[out_selection]

             # put data
-            self._chunk_setitem(chunk_coords, chunk_selection, chunk_value, tmp_result)
-        self.store.setitems(tmp_result)
+            self._chunk_setitem(chunk_coords, chunk_selection, chunk_value)

-    def _chunk_setitem(self, chunk_coords, chunk_selection, value, tmp_result):
+    def _chunk_setitem(self, chunk_coords, chunk_selection, value):

         # obtain key for chunk storage
         chunk_key = self._chunk_key(chunk_coords)
@@ -779,7 +784,6 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value, tmp_result):

         try:
             # obtain compressed data for chunk
-            # this might be optimized to use getitems
             encoded_chunk_data = self.store[chunk_key]

         except KeyError:
@@ -807,7 +811,7 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value, tmp_result):
             encoded_chunk_data = self._encode_chunk(chunk)

         # store
-        tmp_result[chunk_key] = encoded_chunk_data.tobytes()
+        self.store[chunk_key] = encoded_chunk_data.tobytes()

     def _encode_chunk(self, chunk):

@@ -1074,22 +1078,9 @@ class Store(MutableMapping):
     def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
         raise NotImplementedError

-    def getitems(self, keys: Iterable[str]) -> Dict[str, bytes]:
-        result = {}
-        for key in keys:
-            try:
-                result[key] = self[key]
-            except KeyError:
-                pass
-        return result
-
     def __setitem__(self, key: str, value: bytes) -> None:
         raise NotImplementedError

-    def setitems(self, values: Dict[str, bytes]) -> None:
-        for key, value in values.items():
-            self[key] = value
-
     def __delitem__(self, key: str) -> None:
         raise NotImplementedError

@@ -1110,32 +1101,6 @@ def list_dir(self, prefix: str) -> ListDirResult:
         raise NotImplementedError


-class MultiStore(Store):
-    # Stores that can optimize reads and writes of multiple keys
-    # should inherit from MultiStore
-
-    def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
-        result = self.getitems([key]).get(key, default)
-        if result is None:
-            raise KeyError(key)
-        return result
-
-    def getitems(self, keys: Iterable[str]) -> Dict[str, bytes]:
-        raise NotImplementedError
-
-    def __setitem__(self, key: str, value: bytes) -> None:
-        self.setitems({key: value})
-
-    def setitems(self, values: Dict[str, bytes]) -> None:
-        raise NotImplementedError
-
-    def __delitem__(self, key: str) -> None:
-        self.delitems([key])
-
-    def delitems(self, keys: Iterable[str]) -> None:
-        raise NotImplementedError
-
-
 class FileSystemStore(Store):

     # TODO ultimately replace this with the fsspec FSMap class, but for now roll
@@ -1230,13 +1195,6 @@ def __repr__(self) -> str:
 MAX_UINT_64 = 2 ** 64 - 1


-def _partition(pred, iterable):
-    'Use a predicate to partition entries into false entries and true entries'
-    # partition(is_odd, range(10)) --> 0 2 4 6 8  and  1 3 5 7 9
-    t1, t2 = itertools.tee(iterable)
-    return itertools.filterfalse(pred, t1), filter(pred, t2)
-
-
 def _is_data_key(key: str) -> bool:
     return key.startswith("data/root")

@@ -1289,7 +1247,7 @@ def create_empty(cls, store: "IndexedShardedStore"):
         )


-class IndexedShardedStore(MultiStore):
+class IndexedShardedStore(Store):
     """This class should not be used directly,
     but is added to an Array as a wrapper when needed automatically."""

@@ -1304,19 +1262,16 @@ def __init__(
         self._num_chunks_per_shard = functools.reduce(lambda x, y: x*y, shards, 1)
         self._chunk_separator = chunk_separator

-    def __keys_to_shard_groups__(
-        self, keys: Iterable[str]
-    ) -> Dict[str, List[Tuple[str, Tuple[int, ...]]]]:
-        shard_indices_per_shard_key = defaultdict(list)
-        for chunk_key in keys:
-            prefix, _, chunk_string = chunk_key.rpartition("c")
-            chunk_subkeys = tuple(map(int, chunk_string.split(self._chunk_separator)))
-            shard_key_tuple = (
-                subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._shards)
-            )
-            shard_key = prefix + "c" + self._chunk_separator.join(map(str, shard_key_tuple))
-            shard_indices_per_shard_key[shard_key].append((chunk_key, chunk_subkeys))
-        return shard_indices_per_shard_key
+    def __key_to_shard__(
+        self, chunk_key: str
+    ) -> Tuple[str, Tuple[int, ...]]:
+        prefix, _, chunk_string = chunk_key.rpartition("c")
+        chunk_subkeys = tuple(map(int, chunk_string.split(self._chunk_separator)))
+        shard_key_tuple = (
+            subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._shards)
+        )
+        shard_key = prefix + "c" + self._chunk_separator.join(map(str, shard_key_tuple))
+        return shard_key, chunk_subkeys

     def __get_index__(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
         # At the end of each shard 2*64bit per chunk for offset and length define the index:
@@ -1332,33 +1287,29 @@ def __get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
                 in zip(shard_key_tuple, chunk_offset, self._shards)
             )

-    def getitems(self, keys: Iterable[str]) -> Dict[str, bytes]:
-        other_keys, data_keys = _partition(_is_data_key, keys)
-        result = self._store.getitems(other_keys)
-
-        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(data_keys).items():
+    def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
+        if _is_data_key(key):
+            shard_key, chunk_subkeys = self.__key_to_shard__(key)
             full_shard_value = self._store[shard_key]
             index = self.__get_index__(full_shard_value)
-            for chunk_key, chunk_subkeys in chunks_in_shard:
-                chunk_slice = index.get_chunk_slice(chunk_subkeys)
-                if chunk_slice is not None:
-                    result[chunk_key] = full_shard_value[chunk_slice]
-        return result
-
-    def setitems(self, values: Dict[str, bytes]) -> None:
-        other_keys, data_keys = _partition(_is_data_key, values.keys())
-        to_set = {key: values[key] for key in other_keys}
-        shard_groups = self.__keys_to_shard_groups__(data_keys)
-        full_shard_values = self._store.getitems(shard_groups.keys())
-        for shard_key, chunks_in_shard in shard_groups.items():
-            all_chunks = set(self.__get_chunks_in_shard(shard_key))
-            chunks_to_set = set(chunk_subkeys for _chunk_key, chunk_subkeys in chunks_in_shard)
-            chunks_to_read = all_chunks - chunks_to_set
-            new_content = {
-                chunk_subkeys: values[chunk_key] for chunk_key, chunk_subkeys in chunks_in_shard
-            }
+            chunk_slice = index.get_chunk_slice(chunk_subkeys)
+            if chunk_slice is not None:
+                return full_shard_value[chunk_slice]
+            else:
+                if default is not None:
+                    return default
+                raise KeyError(key)
+        else:
+            return self._store.__getitem__(key, default)
+
+    def __setitem__(self, key: str, value: bytes) -> None:
+        if _is_data_key(key):
+            shard_key, chunk_subkeys = self.__key_to_shard__(key)
+            chunks_to_read = set(self.__get_chunks_in_shard(shard_key))
+            chunks_to_read.remove(chunk_subkeys)
+            new_content = {chunk_subkeys: value}
             try:
-                full_shard_value = full_shard_values[shard_key]
+                full_shard_value = self._store[shard_key]
             except KeyError:
                 index = _ShardIndex.create_empty(self)
             else:
                 index = self.__get_index__(full_shard_value)
@@ -1369,15 +1320,15 @@ def setitems(self, values: Dict[str, bytes]) -> None:
             for chunk_to_read in chunks_to_read:
                 chunk_slice = index.get_chunk_slice(chunk_to_read)
                 if chunk_slice is not None:
                     new_content[chunk_to_read] = full_shard_value[chunk_slice]

             shard_content = b""
-            # TODO: order the chunks in the shard:
             for chunk_subkeys, chunk_content in new_content.items():
                 chunk_slice = slice(len(shard_content), len(shard_content) + len(chunk_content))
                 index.set_chunk_slice(chunk_subkeys, chunk_slice)
                 shard_content += chunk_content
             # Appending the index at the end of the shard:
             shard_content += index.to_bytes()
-            to_set[shard_key] = shard_content
-        self._store.setitems(to_set)
+            self._store[shard_key] = shard_content
+        else:
+            self._store[key] = value

     def delitems(self, keys: Iterable[str]) -> None:
         raise NotImplementedError
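With getitems/setitems gone, every single-chunk write in __setitem__ performs a full read-modify-write of its shard: read the old shard, decode the trailing index, keep the surviving chunks, insert the new one, and write the rebuilt shard back. A compact self-contained sketch of that cycle over a plain dict store; write_chunk is illustrative and assumes the uncompressed indexed layout, not the patch's actual API:

    import itertools
    import numpy as np

    MAX_UINT_64 = 2 ** 64 - 1  # sentinel for an empty chunk slot

    def write_chunk(store: dict, shard_key: str, chunk_slot: tuple,
                    chunk_bytes: bytes, shards=(2, 2)) -> None:
        num_chunks = int(np.prod(shards))
        chunks = {}
        if shard_key in store:  # 1. read the existing shard, keep other chunks
            shard = store[shard_key]
            index = np.frombuffer(shard[-16 * num_chunks:], dtype="<u8").reshape(*shards, 2)
            for slot in itertools.product(*(range(n) for n in shards)):
                offset, length = index[slot]
                if (offset, length) != (MAX_UINT_64, MAX_UINT_64):
                    chunks[slot] = shard[offset:offset + length]
        chunks[chunk_slot] = chunk_bytes  # 2. insert or replace the written chunk
        # 3. rebuild the shard: concatenate all chunks, then append a fresh index
        new_index = np.full((*shards, 2), MAX_UINT_64, dtype="<u8")
        payload = b""
        for slot, blob in chunks.items():
            new_index[slot] = (len(payload), len(blob))
            payload += blob
        store[shard_key] = payload + new_index.tobytes(order="C")

This is simpler than the batched setitems of the first commit, but it means writing n chunks of one shard rewrites that shard n times; that write amplification is the trade-off this commit accepts.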
From 2cfcb5eb72cb298094e91f60cada28406829c4af Mon Sep 17 00:00:00 2001
From: Jonathan Striebel
Date: Mon, 24 Jan 2022 12:46:19 +0100
Subject: [PATCH 3/4] rm delitems

---
 zarrita.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/zarrita.py b/zarrita.py
index 9adba8d..b9d944a 100644
--- a/zarrita.py
+++ b/zarrita.py
@@ -1084,10 +1084,6 @@ def __setitem__(self, key: str, value: bytes) -> None:
     def __delitem__(self, key: str) -> None:
         raise NotImplementedError

-    def delitems(self, keys: Iterable[str]) -> None:
-        for key in keys:
-            del self[key]
-
     def __iter__(self) -> Iterator[str]:
         raise NotImplementedError

@@ -1330,9 +1326,6 @@ def __setitem__(self, key: str, value: bytes) -> None:
         else:
             self._store[key] = value

-    def delitems(self, keys: Iterable[str]) -> None:
-        raise NotImplementedError
-
     def __shard_key_to_original_keys__(self, key: str) -> Iterator[str]:
         if not _is_data_key(key):
             # Special keys such as meta-keys are passed on as-is

From 31789f65360cb639cff44248af8388bcf331b122 Mon Sep 17 00:00:00 2001
From: Jonathan Striebel
Date: Mon, 24 Jan 2022 18:30:28 +0100
Subject: [PATCH 4/4] cleanup

---
 sharding_test.py | 20 +++++-----
 zarrita.py       | 95 +++++++++++++++++++++----------------------
 2 files changed, 53 insertions(+), 62 deletions(-)

diff --git a/sharding_test.py b/sharding_test.py
index c8b79b9..0c1b3c3 100644
--- a/sharding_test.py
+++ b/sharding_test.py
@@ -12,7 +12,7 @@
     shape=(20, 3),
     dtype="float64",
     chunk_shape=(3, 2),
-    shards=(2, 2),
+    sharding={"chunks_per_shard": (2, 2)},
 )

 a[:10, :] = 42
@@ -20,7 +20,7 @@
 a[19, 2] = 1
 a[0, 1] = -4.2

-assert a.store._shards == (2, 2)
+assert a.store._chunks_per_shard == (2, 2)
 assert a[15, 1] == 389
 assert a[19, 2] == 1
 assert a[0, 1] == -4.2
@@ -47,14 +47,16 @@
 # "fill_value": null,
 # "extensions": [],
 # "attributes": {},
-# "shards": [
-#     2,
-#     2
-# ],
-# "shard_format": "indexed"
+# "sharding": {
+#     "chunks_per_shard": [
+#         2,
+#         2
+#     ],
+#     "format": "indexed"
+# }
 # }

-assert json.loads(array_json)["shards"] == [2, 2]
+assert json.loads(array_json)["sharding"]["chunks_per_shard"] == [2, 2]

 print("ONDISK")
 for root, dirs, files in os.walk("sharding_test.zr3"):
@@ -79,7 +81,7 @@


 a_reopened = zarrita.get_hierarchy("sharding_test.zr3").get_array("testarray")
-assert a_reopened.store._shards == (2, 2)
+assert a_reopened.store._chunks_per_shard == (2, 2)
 assert a_reopened[15, 1] == 389
 assert a_reopened[19, 2] == 1
 assert a_reopened[0, 1] == -4.2
diff --git a/zarrita.py b/zarrita.py
index b9d944a..37ed72b 100644
--- a/zarrita.py
+++ b/zarrita.py
@@ -171,21 +171,16 @@ def _check_compressor(compressor: Optional[Codec]) -> None:
     assert compressor is None or isinstance(compressor, Codec)


-def _check_shard_format(shard_format: str) -> None:
-    assert shard_format in SHARDED_STORES, (
-        f"Shard format {shard_format} is not supported, "
+def _check_sharding(sharding: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    if sharding is None:
+        return None
+    if "format" not in sharding:
+        sharding["format"] = "indexed"
+    assert sharding["format"] in SHARDED_STORES, (
+        f"Shard format {sharding['format']} is not supported, "
         + f"use one of {list(SHARDED_STORES)}"
     )
-
-
-def _check_shards(shards: Union[int, Tuple[int, ...], None]) -> Optional[Tuple[int, ...]]:
-    if shards is None:
-        return None
-    assert isinstance(shards, (int, tuple))
-    if isinstance(shards, int):
-        shards = shards,
-    assert all([isinstance(s, int) for s in shards])
-    return shards
+    return sharding


 def _encode_codec_metadata(codec: Codec) -> Optional[Mapping]:
@@ -284,8 +279,7 @@ def create_array(self,
                      compressor: Optional[Codec] = None,
                      fill_value: Any = None,
                      attrs: Optional[Mapping] = None,
-                     shard_format: str = "indexed",
-                     shards: Union[int, Tuple[int, ...], None] = None) -> Array:
+                     sharding: Optional[Dict[str, Any]] = None) -> Array:

         # sanity checks
         path = _check_path(path)
@@ -294,8 +288,7 @@ def create_array(self,
         chunk_shape = _check_chunk_shape(chunk_shape, shape)
         _check_compressor(compressor)
         attrs = _check_attrs(attrs)
-        _check_shard_format(shard_format)
-        shards = _check_shards(shards)
+        sharding = _check_sharding(sharding)

         # encode data type
         if dtype == np.bool_:
@@ -319,9 +312,8 @@ def create_array(self,
         )
         if compressor is not None:
             meta["compressor"] = _encode_codec_metadata(compressor)
-        if shards is not None:
-            meta["shards"] = shards
-            meta["shard_format"] = shard_format
+        if sharding is not None:
+            meta["sharding"] = sharding

         # serialise and store metadata document
         meta_doc = _json_encode_object(meta)
@@ -333,7 +325,7 @@ def create_array(self,
                       shape=shape, dtype=dtype, chunk_shape=chunk_shape,
                       chunk_separator=chunk_separator, compressor=compressor,
                       fill_value=fill_value, attrs=attrs,
-                      shard_format=shard_format, shards=shards)
+                      sharding=sharding)

         return array

@@ -367,17 +359,13 @@ def get_array(self, path: str) -> Array:
             if spec["must_understand"]:
                 raise NotImplementedError(spec)
         attrs = meta["attributes"]
-        shards = meta.get("shards", None)
-        if shards is not None:
-            shards = tuple(shards)
-        shard_format = meta.get("shard_format", "indexed")
+        sharding = meta.get("sharding", None)

         # instantiate array
         a = Array(store=self.store, path=path, owner=self,
                   shape=shape, dtype=dtype, chunk_shape=chunk_shape,
                   chunk_separator=chunk_separator, compressor=compressor,
-                  fill_value=fill_value, attrs=attrs,
-                  shard_format=shard_format, shards=shards)
+                  fill_value=fill_value, attrs=attrs, sharding=sharding)

         return a

@@ -619,14 +607,13 @@ def __init__(self,
                  compressor: Optional[Codec],
                  fill_value: Any = None,
                  attrs: Optional[Mapping] = None,
-                 shard_format: str = "indexed",
-                 shards: Optional[Tuple[int, ...]] = None,
+                 sharding: Optional[Dict[str, Any]] = None,
                  ):
-        if shards is not None:
-            store = SHARDED_STORES[shard_format](  # type: ignore
+        if sharding is not None:
+            store = SHARDED_STORES[sharding["format"]](  # type: ignore
                 store=store,
-                shards=shards,
                 chunk_separator=chunk_separator,
+                **sharding,
             )
         super().__init__(store=store, path=path, owner=owner)
         self.shape = shape
@@ -1197,10 +1184,10 @@ def _is_data_key(key: str) -> bool:

 class _ShardIndex(NamedTuple):
     store: "IndexedShardedStore"
-    offsets_and_lengths: np.ndarray  # dtype uint64, shape (shards_0, shards_1, ..., 2)
+    offsets_and_lengths: np.ndarray  # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)

     def __localize_chunk__(self, chunk: Tuple[int, ...]) -> Tuple[int, ...]:
-        return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.store._shards))
+        return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.store._chunks_per_shard))

     def get_chunk_slice(self, chunk: Tuple[int, ...]) -> Optional[slice]:
         localized_chunk = self.__localize_chunk__(chunk)
@@ -1231,7 +1218,7 @@ def from_bytes(
             store=store,
             offsets_and_lengths=np.frombuffer(
                 bytearray(buffer), dtype="<u8"
-            ).reshape(*store._shards, 2),
+            ).reshape(*store._chunks_per_shard, 2),
         )

     @classmethod
@@ -1250,44 +1237,46 @@ class IndexedShardedStore(Store):
     def __init__(
         self,
         store: "Store",
-        shards: Tuple[int, ...],
         chunk_separator: str,
+        chunks_per_shard: Iterable[int],
+        format: str = "indexed",
     ) -> None:
         self._store = store
-        self._shards = shards
-        self._num_chunks_per_shard = functools.reduce(lambda x, y: x*y, shards, 1)
+        self._num_chunks_per_shard = functools.reduce(lambda x, y: x*y, chunks_per_shard, 1)
         self._chunk_separator = chunk_separator
+        assert all(isinstance(s, int) for s in chunks_per_shard)
+        self._chunks_per_shard = tuple(chunks_per_shard)

-    def __key_to_shard__(
+    def _key_to_shard(
         self, chunk_key: str
     ) -> Tuple[str, Tuple[int, ...]]:
         prefix, _, chunk_string = chunk_key.rpartition("c")
         chunk_subkeys = tuple(map(int, chunk_string.split(self._chunk_separator)))
         shard_key_tuple = (
-            subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._shards)
+            subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._chunks_per_shard)
         )
         shard_key = prefix + "c" + self._chunk_separator.join(map(str, shard_key_tuple))
         return shard_key, chunk_subkeys

-    def __get_index__(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
+    def _get_index(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
         # At the end of each shard 2*64bit per chunk for offset and length define the index:
         return _ShardIndex.from_bytes(buffer[-16 * self._num_chunks_per_shard:], self)

-    def __get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
+    def _get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
         _, _, chunk_string = shard_key.rpartition("c")
         shard_key_tuple = tuple(map(int, chunk_string.split(self._chunk_separator)))
-        for chunk_offset in itertools.product(*(range(i) for i in self._shards)):
+        for chunk_offset in itertools.product(*(range(i) for i in self._chunks_per_shard)):
             yield tuple(
                 shard_key_i * shards_i + offset_i
                 for shard_key_i, offset_i, shards_i
-                in zip(shard_key_tuple, chunk_offset, self._shards)
+                in zip(shard_key_tuple, chunk_offset, self._chunks_per_shard)
             )

     def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
         if _is_data_key(key):
-            shard_key, chunk_subkeys = self.__key_to_shard__(key)
+            shard_key, chunk_subkeys = self._key_to_shard(key)
             full_shard_value = self._store[shard_key]
-            index = self.__get_index__(full_shard_value)
+            index = self._get_index(full_shard_value)
             chunk_slice = index.get_chunk_slice(chunk_subkeys)
@@ -1300,8 +1289,8 @@ def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:

     def __setitem__(self, key: str, value: bytes) -> None:
         if _is_data_key(key):
-            shard_key, chunk_subkeys = self.__key_to_shard__(key)
-            chunks_to_read = set(self.__get_chunks_in_shard(shard_key))
+            shard_key, chunk_subkeys = self._key_to_shard(key)
+            chunks_to_read = set(self._get_chunks_in_shard(shard_key))
             chunks_to_read.remove(chunk_subkeys)
             new_content = {chunk_subkeys: value}
             try:
@@ -1309,7 +1298,7 @@ def __setitem__(self, key: str, value: bytes) -> None:
             except KeyError:
                 index = _ShardIndex.create_empty(self)
             else:
-                index = self.__get_index__(full_shard_value)
+                index = self._get_index(full_shard_value)
             for chunk_to_read in chunks_to_read:
                 chunk_slice = index.get_chunk_slice(chunk_to_read)
                 if chunk_slice is not None:
@@ -1326,20 +1315,20 @@ def __setitem__(self, key: str, value: bytes) -> None:
         else:
             self._store[key] = value

-    def __shard_key_to_original_keys__(self, key: str) -> Iterator[str]:
+    def _shard_key_to_original_keys(self, key: str) -> Iterator[str]:
         if not _is_data_key(key):
             # Special keys such as meta-keys are passed on as-is
             yield key
         else:
-            index = self.__get_index__(self._store[key])
+            index = self._get_index(self._store[key])
             prefix, _, _ = key.rpartition("c")
-            for chunk_tuple in self.__get_chunks_in_shard(key):
+            for chunk_tuple in self._get_chunks_in_shard(key):
                 if index.get_chunk_slice(chunk_tuple) is not None:
                     yield prefix + "c" + self._chunk_separator.join(map(str, chunk_tuple))

     def __iter__(self) -> Iterator[str]:
         for key in self._store:
-            yield from self.__shard_key_to_original_keys__(key)
+            yield from self._shard_key_to_original_keys(key)

     def list_prefix(self, prefix: str) -> List[str]:
         if _is_data_key(prefix):
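The excerpt ends mid-method, but the on-disk format is fully determined by the patches above. As an end-to-end check, one value written by sharding_test.py can be decoded straight from a shard file with numpy alone; this snippet is an illustration based on the test's uncompressed float64 layout and the paths printed above, not part of the patch series:

    import numpy as np

    with open("sharding_test.zr3/data/root/testarray/c0/0", "rb") as f:
        shard = f.read()
    # 2 x 2 chunks per shard, 16 index bytes each, appended at the end:
    index = np.frombuffer(shard[-2 * 2 * 16:], dtype="<u8").reshape(2, 2, 2)
    offset, length = index[0, 0]            # chunk (0, 0) within the shard
    chunk = np.frombuffer(shard[offset:offset + length], dtype="<f8").reshape(3, 2)
    assert chunk[0, 1] == -4.2              # the value written by sharding_test.py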