diff --git a/satpy/dataset.py b/satpy/dataset.py
index db1055d385..14f2042224 100644
--- a/satpy/dataset.py
+++ b/satpy/dataset.py
@@ -17,14 +17,12 @@
 # satpy. If not, see <http://www.gnu.org/licenses/>.
 """Dataset objects."""
 
-import sys
 import logging
 import numbers
 from collections import namedtuple
+from collections.abc import Collection
 from datetime import datetime
 
-import numpy as np
-
 logger = logging.getLogger(__name__)
 
 
@@ -62,11 +60,13 @@ def average_datetimes(dt_list):
 def combine_metadata(*metadata_objects, **kwargs):
     """Combine the metadata of two or more Datasets.
 
-    If any keys are not equal or do not exist in all provided dictionaries
-    then they are not included in the returned dictionary.
-    By default any keys with the word 'time' in them and consisting
-    of datetime objects will be averaged. This is to handle cases where
-    data were observed at almost the same time but not exactly.
+    If the values corresponding to any keys are not equal or do not
+    exist in all provided dictionaries then they are not included in
+    the returned dictionary. By default any keys with the word 'time'
+    in them and consisting of datetime objects will be averaged. This
+    is to handle cases where data were observed at almost the same time
+    but not exactly. In the interest of time, arrays are compared by
+    object identity rather than by their contents.
 
     Args:
         *metadata_objects: MetadataObject or dict objects to combine
@@ -98,18 +98,57 @@ def combine_metadata(*metadata_objects, **kwargs):
     shared_info = {}
     for k in shared_keys:
         values = [nfo[k] for nfo in info_dicts]
-        any_arrays = any([isinstance(val, np.ndarray) for val in values])
-        if any_arrays:
-            if all(np.all(val == values[0]) for val in values[1:]):
+        if _share_metadata_key(k, values, average_times):
+            if 'time' in k and isinstance(values[0], datetime) and average_times:
+                shared_info[k] = average_datetimes(values)
+            else:
                 shared_info[k] = values[0]
-        elif 'time' in k and isinstance(values[0], datetime) and average_times:
-            shared_info[k] = average_datetimes(values)
-        elif all(val == values[0] for val in values[1:]):
-            shared_info[k] = values[0]
 
     return shared_info
 
 
+def _share_metadata_key(k, values, average_times):
+    """Helper for combine_metadata, decide if key is shared."""
+    any_arrays = any([hasattr(val, "__array__") for val in values])
+    # in the real world, the `ancillary_variables` attribute may be
+    # List[xarray.DataArray], this means our values are now
+    # List[List[xarray.DataArray]].
+    # note that this list_of_arrays check is also true for any
+    # higher-dimensional ndarray, but we only use this check after we have
+    # checked any_arrays so this false positive should have no impact
+    list_of_arrays = any(
+            [isinstance(val, Collection) and len(val) > 0 and
+             all([hasattr(subval, "__array__")
+                 for subval in val])
+             for val in values])
+    if any_arrays:
+        return _share_metadata_key_array(values)
+    elif list_of_arrays:
+        return _share_metadata_key_list_arrays(values)
+    elif 'time' in k and isinstance(values[0], datetime) and average_times:
+        return True
+    elif all(val == values[0] for val in values[1:]):
+        return True
+    return False
+
+
+def _share_metadata_key_array(values):
+    """Helper for combine_metadata, check object identity in list of arrays."""
+    for val in values[1:]:
+        if val is not values[0]:
+            return False
+    return True
+
+
+def _share_metadata_key_list_arrays(values):
+    """Helper for combine_metadata, check object identity in list of list of arrays."""
+    for val in values[1:]:
+        for arr, ref in zip(val, values[0]):
+            if arr is not ref:
+                return False
+    return True
+
+
 DATASET_KEYS = ("name", "wavelength", "resolution", "polarization",
                 "calibration", "level", "modifiers")
 DatasetID = namedtuple("DatasetID", " ".join(DATASET_KEYS))
diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py
index f72bb147e5..82bd360dc2 100644
--- a/satpy/tests/test_dataset.py
+++ b/satpy/tests/test_dataset.py
@@ -84,3 +84,39 @@ def test_combine_times(self):
         ret = combine_metadata(*dts, average_times=False)
         # times are not equal so don't include it in the final result
         self.assertNotIn('start_time', ret)
+
+    def test_combine_arrays(self):
+        """Test the combine_metadata with arrays."""
+        from satpy.dataset import combine_metadata
+        from numpy import arange, ones
+        from xarray import DataArray
+        dts = [
+            {"quality": (arange(25) % 2).reshape(5, 5).astype("?")},
+            {"quality": (arange(1, 26) % 3).reshape(5, 5).astype("?")},
+            {"quality": ones((5, 5,), "?")},
+        ]
+        assert "quality" not in combine_metadata(*dts)
+        dts2 = [{"quality": DataArray(d["quality"])} for d in dts]
+        assert "quality" not in combine_metadata(*dts2)
+        # the ancillary_variables attribute is actually a list of data arrays
+        dts3 = [{"quality": [d["quality"]]} for d in dts]
+        assert "quality" not in combine_metadata(*dts3)
+        # check cases with repeated arrays
+        dts4 = [
+            {"quality": dts[0]["quality"]},
+            {"quality": dts[0]["quality"]},
+        ]
+        assert "quality" in combine_metadata(*dts4)
+        dts5 = [
+            {"quality": dts3[0]["quality"]},
+            {"quality": dts3[0]["quality"]},
+        ]
+        assert "quality" in combine_metadata(*dts5)
+        # check with other types
+        dts6 = [
+            DataArray(arange(5), attrs=dts[0]),
+            DataArray(arange(5), attrs=dts[0]),
+            DataArray(arange(5), attrs=dts[1]),
+            object()
+        ]
+        assert "quality" not in combine_metadata(*dts6)
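A minimal sketch of the resulting behavior (not part of the patch; it assumes a checkout with this change applied and imports combine_metadata from satpy.dataset, as the tests above do). Because arrays are now compared by object identity, equal contents alone no longer keep a key:

    import numpy as np
    from satpy.dataset import combine_metadata

    arr = np.zeros((5, 5))
    # same object in every dict: the identity check passes, the key is kept
    assert "quality" in combine_metadata({"quality": arr}, {"quality": arr})
    # equal contents but a distinct object: the key is dropped
    assert "quality" not in combine_metadata(
        {"quality": arr}, {"quality": arr.copy()})

Checking identity instead of contents keeps combine_metadata cheap when the values are lazy (e.g. dask-backed) arrays, where an element-wise comparison could trigger a computation; that is the "interest of time" the new docstring refers to.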