diff --git a/cylc/flow/network/multi.py b/cylc/flow/network/multi.py
index 724a8cbc368..2b9ea418976 100644
--- a/cylc/flow/network/multi.py
+++ b/cylc/flow/network/multi.py
@@ -15,9 +15,8 @@
# along with this program. If not, see .
import asyncio
-from functools import partial
import sys
-from typing import Callable, Dict, List, Tuple, Optional, Union
+from typing import Callable, Dict, List, Tuple, Optional, Union, Type
from ansimarkup import ansiprint
@@ -46,6 +45,7 @@ async def call_multi_async(
] = None,
max_workflows=None,
max_tasks=None,
+ success_exceptions: Optional[Tuple[Type]] = None,
) -> Dict[str, bool]:
"""Call a function for each workflow in a list of IDs.
@@ -71,6 +71,10 @@ async def call_multi_async(
Reporter functions are provided with the "response". They must
return the outcome of the operation and may also return stdout/err
text which will be written to the terminal.
+ success_exceptions:
+ An optional tuple of exceptions that can convey success outcomes.
+ E.G. a "WorkflowStopped" exception indicates an error state for
+ "cylc broadcast" but a success state for "cylc stop".
Returns:
{workflow_id: outcome}
@@ -90,9 +94,9 @@ async def call_multi_async(
if not report:
report = _report
if multi_mode:
- reporter = partial(_report_multi, report)
+ reporter = _report_multi
else:
- reporter = partial(_report_single, report)
+ reporter = _report_single
if constraint == 'workflows':
workflow_args = {
@@ -102,7 +106,7 @@ async def call_multi_async(
# run coros
results: Dict[str, bool] = {}
- async for (workflow_id, *args), result in unordered_map(
+ async for (workflow_id, *args), response in unordered_map(
fcn,
(
(workflow_id, *args)
@@ -112,19 +116,24 @@ async def call_multi_async(
# (this way if one command errors, others may still run)
wrap_exceptions=True,
):
- results[workflow_id] = reporter(workflow_id, result)
+ # get outcome
+ out, err, outcome = _process_response(
+ report, response, success_exceptions
+ )
+ # report outcome
+ reporter(workflow_id, out, err)
+ results[workflow_id] = outcome
+
return results
def _report_multi(
- report: Callable, workflow: str, response: Union[dict, Exception]
-) -> bool:
+ workflow: str, out: Optional[str], err: Optional[str]
+) -> None:
"""Report a response for a multi-workflow operation.
This is called once for each workflow the operation is called against.
"""
- out, err, outcome = _process_response(report, response)
-
msg = f'{workflow}:'
if out:
out = out.replace('\n', '\n ') # indent
@@ -137,26 +146,21 @@ def _report_multi(
err = f'{msg} {err}'
ansiprint(err, file=sys.stdout)
- return outcome
-
def _report_single(
- report: Callable, _workflow: str, response: Union[dict, Exception]
-) -> bool:
+ workflow: str, out: Optional[str], err: Optional[str]
+) -> None:
"""Report the response for a single-workflow operation."""
- out, err, outcome = _process_response(report, response)
-
if out:
ansiprint(out)
if err:
ansiprint(err, file=sys.stderr)
- return outcome
-
def _process_response(
report: Callable,
response: Union[dict, Exception],
+ success_exceptions: Optional[Tuple[Type]] = None,
) -> Tuple[Optional[str], Optional[str], bool]:
"""Handle exceptions and return processed results.
@@ -169,16 +173,28 @@ def _process_response(
report:
The reporter function for extracting the result from the provided
response.
+ success_exceptions:
+ An optional tuple of exceptions that can convey success outcomes.
+ E.G. a "WorkflowStopped" exception indicates an error state for
+ "cylc broadcast" but a success state for "cylc stop".
Returns:
(stdout, stderr, outcome)
"""
- if isinstance(response, WorkflowStopped):
- # workflow stopped -> can't do anything
+ if success_exceptions and isinstance(response, success_exceptions):
+ # an exception was raised, however, that exception indicates a success
+ # outcome in this case
+ out = f'{response.__class__.__name__}: {response}'
+ err = None
+ outcome = True
+
+ elif isinstance(response, WorkflowStopped):
+ # workflow stopped -> report differently to other CylcErrors
out = None
err = f'{response.__class__.__name__}: {response}'
outcome = False
+
elif isinstance(response, CylcError):
# exception -> report error
if cylc.flow.flags.verbosity > 1: # debug mode
@@ -186,9 +202,11 @@ def _process_response(
out = None
err = f'{response.__class__.__name__}: {response}'
outcome = False
+
elif isinstance(response, Exception):
# unexpected error -> raise
raise response
+
else:
try:
# run the reporter to extract the operation outcome
@@ -196,7 +214,7 @@ def _process_response(
except Exception as exc:
# an exception was raised in the reporter -> report this error the
# same was as an error in the response
- return _process_response(report, exc)
+ return _process_response(report, exc, success_exceptions)
return out, err, outcome
diff --git a/cylc/flow/scripts/stop.py b/cylc/flow/scripts/stop.py
index 8154e52dc80..543bb257dea 100755
--- a/cylc/flow/scripts/stop.py
+++ b/cylc/flow/scripts/stop.py
@@ -71,6 +71,7 @@
ClientTimeout,
CylcError,
InputError,
+ WorkflowStopped,
)
from cylc.flow.network.client_factory import get_client
from cylc.flow.network.multi import call_multi
@@ -206,6 +207,14 @@ async def run(
options: 'Values',
workflow_id,
*tokens_list,
+) -> object:
+ return await _run(options, workflow_id, *tokens_list)
+
+
+async def _run(
+ options: 'Values',
+ workflow_id,
+ *tokens_list,
) -> object:
# parse the stop-task or stop-cycle if provided
stop_task = stop_cycle = None
@@ -274,5 +283,6 @@ def main(
*ids,
constraint='mixed',
max_tasks=1,
+ success_exceptions=(WorkflowStopped,),
)
sys.exit(all(rets.values()) is False)
diff --git a/tests/unit/network/test_multi.py b/tests/unit/network/test_multi.py
index 81e999b0c22..e7d0f7a344a 100644
--- a/tests/unit/network/test_multi.py
+++ b/tests/unit/network/test_multi.py
@@ -35,13 +35,26 @@ def test_report_valid(monkeypatch):
"""It should report command outcome."""
monkeypatch.setattr('cylc.flow.flags.verbosity', 0)
- assert _report(response(False, 'MyError')) == (None, 'MyError', False)
- assert _report(response(True, '12345')) == ('Command queued', None, True)
+ # fail case
+ assert _report(response(False, 'MyError')) == (
+ None,
+ 'MyError',
+ False,
+ )
+
+ # success case
+ assert _report(response(True, '12345')) == (
+ 'Command queued',
+ None,
+ True,
+ )
+ # success case (debug mode)
monkeypatch.setattr('cylc.flow.flags.verbosity', 1)
- assert (
- _report(response(True, '12345'))
- == (f'Command queued <{DIM}>id=12345{DIM}>', None, True)
+ assert _report(response(True, '12345')) == (
+ f'Command queued <{DIM}>id=12345{DIM}>',
+ None,
+ True,
)
@@ -105,7 +118,7 @@ def report(exception_class, _response):
class Foo(Exception):
pass
- # WorkflowStopped -> expected error, log it
+ # WorkflowStopped -> fail case
monkeypatch.setattr('cylc.flow.flags.verbosity', 0)
assert _process_response(partial(report, WorkflowStopped), {}) == (
None,
@@ -113,6 +126,20 @@ class Foo(Exception):
False,
)
+ # WorkflowStopped -> success case for this command
+ monkeypatch.setattr('cylc.flow.flags.verbosity', 0)
+ assert _process_response(
+ partial(report, WorkflowStopped),
+ {},
+ # this overrides the default interpretation of "WorkflowStopped" as a
+ # fail case
+ success_exceptions=(WorkflowStopped,),
+ ) == (
+ 'WorkflowStopped: xxx is not running',
+ None,
+ True, # success outcome
+ )
+
# CylcError -> expected error, log it
monkeypatch.setattr('cylc.flow.flags.verbosity', 0)
assert _process_response(partial(report, CylcError), {}) == (