Skip to content

Commit

Permalink
Merge pull request #139 from rabernat/make_recipe_test_subset
Browse files Browse the repository at this point in the history
Add copy_pruned() method to XarrayZarrRecipe
  • Loading branch information
rabernat authored May 24, 2021
2 parents ecadd35 + 5d81dc5 commit 23b7296
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 4 deletions.
28 changes: 26 additions & 2 deletions pangeo_forge_recipes/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Filename / URL patterns.
"""

from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from itertools import product
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -47,6 +47,7 @@ class MergeDim:


Index = Tuple[int, ...]
CombineDim = Union[MergeDim, ConcatDim]


class FilePattern:
Expand Down Expand Up @@ -77,7 +78,7 @@ def _make_da(format_function, combine_dims) -> xr.DataArray:
coords = {cdim.name: (cdim.name, cdim.keys) for cdim in combine_dims}
return xr.DataArray(fnames_np, dims=list(coords), coords=coords) # type: ignore

def __init__(self, format_function: Callable, *combine_dims: Union[MergeDim, ConcatDim]):
def __init__(self, format_function: Callable, *combine_dims: CombineDim):
self.__setstate__((format_function, combine_dims))

def __getstate__(self):
Expand Down Expand Up @@ -157,3 +158,26 @@ def format_function(**kwargs):
return file_list[kwargs[concat_dim]]

return FilePattern(format_function, concat)


def prune_pattern(fp: FilePattern, nkeep: int = 2) -> FilePattern:
"""
Create a smaller pattern from a full pattern.
Keeps all MergeDims but only the first `nkeep` items from each ConcatDim
:param fp: The original pattern.
:param nkeep: The number of items to keep from each ConcatDim sequence.
"""

new_combine_dims = [] # type: List[CombineDim]
for cdim in fp.combine_dims:
if isinstance(cdim, MergeDim):
new_combine_dims.append(cdim)
elif isinstance(cdim, ConcatDim):
new_keys = cdim.keys[:nkeep]
new_cdim = replace(cdim, keys=new_keys)
new_combine_dims.append(new_cdim)
else: # pragma: no cover
assert "Should never happen"

return FilePattern(fp.format_function, *new_combine_dims)
15 changes: 13 additions & 2 deletions pangeo_forge_recipes/recipes/xarray_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import warnings
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from itertools import product
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Expand All @@ -14,7 +14,7 @@
import xarray as xr
import zarr

from ..patterns import FilePattern
from ..patterns import FilePattern, prune_pattern
from ..storage import AbstractTarget, CacheFSSpecTarget, MetadataTarget, file_opener
from ..utils import (
chunk_bounds_and_conflicts,
Expand Down Expand Up @@ -181,6 +181,17 @@ def _validate_input_and_chunk_keys(self):
if not all_chunk_keys == set([c for val in self._inputs_chunks.values() for c in val]):
raise ValueError("_inputs_chunks and _chunks_inputs don't use the same chunk keys.")

def copy_pruned(self, nkeep: int = 2) -> BaseRecipe:
"""Make a copy of this recipe with a pruned file pattern.
:param nkeep: The number of items to keep from each ConcatDim sequence.
"""

new_pattern = prune_pattern(self.file_pattern, nkeep=nkeep)
return replace(self, file_pattern=new_pattern)

# below here are methods that are part of recipe execution

def _set_target_chunks(self):
target_concat_dim_chunks = self.target_chunks.get(self._concat_dim)
if (self._nitems_per_input is None) and (target_concat_dim_chunks is None):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
FilePattern,
MergeDim,
pattern_from_file_sequence,
prune_pattern,
)


Expand Down Expand Up @@ -73,3 +74,17 @@ def format_function(time, variable):
assert fp[key] == fname
fnames.append(fname)
assert list(fp.items()) == list(zip(expected_keys, fnames))


@pytest.mark.parametrize("nkeep", [1, 2])
def test_prune(nkeep):
concat = ConcatDim(name="time", keys=list(range(3)))
merge = MergeDim(name="variable", keys=["foo", "bar"])

def format_function(time, variable):
return f"T_{time}_V_{variable}"

fp = FilePattern(format_function, merge, concat)
fp_pruned = prune_pattern(fp, nkeep=nkeep)
assert fp_pruned.dims == {"variable": 2, "time": nkeep}
assert len(list(fp_pruned.items())) == 2 * nkeep
15 changes: 15 additions & 0 deletions tests/test_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ def test_recipe(recipe_fixture, execute_recipe):
xr.testing.assert_identical(ds_actual, ds_expected)


@pytest.mark.parametrize("recipe_fixture", all_recipes)
@pytest.mark.parametrize("nkeep", [1, 2])
def test_prune_recipe(recipe_fixture, execute_recipe, nkeep):
"""The basic recipe test. Use this as a template for other tests."""

RecipeClass, file_pattern, kwargs, ds_expected, target = recipe_fixture
rec = RecipeClass(file_pattern, **kwargs)
rec_pruned = rec.copy_pruned(nkeep=nkeep)
assert len(list(rec.iter_inputs())) > len(list(rec_pruned.iter_inputs()))
execute_recipe(rec_pruned)
ds_pruned = xr.open_zarr(target.get_mapper()).load()
nitems_per_input = list(file_pattern.nitems_per_input.values())[0]
assert ds_pruned.dims["time"] == nkeep * nitems_per_input


@pytest.mark.parametrize("cache_inputs", [True, False])
@pytest.mark.parametrize("copy_input_to_local_file", [True, False])
def test_recipe_caching_copying(
Expand Down

0 comments on commit 23b7296

Please sign in to comment.