Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for toggling data mode for array node #2940

Merged
merged 9 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import ReferenceTask
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand All @@ -34,8 +35,7 @@
class ArrayNode:
def __init__(
self,
target: Union[LaunchPlan, "FlyteLaunchPlan"],
execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"],
bindings: Optional[List[_literal_models.Binding]] = None,
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
Expand All @@ -51,17 +51,17 @@ def __init__(
:param min_successes: The minimum number of successful executions. If set, this takes precedence over
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions.
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying node
"""
from flytekit.remote import FlyteLaunchPlan

self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
self.id = target.name
self._bindings = bindings or []
self.metadata = metadata
self._data_mode = None
self._execution_mode = None

if min_successes is not None:
self._min_successes = min_successes
Expand Down Expand Up @@ -92,9 +92,12 @@ def __init__(
else:
raise ValueError("No interface found for the target entity.")

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if isinstance(target, (LaunchPlan, FlyteLaunchPlan)):
self._data_mode = _core_workflow.ArrayNode.SINGLE_INPUT_FILE
self._execution_mode = _core_workflow.ArrayNode.FULL_STATE
elif isinstance(target, ReferenceTask):
self._data_mode = _core_workflow.ArrayNode.INDIVIDUAL_INPUT_FILES
self._execution_mode = _core_workflow.ArrayNode.MINIMAL_STATE
else:
raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}")

Expand Down Expand Up @@ -133,6 +136,10 @@ def upstream_nodes(self) -> List[Node]:
def flyte_entity(self) -> Any:
return self.target

# Read-only accessor for the data mode selected in __init__ from the target type:
# SINGLE_INPUT_FILE for LaunchPlan / FlyteLaunchPlan targets and
# INDIVIDUAL_INPUT_FILES for ReferenceTask targets. The translator reads this
# value (as array_node.data_mode) when building the serializable array node.
@property
def data_mode(self) -> _core_workflow.ArrayNode.DataMode:
return self._data_mode

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
if self._remote_interface:
raise ValueError("Mapping over remote entities is not supported in local execution.")
Expand Down Expand Up @@ -254,7 +261,7 @@ def __call__(self, *args, **kwargs):


def array_node(
target: Union[LaunchPlan, "FlyteLaunchPlan"],
target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
Expand All @@ -275,8 +282,8 @@ def array_node(
"""
from flytekit.remote import FlyteLaunchPlan

if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")
if not isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
raise ValueError("Only LaunchPlans and ReferenceTasks are supported for now.")

node = ArrayNode(
target=target,
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.task import ReferenceTask
from flytekit.core.type_engine import TypeEngine
from flytekit.core.utils import timeit
from flytekit.loggers import logger
Expand Down Expand Up @@ -390,7 +391,7 @@ def map_task(
"""
from flytekit.remote import FlyteLaunchPlan

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
return array_node(
target=target,
concurrency=concurrency,
Expand Down
3 changes: 3 additions & 0 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def __init__(
min_success_ratio=None,
execution_mode=None,
is_original_sub_node_interface=False,
data_mode=None,
) -> None:
"""
TODO: docstring
Expand All @@ -401,6 +402,7 @@ def __init__(
self._min_success_ratio = min_success_ratio
self._execution_mode = execution_mode
self._is_original_sub_node_interface = is_original_sub_node_interface
self._data_mode = data_mode

@property
def node(self) -> "Node":
Expand All @@ -414,6 +416,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
min_success_ratio=self._min_success_ratio,
execution_mode=self._execution_mode,
is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface),
data_mode=self._data_mode,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def get_serializable_array_node(
min_success_ratio=array_node.min_success_ratio,
execution_mode=array_node.execution_mode,
is_original_sub_node_interface=array_node.is_original_sub_node_interface,
data_mode=array_node.data_mode,
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.13.9",
"flyteidl>=1.14.1",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
Loading