From 6a5f845c328cecb7a02211b141781bc66bea674c Mon Sep 17 00:00:00 2001
From: Kirill Kouzoubov
Date: Tue, 25 Jun 2024 12:40:08 +1000
Subject: [PATCH] Zarr reader driver

---
 odc/loader/_driver.py        |   4 +-
 odc/loader/_zarr.py          | 235 ++++++++++++++++++++++++++++-------
 odc/loader/test_memreader.py |  21 ++--
 3 files changed, 206 insertions(+), 54 deletions(-)

diff --git a/odc/loader/_driver.py b/odc/loader/_driver.py
index 68938b6..e9319c0 100644
--- a/odc/loader/_driver.py
+++ b/odc/loader/_driver.py
@@ -10,10 +10,12 @@
 from typing import Any, Callable
 
 from ._rio import RioDriver
+from ._zarr import XrMemReaderDriver
 from .types import ReaderDriver, ReaderDriverSpec, is_reader_driver
 
 _available_drivers: dict[str, Callable[[], ReaderDriver] | ReaderDriver] = {
-    "rio": RioDriver
+    "rio": RioDriver,
+    "zarr": XrMemReaderDriver,
 }
 
 
diff --git a/odc/loader/_zarr.py b/odc/loader/_zarr.py
index 699b516..4351059 100644
--- a/odc/loader/_zarr.py
+++ b/odc/loader/_zarr.py
@@ -5,14 +5,20 @@
 from __future__ import annotations
 
 import json
+import math
+from collections.abc import Mapping, Sequence
 from contextlib import contextmanager
-from typing import Any, Iterator, Mapping, Sequence
+from typing import Any, Iterator
 
+import dask.array as da
 import fsspec
 import numpy as np
 import xarray as xr
+from dask import is_dask_collection
+from dask.array.core import normalize_chunks
 from dask.delayed import Delayed, delayed
-from odc.geo.geobox import GeoBox
+from fsspec.core import url_to_fs
+from odc.geo.geobox import GeoBox, GeoBoxBase
 from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject
 
 from .types import (
@@ -30,40 +36,58 @@
 # TODO: tighten specs for Zarr*
 SomeDoc = Mapping[str, Any]
 ZarrSpec = Mapping[str, Any]
-ZarrSpecFs = Mapping[str, Any]
-ZarrSpecFsDict = dict[str, Any]
+ZarrSpecDict = dict[str, Any]
 # pylint: disable=too-few-public-methods
 
 
-def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFsDict | None:
-    if ".zmetadata" in src:
+def extract_zarr_spec(src: SomeDoc) -> ZarrSpecDict | None:
+    if ".zgroup" in src:
         return dict(src)
 
     if "zarr:metadata" in src:
         # TODO: handle zarr:chunks for reference filesystem
-        zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]}
-    elif "zarr_consolidated_format" in src:
-        zmd = dict(src)
-    else:
-        zmd = {"zarr_consolidated_format": 1, "metadata": src}
+        return dict(src["zarr:metadata"])
+
+    if "zarr_consolidated_format" in src:
+        return dict(src["metadata"])
 
-    return {".zmetadata": json.dumps(zmd)}
+    if ".zmetadata" in src:
+        return dict(json.loads(src[".zmetadata"])["metadata"])
+
+    return None
 
 
 def _from_zarr_spec(
-    spec_doc: ZarrSpecFs,
+    spec_doc: ZarrSpecDict,
+    *,
     regen_coords: bool = False,
-    fs: fsspec.AbstractFileSystem | None = None,
+    chunk_store: fsspec.AbstractFileSystem | Mapping[str, Any] | None = None,
     chunks=None,
     target: str | None = None,
    fsspec_opts: dict[str, Any] | None = None,
+    drop_variables: Sequence[str] = (),
 ) -> xr.Dataset:
     fsspec_opts = fsspec_opts or {}
-    rfs = fsspec.filesystem(
-        "reference", fo=spec_doc, fs=fs, target=target, **fsspec_opts
+    if target is not None:
+        if chunk_store is None:
+            fs, target = url_to_fs(target, **fsspec_opts)
+            chunk_store = fs.get_mapper(target)
+        elif isinstance(chunk_store, fsspec.AbstractFileSystem):
+            chunk_store = chunk_store.get_mapper(target)
+
+    # TODO: deal with coordinates being loaded at open time.
+    #
+    # When chunk store is supplied xarray will try to load index coords (i.e.
+    # name == dim, coords)
+    xx = xr.open_zarr(
+        spec_doc,
+        chunk_store=chunk_store,
+        drop_variables=drop_variables,
+        chunks=chunks,
+        decode_coords="all",
+        consolidated=False,
     )
-
-    xx = xr.open_dataset(rfs.get_mapper(""), engine="zarr", mode="r", chunks=chunks)
     gbox = xx.odc.geobox
     if gbox is not None and regen_coords:
         # re-gen x,y coords from geobox
@@ -160,7 +184,13 @@ def __init__(
 
     def with_env(self, env: dict[str, Any]) -> "Context":
         assert isinstance(env, dict)
-        return Context(self.geobox, self.chunks)
+        return Context(self.geobox, self.chunks, driver=self.driver)
+
+    @property
+    def fs(self) -> fsspec.AbstractFileSystem | None:
+        if self.driver is None:
+            return None
+        return self.driver.fs
 
 
 class XrSource:
@@ -168,33 +198,89 @@ class XrSource:
     RasterSource -> xr.DataArray|xr.Dataset
     """
 
-    def __init__(self, src: RasterSource, chunks: Any | None = None) -> None:
+    def __init__(
+        self,
+        src: RasterSource,
+        chunks: Any | None = None,
+        chunk_store: (
+            fsspec.AbstractFileSystem | fsspec.FSMap | Mapping[str, Any] | None
+        ) = None,
+        drop_variables: Sequence[str] = (),
+    ) -> None:
+        if isinstance(chunk_store, fsspec.AbstractFileSystem):
+            chunk_store = chunk_store.get_mapper(src.uri)
+
         driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data
-        self._spec: ZarrSpecFs | None = None
+        self._spec: ZarrSpecDict | None = None
         self._ds: xr.Dataset | None = None
         self._xx: xr.DataArray | None = None
         self._src = src
         self._chunks = chunks
+        self._chunk_store = chunk_store
+        self._drop_variables = drop_variables
+
+        subdataset = self._src.subdataset
 
         if isinstance(driver_data, xr.DataArray):
             self._xx = driver_data
         elif isinstance(driver_data, xr.Dataset):
-            subdataset = src.subdataset
             self._ds = driver_data
+            assert subdataset is not None
             assert subdataset in driver_data.data_vars
             self._xx = driver_data.data_vars[subdataset]
         elif isinstance(driver_data, dict):
-            self._spec = extract_zarr_spec(driver_data)
+            spec = extract_zarr_spec(driver_data)
+            if spec is None:
+                raise ValueError(f"Unsupported driver data: {type(driver_data)}")
+
+            # create unloadable xarray.Dataset
+            ds = xr.open_zarr(spec, consolidated=False, decode_coords="all", chunks={})
+            assert subdataset is not None
+            assert subdataset in ds.data_vars
+
+            if chunk_store is None:
+                chunk_store = fsspec.get_mapper(src.uri)
+
+            # recreate xr.DataArray with all the dims/coords/attrs
+            # but this time loadable from chunk_store
+            xx = ds.data_vars[subdataset]
+            xx = xr.DataArray(
+                da.from_zarr(
+                    spec,
+                    component=subdataset,
+                    chunk_store=chunk_store,
+                ),
+                coords=xx.coords,
+                dims=xx.dims,
+                name=xx.name,
+                attrs=xx.attrs,
+            )
+            assert xx.odc.geobox is not None
+            self._spec = spec
+            self._ds = ds
+            self._xx = xx
+
         elif driver_data is not None:
             raise ValueError(f"Unsupported driver data type: {type(driver_data)}")
 
-        assert driver_data is None or (self._spec is not None or self._xx is not None)
-
     @property
-    def spec(self) -> ZarrSpecFs | None:
+    def spec(self) -> ZarrSpecDict | None:
         return self._spec
 
-    def base(self, regen_coords: bool = False) -> xr.Dataset | None:
+    @property
+    def geobox(self) -> GeoBoxBase | None:
+        if self._src.geobox is not None:
+            return self._src.geobox
+        return self.resolve().odc.geobox
+
+    def base(
+        self,
+        regen_coords: bool = False,
+        refresh: bool = False,
+    ) -> xr.Dataset | None:
+        if refresh and self._spec:
+            self._ds = None
+
         if self._ds is not None:
             return self._ds
         if self._spec is None:
@@ -202,6 +288,7 @@ def base(self, regen_coords: bool = False) -> xr.Dataset | None:
         self._ds = _from_zarr_spec(
             self._spec,
             regen_coords=regen_coords,
+            chunk_store=self._chunk_store,
             target=self._src.uri,
             chunks=self._chunks,
         )
@@ -210,18 +297,20 @@ def base(self, regen_coords: bool = False) -> xr.Dataset | None:
     def resolve(
         self,
         regen_coords: bool = False,
+        refresh: bool = False,
     ) -> xr.DataArray:
+        if refresh:
+            self._xx = None
+
         if self._xx is not None:
             return self._xx
 
-        src_ds = self.base(regen_coords=regen_coords)
+        src_ds = self.base(regen_coords=regen_coords, refresh=refresh)
         if src_ds is None:
             raise ValueError("Failed to interpret driver data")
 
         subdataset = self._src.subdataset
-        if subdataset is None:
-            _first, *_ = src_ds.data_vars
-            subdataset = str(_first)
+        assert subdataset is not None
 
         if subdataset not in src_ds.data_vars:
             raise ValueError(f"Band {subdataset!r} not found in dataset")
@@ -251,8 +340,7 @@ class XrMemReader:
     """
 
     def __init__(self, src: RasterSource, ctx: Context) -> None:
-        self._src = XrSource(src, chunks=None)
-        self._ctx = ctx
+        self._src = XrSource(src, chunks=None, chunk_store=ctx.fs)
 
     def read(
         self,
@@ -266,12 +354,17 @@ def read(
         src = _subset_src(src, selection, cfg)
         warped = xr_reproject(src, dst_geobox, resampling=cfg.resampling)
 
-        assert isinstance(warped.data, np.ndarray)
+        if is_dask_collection(warped):
+            warped = warped.data.compute(scheduler="synchronous")
+        else:
+            warped = warped.data
+
+        assert isinstance(warped, np.ndarray)
 
         if dst is None:
-            dst = warped.data
+            dst = warped
         else:
-            dst[...] = warped.data
+            dst[...] = warped
 
         yx_roi = (slice(None), slice(None))
         return yx_roi, dst
@@ -288,15 +381,13 @@ class XrMemReaderDask:
 
     def __init__(
         self,
-        src: RasterSource | None = None,
-        ctx: Context | None = None,
+        src: xr.DataArray | None = None,
         layer_name: str = "",
         idx: int = -1,
     ) -> None:
-        self._src = XrSource(src, chunks={}) if src is not None else None
-        self._ctx = ctx
         self._layer_name = layer_name
         self._idx = idx
+        self._xx = src
 
     def read(
         self,
@@ -306,15 +397,19 @@ def read(
         selection: ReaderSubsetSelection | None = None,
         idx: tuple[int, ...] = (),
     ) -> Delayed:
-        assert self._src is not None
+        assert self._xx is not None
         assert isinstance(idx, tuple)
 
-        xx = self._src.resolve(regen_coords=True)
-        xx = _subset_src(xx, selection, cfg)
+        xx = _subset_src(self._xx, selection, cfg)
+        assert xx.odc.geobox is not None
+        assert not math.isnan(xx.odc.geobox.transform.a)
+
         yy = xr_reproject(
             xx,
             dst_geobox,
             resampling=cfg.resampling,
+            dst_nodata=cfg.fill_value,
+            dtype=cfg.dtype,
             chunks=dst_geobox.shape.yx,
         )
         return delayed(_with_roi)(yy.data, dask_key_name=(self._layer_name, *idx))
@@ -327,7 +422,12 @@ def open(
         layer_name: str,
         idx: int,
     ) -> DaskRasterReader:
-        return XrMemReaderDask(src, ctx, layer_name=layer_name, idx=idx)
+        assert ctx is not None
+        _src = XrSource(src, chunks={}, chunk_store=ctx.fs)
+        xx = _src.resolve(regen_coords=True)
+        assert xx.odc.geobox is not None
+        assert not any(map(math.isnan, xx.odc.geobox.transform[:6]))
+        return XrMemReaderDask(xx, layer_name=layer_name, idx=idx)
 
 
 class XrMemReaderDriver:
@@ -341,6 +441,7 @@ def __init__(
         self,
         src: xr.Dataset | None = None,
         template: RasterGroupMetadata | None = None,
+        fs: fsspec.AbstractFileSystem | None = None,
     ) -> None:
         if src is not None and template is None:
             template = raster_group_md(src)
@@ -348,6 +449,7 @@ def __init__(
             template = RasterGroupMetadata({}, {}, {}, [])
         self.src = src
         self.template = template
+        self.fs = fs
 
     def new_load(
         self,
@@ -457,3 +559,48 @@ def raster_group_md(
         extra_dims=edims,
         extra_coords=extra_coords,
     )
+
+
+def _zarr_chunk_refs(
+    zspec: SomeDoc,
+    href: str,
+    *,
+    bands: Sequence[str] | None = None,
+    sep: str = ".",
+    overrides: dict[str, Any] | None = None,
+) -> Iterator[tuple[str, Any]]:
+    if ".zmetadata" in zspec:
+        zspec = json.loads(zspec[".zmetadata"])["metadata"]
+    elif "zarr:metadata" in zspec:
+        zspec = zspec["zarr:metadata"]
+
+    assert ".zgroup" in zspec, "Not a zarr spec"
+
+    href = href.rstrip("/")
+
+    if bands is None:
+        _bands = [k.rsplit("/", 1)[0] for k in zspec if k.endswith("/.zarray")]
+    else:
+        _bands = list(bands)
+
+    if overrides is None:
+        overrides = {}
+
+    for b in _bands:
+        meta = zspec[f"{b}/.zarray"]
+        assert "chunks" in meta and "shape" in meta
+
+        shape_in_blocks = tuple(
+            map(len, normalize_chunks(meta["chunks"], shape=meta["shape"]))
+        )
+
+        for idx in np.ndindex(shape_in_blocks):
+            if idx == ():
+                k = f"{b}/0"
+            else:
+                k = f"{b}/{sep.join(map(str, idx))}"
+            v = overrides.get(k, None)
+            if v is None:
+                v = (f"{href}/{k}",)
+
+            yield (k, v)
diff --git a/odc/loader/test_memreader.py b/odc/loader/test_memreader.py
index 65717d7..795b55f 100644
--- a/odc/loader/test_memreader.py
+++ b/odc/loader/test_memreader.py
@@ -186,9 +186,7 @@ def test_memreader_zarr(sample_ds: xr.Dataset):
     zarr = pytest.importorskip("zarr")
     assert zarr is not None
-
     _gbox = sample_ds.odc.geobox
-    chunks = None
     assert _gbox is not None
     gbox = _gbox.approx if isinstance(_gbox, GCPGeoBox) else _gbox
 
@@ -205,23 +203,28 @@ def test_memreader_zarr(sample_ds: xr.Dataset):
         driver_data=zmd,
     )
     assert src.driver_data is zmd
-
     cfg = RasterLoadParams.same_as(src)
-    ctx = Context(gbox, chunks)
-    rdr = XrMemReader(src, ctx)
+    driver = XrMemReaderDriver()
+    ctx = driver.new_load(gbox, chunks=None)
+    rdr = driver.open(src, ctx)
 
     roi, xx = rdr.read(cfg, gbox)
     assert isinstance(xx, np.ndarray)
     assert xx.shape == gbox[roi].shape.yx
     assert gbox == gbox[roi]
 
+    assert driver.dask_reader is not None
+
     tk = tokenize(src, cfg, gbox)
-    ctx = Context(gbox, {})
-    rdr = XrMemReaderDask().open(src, ctx, layer_name=f"xx-{tk}", idx=0)
+    ctx = driver.new_load(gbox, chunks={})
+    assert isinstance(ctx, Context)
+
+    rdr = driver.dask_reader.open(src, ctx, layer_name=f"xx-{tk}", idx=0)
     assert isinstance(rdr, XrMemReaderDask)
-    assert rdr._src is not None
-    assert rdr._src._chunks == {}
+    assert rdr._xx is not None
+    assert is_dask_collection(rdr._xx)
 
     fut = rdr.read(cfg, gbox)
     assert is_dask_collection(fut)
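
Reviewer note (usage sketch, not part of the patch): the reworked extract_zarr_spec()
accepts any of the metadata layouts handled above -- raw group metadata containing
".zgroup", a "zarr:metadata" document, a consolidated-format document, or a raw
".zmetadata" JSON string -- and normalises all of them to the same flat
key -> metadata mapping that is then handed to xr.open_zarr() together with a chunk
store. A minimal illustration of that flow, assuming zarr-python 2.x metadata
conventions; the in-memory store, sample dataset and variable names below are made up
for the example:

    import json

    import numpy as np
    import xarray as xr

    from odc.loader._zarr import extract_zarr_spec

    # tiny dataset written to an in-memory zarr store with consolidated metadata
    store = {}
    xr.Dataset({"xx": (("y", "x"), np.zeros((4, 5), dtype="int16"))}).to_zarr(
        store, consolidated=True
    )
    zmd = json.loads(store[".zmetadata"])["metadata"]

    # three documents a RasterSource.driver_data dict might carry;
    # all of them normalise to the same spec dict
    doc_group = zmd  # contains ".zgroup"
    doc_consolidated = {"zarr_consolidated_format": 1, "metadata": zmd}
    doc_zmetadata = {".zmetadata": json.dumps(doc_consolidated)}

    spec = extract_zarr_spec(doc_zmetadata)
    assert spec == extract_zarr_spec(doc_group)
    assert spec == extract_zarr_spec(doc_consolidated)

    # same call pattern as _from_zarr_spec(): metadata comes from the spec,
    # pixels come from the chunk store
    ds = xr.open_zarr(spec, chunk_store=store, consolidated=False, decode_coords="all")
    assert "xx" in ds.data_vars

The updated test_memreader_zarr above exercises the same flow through the driver
interface: XrMemReaderDriver.open() and dask_reader.open() build an XrSource whose
chunk store comes from ctx.fs (the new XrMemReaderDriver(fs=...) argument), falling
back to fsspec.get_mapper(src.uri).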