Include merge dim positions in group keys emitted by split_fragments (#521)

* issue 517

* fix split_fragments docstring

* adapt test to replicate #517 bug

* rename dims_starts_sizes -> concat_dims_starts_sizes

* parametrize end-to-end test with multivar pattern

* combine multivar fragments test continued

* rework combine_fragments for (single) merge dim

* drop unused nvars parameter

* add .vscode to .gitignore

* improve coverage of combine fragments merge dim test

* use merge_fragments func to pre-merge fragments

* distinguish fragments from merged_fragments

* revert rechunking.py changes

* modify test for split fragments grouping rework

* make split_fragments merge dim aware WIP

* try combining fragments in unit test

* testing tweaks

* add split fragments possible bug test

* fix offset in possible bug test (it's not a bug)

* remove possible bug test (it's not a bug)

* fix IndexedPosition mistake in test rechunking

* make some assertions about combined ds

* remove stray comment line

* clarify comment in test rechunking

---------

Co-authored-by: Charles Stern <62192187+cisaacstern@users.noreply.github.com>
norlandrhagen and cisaacstern authored Jun 29, 2023
1 parent 2db1624 commit e8e6609
Showing 5 changed files with 102 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -133,3 +133,6 @@ _version.py
# tutorials
*.nc
dask-worker-space

# vscode
.vscode/
18 changes: 17 additions & 1 deletion pangeo_forge_recipes/rechunking.py
@@ -82,6 +82,17 @@ def split_fragment(
for dim, chunk_slice in target_chunk_slices.items()
)
)
# extract the position along each merge dim at which this fragment resides.
# this will be appended to the groupkey to ensure that `combine_fragments`
# (which consumes the output of this function) receives groups of fragments which are
# homogeneous in all merge dimensions. a possible value here would be `[("variable", 0)]`.
merge_dim_positions = sorted(
[
(dim.name, position.value)
for dim, position in common_index.items()
if dim.operation == CombineOp.MERGE
]
)

# this iteration yields new fragments, indexed by their target chunk group
for target_chunk_group in all_chunks:
@@ -104,7 +115,12 @@
)
sub_fragment_ds = ds.isel(**sub_fragment_indexer)

yield tuple(sorted(target_chunk_group)), (sub_fragment_index, sub_fragment_ds)
yield (
# append the `merge_dim_positions` to the target_chunk_group before returning,
# to ensure correct grouping of merge dims. e.g., `(("time", 0), ("variable", 0))`.
tuple(sorted(target_chunk_group) + merge_dim_positions),
(sub_fragment_index, sub_fragment_ds),
)


def _sort_index_key(item):
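To illustrate the effect of the change above, here is a minimal sketch (not part of the commit) of how a group key is built once the merge dim positions are appended. It assumes `pangeo_forge_recipes` is importable and reuses only names that appear in the diff; the concrete index values are hypothetical.

```python
from pangeo_forge_recipes.types import CombineOp, Dimension, Index, IndexedPosition, Position

# a hypothetical index for a fragment holding the variable at merge position 0,
# starting at offset 0 along a time axis of length 10
common_index = Index(
    {
        Dimension("variable", CombineOp.MERGE): Position(0),
        Dimension("time", CombineOp.CONCAT): IndexedPosition(0, dimsize=10),
    }
)

# the extraction added to split_fragment in the hunk above
merge_dim_positions = sorted(
    (dim.name, position.value)
    for dim, position in common_index.items()
    if dim.operation == CombineOp.MERGE
)

# previously the key carried only the concat-dim chunk, e.g. (("time", 0),);
# appending the merge dim position keeps fragments of different variables apart
group_key = tuple(sorted([("time", 0)]) + merge_dim_positions)
print(group_key)  # (('time', 0), ('variable', 0))
```

Without the `("variable", 0)` entry, sub-fragments of different variables landing in the same time chunk would share a key and be handed to `combine_fragments` together, which appears to be the #517 failure mode this commit addresses.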
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -22,6 +22,7 @@
import apache_beam as beam
import fsspec
import pytest
import xarray as xr
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline

@@ -59,7 +60,10 @@ def split_up_files_by_day(ds, day_param):
return datasets, fnames


def split_up_files_by_variable_and_day(ds, day_param):
def split_up_files_by_variable_and_day(
ds: xr.Dataset,
day_param: str,
) -> tuple[list[xr.Dataset], list[str], dict]:
all_dsets = []
all_fnames = []
fnames_by_variable = {}
4 changes: 2 additions & 2 deletions tests/test_end_to_end.py
@@ -33,12 +33,12 @@ def pipeline():
@pytest.mark.parametrize("target_chunks", [{"time": 1}, {"time": 2}, {"time": 3}])
def test_xarray_zarr(
daily_xarray_dataset,
netcdf_local_file_pattern_sequential,
netcdf_local_file_pattern,
pipeline,
tmp_target_url,
target_chunks,
):
pattern = netcdf_local_file_pattern_sequential
pattern = netcdf_local_file_pattern
with pipeline as p:
(
p
77 changes: 75 additions & 2 deletions tests/test_rechunking.py
@@ -1,14 +1,87 @@
import itertools
import random
from collections import namedtuple

import pytest
import xarray as xr

from pangeo_forge_recipes.rechunking import combine_fragments, split_fragment
from pangeo_forge_recipes.rechunking import GroupKey, combine_fragments, split_fragment
from pangeo_forge_recipes.types import CombineOp, Dimension, Index, IndexedPosition, Position

from .conftest import split_up_files_by_variable_and_day
from .data_generation import make_ds


@pytest.mark.parametrize(
"nt_dayparam",
[(5, "1D"), (10, "2D")],
)
@pytest.mark.parametrize("time_chunks", [1, 2, 5])
def test_split_and_combine_fragments_with_merge_dim(nt_dayparam, time_chunks):
"""Test if sub-fragments split from datasets with merge dims can be combined with each other."""

target_chunks = {"time": time_chunks}
nt, dayparam = nt_dayparam
ds = make_ds(nt=nt)
dsets, _, _ = split_up_files_by_variable_and_day(ds, dayparam)

# replicates indexes created by IndexItems transform.
time_positions = {t: i for i, t in enumerate(ds.time.values)}
merge_dim = Dimension("variable", CombineOp.MERGE)
concat_dim = Dimension("time", CombineOp.CONCAT)
indexes = [
Index(
{
merge_dim: Position((0 if "bar" in ds.data_vars else 1)),
concat_dim: IndexedPosition(time_positions[ds.time[0].values], dimsize=nt),
}
)
for ds in dsets
]

# split the (mock indexed) datasets into sub-fragments.
# the splits list contains nested tuples which are a bit confusing for humans to think about.
# create a namedtuple to help remember the structure of these tuples and cast the
# elements of the splits list to this more descriptive type.
splits = [
list(split_fragment((index, ds), target_chunks=target_chunks))
for index, ds in zip(indexes, dsets)
]
Subfragment = namedtuple("Subfragment", "groupkey, content")
subfragments = list(itertools.chain(*[[Subfragment(*s) for s in split] for split in splits]))

# combine subfragments, starting by grouping subfragments by groupkey.
# replicates behavior of `... | beam.GroupByKey() | beam.MapTuple(combine_fragments)`
# in the `Rechunk` transform.
groupkeys = set([sf.groupkey for sf in subfragments])
grouped_subfragments: dict[GroupKey, list[Subfragment]] = {g: [] for g in groupkeys}
for sf in subfragments:
grouped_subfragments[sf.groupkey].append(sf)

for g in sorted(groupkeys):
# just confirms that grouping logic within this test is correct
assert all([sf.groupkey == g for sf in grouped_subfragments[g]])
# for the merge dimension of each subfragment in the current group, assert that there
# is only one positional value present. this verifies that `split_fragments` has not
# grouped distinct merge dimension positional values together under the same groupkey.
merge_position_vals = [sf.content[0][merge_dim].value for sf in grouped_subfragments[g]]
assert all([v == merge_position_vals[0] for v in merge_position_vals])
# now actually try to combine the fragments
_, ds_combined = combine_fragments(
g,
[sf.content for sf in grouped_subfragments[g]],
)
# ensure vars are *not* combined (we only want to concat, not merge)
assert len([k for k in ds_combined.data_vars.keys()]) == 1
# check that time chunking is correct
if nt % time_chunks == 0:
assert len(ds_combined.time) == time_chunks
else:
# if `nt` is not evenly divisible by `time_chunks`, all chunks will be of length
# `time_chunks` except the last one, which will be the length of the remainder
assert len(ds_combined.time) in [time_chunks, nt % time_chunks]


@pytest.mark.parametrize("offset", [0, 5]) # hypothetical offset of this fragment
@pytest.mark.parametrize("time_chunks", [1, 3, 5, 10, 11])
def test_split_fragment(time_chunks, offset):
@@ -36,7 +109,7 @@ def test_split_fragment(time_chunks, offset):

for n in range(len(all_splits)):
chunk_number = offset // time_chunks + n
assert group_keys[n] == (("time", chunk_number),)
assert group_keys[n] == (("time", chunk_number), ("bar", 1))
chunk_start = time_chunks * chunk_number
chunk_stop = min(time_chunks * (chunk_number + 1), nt_total)
fragment_start = max(chunk_start, offset)
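To make the grouping behaviour concrete outside of pytest, below is a standalone sketch (not part of the commit) that mimics the `beam.GroupByKey() | beam.MapTuple(combine_fragments)` step of the `Rechunk` transform with a plain dictionary. The two single-variable datasets and their indexes are hypothetical stand-ins for what the `IndexItems` transform would produce.

```python
import numpy as np
import xarray as xr

from pangeo_forge_recipes.rechunking import combine_fragments, split_fragment
from pangeo_forge_recipes.types import CombineOp, Dimension, Index, IndexedPosition, Position

# two hypothetical single-variable datasets sharing the same time axis,
# standing in for a file pattern with a "variable" merge dim
time = np.arange(4)
fragments = []
for varname, merge_position in [("foo", 0), ("bar", 1)]:
    ds = xr.Dataset({varname: ("time", np.random.rand(4))}, coords={"time": time})
    index = Index(
        {
            Dimension("variable", CombineOp.MERGE): Position(merge_position),
            Dimension("time", CombineOp.CONCAT): IndexedPosition(0, dimsize=len(time)),
        }
    )
    fragments.append((index, ds))

# split each fragment, then group the sub-fragments by the emitted key;
# the dict stands in for beam.GroupByKey() in the Rechunk transform
groups = {}
for fragment in fragments:
    for key, subfragment in split_fragment(fragment, target_chunks={"time": 2}):
        groups.setdefault(key, []).append(subfragment)

# with merge dim positions in the keys, "foo" and "bar" never share a group,
# so combine_fragments only concatenates along time within each group
for key in sorted(groups):
    _, combined = combine_fragments(key, groups[key])
    print(key, list(combined.data_vars))
```

Before this change, the `foo` and `bar` sub-fragments of the same time chunk shared a group key, so `combine_fragments` received a mix of variables; with the merge dim position in the key, each group holds a single variable and combining only concatenates along time, which is what the new test asserts.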
