diff --git a/esmvalcore/_task.py b/esmvalcore/_task.py
index 27a6b83d14..cb9269b087 100644
--- a/esmvalcore/_task.py
+++ b/esmvalcore/_task.py
@@ -5,6 +5,7 @@
 import datetime
 import importlib
 import logging
+import multiprocessing
 import numbers
 import os
 import pprint
@@ -14,7 +15,6 @@
 import threading
 import time
 from copy import deepcopy
-from multiprocessing import Pool
 from pathlib import Path, PosixPath
 from shutil import which
 from typing import Optional
@@ -260,6 +260,7 @@ def __init__(self, ancestors=None, name="", products=None):
         self.name = name
         self.activity = None
         self.priority = 0
+        self.scheduler_lock = None

     def initialize_provenance(self, recipe_entity):
         """Initialize task provenance activity."""
@@ -854,45 +855,60 @@ def done(task):
         """Assume a task is done if it not scheduled or running."""
         return not (task in scheduled or task in running)

-    with Pool(processes=max_parallel_tasks) as pool:
-        while scheduled or running:
-            # Submit new tasks to pool
-            for task in sorted(scheduled, key=lambda t: t.priority):
-                if len(running) >= max_parallel_tasks:
-                    break
-                if all(done(t) for t in task.ancestors):
-                    future = pool.apply_async(
-                        _run_task, [task, scheduler_address]
+    with multiprocessing.Manager() as manager:
+        # Use a lock to avoid overloading the Dask workers by making only
+        # one :class:`esmvalcore.preprocessor.PreprocessingTask` submit its
+        # data save task graph to the distributed scheduler at a time.
+        #
+        # See https://github.com/ESMValGroup/ESMValCore/issues/2609 for
+        # additional detail.
+        scheduler_lock = (
+            None if scheduler_address is None else manager.Lock()
+        )
+
+        with multiprocessing.Pool(processes=max_parallel_tasks) as pool:
+            while scheduled or running:
+                # Submit new tasks to pool
+                for task in sorted(scheduled, key=lambda t: t.priority):
+                    if len(running) >= max_parallel_tasks:
+                        break
+                    if all(done(t) for t in task.ancestors):
+                        future = pool.apply_async(
+                            _run_task,
+                            [task, scheduler_address, scheduler_lock],
+                        )
+                        running[task] = future
+                        scheduled.remove(task)
+
+                # Handle completed tasks
+                ready = {t for t in running if running[t].ready()}
+                for task in ready:
+                    _copy_results(task, running[task])
+                    running.pop(task)
+
+                # Wait if there are still tasks running
+                if running:
+                    time.sleep(0.1)
+
+                # Log progress message
+                if (
+                    len(scheduled) != n_scheduled
+                    or len(running) != n_running
+                ):
+                    n_scheduled, n_running = len(scheduled), len(running)
+                    n_done = n_tasks - n_scheduled - n_running
+                    logger.info(
+                        "Progress: %s tasks running, %s tasks waiting for "
+                        "ancestors, %s/%s done",
+                        n_running,
+                        n_scheduled,
+                        n_done,
+                        n_tasks,
                     )
-                    running[task] = future
-                    scheduled.remove(task)
-
-            # Handle completed tasks
-            ready = {t for t in running if running[t].ready()}
-            for task in ready:
-                _copy_results(task, running[task])
-                running.pop(task)
-
-            # Wait if there are still tasks running
-            if running:
-                time.sleep(0.1)
-
-            # Log progress message
-            if len(scheduled) != n_scheduled or len(running) != n_running:
-                n_scheduled, n_running = len(scheduled), len(running)
-                n_done = n_tasks - n_scheduled - n_running
-                logger.info(
-                    "Progress: %s tasks running, %s tasks waiting for "
-                    "ancestors, %s/%s done",
-                    n_running,
-                    n_scheduled,
-                    n_done,
-                    n_tasks,
-                )

-    logger.info("Successfully completed all tasks.")
-    pool.close()
-    pool.join()
+            logger.info("Successfully completed all tasks.")
+            pool.close()
+            pool.join()


 def _copy_results(task, future):
@@ -900,7 +916,7 @@ def _copy_results(task, future):
     task.output_files, task.products = future.get()


-def _run_task(task, scheduler_address):
+def _run_task(task, scheduler_address, scheduler_lock):
     """Run task and return the result."""
     if scheduler_address is None:
         client = contextlib.nullcontext()
@@ -908,6 +924,7 @@
         client = Client(scheduler_address)

     with client:
+        task.scheduler_lock = scheduler_lock
         output_files = task.run()

     return output_files, task.products
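The heart of the _task.py change is a lock created once by a multiprocessing.Manager in the parent process and handed to every pool worker. Reduced to a self-contained sketch (illustration only: run_task, the worker count, and the sleeps are placeholders, not part of this patch):

# Illustration only -- not part of the patch. A manager lock is a proxy
# object and therefore picklable, so unlike a plain multiprocessing.Lock
# it can be passed through Pool.apply_async and works with both the
# "fork" and "spawn" start methods exercised by the tests below.
import multiprocessing
import time


def run_task(task_id, lock):
    time.sleep(0.1)  # the parallel part of the task
    with lock:
        # Only one worker at a time reaches this section, just as only
        # one PreprocessingTask at a time submits its save graph to the
        # distributed scheduler.
        print(f"task {task_id}: saving (lock held)")
        time.sleep(0.1)
    return task_id


if __name__ == "__main__":
    with multiprocessing.Manager() as manager:
        lock = manager.Lock()
        with multiprocessing.Pool(processes=4) as pool:
            futures = [
                pool.apply_async(run_task, [i, lock]) for i in range(8)
            ]
            print([f.get() for f in futures])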
diff --git a/esmvalcore/preprocessor/__init__.py b/esmvalcore/preprocessor/__init__.py
index 2c956aa0ad..6ba0d7c946 100644
--- a/esmvalcore/preprocessor/__init__.py
+++ b/esmvalcore/preprocessor/__init__.py
@@ -736,9 +736,23 @@ def _run(self, _) -> list[str]:
             delayed = product.close()
             delayeds.append(delayed)

-        logger.info("Computing and saving data for task %s", self.name)
         delayeds = [d for d in delayeds if d is not None]
-        _compute_with_progress(delayeds, description=self.name)
+
+        if self.scheduler_lock is not None:
+            logger.debug("Acquiring save lock for task %s", self.name)
+            self.scheduler_lock.acquire()
+            logger.debug("Acquired save lock for task %s", self.name)
+        try:
+            logger.info(
+                "Computing and saving data for preprocessing task %s",
+                self.name,
+            )
+            _compute_with_progress(delayeds, description=self.name)
+        finally:
+            if self.scheduler_lock is not None:
+                self.scheduler_lock.release()
+                logger.debug("Released save lock for task %s", self.name)
+
         metadata_files = write_metadata(
             self.products, self.write_ncl_interface
         )
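Note the lock discipline in _run above: acquire() is called before the try block, so the finally clause can never release a lock that was not successfully acquired, and both branches re-check for None because a run without a distributed scheduler leaves scheduler_lock unset. The same shape as a standalone sketch (compute is a placeholder):

# Illustration only -- the acquire sits *before* the try block so that a
# failed acquire cannot trigger a release in the finally clause, and the
# lock is optional (None means "no serialization needed").
def compute_with_optional_lock(compute, lock=None):
    if lock is not None:
        lock.acquire()
    try:
        return compute()
    finally:
        if lock is not None:
            lock.release()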
diff --git a/tests/integration/preprocessor/test_preprocessing_task.py b/tests/integration/preprocessor/test_preprocessing_task.py
index 5b74a94cda..43dc7af6a6 100644
--- a/tests/integration/preprocessor/test_preprocessing_task.py
+++ b/tests/integration/preprocessor/test_preprocessing_task.py
@@ -2,6 +2,7 @@

 import iris
 import iris.cube
+import pytest
 from prov.model import ProvDocument

 import esmvalcore.preprocessor
@@ -9,7 +10,8 @@
 from esmvalcore.preprocessor import PreprocessingTask, PreprocessorFile


-def test_load_save_task(tmp_path):
+@pytest.mark.parametrize("scheduler_lock", [False, True])
+def test_load_save_task(tmp_path, mocker, scheduler_lock):
     """Test that a task that just loads and saves a file."""
     # Prepare a test dataset
     cube = iris.cube.Cube(data=[273.0], var_name="tas", units="K")
@@ -36,6 +38,9 @@ def test_load_save_task(tmp_path):
     activity = provenance.activity("software:esmvalcore")
     task.initialize_provenance(activity)

+    if scheduler_lock:
+        task.scheduler_lock = mocker.Mock()
+
     task.run()

     assert len(task.products) == 1
@@ -45,6 +50,12 @@ def test_load_save_task(tmp_path):
     result.attributes.clear()
     assert result == cube

+    if scheduler_lock:
+        task.scheduler_lock.acquire.assert_called_once_with()
+        task.scheduler_lock.release.assert_called_once_with()
+    else:
+        assert task.scheduler_lock is None
+

 def test_load_save_and_other_task(tmp_path, monkeypatch):
     """Test that a task just copies one file and preprocesses another file."""
diff --git a/tests/integration/test_task.py b/tests/integration/test_task.py
index 9570ec8e58..d8fec5a416 100644
--- a/tests/integration/test_task.py
+++ b/tests/integration/test_task.py
@@ -92,7 +92,9 @@ def test_run_tasks(monkeypatch, max_parallel_tasks, example_tasks, mpmethod):
         get_distributed_client_mock(None),
     )
     monkeypatch.setattr(
-        esmvalcore._task, "Pool", multiprocessing.get_context(mpmethod).Pool
+        esmvalcore._task.multiprocessing,
+        "Pool",
+        multiprocessing.get_context(mpmethod).Pool,
    )
     example_tasks.run(max_parallel_tasks=max_parallel_tasks)

@@ -152,7 +154,7 @@ def _run(self, input_files):
         return [f"{self.name}_test.nc"]

     monkeypatch.setattr(MockBaseTask, "_run", _run)
-    monkeypatch.setattr(esmvalcore._task, "Pool", ThreadPool)
+    monkeypatch.setattr(esmvalcore._task.multiprocessing, "Pool", ThreadPool)

     runner(example_tasks)
     print(order)
@@ -165,11 +167,17 @@ def test_run_task(mocker, address):
     # Set up mock Dask distributed client
     mocker.patch.object(esmvalcore._task, "Client")

+    # Set up a mock multiprocessing.Lock
+    scheduler_lock = mocker.sentinel
+
     task = mocker.create_autospec(DiagnosticTask, instance=True)
     task.products = mocker.Mock()
-    output_files, products = _run_task(task, scheduler_address=address)
+    output_files, products = _run_task(
+        task, scheduler_address=address, scheduler_lock=scheduler_lock
+    )
     assert output_files == task.run.return_value
     assert products == task.products
+    assert task.scheduler_lock == scheduler_lock
     if address is None:
         esmvalcore._task.Client.assert_not_called()
     else:
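The parametrized test above verifies the lock protocol purely through call recording on mocker.Mock(). The same check with plain unittest.mock, as an illustrative standalone sketch:

# Illustration only: a unittest.mock.Mock records attribute calls, so the
# acquire/release protocol can be asserted without a real lock, exactly
# as test_load_save_task does with the pytest-mock mocker fixture.
from unittest import mock


def test_lock_protocol():
    lock = mock.Mock()
    lock.acquire()
    try:
        pass  # stand-in for computing and saving the data
    finally:
        lock.release()
    lock.acquire.assert_called_once_with()
    lock.release.assert_called_once_with()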