Skip to content

Commit

Permalink
Control attrs of result in merge(), concat(), `combine_by_coords(…
Browse files Browse the repository at this point in the history
…)` and `combine_nested()` (#3877)

* Optionally promote attrs from DataArray to Dataset in to_dataset

Adds option 'promote_attrs' to DataArray.to_dataset(). By default
promote_attrs=False, maintaining current behaviour. If
promote_attrs=True, the attrs of the DataArray are shallow-copied to the
Dataset returned by to_dataset().

* utils.ordered_dict_union returns the union of two compatible dicts

If the values of any shared key are not equivalent, then raises an
error.

* combine_attrs argument for merge()

Provides several options for how to combine the attributes of the passed
objects and give them to the returned Dataset.

* combine_attrs argument for concat()

Provides several options for how to combine the attributes of the passed
objects and give them to the returned DataArray or Dataset.

* combine_attrs argument for combine_by_coords() and combine_nested()

Provides several options for how to combine the attributes of the passed
objects and give them to the returned Dataset.

* Add combine_attrs changes to whats-new.rst

* Update docstrings to note default values

Apply suggestions from code review

Co-Authored-By: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>

* First argument of update_safety_check and ordered_dict_union not mutable

No need for these arguments to be MutableMapping rather than just
Mapping.

* Rename ordered_dict_union -> compat_dict_union

Do not use OrderedDicts any more, so name did not make sense.

* Move combine_attrs to v0.16.0 in whats-new.rst

* Fix merge of whats-new.rst

Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
  • Loading branch information
johnomotani and max-sixty authored Mar 24, 2020
1 parent c10c992 commit d8bb620
Show file tree
Hide file tree
Showing 12 changed files with 420 additions and 25 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ Breaking changes

New Features
~~~~~~~~~~~~
- Control over attributes of result in :py:func:`merge`, :py:func:`concat`,
:py:func:`combine_by_coords` and :py:func:`combine_nested` using
combine_attrs keyword argument. (:issue:`3865`, :pull:`3877`)
By `John Omotani <https://github.com/johnomotani>`_


Bug fixes
Expand Down
50 changes: 47 additions & 3 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def _combine_nd(
compat="no_conflicts",
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
):
"""
Combines an N-dimensional structure of datasets into one by applying a
Expand Down Expand Up @@ -202,13 +203,21 @@ def _combine_nd(
compat=compat,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
)
(combined_ds,) = combined_ids.values()
return combined_ds


def _combine_all_along_first_dim(
combined_ids, dim, data_vars, coords, compat, fill_value=dtypes.NA, join="outer"
combined_ids,
dim,
data_vars,
coords,
compat,
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
):

# Group into lines of datasets which must be combined along dim
Expand All @@ -223,7 +232,7 @@ def _combine_all_along_first_dim(
combined_ids = dict(sorted(group))
datasets = combined_ids.values()
new_combined_ids[new_id] = _combine_1d(
datasets, dim, compat, data_vars, coords, fill_value, join
datasets, dim, compat, data_vars, coords, fill_value, join, combine_attrs
)
return new_combined_ids

Expand All @@ -236,6 +245,7 @@ def _combine_1d(
coords="different",
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
):
"""
Applies either concat or merge to 1D list of datasets depending on value
Expand All @@ -252,6 +262,7 @@ def _combine_1d(
compat=compat,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
)
except ValueError as err:
if "encountered unexpected variable" in str(err):
Expand All @@ -265,7 +276,13 @@ def _combine_1d(
else:
raise
else:
combined = merge(datasets, compat=compat, fill_value=fill_value, join=join)
combined = merge(
datasets,
compat=compat,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
)

return combined

Expand All @@ -284,6 +301,7 @@ def _nested_combine(
ids,
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
):

if len(datasets) == 0:
Expand Down Expand Up @@ -311,6 +329,7 @@ def _nested_combine(
coords=coords,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
)
return combined

Expand All @@ -323,6 +342,7 @@ def combine_nested(
coords="different",
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
):
"""
Explicitly combine an N-dimensional grid of datasets into one by using a
Expand Down Expand Up @@ -390,6 +410,16 @@ def combine_nested(
- 'override': if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'},
default 'drop'
String indicating how to combine attrs of the objects being merged:
- 'drop': empty attrs on returned Dataset.
- 'identical': all attrs must be the same on every object.
- 'no_conflicts': attrs from all objects are combined, any that have
the same name must also have the same value.
- 'override': skip comparing and copy attrs from the first dataset to
the result.
Returns
-------
Expand Down Expand Up @@ -468,6 +498,7 @@ def combine_nested(
ids=False,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
)


Expand All @@ -482,6 +513,7 @@ def combine_by_coords(
coords="different",
fill_value=dtypes.NA,
join="outer",
combine_attrs="no_conflicts",
):
"""
Attempt to auto-magically combine the given datasets into one by using
Expand Down Expand Up @@ -557,6 +589,16 @@ def combine_by_coords(
- 'override': if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'},
default 'drop'
String indicating how to combine attrs of the objects being merged:
- 'drop': empty attrs on returned Dataset.
- 'identical': all attrs must be the same on every object.
- 'no_conflicts': attrs from all objects are combined, any that have
the same name must also have the same value.
- 'override': skip comparing and copy attrs from the first dataset to
the result.
Returns
-------
Expand Down Expand Up @@ -700,6 +742,7 @@ def combine_by_coords(
compat=compat,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
)

# Check the overall coordinates are monotonically increasing
Expand All @@ -717,6 +760,7 @@ def combine_by_coords(
compat=compat,
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
)


Expand Down
34 changes: 26 additions & 8 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from . import dtypes, utils
from .alignment import align
from .duck_array_ops import lazy_array_equiv
from .merge import _VALID_COMPAT, unique_variable
from .merge import _VALID_COMPAT, merge_attrs, unique_variable
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars

Expand All @@ -17,6 +17,7 @@ def concat(
positions=None,
fill_value=dtypes.NA,
join="outer",
combine_attrs="override",
):
"""Concatenate xarray objects along a new or existing dimension.
Expand Down Expand Up @@ -92,15 +93,21 @@ def concat(
- 'override': if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'},
default 'override
String indicating how to combine attrs of the objects being merged:
- 'drop': empty attrs on returned Dataset.
- 'identical': all attrs must be the same on every object.
- 'no_conflicts': attrs from all objects are combined, any that have
the same name must also have the same value.
- 'override': skip comparing and copy attrs from the first dataset to
the result.
Returns
-------
concatenated : type of objs
Notes
-----
Each concatenated Variable preserves corresponding ``attrs`` from the first element of ``objs``.
See also
--------
merge
Expand Down Expand Up @@ -132,7 +139,9 @@ def concat(
"can only concatenate xarray Dataset and DataArray "
"objects, got %s" % type(first_obj)
)
return f(objs, dim, data_vars, coords, compat, positions, fill_value, join)
return f(
objs, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs
)


def _calc_concat_dim_coord(dim):
Expand Down Expand Up @@ -306,6 +315,7 @@ def _dataset_concat(
positions,
fill_value=dtypes.NA,
join="outer",
combine_attrs="override",
):
"""
Concatenate a sequence of datasets along a new or existing dimension
Expand Down Expand Up @@ -362,7 +372,7 @@ def _dataset_concat(
result_vars.update(dim_coords)

# assign attrs and encoding from first dataset
result_attrs = datasets[0].attrs
result_attrs = merge_attrs([ds.attrs for ds in datasets], combine_attrs)
result_encoding = datasets[0].encoding

# check that global attributes are fixed across all datasets if necessary
Expand Down Expand Up @@ -425,6 +435,7 @@ def _dataarray_concat(
positions,
fill_value=dtypes.NA,
join="outer",
combine_attrs="override",
):
arrays = list(arrays)

Expand Down Expand Up @@ -453,5 +464,12 @@ def _dataarray_concat(
positions,
fill_value=fill_value,
join=join,
combine_attrs="drop",
)
return arrays[0]._from_temp_dataset(ds, name)

merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs)

result = arrays[0]._from_temp_dataset(ds, name)
result.attrs = merged_attrs

return result
19 changes: 16 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,13 @@ def _to_dataset_whole(
dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes)
return dataset

def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
def to_dataset(
self,
dim: Hashable = None,
*,
name: Hashable = None,
promote_attrs: bool = False,
) -> Dataset:
"""Convert a DataArray to a Dataset.
Parameters
Expand All @@ -487,6 +493,8 @@ def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
name : hashable, optional
Name to substitute for this array's name. Only valid if ``dim`` is
not provided.
promote_attrs : bool, default False
Set to True to shallow copy attrs of DataArray to returned Dataset.
Returns
-------
Expand All @@ -500,9 +508,14 @@ def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
if dim is not None:
if name is not None:
raise TypeError("cannot supply both dim and name arguments")
return self._to_dataset_split(dim)
result = self._to_dataset_split(dim)
else:
return self._to_dataset_whole(name)
result = self._to_dataset_whole(name)

if promote_attrs:
result.attrs = dict(self.attrs)

return result

@property
def name(self) -> Optional[Hashable]:
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def __init__(
if isinstance(coords, Dataset):
coords = coords.variables

variables, coord_names, dims, indexes = merge_data_and_coords(
variables, coord_names, dims, indexes, _ = merge_data_and_coords(
data_vars, coords, compat="broadcast_equals"
)

Expand Down
Loading

0 comments on commit d8bb620

Please sign in to comment.