Skip to content

Commit

Permalink
Added the ability to append an InferenceData object to an existing ne…
Browse files Browse the repository at this point in the history
…tCDF file (#2227)

* added the ability to append to an existing file

* set initial mode for netcdf write

* addressed comments in PR

* added spaces in doc

* added a test for base group

* pylint fix and changelog entry

* update changelog

---------

Co-authored-by: Oriol (ZBook) <oriol.abril.pla@gmail.com>
  • Loading branch information
varchasgopalaswamy and OriolAbril authored Jun 10, 2023
1 parent 8a1b37e commit 353508d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New features

- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237))
- InferenceData objects can now be appended to existing netCDF4 files and to specific groups within them ([2227](https://github.com/arviz-devs/arviz/pull/2227))

### Maintenance and fixes
- Fix numba deprecation warning ([2246](https://github.com/arviz-devs/arviz/pull/2246))
Expand Down
44 changes: 37 additions & 7 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Union,
overload,
)
import os

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -337,7 +338,12 @@ def items(self) -> "InferenceData.InferenceDataItemsView":

@staticmethod
def from_netcdf(
filename, *, engine="h5netcdf", group_kwargs=None, regex=False
filename,
*,
engine="h5netcdf",
group_kwargs=None,
regex=False,
base_group: str = "/",
) -> "InferenceData":
"""Initialize object from a netcdf file.
Expand All @@ -360,6 +366,9 @@ def from_netcdf(
regex : bool, default False
Specifies where regex search should be used to extend the keyword arguments.
This feature is currently experimental.
base_group : str, default "/"
The group in the netCDF file where the InferenceData is stored. By default,
assumes that the file only contains an InferenceData object.
Returns
-------
Expand All @@ -380,7 +389,12 @@ def from_netcdf(
try:
with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset(
filename, mode="r"
) as data:
) as file_handle:
if base_group == "/":
data = file_handle
else:
data = file_handle[base_group]

data_groups = list(data.groups)

for group in data_groups:
Expand All @@ -394,13 +408,13 @@ def from_netcdf(
if re.search(key, group):
group_kws = kws
group_kws.setdefault("engine", engine)
with xr.open_dataset(filename, group=group, **group_kws) as data:
with xr.open_dataset(filename, group=f"{base_group}/{group}", **group_kws) as data:
if rcParams["data.load"] == "eager":
groups[group] = data.load()
else:
groups[group] = data

with xr.open_dataset(filename, engine=engine) as data:
with xr.open_dataset(filename, engine=engine, group=base_group) as data:
attrs.update(data.load().attrs)

return InferenceData(attrs=attrs, **groups)
Expand All @@ -424,6 +438,8 @@ def to_netcdf(
compress: bool = True,
groups: Optional[List[str]] = None,
engine: str = "h5netcdf",
base_group: str = "/",
overwrite_existing: bool = True,
) -> str:
"""Write InferenceData to netcdf4 file.
Expand All @@ -438,15 +454,29 @@ def to_netcdf(
Write only these groups to netcdf file.
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
Library used to read the netcdf file.
base_group : str, default "/"
The group in the netCDF file where the InferenceData is will be stored.
By default, will write to the root of the netCDF file
overwrite_existing : bool, default True
Whether to overwrite the existing file or append to it.
Returns
-------
str
Location of netcdf file
"""
mode = "w" # overwrite first, then append
if base_group is None:
base_group = "/"

if os.path.exists(filename) and not overwrite_existing:
mode = "a"
else:
mode = "w" # overwrite first, then append

if self._attrs:
xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode, engine=engine)
xr.Dataset(attrs=self._attrs).to_netcdf(
filename, mode=mode, engine=engine, group=base_group
)
mode = "a"

if self._groups_all: # check's whether a group is present or not.
Expand All @@ -464,7 +494,7 @@ def to_netcdf(
for var_name, values in data.variables.items()
if _compressible_dtype(values.dtype)
}
data.to_netcdf(filename, mode=mode, group=group, **kwargs)
data.to_netcdf(filename, mode=mode, group=f"{base_group}/{group}", **kwargs)
data.close()
mode = "a"
elif not self._attrs: # creates a netcdf file for an empty InferenceData object.
Expand Down
6 changes: 4 additions & 2 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,10 +1294,11 @@ def test_io_function(self, data, eight_schools_params):
os.remove(filepath)
assert not os.path.exists(filepath)

@pytest.mark.parametrize("base_group", ["/", "test_group", "group/subgroup"])
@pytest.mark.parametrize("groups_arg", [False, True])
@pytest.mark.parametrize("compress", [True, False])
@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4"])
def test_io_method(self, data, eight_schools_params, groups_arg, compress, engine):
def test_io_method(self, data, eight_schools_params, groups_arg, base_group, compress, engine):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
Expand Down Expand Up @@ -1334,12 +1335,13 @@ def test_io_method(self, data, eight_schools_params, groups_arg, compress, engin
filepath,
groups=("posterior", "observed_data") if groups_arg else None,
compress=compress,
base_group=base_group,
)

# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0
inference_data2 = InferenceData.from_netcdf(filepath)
inference_data2 = InferenceData.from_netcdf(filepath, base_group=base_group)
if groups_arg: # if groups arg, update test dict to contain only saved groups
test_dict = {
"posterior": ["eta", "theta", "mu", "tau"],
Expand Down

0 comments on commit 353508d

Please sign in to comment.