diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py
index ae653f37ea..ca7d68a992 100644
--- a/kedro/runner/runner.py
+++ b/kedro/runner/runner.py
@@ -15,7 +15,7 @@
     as_completed,
     wait,
 )
-from typing import Any, Iterable, Iterator
+from typing import Any, Collection, Iterable, Iterator
 
 from more_itertools import interleave
 from pluggy import PluginManager
@@ -198,16 +198,15 @@ def _suggest_resume_scenario(
 
     postfix = ""
     if done_nodes:
-        node_names = (n.name for n in remaining_nodes)
-        resume_p = pipeline.only_nodes(*node_names)
-        start_p = resume_p.only_nodes_with_inputs(*resume_p.inputs())
+        # Find which of the remaining nodes would need to run first (in topo sort)
+        remaining_initial_nodes = _find_initial_nodes(pipeline, remaining_nodes)
 
-        # find the nearest persistent ancestors of the nodes in start_p
-        start_p_persistent_ancestors = _find_persistent_ancestors(
-            pipeline, start_p.nodes, catalog
+        # Find the nearest persistent ancestors of these nodes
+        persistent_ancestors = _find_persistent_ancestors(
+            pipeline, remaining_initial_nodes, catalog
         )
 
-        start_node_names = (n.name for n in start_p_persistent_ancestors)
+        start_node_names = sorted(n.name for n in persistent_ancestors)
         postfix += f" --from-nodes \"{','.join(start_node_names)}\""
 
     if not postfix:
@@ -230,7 +229,7 @@
 ) -> set[Node]:
     """Breadth-first search approach to finding the complete set of
     persistent ancestors of an iterable of ``Node``s. Persistent
-    ancestors exclusively have persisted ``Dataset``s as inputs.
+    ancestors exclusively have persisted ``Dataset``s or parameters as inputs.
 
     Args:
         pipeline: the ``Pipeline`` to find ancestors in.
@@ -242,54 +241,86 @@
         ``Node``s.
 
     """
-    ancestor_nodes_to_run = set()
+    initial_nodes_to_run: set[Node] = set()
+
     queue, visited = deque(children), set(children)
     while queue:
         current_node = queue.popleft()
-        if _has_persistent_inputs(current_node, catalog):
-            ancestor_nodes_to_run.add(current_node)
+        impersistent_inputs = _enumerate_impersistent_inputs(current_node, catalog)
+
+        # If all inputs are persistent, we can run this node as is
+        if not impersistent_inputs:
+            initial_nodes_to_run.add(current_node)
             continue
-        for parent in _enumerate_parents(pipeline, current_node):
-            if parent in visited:
+
+        # Otherwise, look for the nodes that produce impersistent inputs
+        for node in _enumerate_nodes_with_outputs(pipeline, impersistent_inputs):
+            if node in visited:
                 continue
-            visited.add(parent)
-            queue.append(parent)
-    return ancestor_nodes_to_run
+            visited.add(node)
+            queue.append(node)
+
+    return initial_nodes_to_run
+
+
+def _enumerate_impersistent_inputs(node: Node, catalog: DataCatalog) -> set[str]:
+    """Enumerate impersistent input Datasets of a ``Node``.
+
+    Args:
+        node: the ``Node`` to check the inputs of.
+        catalog: the ``DataCatalog`` of the run.
+
+    Returns:
+        Set of names of impersistent inputs of the given ``Node``.
+
+    """
+    # We use catalog._datasets because it retains the parameter name format
+    catalog_datasets = catalog._datasets
+    missing_inputs: set[str] = set()
+    for node_input in node.inputs:
+        # Important difference vs. Kedro: treat parameters as persistent inputs
+        if node_input.startswith("params:"):
+            continue
+        if isinstance(catalog_datasets[node_input], MemoryDataset):
+            missing_inputs.add(node_input)
 
-
-def _enumerate_parents(pipeline: Pipeline, child: Node) -> list[Node]:
-    """For a given ``Node``, returns a list containing the direct parents
-    of that ``Node`` in the given ``Pipeline``.
+    return missing_inputs
+
+
+def _enumerate_nodes_with_outputs(
+    pipeline: Pipeline, outputs: Collection[str]
+) -> list[Node]:
+    """For the given outputs, returns a list of the nodes that
+    generate them in the given ``Pipeline``.
 
     Args:
-        pipeline: the ``Pipeline`` to search for direct parents in.
-        child: the ``Node`` to find parents of.
+        pipeline: the ``Pipeline`` to search for nodes in.
+        outputs: the dataset names to find source nodes for.
 
     Returns:
-        A list of all ``Node``s that are direct parents of ``child``.
+        A list of all ``Node``s that produce ``outputs``.
 
     """
-    parent_pipeline = pipeline.only_nodes_with_outputs(*child.inputs)
+    parent_pipeline = pipeline.only_nodes_with_outputs(*outputs)
     return parent_pipeline.nodes
 
 
-def _has_persistent_inputs(node: Node, catalog: DataCatalog) -> bool:
-    """Check if a ``Node`` exclusively has persisted Datasets as inputs.
-    If at least one input is a ``MemoryDataset``, return False.
+def _find_initial_nodes(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]:
+    """Given a collection of ``Node``s in a ``Pipeline``,
+    find the initial group of ``Node``s to be run (in topological order).
 
     Args:
-        node: the ``Node`` to check the inputs of.
-        catalog: the ``DataCatalog`` of the run.
+        pipeline: the ``Pipeline`` to search for initial ``Node``s in.
+        nodes: the ``Node``s to find the initial group for.
 
     Returns:
-        True if the ``Node`` being checked exclusively has inputs that
-        are not ``MemoryDataset``, else False.
+        A list of the initial ``Node``s to run (in topological order).
 
     """
-    for node_input in node.inputs:
-        if isinstance(catalog._datasets[node_input], MemoryDataset):
-            return False
-    return True
+    node_names = set(n.name for n in nodes)
+    sub_pipeline = pipeline.only_nodes(*node_names)
+    initial_nodes = sub_pipeline.grouped_nodes[0]
+    return initial_nodes
 
 
 def run_node(
diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py
index 25ca233e97..5262a43fbd 100644
--- a/tests/runner/conftest.py
+++ b/tests/runner/conftest.py
@@ -15,6 +15,10 @@ def identity(arg):
     return arg
 
 
+def first_arg(*args):
+    return args[0]
+
+
 def sink(arg):
     pass
 
@@ -36,7 +40,7 @@ def return_not_serialisable(arg):
     return lambda x: x
 
 
-def multi_input_list_output(arg1, arg2):
+def multi_input_list_output(arg1, arg2, arg3=None):  # pylint: disable=unused-argument
     return [arg1, arg2]
 
 
@@ -80,6 +84,8 @@ def _save(arg):
             "ds0_B": persistent_dataset,
             "ds2_A": persistent_dataset,
             "ds2_B": persistent_dataset,
+            "dsX": persistent_dataset,
+            "params:p": MemoryDataset(1),
         }
     )
 
@@ -167,7 +173,39 @@ def two_branches_crossed_pipeline():
     )
 
 
-@pytest.fixture
+@pytest.fixture(
+    params=[(), ("dsX",), ("params:p",)],
+    ids=[
+        "no_extras",
+        "extra_persistent_ds",
+        "extra_param",
+    ],
+)
+def two_branches_crossed_pipeline_variable_inputs(request):
+    """A ``Pipeline`` with an X-shape (two branches with one common node).
+    Non-persistent datasets (other than parameters) are prefixed with an underscore.
+ """ + extra_inputs = list(request.param) + + return pipeline( + [ + node(first_arg, ["ds0_A"] + extra_inputs, "_ds1_A", name="node1_A"), + node(first_arg, ["ds0_B"] + extra_inputs, "_ds1_B", name="node1_B"), + node( + multi_input_list_output, + ["_ds1_A", "_ds1_B"] + extra_inputs, + ["ds2_A", "ds2_B"], + name="node2", + ), + node(first_arg, ["ds2_A"] + extra_inputs, "_ds3_A", name="node3_A"), + node(first_arg, ["ds2_B"] + extra_inputs, "_ds3_B", name="node3_B"), + node(first_arg, ["_ds3_A"] + extra_inputs, "_ds4_A", name="node4_A"), + node(first_arg, ["_ds3_B"] + extra_inputs, "_ds4_B", name="node4_B"), + ] + ) + + def pipeline_with_memory_datasets(): return pipeline( [ diff --git a/tests/runner/test_sequential_runner.py b/tests/runner/test_sequential_runner.py index 0e28feed6d..ee4a98677b 100644 --- a/tests/runner/test_sequential_runner.py +++ b/tests/runner/test_sequential_runner.py @@ -252,18 +252,18 @@ def test_confirms(self, mocker, test_pipeline, is_async): fake_dataset_instance.confirm.assert_called_once_with() -@pytest.mark.parametrize( - "failing_node_names,expected_pattern", - [ - (["node1_A"], r"No nodes ran."), - (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"), - (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A)"), - (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"), - (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"), - (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"), - ], -) class TestSuggestResumeScenario: + @pytest.mark.parametrize( + "failing_node_names,expected_pattern", + [ + (["node1_A"], r"No nodes ran."), + (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"), + (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A)"), + (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"), + (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"), + (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"), + ], + ) def test_suggest_resume_scenario( self, caplog, @@ -284,7 +284,49 @@ def test_suggest_resume_scenario( persistent_dataset_catalog, hook_manager=_create_hook_manager(), ) - assert re.search(expected_pattern, caplog.text) + assert re.search( + expected_pattern, caplog.text + ), f"{expected_pattern=}, {caplog.text=}" + + @pytest.mark.parametrize( + "failing_node_names,expected_pattern", + [ + (["node1_A"], r"No nodes ran."), + (["node2"], r'"node1_A,node1_B"'), + (["node3_A"], r'"node3_A,node3_B"'), + (["node4_A"], r'"node3_A,node3_B"'), + (["node3_A", "node4_A"], r'"node3_A,node3_B"'), + (["node2", "node4_A"], r'"node1_A,node1_B"'), + ], + ) + def test_stricter_suggest_resume_scenario( + self, + caplog, + two_branches_crossed_pipeline_variable_inputs, + persistent_dataset_catalog, + failing_node_names, + expected_pattern, + ): + """ + Stricter version of previous test. + Covers pipelines where inputs are shared across nodes. + """ + test_pipeline = two_branches_crossed_pipeline_variable_inputs + + nodes = {n.name: n for n in test_pipeline.nodes} + for name in failing_node_names: + test_pipeline -= modular_pipeline([nodes[name]]) + test_pipeline += modular_pipeline([nodes[name]._copy(func=exception_fn)]) + + with pytest.raises(Exception, match="test exception"): + SequentialRunner().run( + test_pipeline, + persistent_dataset_catalog, + hook_manager=_create_hook_manager(), + ) + assert re.search( + expected_pattern, caplog.text + ), f"{expected_pattern=}, {caplog.text=}" class TestMemoryDatasetBehaviour: