Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ability to append an InferenceData object to an existing netCDF file #2227

Merged
merged 7 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New features

- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237))
- InferenceData objects can now be appended to existing netCDF4 files and to specific groups within them ([2227](https://github.com/arviz-devs/arviz/pull/2227))

### Maintenance and fixes

Expand Down
44 changes: 37 additions & 7 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Union,
overload,
)
import os

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -337,7 +338,12 @@ def items(self) -> "InferenceData.InferenceDataItemsView":

@staticmethod
def from_netcdf(
filename, *, engine="h5netcdf", group_kwargs=None, regex=False
filename,
*,
engine="h5netcdf",
group_kwargs=None,
regex=False,
base_group: str = "/",
) -> "InferenceData":
"""Initialize object from a netcdf file.

Expand All @@ -360,6 +366,9 @@ def from_netcdf(
regex : bool, default False
Specifies where regex search should be used to extend the keyword arguments.
This feature is currently experimental.
base_group : str, default "/"
The group in the netCDF file where the InferenceData is stored. By default,
assumes that the file only contains an InferenceData object.

Returns
-------
Expand All @@ -380,7 +389,12 @@ def from_netcdf(
try:
with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset(
filename, mode="r"
) as data:
) as file_handle:
if base_group == "/":
data = file_handle
else:
data = file_handle[base_group]

data_groups = list(data.groups)

for group in data_groups:
Expand All @@ -394,13 +408,13 @@ def from_netcdf(
if re.search(key, group):
group_kws = kws
group_kws.setdefault("engine", engine)
with xr.open_dataset(filename, group=group, **group_kws) as data:
with xr.open_dataset(filename, group=f"{base_group}/{group}", **group_kws) as data:
if rcParams["data.load"] == "eager":
groups[group] = data.load()
else:
groups[group] = data

with xr.open_dataset(filename, engine=engine) as data:
with xr.open_dataset(filename, engine=engine, group=base_group) as data:
attrs.update(data.load().attrs)

return InferenceData(attrs=attrs, **groups)
Expand All @@ -424,6 +438,8 @@ def to_netcdf(
compress: bool = True,
groups: Optional[List[str]] = None,
engine: str = "h5netcdf",
base_group: str = "/",
overwrite_existing: bool = True,
) -> str:
"""Write InferenceData to netcdf4 file.

Expand All @@ -438,15 +454,29 @@ def to_netcdf(
Write only these groups to netcdf file.
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
Library used to read the netcdf file.
base_group : str, default "/"
The group in the netCDF file where the InferenceData will be stored.
By default, will write to the root of the netCDF file.
varchasgopalaswamy marked this conversation as resolved.
Show resolved Hide resolved
overwrite_existing : bool, default True
Whether to overwrite the existing file or append to it.

Returns
-------
str
Location of netcdf file
"""
mode = "w" # overwrite first, then append
if base_group is None:
base_group = "/"

if os.path.exists(filename) and not overwrite_existing:
mode = "a"
else:
mode = "w" # overwrite first, then append

if self._attrs:
xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode, engine=engine)
xr.Dataset(attrs=self._attrs).to_netcdf(
filename, mode=mode, engine=engine, group=base_group
)
mode = "a"

if self._groups_all: # check's whether a group is present or not.
Expand All @@ -464,7 +494,7 @@ def to_netcdf(
for var_name, values in data.variables.items()
if _compressible_dtype(values.dtype)
}
data.to_netcdf(filename, mode=mode, group=group, **kwargs)
data.to_netcdf(filename, mode=mode, group=f"{base_group}/{group}", **kwargs)
data.close()
mode = "a"
elif not self._attrs: # creates a netcdf file for an empty InferenceData object.
Expand Down
6 changes: 4 additions & 2 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,10 +1294,11 @@ def test_io_function(self, data, eight_schools_params):
os.remove(filepath)
assert not os.path.exists(filepath)

@pytest.mark.parametrize("base_group", ["/", "test_group", "group/subgroup"])
@pytest.mark.parametrize("groups_arg", [False, True])
@pytest.mark.parametrize("compress", [True, False])
@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4"])
def test_io_method(self, data, eight_schools_params, groups_arg, compress, engine):
def test_io_method(self, data, eight_schools_params, groups_arg, base_group, compress, engine):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
Expand Down Expand Up @@ -1334,12 +1335,13 @@ def test_io_method(self, data, eight_schools_params, groups_arg, compress, engin
filepath,
groups=("posterior", "observed_data") if groups_arg else None,
compress=compress,
base_group=base_group,
)

# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0
inference_data2 = InferenceData.from_netcdf(filepath)
inference_data2 = InferenceData.from_netcdf(filepath, base_group=base_group)
if groups_arg: # if groups arg, update test dict to contain only saved groups
test_dict = {
"posterior": ["eta", "theta", "mu", "tau"],
Expand Down