Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 11, 2022
1 parent a2ba6c6 commit d381d96
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 28 deletions.
10 changes: 5 additions & 5 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,11 @@ def open_group(

if zarr_version is None:
# default to 2 if store doesn't specify it's version (e.g. a path)
zarr_version = getattr(store, '_store_version', 2)
zarr_version = getattr(store, "_store_version", 2)

if zarr_version > 2 and group is None:
# v3 stores require a group name: use 'xarray' as a default one.
group = 'xarray'
group = "xarray"

open_kwargs = dict(
mode=mode,
Expand Down Expand Up @@ -427,7 +427,7 @@ def open_store_variable(self, name, zarr_array):
# TODO: how to properly handle 'filters' for v3 stores
# currently these use a hack to store 'filters' within attributes
# need to drop this here for V3 store tests to succeed
attributes.pop('filters', None)
attributes.pop("filters", None)

encoding = {
"chunks": zarr_array.chunks,
Expand Down Expand Up @@ -558,11 +558,11 @@ def store(
self.set_variables(
variables_encoded, check_encoding_set, writer, unlimited_dims=unlimited_dims
)
zarr_version = getattr(self.zarr_group.store, '_store_version', 3)
zarr_version = getattr(self.zarr_group.store, "_store_version", 3)
consolidate_kwargs = {}
if zarr_version > 2:
# zarr v3 spec requires providing a path
consolidate_kwargs['path'] = self.zarr_group.path
consolidate_kwargs["path"] = self.zarr_group.path
if self._consolidate_on_close:
zarr.consolidate_metadata(self.zarr_group.store, **consolidate_kwargs)

Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,12 +1781,12 @@ def to_zarr(
group: str = None,
encoding: Mapping = None,
compute: bool = True,
consolidated: Optional[bool] = None,
consolidated: bool | None = None,
append_dim: Hashable = None,
region: Mapping[str, slice] = None,
safe_chunks: bool = True,
storage_options: dict[str, str] = None,
zarr_version: Optional[int] = None,
zarr_version: int | None = None,
) -> ZarrStore:
"""Write dataset contents to a zarr group.
Expand Down
84 changes: 63 additions & 21 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,15 @@
have_zarr_kvstore = False
try:
from zarr.storage import KVStore

have_zarr_kvstore = True
except ImportError:
KVStore = None

have_zarr_v3 = False
try:
from zarr.storage_v3 import DirectoryStoreV3, KVStoreV3

have_zarr_v3 = True
except ImportError:
KVStoreV3 = None
Expand Down Expand Up @@ -1698,14 +1700,18 @@ def create_zarr_target(self):
@contextlib.contextmanager
def create_store(self):
with self.create_zarr_target() as store_target:
yield backends.ZarrStore.open_group(store_target, mode="w", **self.version_kwargs)
yield backends.ZarrStore.open_group(
store_target, mode="w", **self.version_kwargs
)

def save(self, dataset, store_target, **kwargs):
return dataset.to_zarr(store=store_target, **kwargs, **self.version_kwargs)

@contextlib.contextmanager
def open(self, store_target, **kwargs):
with xr.open_dataset(store_target, engine="zarr", **kwargs, **self.version_kwargs) as ds:
with xr.open_dataset(
store_target, engine="zarr", **kwargs, **self.version_kwargs
) as ds:
yield ds

@contextlib.contextmanager
Expand Down Expand Up @@ -2029,9 +2035,13 @@ def test_write_persistence_modes(self, group):
ds, ds_to_append, _ = create_append_test_data()
with self.create_zarr_target() as store_target:
ds.to_zarr(store_target, mode="w", group=group, **self.version_kwargs)
ds_to_append.to_zarr(store_target, append_dim="time", group=group, **self.version_kwargs)
ds_to_append.to_zarr(
store_target, append_dim="time", group=group, **self.version_kwargs
)
original = xr.concat([ds, ds_to_append], dim="time")
actual = xr.open_dataset(store_target, group=group, engine="zarr", **self.version_kwargs)
actual = xr.open_dataset(
store_target, group=group, engine="zarr", **self.version_kwargs
)
assert_identical(original, actual)

def test_compressor_encoding(self):
Expand Down Expand Up @@ -2093,13 +2103,19 @@ def test_append_with_invalid_dim_raises(self):
with pytest.raises(
ValueError, match="does not match any existing dataset dimensions"
):
ds_to_append.to_zarr(store_target, append_dim="notvalid", **self.version_kwargs)
ds_to_append.to_zarr(
store_target, append_dim="notvalid", **self.version_kwargs
)

def test_append_with_no_dims_raises(self):
with self.create_zarr_target() as store_target:
Dataset({"foo": ("x", [1])}).to_zarr(store_target, mode="w", **self.version_kwargs)
Dataset({"foo": ("x", [1])}).to_zarr(
store_target, mode="w", **self.version_kwargs
)
with pytest.raises(ValueError, match="different dimension names"):
Dataset({"foo": ("y", [2])}).to_zarr(store_target, mode="a", **self.version_kwargs)
Dataset({"foo": ("y", [2])}).to_zarr(
store_target, mode="a", **self.version_kwargs
)

def test_append_with_append_dim_not_set_raises(self):
ds, ds_to_append, _ = create_append_test_data()
Expand All @@ -2113,7 +2129,9 @@ def test_append_with_mode_not_a_raises(self):
with self.create_zarr_target() as store_target:
ds.to_zarr(store_target, mode="w", **self.version_kwargs)
with pytest.raises(ValueError, match="cannot set append_dim unless"):
ds_to_append.to_zarr(store_target, mode="w", append_dim="time", **self.version_kwargs)
ds_to_append.to_zarr(
store_target, mode="w", append_dim="time", **self.version_kwargs
)

def test_append_with_existing_encoding_raises(self):
ds, ds_to_append, _ = create_append_test_data()
Expand All @@ -2139,11 +2157,15 @@ def test_check_encoding_is_consistent_after_append(self):
encoding = {"da": {"compressor": compressor}}
ds.to_zarr(store_target, mode="w", encoding=encoding, **self.version_kwargs)
ds_to_append.to_zarr(store_target, append_dim="time", **self.version_kwargs)
actual_ds = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs)
actual_ds = xr.open_dataset(
store_target, engine="zarr", **self.version_kwargs
)
actual_encoding = actual_ds["da"].encoding["compressor"]
assert actual_encoding.get_config() == compressor.get_config()
assert_identical(
xr.open_dataset(store_target, engine="zarr", **self.version_kwargs).compute(),
xr.open_dataset(
store_target, engine="zarr", **self.version_kwargs
).compute(),
xr.concat([ds, ds_to_append], dim="time"),
)

Expand All @@ -2160,7 +2182,8 @@ def test_append_with_new_variable(self):
combined = xr.concat([ds, ds_to_append], dim="time")
combined["new_var"] = ds_with_new_var["new_var"]
assert_identical(
combined, xr.open_dataset(store_target, engine="zarr", **self.version_kwargs)
combined,
xr.open_dataset(store_target, engine="zarr", **self.version_kwargs),
)

@requires_dask
Expand Down Expand Up @@ -2267,9 +2290,14 @@ def test_write_region(self, consolidated, compute, use_dask):
for i in range(0, 10, 2):
region = {"x": slice(i, i + 2)}
nonzeros.isel(region).to_zarr(
store, region=region, consolidated=consolidated, **self.version_kwargs,
store,
region=region,
consolidated=consolidated,
**self.version_kwargs,
)
with xr.open_zarr(store, consolidated=consolidated, **self.version_kwargs) as actual:
with xr.open_zarr(
store, consolidated=consolidated, **self.version_kwargs
) as actual:
assert_identical(actual, nonzeros)

@pytest.mark.parametrize("mode", [None, "r+", "a"])
Expand All @@ -2279,7 +2307,9 @@ def test_write_region_mode(self, mode):
with self.create_zarr_target() as store:
zeros.to_zarr(store, **self.version_kwargs)
for region in [{"x": slice(5)}, {"x": slice(5, 10)}]:
nonzeros.isel(region).to_zarr(store, region=region, mode=mode, **self.version_kwargs)
nonzeros.isel(region).to_zarr(
store, region=region, mode=mode, **self.version_kwargs
)
with xr.open_zarr(store, **self.version_kwargs) as actual:
assert_identical(actual, nonzeros)

Expand Down Expand Up @@ -2321,7 +2351,9 @@ def test_write_preexisting_override_metadata(self):
with self.create_zarr_target() as store:
original.to_zarr(store, compute=False, **self.version_kwargs)
# with region, the default mode becomes r+
both_modified.to_zarr(store, region={"x": slice(None)}, **self.version_kwargs)
both_modified.to_zarr(
store, region={"x": slice(None)}, **self.version_kwargs
)
with self.open(store) as actual:
assert_identical(actual, only_new_data)

Expand Down Expand Up @@ -2349,7 +2381,9 @@ def setup_and_verify_store(expected=data):
"cannot set region unless mode='a', mode='r+' or mode=None"
),
):
data.to_zarr(store, region={"x": slice(None)}, mode="w", **self.version_kwargs)
data.to_zarr(
store, region={"x": slice(None)}, mode="w", **self.version_kwargs
)

with setup_and_verify_store() as store:
with pytest.raises(TypeError, match=r"must be a dict"):
Expand All @@ -2361,7 +2395,9 @@ def setup_and_verify_store(expected=data):

with setup_and_verify_store() as store:
with pytest.raises(ValueError, match=r"step on all slices"):
data2.to_zarr(store, region={"x": slice(None, None, 2)}, **self.version_kwargs)
data2.to_zarr(
store, region={"x": slice(None, None, 2)}, **self.version_kwargs
)

with setup_and_verify_store() as store:
with pytest.raises(
Expand All @@ -2375,13 +2411,20 @@ def setup_and_verify_store(expected=data):
ValueError,
match=r"all variables in the dataset to write must have at least one dimension in common",
):
data2.assign(v=2).to_zarr(store, region={"x": slice(2)}, **self.version_kwargs)
data2.assign(v=2).to_zarr(
store, region={"x": slice(2)}, **self.version_kwargs
)

with setup_and_verify_store() as store:
with pytest.raises(
ValueError, match=r"cannot list the same dimension in both"
):
data.to_zarr(store, region={"x": slice(None)}, append_dim="x", **self.version_kwargs)
data.to_zarr(
store,
region={"x": slice(None)},
append_dim="x",
**self.version_kwargs,
)

with setup_and_verify_store() as store:
with pytest.raises(
Expand Down Expand Up @@ -2479,7 +2522,6 @@ def create_store(self):
yield group



class ZarrBaseV3(ZarrBase):
def test_roundtrip_coordinates_with_space(self):
original = Dataset(coords={"x": 0, "y z": 1})
Expand Down Expand Up @@ -2510,7 +2552,7 @@ def create_zarr_target(self):
class TestZarrDirectoryStoreV3FromPath(TestZarrDirectoryStoreV3):
# Must specify zarr_version=3 to get a v3 store because create_zarr_target
# is a string path.
version_kwargs = {'zarr_version': 3}
version_kwargs = {"zarr_version": 3}

@contextlib.contextmanager
def create_zarr_target(self):
Expand Down

0 comments on commit d381d96

Please sign in to comment.