Commit 04e0cbd

Only save data from one preprocessing task at a time with the Distributed scheduler (#2610)

Co-authored-by: Manuel Schlund <32543114+schlunma@users.noreply.github.com>
bouweandela and schlunma authored Jan 17, 2025
1 parent 5c585da commit 04e0cbd
Showing 4 changed files with 95 additions and 45 deletions.
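In isolation, the mechanism this commit adds is: a lock created by a `multiprocessing.Manager` is handed to every pool worker, and each worker holds the lock while it submits its data-save graph to the Dask scheduler. A minimal, self-contained sketch of that pattern (`do_compute` and the sleep are illustrative stand-ins, not code from this commit):

```python
import multiprocessing
import time


def do_compute(task_id, lock):
    """Stand-in for a task that submits a Dask graph and waits on it."""
    with lock:  # only one pool worker "computes" at a time
        time.sleep(0.1)  # stands in for _compute_with_progress(...)
    return task_id


if __name__ == "__main__":
    with multiprocessing.Manager() as manager:
        # manager.Lock() returns a picklable proxy, so it can travel
        # through pool.apply_async to the worker processes.
        lock = manager.Lock()
        with multiprocessing.Pool(processes=4) as pool:
            results = [
                pool.apply_async(do_compute, [i, lock]) for i in range(8)
            ]
            print([r.get() for r in results])
```

Because only one worker can hold the lock at a time, at most one task's graph is in flight on the scheduler at any moment, which is the serialization the commit title describes.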
95 changes: 56 additions & 39 deletions esmvalcore/_task.py
@@ -5,6 +5,7 @@
 import datetime
 import importlib
 import logging
+import multiprocessing
 import numbers
 import os
 import pprint
Expand All @@ -14,7 +15,6 @@
import threading
import time
from copy import deepcopy
from multiprocessing import Pool
from pathlib import Path, PosixPath
from shutil import which
from typing import Optional
@@ -260,6 +260,7 @@ def __init__(self, ancestors=None, name="", products=None):
         self.name = name
         self.activity = None
         self.priority = 0
+        self.scheduler_lock = None

     def initialize_provenance(self, recipe_entity):
         """Initialize task provenance activity."""
@@ -854,60 +855,76 @@ def done(task):
         """Assume a task is done if it not scheduled or running."""
         return not (task in scheduled or task in running)

-    with Pool(processes=max_parallel_tasks) as pool:
-        while scheduled or running:
-            # Submit new tasks to pool
-            for task in sorted(scheduled, key=lambda t: t.priority):
-                if len(running) >= max_parallel_tasks:
-                    break
-                if all(done(t) for t in task.ancestors):
-                    future = pool.apply_async(
-                        _run_task, [task, scheduler_address]
-                    )
-                    running[task] = future
-                    scheduled.remove(task)
+    with multiprocessing.Manager() as manager:
+        # Use a lock to avoid overloading the Dask workers by making only
+        # one :class:`esmvalcore.preprocessor.PreprocessingTask` submit its
+        # data save task graph to the distributed scheduler at a time.
+        #
+        # See https://github.com/ESMValGroup/ESMValCore/issues/2609 for
+        # additional detail.
+        scheduler_lock = (
+            None if scheduler_address is None else manager.Lock()
+        )
+
+        with multiprocessing.Pool(processes=max_parallel_tasks) as pool:
+            while scheduled or running:
+                # Submit new tasks to pool
+                for task in sorted(scheduled, key=lambda t: t.priority):
+                    if len(running) >= max_parallel_tasks:
+                        break
+                    if all(done(t) for t in task.ancestors):
+                        future = pool.apply_async(
+                            _run_task,
+                            [task, scheduler_address, scheduler_lock],
+                        )
+                        running[task] = future
+                        scheduled.remove(task)

-            # Handle completed tasks
-            ready = {t for t in running if running[t].ready()}
-            for task in ready:
-                _copy_results(task, running[task])
-                running.pop(task)
+                # Handle completed tasks
+                ready = {t for t in running if running[t].ready()}
+                for task in ready:
+                    _copy_results(task, running[task])
+                    running.pop(task)

-            # Wait if there are still tasks running
-            if running:
-                time.sleep(0.1)
+                # Wait if there are still tasks running
+                if running:
+                    time.sleep(0.1)

-            # Log progress message
-            if len(scheduled) != n_scheduled or len(running) != n_running:
-                n_scheduled, n_running = len(scheduled), len(running)
-                n_done = n_tasks - n_scheduled - n_running
-                logger.info(
-                    "Progress: %s tasks running, %s tasks waiting for "
-                    "ancestors, %s/%s done",
-                    n_running,
-                    n_scheduled,
-                    n_done,
-                    n_tasks,
-                )
+                # Log progress message
+                if (
+                    len(scheduled) != n_scheduled
+                    or len(running) != n_running
+                ):
+                    n_scheduled, n_running = len(scheduled), len(running)
+                    n_done = n_tasks - n_scheduled - n_running
+                    logger.info(
+                        "Progress: %s tasks running, %s tasks waiting for "
+                        "ancestors, %s/%s done",
+                        n_running,
+                        n_scheduled,
+                        n_done,
+                        n_tasks,
+                    )

-    logger.info("Successfully completed all tasks.")
-    pool.close()
-    pool.join()
+            logger.info("Successfully completed all tasks.")
+            pool.close()
+            pool.join()


 def _copy_results(task, future):
     """Update task with the results from the remote process."""
     task.output_files, task.products = future.get()


-def _run_task(task, scheduler_address):
+def _run_task(task, scheduler_address, scheduler_lock):
     """Run task and return the result."""
     if scheduler_address is None:
         client = contextlib.nullcontext()
     else:
         client = Client(scheduler_address)

     with client:
+        task.scheduler_lock = scheduler_lock
         output_files = task.run()

     return output_files, task.products
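A design note on `manager.Lock()` versus a bare `multiprocessing.Lock()`: a bare lock may only be shared with child processes through inheritance at process creation, so it cannot be pickled into `pool.apply_async` arguments, whereas the manager returns a proxy object that pickles cleanly and forwards acquire/release to the manager process. A quick sketch of the difference (illustrative, not from the commit):

```python
import multiprocessing
import pickle

if __name__ == "__main__":
    try:
        pickle.dumps(multiprocessing.Lock())
    except RuntimeError as exc:
        # Raises "Lock objects should only be shared between
        # processes through inheritance".
        print("bare Lock:", exc)

    with multiprocessing.Manager() as manager:
        proxy = manager.Lock()  # a picklable AcquirerProxy
        print("manager Lock pickles to", len(pickle.dumps(proxy)), "bytes")
```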
18 changes: 16 additions & 2 deletions esmvalcore/preprocessor/__init__.py
@@ -736,9 +736,23 @@ def _run(self, _) -> list[str]:
             delayed = product.close()
             delayeds.append(delayed)

-        logger.info("Computing and saving data for task %s", self.name)
         delayeds = [d for d in delayeds if d is not None]
-        _compute_with_progress(delayeds, description=self.name)
+
+        if self.scheduler_lock is not None:
+            logger.debug("Acquiring save lock for task %s", self.name)
+            self.scheduler_lock.acquire()
+            logger.debug("Acquired save lock for task %s", self.name)
+        try:
+            logger.info(
+                "Computing and saving data for preprocessing task %s",
+                self.name,
+            )
+            _compute_with_progress(delayeds, description=self.name)
+        finally:
+            if self.scheduler_lock is not None:
+                self.scheduler_lock.release()
+                logger.debug("Released save lock for task %s", self.name)
+
         metadata_files = write_metadata(
             self.products, self.write_ncl_interface
         )
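Note the shape of the change above: the acquire happens before the `try` and the release inside `finally`, so the lock is released even if `_compute_with_progress` raises, and the `None` case (no distributed scheduler) skips locking entirely. Since manager lock proxies support the context-manager protocol, the same guard could also be spelled with `contextlib` (a sketch of an alternative shape, not what the commit uses):

```python
import contextlib


def _maybe_locked(lock):
    """Hold `lock` while the block runs if one was given, else do nothing."""
    return lock if lock is not None else contextlib.nullcontext()


# Hypothetical use inside PreprocessingTask._run:
#     with _maybe_locked(self.scheduler_lock):
#         _compute_with_progress(delayeds, description=self.name)
```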
13 changes: 12 additions & 1 deletion tests/integration/preprocessor/test_preprocessing_task.py
@@ -2,14 +2,16 @@

 import iris
 import iris.cube
+import pytest
 from prov.model import ProvDocument

 import esmvalcore.preprocessor
 from esmvalcore.dataset import Dataset
 from esmvalcore.preprocessor import PreprocessingTask, PreprocessorFile


-def test_load_save_task(tmp_path):
+@pytest.mark.parametrize("scheduler_lock", [False, True])
+def test_load_save_task(tmp_path, mocker, scheduler_lock):
     """Test that a task that just loads and saves a file."""
     # Prepare a test dataset
     cube = iris.cube.Cube(data=[273.0], var_name="tas", units="K")
@@ -36,6 +38,9 @@ def test_load_save_task(tmp_path):
     activity = provenance.activity("software:esmvalcore")
     task.initialize_provenance(activity)

+    if scheduler_lock:
+        task.scheduler_lock = mocker.Mock()
+
     task.run()

     assert len(task.products) == 1
@@ -45,6 +50,12 @@
     result.attributes.clear()
     assert result == cube

+    if scheduler_lock:
+        task.scheduler_lock.acquire.assert_called_once_with()
+        task.scheduler_lock.release.assert_called_once_with()
+    else:
+        assert task.scheduler_lock is None
+

 def test_load_save_and_other_task(tmp_path, monkeypatch):
     """Test that a task just copies one file and preprocesses another file."""
14 changes: 11 additions & 3 deletions tests/integration/test_task.py
@@ -92,7 +92,9 @@ def test_run_tasks(monkeypatch, max_parallel_tasks, example_tasks, mpmethod):
         get_distributed_client_mock(None),
     )
     monkeypatch.setattr(
-        esmvalcore._task, "Pool", multiprocessing.get_context(mpmethod).Pool
+        esmvalcore._task.multiprocessing,
+        "Pool",
+        multiprocessing.get_context(mpmethod).Pool,
     )
     example_tasks.run(max_parallel_tasks=max_parallel_tasks)

@@ -152,7 +154,7 @@ def _run(self, input_files):
         return [f"{self.name}_test.nc"]

     monkeypatch.setattr(MockBaseTask, "_run", _run)
-    monkeypatch.setattr(esmvalcore._task, "Pool", ThreadPool)
+    monkeypatch.setattr(esmvalcore._task.multiprocessing, "Pool", ThreadPool)

     runner(example_tasks)
     print(order)
@@ -165,11 +167,17 @@ def test_run_task(mocker, address):
     # Set up mock Dask distributed client
     mocker.patch.object(esmvalcore._task, "Client")

+    # Set up a mock multiprocessing.Lock
+    scheduler_lock = mocker.sentinel
+
     task = mocker.create_autospec(DiagnosticTask, instance=True)
     task.products = mocker.Mock()
-    output_files, products = _run_task(task, scheduler_address=address)
+    output_files, products = _run_task(
+        task, scheduler_address=address, scheduler_lock=scheduler_lock
+    )
     assert output_files == task.run.return_value
     assert products == task.products
+    assert task.scheduler_lock == scheduler_lock
     if address is None:
         esmvalcore._task.Client.assert_not_called()
     else:
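The updated `test_run_task` passes `mocker.sentinel` straight through as the lock, and since `_run_task` only stores the object on the task, an identity comparison is all the test needs. For reference, pytest-mock's sentinel attributes are unique, stable objects (a small illustration, assuming pytest-mock is installed; not part of the commit):

```python
def test_sentinel_identity(mocker):
    # Each named attribute of mocker.sentinel is created once and always
    # returns the same object, so it is well suited to asserting that a
    # value was passed through a call chain unchanged.
    lock = mocker.sentinel.scheduler_lock
    assert lock is mocker.sentinel.scheduler_lock
    assert lock is not mocker.sentinel.other_lock
```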
