From 6c9a1ccd0e09333c432d214ccc1b5678f03d28f9 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Mon, 30 Aug 2021 09:25:57 -0700 Subject: [PATCH] Add zarr read/write https://github.com/xarray-contrib/datatree/pull/30 * add test for roundtrip and support empty nodes * update roundtrip test, improves empty node handling in IO * add zarr read/write support * support netcdf4 or h5netcdf * netcdf is optional, zarr too! * Apply suggestions from code review Co-authored-by: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Co-authored-by: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> --- ci/environment.yml | 1 + datatree/datatree.py | 31 +++++++ datatree/io.py | 140 ++++++++++++++++++++++++++------ datatree/tests/test_datatree.py | 21 +++-- requirements.txt | 1 - 5 files changed, 159 insertions(+), 35 deletions(-) diff --git a/ci/environment.yml b/ci/environment.yml index 8486fc9..e379a9f 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -11,3 +11,4 @@ dependencies: - black - codecov - pytest-cov + - zarr diff --git a/datatree/datatree.py b/datatree/datatree.py index 1828f7c..7925734 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -854,6 +854,37 @@ def to_netcdf( **kwargs, ) + def to_zarr(self, store, mode: str = "w", encoding=None, **kwargs): + """ + Write datatree contents to a Zarr store. + + Parameters + ---------- + store : MutableMapping, str or Path, optional + Store or path to directory in file system + mode : {"w", "w-", "a", "r+", None}, default: "w" + Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists); + “a” means override existing variables (create if does not exist); “r+” means modify existing + array values only (raise an error if any metadata or shapes would change). The default mode + is “a” if append_dim is set. Otherwise, it is “r+” if region is set and w- otherwise. 
+ encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``. + See ``xarray.Dataset.to_zarr`` for available options. + kwargs : + Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` + """ + from .io import _datatree_to_zarr + + _datatree_to_zarr( + self, + store, + mode=mode, + encoding=encoding, + **kwargs, + ) + def plot(self): raise NotImplementedError diff --git a/datatree/io.py b/datatree/io.py index 84be248..f7bdf57 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -1,10 +1,9 @@ -import os -from typing import Dict, Sequence +import pathlib +from typing import Sequence -import netCDF4 from xarray import open_dataset -from .datatree import DataNode, DataTree, PathType +from .datatree import DataTree, PathType def _ds_or_none(ds): @@ -14,37 +13,87 @@ def _ds_or_none(ds): return None -def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs): - for g in ncgroup.groups.values(): +def _iter_zarr_groups(root, parrent=""): + parrent = pathlib.Path(parrent) + for path, group in root.groups(): + gpath = parrent / path + yield str(gpath) + yield from _iter_zarr_groups(group, parrent=gpath) - # Open and add this node's dataset to the tree - name = os.path.basename(g.path) - ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs) - ds = _ds_or_none(ds) - child_node = DataNode(name, ds) - node.add_child(child_node) - _open_group_children_recursively(filename, node[name], g, chunks, **kwargs) +def _iter_nc_groups(root, parrent=""): + parrent = pathlib.Path(parrent) + for path, group in root.groups.items(): + gpath = parrent / path + yield str(gpath) + yield from _iter_nc_groups(group, parrent=gpath) -def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree: +def _get_nc_dataset_class(engine): + if engine == "netcdf4": + from netCDF4 
import Dataset + elif engine == "h5netcdf": + from h5netcdf import Dataset + elif engine is None: + try: + from netCDF4 import Dataset + except ImportError: + from h5netcdf import Dataset + else: + raise ValueError(f"unsupported engine: {engine}") + return Dataset + + + def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree: """ Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file. Parameters ---------- - filename - chunks + filename_or_obj : str, Path, file-like, or DataStore + Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. + engine : str, optional + Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`. + kwargs : + Additional keyword arguments passed to ``xarray.open_dataset`` for each group. Returns ------- DataTree """ - with netCDF4.Dataset(filename, mode="r") as ncfile: - ds = open_dataset(filename, chunks=chunks, **kwargs) + if engine == "zarr": + return _open_datatree_zarr(filename_or_obj, **kwargs) + elif engine in [None, "netcdf4", "h5netcdf"]: + return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs) + + + def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree: + ncDataset = _get_nc_dataset_class(kwargs.get("engine", None)) + + with ncDataset(filename, mode="r") as ncds: + ds = open_dataset(filename, **kwargs).pipe(_ds_or_none) + tree_root = DataTree(data_objects={"root": ds}) + for key in _iter_nc_groups(ncds): + tree_root[key] = open_dataset(filename, group=key, **kwargs).pipe( + _ds_or_none + ) + return tree_root + + + def _open_datatree_zarr(store, **kwargs) -> DataTree: + import zarr + + with zarr.open_group(store, mode="r") as zds: + ds = open_dataset(store, engine="zarr", **kwargs).pipe(_ds_or_none) tree_root = DataTree(data_objects={"root": ds}) - _open_group_children_recursively(filename, tree_root, ncfile, chunks, **kwargs) + for key in _iter_zarr_groups(zds): + try: + tree_root[key] = 
open_dataset( + store, engine="zarr", group=key, **kwargs + ).pipe(_ds_or_none) + except zarr.errors.PathNotFoundError: + tree_root[key] = None return tree_root @@ -80,8 +129,10 @@ def _maybe_extract_group_kwargs(enc, group): return None -def _create_empty_group(filename, group, mode): - with netCDF4.Dataset(filename, mode=mode) as rootgrp: +def _create_empty_netcdf_group(filename, group, mode, engine): + ncDataset = _get_nc_dataset_class(engine) + + with ncDataset(filename, mode=mode) as rootgrp: rootgrp.createGroup(group) @@ -91,13 +142,14 @@ def _datatree_to_netcdf( mode: str = "w", encoding=None, unlimited_dims=None, - **kwargs + **kwargs, ): if kwargs.get("format", None) not in [None, "NETCDF4"]: raise ValueError("to_netcdf only supports the NETCDF4 format") - if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]: + engine = kwargs.get("engine", None) + if engine not in [None, "netcdf4", "h5netcdf"]: raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") if kwargs.get("group", None) is not None: @@ -118,14 +170,52 @@ def _datatree_to_netcdf( ds = node.ds group_path = node.pathstr.replace(dt.root.pathstr, "") if ds is None: - _create_empty_group(filepath, group_path, mode) + _create_empty_netcdf_group(filepath, group_path, mode, engine) else: + ds.to_netcdf( filepath, group=group_path, mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), - **kwargs + **kwargs, ) mode = "a" + + +def _create_empty_zarr_group(store, group, mode): + import zarr + + root = zarr.open_group(store, mode=mode) + root.create_group(group, overwrite=True) + + +def _datatree_to_zarr(dt: DataTree, store, mode: str = "w", encoding=None, **kwargs): + + if kwargs.get("group", None) is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + + if not kwargs.get("compute", True): + raise 
NotImplementedError("compute=False has not been implemented yet") + + if encoding is None: + encoding = {} + + for node in dt.subtree: + ds = node.ds + group_path = node.pathstr.replace(dt.root.pathstr, "") + if ds is None: + _create_empty_zarr_group(store, group_path, mode) + else: + ds.to_zarr( + store, + group=group_path, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + **kwargs, + ) + if "w" in mode: + mode = "a" diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index f13a7f3..4592643 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -322,12 +322,15 @@ def test_to_netcdf(self, tmpdir): roundtrip_dt = open_datatree(filepath) - original_dt.name == roundtrip_dt.name - assert original_dt.ds.identical(roundtrip_dt.ds) - for a, b in zip(original_dt.descendants, roundtrip_dt.descendants): - assert a.name == b.name - assert a.pathstr == b.pathstr - if a.has_data: - assert a.ds.identical(b.ds) - else: - assert a.ds is b.ds + assert_tree_equal(original_dt, roundtrip_dt) + + def test_to_zarr(self, tmpdir): + filepath = str( + tmpdir / "test.zarr" + ) # casting to str avoids a pathlib bug in xarray + original_dt = create_test_datatree() + original_dt.to_zarr(filepath) + + roundtrip_dt = open_datatree(filepath, engine="zarr") + + assert_tree_equal(original_dt, roundtrip_dt) diff --git a/requirements.txt b/requirements.txt index 67e19d1..a95f277 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ xarray>=0.19.0 -netcdf4 anytree future