Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce Dask computations in DayNightCompositor #2617

Merged
merged 8 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 69 additions & 31 deletions satpy/composites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import os
import warnings
from typing import Optional, Sequence

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -119,7 +120,12 @@ def id(self):
id_keys = self.attrs.get("_satpy_id_keys", minimal_default_keys_config)
return DataID(id_keys, **self.attrs)

def __call__(self, datasets, optional_datasets=None, **info):
def __call__(
self,
datasets: Sequence[xr.DataArray],
optional_datasets: Optional[Sequence[xr.DataArray]] = None,
**info
) -> xr.DataArray:
"""Generate a composite."""
raise NotImplementedError()

Expand Down Expand Up @@ -422,7 +428,12 @@ def _get_sensors(self, projectables):
sensor = list(sensor)[0]
return sensor

def __call__(self, projectables, nonprojectables=None, **attrs):
def __call__(
self,
datasets: Sequence[xr.DataArray],
optional_datasets: Optional[Sequence[xr.DataArray]] = None,
**attrs
) -> xr.DataArray:
"""Build the composite."""
if "deprecation_warning" in self.attrs:
warnings.warn(
Expand All @@ -431,29 +442,29 @@ def __call__(self, projectables, nonprojectables=None, **attrs):
stacklevel=2
)
self.attrs.pop("deprecation_warning", None)
num = len(projectables)
num = len(datasets)
mode = attrs.get("mode")
if mode is None:
# num may not be in `self.modes` so only check if we need to
mode = self.modes[num]
if len(projectables) > 1:
projectables = self.match_data_arrays(projectables)
data = self._concat_datasets(projectables, mode)
if len(datasets) > 1:
datasets = self.match_data_arrays(datasets)
data = self._concat_datasets(datasets, mode)
# Skip masking if user wants it or a specific alpha channel is given.
if self.common_channel_mask and mode[-1] != "A":
data = data.where(data.notnull().all(dim="bands"))
else:
data = projectables[0]
data = datasets[0]

# if inputs have a time coordinate that may differ slightly between
# themselves then find the mid time and use that as the single
# time coordinate value
if len(projectables) > 1:
time = check_times(projectables)
if len(datasets) > 1:
time = check_times(datasets)
if time is not None and "time" in data.dims:
data["time"] = [time]

new_attrs = combine_metadata(*projectables)
new_attrs = combine_metadata(*datasets)
# remove metadata that shouldn't make sense in a composite
new_attrs["wavelength"] = None
new_attrs.pop("units", None)
Expand All @@ -467,7 +478,7 @@ def __call__(self, projectables, nonprojectables=None, **attrs):
new_attrs.update(self.attrs)
if resolution is not None:
new_attrs["resolution"] = resolution
new_attrs["sensor"] = self._get_sensors(projectables)
new_attrs["sensor"] = self._get_sensors(datasets)
new_attrs["mode"] = mode

return xr.DataArray(data=data.data, attrs=new_attrs,
Expand Down Expand Up @@ -692,22 +703,27 @@ def __init__(self, name, lim_low=85., lim_high=88., day_night="day_night", inclu
self._has_sza = False
super(DayNightCompositor, self).__init__(name, **kwargs)

def __call__(self, projectables, **kwargs):
def __call__(
self,
datasets: Sequence[xr.DataArray],
optional_datasets: Optional[Sequence[xr.DataArray]] = None,
**attrs
) -> xr.DataArray:
"""Generate the composite."""
projectables = self.match_data_arrays(projectables)
datasets = self.match_data_arrays(datasets)
# At least one composite is requested.
foreground_data = projectables[0]
foreground_data = datasets[0]

weights = self._get_coszen_blending_weights(projectables)
weights = self._get_coszen_blending_weights(datasets)

# Apply enhancements to the foreground data
foreground_data = enhance2dataset(foreground_data)

if "only" in self.day_night:
attrs = foreground_data.attrs.copy()
fg_attrs = foreground_data.attrs.copy()
day_data, night_data, weights = self._get_data_for_single_side_product(foreground_data, weights)
else:
day_data, night_data, attrs = self._get_data_for_combined_product(foreground_data, projectables[1])
day_data, night_data, fg_attrs = self._get_data_for_combined_product(foreground_data, datasets[1])

# The computed coszen is for the full area, so it needs to be masked for missing and off-swath data
if self.include_alpha and not self._has_sza:
Expand All @@ -718,11 +734,18 @@ def __call__(self, projectables, **kwargs):
day_data = zero_missing_data(day_data, night_data)
night_data = zero_missing_data(night_data, day_data)

data = self._weight_data(day_data, night_data, weights, attrs)
data = self._weight_data(day_data, night_data, weights, fg_attrs)

return super(DayNightCompositor, self).__call__(data, **kwargs)
return super(DayNightCompositor, self).__call__(
data,
optional_datasets=optional_datasets,
**attrs
)

def _get_coszen_blending_weights(self, projectables):
def _get_coszen_blending_weights(
self,
projectables: Sequence[xr.DataArray],
) -> xr.DataArray:
lim_low = np.cos(np.deg2rad(self.lim_low))
lim_high = np.cos(np.deg2rad(self.lim_high))
try:
Expand All @@ -739,7 +762,11 @@ def _get_coszen_blending_weights(self, projectables):

return coszen.clip(0, 1)

def _get_data_for_single_side_product(self, foreground_data, weights):
def _get_data_for_single_side_product(
self,
foreground_data: xr.DataArray,
weights: xr.DataArray,
) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
# Only one portion (day or night) is selected. One composite is requested.
# Add alpha band to single L/RGB composite to make the masked-out portion transparent when needed
# L -> LA
Expand All @@ -754,8 +781,8 @@ def _get_data_for_single_side_product(self, foreground_data, weights):

def _mask_weights(self, weights):
if "day" in self.day_night:
return da.where(weights != 0, weights, np.nan)
return da.where(weights != 1, weights, np.nan)
return weights.where(weights != 0, np.nan)
return weights.where(weights != 1, np.nan)

def _get_day_night_data_for_single_side_product(self, foreground_data):
if "day" in self.day_night:
Expand All @@ -778,25 +805,33 @@ def _get_data_for_combined_product(self, day_data, night_data):

return day_data, night_data, attrs

def _mask_weights_with_data(self, weights, day_data, night_data):
def _mask_weights_with_data(
self,
weights: xr.DataArray,
day_data: xr.DataArray,
night_data: xr.DataArray,
) -> xr.DataArray:
data_a = _get_single_channel(day_data)
data_b = _get_single_channel(night_data)
if "only" in self.day_night:
mask = _get_weight_mask_for_single_side_product(data_a, data_b)
else:
mask = _get_weight_mask_for_daynight_product(weights, data_a, data_b)

return da.where(mask, weights, np.nan)
return weights.where(mask, np.nan)

def _weight_data(self, day_data, night_data, weights, attrs):
def _weight_data(
self,
day_data: xr.DataArray,
night_data: xr.DataArray,
weights: xr.DataArray,
attrs: dict,
) -> list[xr.DataArray]:
if not self.include_alpha:
fill = 1 if self.day_night == "night_only" else 0
weights = da.where(np.isnan(weights), fill, weights)

weights = weights.where(~np.isnan(weights), fill)
data = []
for b in _get_band_names(day_data, night_data):
# if self.day_night == "night_only" and self.include_alpha is False:
# import ipdb; ipdb.set_trace()
day_band = _get_single_band_data(day_data, b)
night_band = _get_single_band_data(night_data, b)
# For day-only and night-only products only the alpha channel is weighted
Expand Down Expand Up @@ -824,9 +859,12 @@ def _get_single_band_data(data, band):
return data.sel(bands=band)


def _get_single_channel(data):
def _get_single_channel(data: xr.DataArray) -> xr.DataArray:
try:
data = data[0, :, :]
# remove coordinates that may be band-specific (ex. "bands")
# and we don't care about anymore
data = data.reset_coords(drop=True)
except (IndexError, TypeError):
pass
return data
Expand Down
2 changes: 1 addition & 1 deletion satpy/modifiers/_crefl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def __call__(self, sensor_azimuth, sensor_zenith, solar_azimuth, solar_zenith, a
def _run_crefl(self, mus, muv, phi, solar_zenith, sensor_zenith, height, coeffs):
raise NotImplementedError()

def _height_from_avg_elevation(self, avg_elevation: Optional[np.ndarray]) -> da.Array:
def _height_from_avg_elevation(self, avg_elevation: Optional[np.ndarray]) -> da.Array | float:
"""Get digital elevation map data for our granule with ocean fill value set to 0."""
if avg_elevation is None:
LOG.debug("No average elevation information provided in CREFL")
Expand Down
89 changes: 56 additions & 33 deletions satpy/tests/test_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pyresample import AreaDefinition

import satpy
from satpy.tests.utils import CustomScheduler

# NOTE:
# The following fixtures are not defined in this file, but are used and injected by Pytest:
Expand Down Expand Up @@ -431,28 +432,34 @@ def setUp(self):
def test_daynight_sza(self):
"""Test compositor with both day and night portions when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b, self.sza))
res = res.compute()
expected = np.array([[0., 0.22122352], [0.5, 1.]])
np.testing.assert_allclose(res.values[0], expected)

def test_daynight_area(self):
"""Test compositor both day and night portions when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b))
res = res.compute()
expected_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
for i in range(3):
np.testing.assert_allclose(res.values[i], expected_channel)

def test_night_only_sza_with_alpha(self):
"""Test compositor with night portion with alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b, self.sza))
res = res.compute()
expected_red_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[0., 0.33296056], [1., 1.]])
np.testing.assert_allclose(res.values[0], expected_red_channel)
Expand All @@ -461,19 +468,23 @@ def test_night_only_sza_with_alpha(self):
def test_night_only_sza_without_alpha(self):
"""Test compositor with night portion without alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()
expected = np.array([[0., 0.11042631], [0.66835017, 1.]])
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands

def test_night_only_area_with_alpha(self):
"""Test compositor with night portion with alpha band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[np.nan, 0.], [0., 0.]])
np.testing.assert_allclose(res.values[0], expected_l_channel)
Expand All @@ -482,19 +493,23 @@ def test_night_only_area_with_alpha(self):
def test_night_only_area_without_alpha(self):
"""Test compositor with night portion without alpha band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_b,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_b,))
res = res.compute()
expected = np.array([[np.nan, 0.], [0., 0.]])
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands

def test_day_only_sza_with_alpha(self):
"""Test compositor with day portion with alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a, self.sza))
res = res.compute()
expected_red_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected_alpha = np.array([[1., 0.66703944], [0., 0.]])
np.testing.assert_allclose(res.values[0], expected_red_channel)
Expand All @@ -503,9 +518,11 @@ def test_day_only_sza_with_alpha(self):
def test_day_only_sza_without_alpha(self):
"""Test compositor with day portion without alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()
expected_channel_data = np.array([[0., 0.22122352], [0., 0.]])
for i in range(3):
np.testing.assert_allclose(res.values[i], expected_channel_data)
Expand All @@ -514,9 +531,11 @@ def test_day_only_sza_without_alpha(self):
def test_day_only_area_with_alpha(self):
"""Test compositor with day portion with alpha_band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a,))
res = res.compute()
expected_l_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected_alpha = np.array([[1., 1.], [1., 1.]])
np.testing.assert_allclose(res.values[0], expected_l_channel)
Expand All @@ -525,9 +544,11 @@ def test_day_only_area_with_alpha(self):
def test_day_only_area_with_alpha_and_missing_data(self):
"""Test compositor with day portion with alpha_band when SZA data is not provided and there is missing data."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[np.nan, 1.], [1., 1.]])
np.testing.assert_allclose(res.values[0], expected_l_channel)
Expand All @@ -536,9 +557,11 @@ def test_day_only_area_with_alpha_and_missing_data(self):
def test_day_only_area_without_alpha(self):
"""Test compositor with day portion without alpha_band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a,))
res = res.compute()
expected = np.array([[0., 0.33164983], [0.66835017, 1.]])
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands
Expand Down