Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for mapping over remote launch plans #2761

Merged
merged 9 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 60 additions & 19 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
import math
from typing import Any, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import ExecutionState, FlyteContext
from flytekit.core.interface import transform_interface_to_list_interface, transform_interface_to_typed_interface
from flytekit.core.interface import (
transform_interface_to_list_interface,
transform_interface_to_typed_interface,
)
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import (
Promise,
VoidPromise,
create_and_link_node,
create_and_link_node_from_remote,
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import TaskMetadata
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Literal, LiteralCollection, Scalar

ARRAY_NODE_SUBNODE_NAME = "array_node_subnode"

if TYPE_CHECKING:
from flytekit.remote import FlyteLaunchPlan


class ArrayNode:
def __init__(
self,
target: LaunchPlan,
target: Union[LaunchPlan, "FlyteLaunchPlan"],
execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
bindings: Optional[List[_literal_models.Binding]] = None,
concurrency: Optional[int] = None,
Expand All @@ -47,6 +55,8 @@ def __init__(
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying entity
"""
from flytekit.remote import FlyteLaunchPlan

self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
Expand All @@ -60,21 +70,30 @@ def __init__(
self._min_success_ratio = min_success_ratio if min_success_ratio is not None else 1.0
self._min_successes = 0

n_outputs = len(self.target.python_interface.outputs)
if self.target.python_interface:
n_outputs = len(self.target.python_interface.outputs)
else:
n_outputs = len(self.target.interface.outputs)
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")

# TODO - bound inputs are not supported at the moment
self._bound_inputs: Set[str] = set()

output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1
collection_interface = transform_interface_to_list_interface(
self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
)
self._collection_interface = collection_interface

self._remote_interface = None
if self.target.python_interface:
self._python_interface = transform_interface_to_list_interface(
self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
)
elif self.target.interface:
self._remote_interface = self.target.interface.transform_interface_to_list()
else:
raise ValueError("No interface found for the target entity.")

self.metadata = None
if isinstance(target, LaunchPlan):
if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if metadata:
Expand All @@ -98,7 +117,14 @@ def name(self) -> str:
@property
def python_interface(self) -> flyte_interface.Interface:
# Part of SupportsNodeCreation interface
return self._collection_interface
return self._python_interface

@property
def interface(self) -> _interface_models.TypedInterface:
# Required in get_serializable_node
if self._remote_interface:
return self._remote_interface
raise AttributeError("interface attribute is not available")

@property
def bindings(self) -> List[_literal_models.Binding]:
Expand All @@ -115,6 +141,9 @@ def flyte_entity(self) -> Any:
return self.target

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
if self._remote_interface:
raise ValueError("Mapping over remote entities is not supported in local execution.")

outputs_expected = True
if not self.python_interface.outputs:
outputs_expected = False
Expand Down Expand Up @@ -199,17 +228,27 @@ def __call__(self, *args, **kwargs):
if not self._bindings:
ctx = FlyteContext.current_context()
# since a new entity with an updated list interface is not created, we have to work around the mismatch
# between the interface and the inputs
collection_interface = transform_interface_to_list_interface(
self.flyte_entity.python_interface, self._bound_inputs
)
# don't link the node to the compilation state, since we don't want to add the subnode to the
# workflow as a node
# between the interface and the inputs. Also, don't link the node to the compilation state,
# since we don't want to add the subnode to the workflow as a node
if self._remote_interface:
bound_subnode = create_and_link_node_from_remote(
ctx,
entity=self.flyte_entity,
add_node_to_compilation_state=False,
overridden_interface=self._remote_interface,
**kwargs,
)
self._bindings = bound_subnode.ref.node.bindings
return create_and_link_node_from_remote(
ctx,
entity=self,
**kwargs,
)
bound_subnode = create_and_link_node(
ctx,
entity=self.flyte_entity,
add_node_to_compilation_state=False,
overridden_interface=collection_interface,
overridden_interface=self.python_interface,
node_id=ARRAY_NODE_SUBNODE_NAME,
**kwargs,
)
Expand All @@ -218,7 +257,7 @@ def __call__(self, *args, **kwargs):


def array_node(
target: Union[LaunchPlan],
target: Union[LaunchPlan, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
Expand All @@ -237,7 +276,9 @@ def array_node(
:return: A callable function that takes in keyword arguments and returns a Promise created by
flyte_entity_call_handler
"""
if not isinstance(target, LaunchPlan):
from flytekit.remote import FlyteLaunchPlan

if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")

node = ArrayNode(
Expand Down
11 changes: 8 additions & 3 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import os # TODO: use flytekit logger
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast

import typing_extensions
from flyteidl.core import tasks_pb2
Expand All @@ -31,6 +31,9 @@
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.utils.asyn import loop_manager

if TYPE_CHECKING:
from flytekit.remote import FlyteLaunchPlan


class ArrayNodeMapTask(PythonTask):
def __init__(
Expand Down Expand Up @@ -359,7 +362,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
target: Union[LaunchPlan, PythonFunctionTask],
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
Expand All @@ -377,7 +380,9 @@ def map_task(
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
"""
if isinstance(target, LaunchPlan):
from flytekit.remote import FlyteLaunchPlan

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
return array_node(
target=target,
concurrency=concurrency,
Expand Down
25 changes: 21 additions & 4 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,9 @@ def extract_obj_name(name: str) -> str:
def create_and_link_node_from_remote(
ctx: FlyteContext,
entity: HasFlyteInterface,
overridden_interface: Optional[_interface_models.TypedInterface] = None,
add_node_to_compilation_state: bool = True,
node_id: str = "",
_inputs_not_allowed: Optional[Set[str]] = None,
_ignorable_inputs: Optional[Set[str]] = None,
**kwargs,
Expand All @@ -1084,20 +1087,25 @@ def create_and_link_node_from_remote(

:param ctx: FlyteContext
:param entity: RemoteEntity
:param overridden_interface: utilize this interface instead of the one provided by the entity. This is useful for
ArrayNode as there's a mismatch between the underlying interface and inputs
:param add_node_to_compilation_state: bool that enables for nodes to be created but not linked to the workflow. This
is useful when creating nodes nested under other nodes such as ArrayNode
:param node_id: str if provided, this will be used as the node id.
:param _inputs_not_allowed: Set of all variable names that should not be provided when using this entity.
Useful for Launchplans with `fixed` inputs
:param _ignorable_inputs: Set of all variable names that are optional, but if provided will be overridden. Useful
for launchplans with `default` inputs
:param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises.
:return: Optional[Union[Tuple[Promise], Promise, VoidPromise]]
"""
if ctx.compilation_state is None:
if ctx.compilation_state is None and add_node_to_compilation_state:
raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

used_inputs = set()
bindings = []

typed_interface = entity.interface
typed_interface = overridden_interface or entity.interface

if _inputs_not_allowed:
inputs_not_allowed_specified = _inputs_not_allowed.intersection(kwargs.keys())
Expand Down Expand Up @@ -1148,14 +1156,23 @@ def create_and_link_node_from_remote(
# These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
upstream_nodes = list(set([n for n in nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID]))

# if not adding to compilation state, we don't need to generate a unique node id
node_id = node_id or (
f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}"
if add_node_to_compilation_state and ctx.compilation_state
else node_id
)

flytekit_node = Node(
id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
id=node_id,
metadata=entity.construct_node_metadata(),
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes,
flyte_entity=entity,
)
ctx.compilation_state.add_node(flytekit_node)

if add_node_to_compilation_state and ctx.compilation_state:
ctx.compilation_state.add_node(flytekit_node)

if len(typed_interface.outputs) == 0:
return VoidPromise(entity.name, NodeOutput(node=flytekit_node, var="placeholder"))
Expand Down
23 changes: 23 additions & 0 deletions flytekit/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from flyteidl.core import artifact_id_pb2 as art_id
from flyteidl.core import interface_pb2 as _interface_pb2
from flyteidl.core import types_pb2 as _types_pb2

from flytekit.models import common as _common
from flytekit.models import literals as _literals
Expand Down Expand Up @@ -64,6 +65,17 @@ def to_flyte_idl(self):
artifact_tag=self.artifact_tag,
)

def to_flyte_idl_list(self):
"""
:rtype: flyteidl.core.interface_pb2.Variable
"""
return _interface_pb2.Variable(
type=_types_pb2.LiteralType(collection_type=self.type.to_flyte_idl()),
description=self.description,
artifact_partial_id=self.artifact_partial_id,
artifact_tag=self.artifact_tag,
)

@classmethod
def from_flyte_idl(cls, variable_proto) -> _interface_pb2.Variable:
"""
Expand Down Expand Up @@ -146,6 +158,17 @@ def from_flyte_idl(cls, proto: _interface_pb2.TypedInterface) -> "TypedInterface
outputs={k: Variable.from_flyte_idl(v) for k, v in proto.outputs.variables.items()},
)

def transform_interface_to_list(self) -> "TypedInterface":
"""
Takes a single task interface and interpolates it to an array interface - to allow performing distributed
python map like functions
"""
list_interface = _interface_pb2.TypedInterface(
inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.inputs.items()}),
outputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.outputs.items()}),
)
return self.from_flyte_idl(list_interface)


class Parameter(_common.FlyteIdlEntity):
def __init__(
Expand Down
50 changes: 43 additions & 7 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from flytekit.core.array_node import array_node
from flytekit.core.array_node_map_task import map_task
from flytekit.models.core import identifier as identifier_models
from flytekit.tools.translator import get_serializable
from flytekit.remote import FlyteLaunchPlan
from flytekit.remote.interface import TypedInterface
from flytekit.tools.translator import gather_dependent_entities, get_serializable


@pytest.fixture
Expand Down Expand Up @@ -40,13 +42,45 @@ def parent_wf(a: int, b: typing.Union[int, str], c: int = 2) -> int:
lp = LaunchPlan.get_default_launch_plan(ctx, parent_wf)


@workflow
def grandparent_wf() -> typing.List[int]:
return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])
def get_grandparent_wf(serialization_settings):
@workflow
def grandparent_wf() -> typing.List[int]:
return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])

return grandparent_wf


def get_grandparent_remote_wf(serialization_settings):
serialized = OrderedDict()
lp_model = get_serializable(serialized, serialization_settings, lp)

task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
for wf_id, spec in wf_specs.items():
break

remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec)
# To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't
# part of the IDL object but is something FlyteRemote does
remote_lp._interface = TypedInterface.promote_from_model(spec.template.interface)

@workflow
def grandparent_remote_wf() -> typing.List[int]:
return array_node(
remote_lp, concurrency=10, min_success_ratio=0.9
)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])

return grandparent_remote_wf

def test_lp_serialization(serialization_settings):
wf_spec = get_serializable(OrderedDict(), serialization_settings, grandparent_wf)

@pytest.mark.parametrize(
"target",
[
get_grandparent_wf,
get_grandparent_remote_wf,
],
)
def test_lp_serialization(target, serialization_settings):
wf_spec = get_serializable(OrderedDict(), serialization_settings, target(serialization_settings))
assert len(wf_spec.template.nodes) == 1

top_level = wf_spec.template.nodes[0]
Expand All @@ -56,7 +90,9 @@ def test_lp_serialization(serialization_settings):
assert binding.scalar.primitive.integer is not None
assert top_level.inputs[1].var == "b"
for binding in top_level.inputs[1].binding.collection.bindings:
assert binding.scalar.union is not None
assert (binding.scalar.union is not None or
binding.scalar.primitive.integer is not None or
binding.scalar.primitive.string_value is not None)
assert len(top_level.inputs[1].binding.collection.bindings) == 3
assert top_level.inputs[2].var == "c"
assert len(top_level.inputs[2].binding.collection.bindings) == 3
Expand Down
Loading
Loading