Add zarr read/write xarray-contrib/datatree#30
* add test for roundtrip and support empty nodes

* update roundtrip test, improves empty node handling in IO

* add zarr read/write support

* support netcdf4 or h5netcdf

* netcdf is optional, zarr too!

* Apply suggestions from code review

Co-authored-by: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com>

Joe Hamman and TomNicholas authored Aug 30, 2021
1 parent 6807504 commit 6c9a1cc
Showing 5 changed files with 159 additions and 35 deletions.
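For orientation before the diff, here is a minimal usage sketch of the round trip this commit enables. It is not part of the commit: the store name example.zarr, the example datasets, and the import paths are assumptions (the tests below use open_datatree from the package's io module).

import xarray as xr
from datatree import DataTree
from datatree.io import open_datatree

# Build a small two-group tree (illustrative data only).
dt = DataTree(data_objects={"root": xr.Dataset({"a": ("x", [1, 2, 3])})})
dt["set1"] = xr.Dataset({"b": ("y", [4.0, 5.0])})

# Write every group to a Zarr store, then read the whole tree back.
dt.to_zarr("example.zarr")
roundtrip_dt = open_datatree("example.zarr", engine="zarr")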
1 change: 1 addition & 0 deletions ci/environment.yml
@@ -11,3 +11,4 @@ dependencies:
- black
- codecov
- pytest-cov
- zarr
31 changes: 31 additions & 0 deletions datatree/datatree.py
@@ -854,6 +854,37 @@ def to_netcdf(
**kwargs,
)

def to_zarr(self, store, mode: str = "w", encoding=None, **kwargs):
"""
Write datatree contents to a Zarr store.
Parameters
----------
store : MutableMapping, str or Path, optional
Store or path to directory in file system
mode : {"w", "w-", "a", "r+", None}, default: "w"
Persistence mode: "w" means create (overwrite if exists); "w-" means create (fail if exists);
"a" means override existing variables (create if does not exist); "r+" means modify existing
array values only (raise an error if any metadata or shapes would change). Default: "w".
encoding : dict, optional
Nested dictionary with variable names as keys and dictionaries of
variable specific encodings as values, e.g.,
``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``.
See ``xarray.Dataset.to_zarr`` for available options.
kwargs :
Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr``.
"""
from .io import _datatree_to_zarr

_datatree_to_zarr(
self,
store,
mode=mode,
encoding=encoding,
**kwargs,
)

def plot(self):
raise NotImplementedError

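A brief sketch of how the nested encoding argument documented above is meant to be passed; the group path "root/set1", the variable name "foo", and the store path are illustrative assumptions, not part of the commit.

# Outer keys are group paths, inner dicts are per-variable encodings that are
# forwarded to xarray.Dataset.to_zarr for that group (names are made up).
encoding = {"root/set1": {"foo": {"dtype": "int16", "scale_factor": 0.1}}}
dt.to_zarr("example.zarr", encoding=encoding)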
140 changes: 115 additions & 25 deletions datatree/io.py
@@ -1,10 +1,9 @@
import os
from typing import Dict, Sequence
import pathlib
from typing import Sequence

import netCDF4
from xarray import open_dataset

from .datatree import DataNode, DataTree, PathType
from .datatree import DataTree, PathType


def _ds_or_none(ds):
@@ -14,37 +13,87 @@ def _ds_or_none(ds):
return None


def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
for g in ncgroup.groups.values():
def _iter_zarr_groups(root, parent=""):
parent = pathlib.Path(parent)
for path, group in root.groups():
gpath = parent / path
yield str(gpath)
yield from _iter_zarr_groups(group, parent=gpath)

# Open and add this node's dataset to the tree
name = os.path.basename(g.path)
ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
ds = _ds_or_none(ds)
child_node = DataNode(name, ds)
node.add_child(child_node)

_open_group_children_recursively(filename, node[name], g, chunks, **kwargs)
def _iter_nc_groups(root, parent=""):
parent = pathlib.Path(parent)
for path, group in root.groups.items():
gpath = parent / path
yield str(gpath)
yield from _iter_nc_groups(group, parent=gpath)


def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree:
def _get_nc_dataset_class(engine):
if engine == "netcdf4":
from netCDF4 import Dataset
elif engine == "h5netcdf":
from h5netcdf import Dataset
elif engine is None:
try:
from netCDF4 import Dataset
except ImportError:
from h5netcdf import Dataset
else:
raise ValueError(f"unsupported engine: {engine}")
return Dataset
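Roughly, the lazy engine resolution above means the new open_datatree (defined next) can be called with or without an explicit engine; a sketch assuming a hypothetical example.nc file and at least one netCDF backend installed:

dt = open_datatree("example.nc")                      # tries netCDF4, falls back to h5netcdf
dt = open_datatree("example.nc", engine="h5netcdf")   # force a specific backend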


def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree:
"""
Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file.
Parameters
----------
filename
chunks
filename_or_obj : str, Path, file-like, or DataStore
Strings and Path objects are interpreted as a path to a netCDF file or Zarr store.
engine : str, optional
Xarray backend engine to use. Valid options include ``{"netcdf4", "h5netcdf", "zarr"}``.
kwargs :
Additional keyword arguments passed to ``xarray.open_dataset`` for each group.
Returns
-------
DataTree
"""

with netCDF4.Dataset(filename, mode="r") as ncfile:
ds = open_dataset(filename, chunks=chunks, **kwargs)
if engine == "zarr":
return _open_datatree_zarr(filename_or_obj, **kwargs)
elif engine in [None, "netcdf4", "h5netcdf"]:
return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs)


def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree:
ncDataset = _get_nc_dataset_class(kwargs.get("engine", None))

with ncDataset(filename, mode="r") as ncds:
ds = open_dataset(filename, **kwargs).pipe(_ds_or_none)
tree_root = DataTree(data_objects={"root": ds})
for key in _iter_nc_groups(ncds):
tree_root[key] = open_dataset(filename, group=key, **kwargs).pipe(
_ds_or_none
)
return tree_root


def _open_datatree_zarr(store, **kwargs) -> DataTree:
import zarr

with zarr.open_group(store, mode="r") as zds:
ds = open_dataset(store, engine="zarr", **kwargs).pipe(_ds_or_none)
tree_root = DataTree(data_objects={"root": ds})
_open_group_children_recursively(filename, tree_root, ncfile, chunks, **kwargs)
for key in _iter_zarr_groups(zds):
try:
tree_root[key] = open_dataset(
store, engine="zarr", group=key, **kwargs
).pipe(_ds_or_none)
except zarr.errors.PathNotFoundError:
tree_root[key] = None
return tree_root
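As an aside, the two group iterators above walk nested groups depth-first and yield slash-separated paths; a small sketch using an in-memory zarr group (group names are made up, and _iter_zarr_groups is the private helper defined in this diff):

import zarr

root = zarr.group()  # in-memory store, for illustration
root.create_group("set1").create_group("set2")
root.create_group("set3")

# Yields "set1", "set1/set2", "set3" (depth-first, relative to the root).
print(list(_iter_zarr_groups(root)))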


@@ -80,8 +129,10 @@ def _maybe_extract_group_kwargs(enc, group):
return None


def _create_empty_group(filename, group, mode):
with netCDF4.Dataset(filename, mode=mode) as rootgrp:
def _create_empty_netcdf_group(filename, group, mode, engine):
ncDataset = _get_nc_dataset_class(engine)

with ncDataset(filename, mode=mode) as rootgrp:
rootgrp.createGroup(group)


@@ -91,13 +142,14 @@ def _datatree_to_netcdf(
mode: str = "w",
encoding=None,
unlimited_dims=None,
**kwargs
**kwargs,
):

if kwargs.get("format", None) not in [None, "NETCDF4"]:
raise ValueError("to_netcdf only supports the NETCDF4 format")

if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]:
engine = kwargs.get("engine", None)
if engine not in [None, "netcdf4", "h5netcdf"]:
raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")

if kwargs.get("group", None) is not None:
@@ -118,14 +170,52 @@
ds = node.ds
group_path = node.pathstr.replace(dt.root.pathstr, "")
if ds is None:
_create_empty_group(filepath, group_path, mode)
_create_empty_netcdf_group(filepath, group_path, mode, engine)
else:

ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr),
**kwargs
**kwargs,
)
mode = "a"


def _create_empty_zarr_group(store, group, mode):
import zarr

root = zarr.open_group(store, mode=mode)
root.create_group(group, overwrite=True)


def _datatree_to_zarr(dt: DataTree, store, mode: str = "w", encoding=None, **kwargs):

if kwargs.get("group", None) is not None:
raise NotImplementedError(
"specifying a root group for the tree has not been implemented"
)

if not kwargs.get("compute", True):
raise NotImplementedError("compute=False has not been implemented yet")

if encoding is None:
encoding = {}

for node in dt.subtree:
ds = node.ds
group_path = node.pathstr.replace(dt.root.pathstr, "")
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
else:
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
**kwargs,
)
if "w" in mode:
mode = "a"
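The mode handling at the end of _datatree_to_zarr is roughly equivalent to writing the first group with mode="w" and appending every later group; a minimal sketch with plain xarray, assuming a fresh local store named example.zarr:

import xarray as xr

ds_root = xr.Dataset({"a": ("x", [1, 2])})
ds_child = xr.Dataset({"b": ("y", [3.0])})

# The first write creates (or overwrites) the store ...
ds_root.to_zarr("example.zarr", mode="w")
# ... later groups are appended so earlier ones are preserved.
ds_child.to_zarr("example.zarr", group="set1", mode="a")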
21 changes: 12 additions & 9 deletions datatree/tests/test_datatree.py
@@ -322,12 +322,15 @@ def test_to_netcdf(self, tmpdir):

roundtrip_dt = open_datatree(filepath)

original_dt.name == roundtrip_dt.name
assert original_dt.ds.identical(roundtrip_dt.ds)
for a, b in zip(original_dt.descendants, roundtrip_dt.descendants):
assert a.name == b.name
assert a.pathstr == b.pathstr
if a.has_data:
assert a.ds.identical(b.ds)
else:
assert a.ds is b.ds
assert_tree_equal(original_dt, roundtrip_dt)

def test_to_zarr(self, tmpdir):
filepath = str(
tmpdir / "test.zarr"
) # casting to str avoids a pathlib bug in xarray
original_dt = create_test_datatree()
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")

assert_tree_equal(original_dt, roundtrip_dt)
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,4 +1,3 @@
xarray>=0.19.0
netcdf4
anytree
future
