allow both h5netcdf and netcdf4 to be used
OriolAbril committed Dec 22, 2022
1 parent ed7b6c5 commit 7d86250
Showing 6 changed files with 63 additions and 20 deletions.
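
In short: `InferenceData.from_netcdf`/`to_netcdf` and the module-level `from_netcdf`/`to_netcdf` gain an `engine` argument ("h5netcdf" by default), so either netcdf backend can read or write InferenceData files. A minimal usage sketch, assuming an arviz build with this commit and at least one backend installed; the file name is made up:

import arviz as az

idata = az.load_arviz_data("centered_eight")  # bundled example dataset

# write with the default h5netcdf backend, read back with netCDF4;
# both engines produce and consume HDF5-based netCDF4 files
idata.to_netcdf("centered.nc")
idata_roundtrip = az.from_netcdf("centered.nc", engine="netcdf4")
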
4 changes: 1 addition & 3 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### New features
- Adds Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152))
- Add `CmdStanPySamplingWrapper` and `PyMCSamplingWrapper` classes ([2158](https://github.com/arviz-devs/arviz/pull/2158))
+- Changed dependency on netcdf4-python to h5netcdf ([2122](https://github.com/arviz-devs/arviz/pull/2122))

### Maintenance and fixes
- Fix `reloo` outdated usage of `ELPDData` ([2158](https://github.com/arviz-devs/arviz/pull/2158))
@@ -55,9 +56,6 @@
* Update tests and docs for updated example data ([2137](https://github.com/arviz-devs/arviz/pull/2137))
* Copy coords before modifying in ppcplot ([2160](https://github.com/arviz-devs/arviz/pull/2160))

-* Changed dependency on netCDF4 to h5netcdf
-* Changed dependency on netCDF4 to h5netcdf ([2122](https://github.com/arviz-devs/arviz/pull/2122))

### Deprecation
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))

49 changes: 38 additions & 11 deletions arviz/data/inference_data.py
@@ -23,7 +23,6 @@
Union,
overload,
)
-import h5netcdf

import numpy as np
import xarray as xr
@@ -337,7 +336,9 @@ def items(self) -> "InferenceData.InferenceDataItemsView":
return InferenceData.InferenceDataItemsView(self)

@staticmethod
-def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
+def from_netcdf(
+    filename, *, engine="h5netcdf", group_kwargs=None, regex=False
+) -> "InferenceData":
"""Initialize object from a netcdf file.
Expects that the file will have groups, each of which can be loaded by xarray.
@@ -349,6 +350,8 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
----------
filename : str
location of netcdf file
+engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
+    Library used to read the netcdf file.
group_kwargs : dict of {str: dict}, optional
Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
The keys of the higher level should be group names or regex matching group
@@ -360,13 +363,24 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
Returns
-------
-InferenceData object
+InferenceData
"""
groups = {}
attrs = {}

+if engine == "h5netcdf":
+    import h5netcdf
+elif engine == "netcdf4":
+    import netCDF4 as nc
+else:
+    raise ValueError(
+        f"Invalid value for engine: {engine}. Valid options are: h5netcdf or netcdf4"
+    )

try:
-with h5netcdf.File(filename, mode="r") as data:
+with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset(
+    filename, mode="r"
+) as data:
data_groups = list(data.groups)

for group in data_groups:
@@ -379,14 +393,14 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
for key, kws in group_kwargs.items():
if re.search(key, group):
group_kws = kws
-group_kws.setdefault("engine", "h5netcdf")
+group_kws.setdefault("engine", engine)
with xr.open_dataset(filename, group=group, **group_kws) as data:
if rcParams["data.load"] == "eager":
groups[group] = data.load()
else:
groups[group] = data

-with xr.open_dataset(filename, mode="r") as data:
+with xr.open_dataset(filename, engine=engine) as data:
attrs.update(data.load().attrs)

return InferenceData(attrs=attrs, **groups)
@@ -405,9 +419,13 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
raise err

def to_netcdf(
-self, filename: str, compress: bool = True, groups: Optional[List[str]] = None
+self,
+filename: str,
+compress: bool = True,
+groups: Optional[List[str]] = None,
+engine: str = "h5netcdf",
) -> str:
"""Write InferenceData to file using netcdf4.
"""Write InferenceData to netcdf4 file.
Parameters
----------
@@ -418,6 +436,8 @@ def to_netcdf(
saving and loading somewhat slower (default: True).
groups : list, optional
Write only these groups to netcdf file.
+engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
+    Library used to write the netcdf file.
Returns
-------
@@ -426,7 +446,7 @@ def to_netcdf(
"""
mode = "w" # overwrite first, then append
if self._attrs:
-xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode)
+xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode, engine=engine)
mode = "a"

if self._groups_all:  # checks whether a group is present or not.
@@ -437,7 +457,7 @@ def to_netcdf(

for group in groups:
data = getattr(self, group)
-kwargs = {'engine':'h5netcdf'}
+kwargs = {"engine": engine}
if compress:
kwargs["encoding"] = {
var_name: {"zlib": True}
@@ -448,7 +468,14 @@ def to_netcdf(
data.close()
mode = "a"
elif not self._attrs: # creates a netcdf file for an empty InferenceData object.
-empty_netcdf_file = h5netcdf.File(filename, mode="w")
+if engine == "h5netcdf":
+    import h5netcdf
+
+    empty_netcdf_file = h5netcdf.File(filename, mode="w")
+elif engine == "netcdf4":
+    import netCDF4 as nc
+
+    empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
empty_netcdf_file.close()
return filename

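A sketch of the method-level API after these changes, including the empty-object branch at the end of to_netcdf; the file name and flow are illustrative only:

from arviz import InferenceData

idata = InferenceData()  # no groups and no attrs
# exercises the `elif not self._attrs` branch: an empty netcdf file is still created
path = idata.to_netcdf("empty.nc", engine="netcdf4")
restored = InferenceData.from_netcdf(path, engine="netcdf4")
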
14 changes: 10 additions & 4 deletions arviz/data/io_netcdf.py
@@ -4,13 +4,15 @@
from .inference_data import InferenceData


-def from_netcdf(filename, group_kwargs=None, regex=False):
+def from_netcdf(filename, *, engine="h5netcdf", group_kwargs=None, regex=False):
"""Load netcdf file back into an arviz.InferenceData.
Parameters
----------
filename : str
name or path of the file to load trace
+engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
+    Library used to read the netcdf file.
group_kwargs : dict of {str: dict}
Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
The keys of the higher level should be group names or regex matching group
@@ -31,10 +33,12 @@ def from_netcdf(filename, group_kwargs=None, regex=False):
"""
if group_kwargs is None:
group_kwargs = {}
-return InferenceData.from_netcdf(filename, group_kwargs, regex)
+return InferenceData.from_netcdf(
+    filename, engine=engine, group_kwargs=group_kwargs, regex=regex
+)


-def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None):
+def to_netcdf(data, filename, *, group="posterior", engine="h5netcdf", coords=None, dims=None):
"""Save dataset as a netcdf file.
WARNING: Only idempotent in case `data` is InferenceData
@@ -47,6 +51,8 @@ def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None):
name or path of the file to save the trace to
group : str (optional)
In case `data` is not InferenceData, this is the group it will be saved to
+engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
+    Library used to write the netcdf file.
coords : dict (optional)
See `convert_to_inference_data`
dims : dict (optional)
@@ -58,5 +64,5 @@
filename saved to
"""
inference_data = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
-file_name = inference_data.to_netcdf(filename)
+file_name = inference_data.to_netcdf(filename, engine=engine)
return file_name
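
The module-level helpers simply forward engine to the InferenceData methods, converting non-InferenceData inputs first. A hedged sketch; the dict and file name are invented:

import numpy as np
import arviz as az

# a plain dict is converted via convert_to_inference_data before saving
draws = {"mu": np.random.default_rng(0).normal(size=(4, 500))}
az.to_netcdf(draws, "draws.nc", group="posterior", engine="netcdf4")
idata = az.from_netcdf("draws.nc", engine="netcdf4")
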
13 changes: 12 additions & 1 deletion arviz/tests/base_tests/test_data.py
@@ -1296,11 +1296,22 @@ def test_io_function(self, data, eight_schools_params):

@pytest.mark.parametrize("groups_arg", [False, True])
@pytest.mark.parametrize("compress", [True, False])
-def test_io_method(self, data, eight_schools_params, groups_arg, compress):
+@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4"])
+def test_io_method(self, data, eight_schools_params, groups_arg, compress, engine):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
)
+if engine == "h5netcdf":
+    try:
+        import h5netcdf  # pylint: disable=unused-import
+    except ImportError:
+        pytest.skip("h5netcdf not installed")
+elif engine == "netcdf4":
+    try:
+        import netCDF4  # pylint: disable=unused-import
+    except ImportError:
+        pytest.skip("netcdf4 not installed")
test_dict = {
"posterior": ["eta", "theta", "mu", "tau"],
"posterior_predictive": ["eta", "theta", "mu", "tau"],
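The guarded imports above skip the test when a backend is missing. An equivalent, more compact pattern would be pytest.importorskip; a sketch of the alternative (not what the commit uses), assuming the test only needs a round trip:

import pytest


@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4"])
def test_roundtrip(engine, tmp_path):
    # module name differs from the engine name for the netCDF4 backend
    pytest.importorskip("h5netcdf" if engine == "h5netcdf" else "netCDF4")
    import arviz as az

    idata = az.from_dict(posterior={"mu": [[0.1, 0.2, 0.3]]})
    path = str(tmp_path / "roundtrip.nc")
    idata.to_netcdf(path, engine=engine)
    az.from_netcdf(path, engine=engine)
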
1 change: 1 addition & 0 deletions requirements-optional.txt
@@ -1,4 +1,5 @@
numba
+netcdf4
bokeh>=1.4.0,<3.0
contourpy
ujson
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -3,4 +3,4 @@ pytest-cov
cloudpickle

-r requirements-optional.txt
--r requirements-external.txt
\ No newline at end of file
+-r requirements-external.txt
