Merge pull request #1639 from BENR0/fix_cf_writer_returning_two_delayeds
Fix MultiScene writer handling of multiple delayed objects
djhoese authored May 17, 2021
Commit 62c0e83 · 2 parents: 590c7c3 + 3a07044
Showing 3 changed files with 61 additions and 24 deletions.
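
Background: Scene.save_datasets(compute=False) returns different result types depending on the writer — plain dask Delayed objects (e.g. from the simple_image writer) or (source, target) pairs of dask arrays and open file targets (e.g. from the geotiff writer). The old distributed code path guessed the type with "isinstance(delayed, (list, tuple)) and len(delayed) == 2", which misfired whenever a writer legitimately returned a list of exactly two delayed objects, as the cf writer can when writing groups (hence the branch name). The fix classifies results with satpy.writers.split_results instead. A minimal sketch of that classification, assuming only that split_results accepts a flat list of results:

    # Sketch: how split_results separates writer results by type.
    import dask
    import dask.array as da
    from satpy.writers import split_results

    delayed_obj = dask.delayed(print)("pretend this writes an image")
    src = da.zeros((2, 2), chunks=2)  # stand-in for a geotiff-style source array

    sources, targets, delayeds = split_results([delayed_obj, src])
    assert len(sources) == 1 and len(delayeds) == 1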
12 changes: 7 additions & 5 deletions satpy/multiscene.py
@@ -29,7 +29,7 @@

 from satpy.dataset import DataID, combine_metadata
 from satpy.scene import Scene
-from satpy.writers import get_enhanced_image
+from satpy.writers import get_enhanced_image, split_results

 try:
     import imageio
@@ -340,18 +340,20 @@ def load_data(q):
                 q.task_done()

         input_q = Queue(batch_size if batch_size is not None else 1)
-        load_thread = Thread(target=load_data, args=(input_q,))
+        # set threads to daemon so they are killed if an error is raised from the main thread
+        load_thread = Thread(target=load_data, args=(input_q,), daemon=True)
         load_thread.start()

         for scene in scenes_iter:
-            delayed = scene.save_datasets(compute=False, **kwargs)
-            if isinstance(delayed, (list, tuple)) and len(delayed) == 2:
+            delayeds = scene.save_datasets(compute=False, **kwargs)
+            sources, targets, delayeds = split_results(delayeds)
+            if len(sources) > 0:
                 # TODO Make this work for (source, target) datasets
                 # given a target, source combination
                 raise NotImplementedError("Distributed save_datasets does not support writers "
                                           "that return (source, target) combinations at this time. Use "
                                           "the non-distributed save_datasets instead.")
-            future = client.compute(delayed)
+            future = client.compute(delayeds)
             input_q.put(future)
         input_q.put(None)

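For context, the code above runs when MultiScene.save_datasets is given a dask distributed client. A usage sketch, assuming a local cluster and made-up file paths (not part of this PR):

    # Sketch: distributed save with a writer that returns Delayed objects.
    from glob import glob
    from dask.distributed import Client
    from satpy import MultiScene

    client = Client()  # local workers; a scheduler address also works
    mscn = MultiScene.from_files(glob('/data/abi/*.nc'), reader='abi_l1b')
    mscn.load(['C01'])
    mscn.save_datasets(base_dir='/tmp/frames', writer='simple_image',
                       client=client, batch_size=2)
    client.close()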
39 changes: 37 additions & 2 deletions satpy/tests/test_multiscene.py
@@ -431,9 +431,10 @@ def test_save_datasets_simple(self):
         self.assertEqual(save_datasets.call_count, 2)

     @mock.patch('satpy.multiscene.get_enhanced_image', _fake_get_enhanced_image)
-    def test_save_datasets_distributed(self):
-        """Save a series of fake scenes to PNG images using dask distributed."""
+    def test_save_datasets_distributed_delayed(self):
+        """Test distributed save for writers returning delayed objects, e.g. simple_image."""
         from satpy import MultiScene
+        from dask.delayed import Delayed
         area = _create_test_area()
         scenes = _create_test_scenes(area=area)

@@ -453,6 +454,7 @@ def test_save_datasets_distributed(self):
         client_mock.compute.side_effect = lambda x: tuple(v for v in x)
         client_mock.gather.side_effect = lambda x: x
         future_mock = mock.MagicMock()
+        future_mock.__class__ = Delayed
         with mock.patch('satpy.multiscene.Scene.save_datasets') as save_datasets:
             save_datasets.return_value = [future_mock]  # some arbitrary return value
             # force order of datasets by specifying them
@@ -462,6 +464,39 @@ def test_save_datasets_distributed(self):
             # 2 for each scene
             self.assertEqual(save_datasets.call_count, 2)

+    @mock.patch('satpy.multiscene.get_enhanced_image', _fake_get_enhanced_image)
+    def test_save_datasets_distributed_source_target(self):
+        """Test distributed save for writers returning sources and targets, e.g. the geotiff writer."""
+        from satpy import MultiScene
+        import dask.array as da
+        area = _create_test_area()
+        scenes = _create_test_scenes(area=area)
+
+        # Add a dataset to only one of the Scenes
+        scenes[1]['ds3'] = _create_test_dataset('ds3')
+        # Add a start and end time
+        for ds_id in ['ds1', 'ds2', 'ds3']:
+            scenes[1][ds_id].attrs['start_time'] = datetime(2018, 1, 2)
+            scenes[1][ds_id].attrs['end_time'] = datetime(2018, 1, 2, 12)
+            if ds_id == 'ds3':
+                continue
+            scenes[0][ds_id].attrs['start_time'] = datetime(2018, 1, 1)
+            scenes[0][ds_id].attrs['end_time'] = datetime(2018, 1, 1, 12)
+
+        mscn = MultiScene(scenes)
+        client_mock = mock.MagicMock()
+        client_mock.compute.side_effect = lambda x: tuple(v for v in x)
+        client_mock.gather.side_effect = lambda x: x
+        source_mock = mock.MagicMock()
+        source_mock.__class__ = da.Array
+        target_mock = mock.MagicMock()
+        with mock.patch('satpy.multiscene.Scene.save_datasets') as save_datasets:
+            save_datasets.return_value = [(source_mock, target_mock)]  # some arbitrary return value
+            # force order of datasets by specifying them
+            with self.assertRaises(NotImplementedError):
+                mscn.save_datasets(base_dir=self.base_dir, client=client_mock, datasets=['ds1', 'ds2', 'ds3'],
+                                   writer='geotiff')
+
     def test_crop(self):
         """Test the crop method."""
         from satpy import Scene, MultiScene
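A side note on the tests above: assigning to a MagicMock's __class__ is the documented way to make a mock pass isinstance checks, which is how the fake results get routed through split_results without building real dask graphs. A standalone illustration (not from the diff):

    # Sketch: making a MagicMock pass an isinstance check.
    from unittest import mock
    from dask.delayed import Delayed

    fake = mock.MagicMock()
    fake.__class__ = Delayed          # mock supports spoofing __class__
    assert isinstance(fake, Delayed)  # now classified as a delayed result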
34 changes: 17 additions & 17 deletions satpy/writers/cf_writer.py
@@ -703,18 +703,6 @@ def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None,
         """
         logger.info('Saving datasets to NetCDF4/CF.')

-        if groups is None:
-            # Write all datasets to the file root without creating a group
-            groups_ = {None: datasets}
-        else:
-            # User specified a group assignment using dataset names. Collect the corresponding datasets.
-            groups_ = defaultdict(list)
-            for dataset in datasets:
-                for group_name, group_members in groups.items():
-                    if dataset.attrs['name'] in group_members:
-                        groups_[group_name].append(dataset)
-                        break
-
         if compression is None:
             compression = {'zlib': True}

@@ -734,11 +722,6 @@ def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None,
         else:
             root.attrs['history'] = _history_create

-        if groups is None:
-            # Groups are not CF-1.7 compliant
-            if 'Conventions' not in root.attrs:
-                root.attrs['Conventions'] = CF_VERSION
-
         # Remove satpy-specific kwargs
         to_netcdf_kwargs = copy.deepcopy(to_netcdf_kwargs)  # may contain dictionaries (encoding)
         satpy_kwargs = ['overlay', 'decorate', 'config_files']
@@ -748,6 +731,22 @@ def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None,
         init_nc_kwargs = to_netcdf_kwargs.copy()
         init_nc_kwargs.pop('encoding', None)  # No variables to be encoded at this point
         init_nc_kwargs.pop('unlimited_dims', None)
+
+        if groups is None:
+            # Groups are not CF-1.7 compliant
+            if 'Conventions' not in root.attrs:
+                root.attrs['Conventions'] = CF_VERSION
+            # Write all datasets to the file root without creating a group
+            groups_ = {None: datasets}
+        else:
+            # User specified a group assignment using dataset names. Collect the corresponding datasets.
+            groups_ = defaultdict(list)
+            for dataset in datasets:
+                for group_name, group_members in groups.items():
+                    if dataset.attrs['name'] in group_members:
+                        groups_[group_name].append(dataset)
+                        break
+
         written = [root.to_netcdf(filename, engine=engine, mode='w', **init_nc_kwargs)]

         # Write datasets to groups (appending to the file; group=None means no group)
@@ -771,4 +770,5 @@ def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None,
             res = dataset.to_netcdf(filename, engine=engine, group=group_name, mode='a', encoding=encoding,
                                     **other_to_netcdf_kwargs)
             written.append(res)
+
         return written
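The relocated block is the heart of the groups handling: each dataset is assigned to the first group whose member list contains its name. A self-contained sketch of that logic, using plain dicts in place of xarray DataArrays (dataset and group names are made up):

    # Sketch: the group-collection loop with plain dicts standing in for datasets.
    from collections import defaultdict

    datasets = [{'name': 'IR_108'}, {'name': 'VIS006'}, {'name': 'HRV'}]
    groups = {'visir': ['IR_108', 'VIS006'], 'hrv': ['HRV']}

    groups_ = defaultdict(list)
    for dataset in datasets:
        for group_name, group_members in groups.items():
            if dataset['name'] in group_members:
                groups_[group_name].append(dataset)
                break

    assert [d['name'] for d in groups_['visir']] == ['IR_108', 'VIS006']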
