From cc5b13313ecef1124526fd18403a30d2fbcbf3d3 Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Tue, 25 Feb 2025 12:29:03 -0600 Subject: [PATCH] Add `@kubernetes` decorator (#17248) Co-authored-by: nate nowack --- .../experimental/decorators.py | 173 ++++++++++++++++++ .../tests/experimental/test_decorator.py | 117 ++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 src/integrations/prefect-kubernetes/prefect_kubernetes/experimental/decorators.py create mode 100644 src/integrations/prefect-kubernetes/tests/experimental/test_decorator.py diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/experimental/decorators.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/experimental/decorators.py new file mode 100644 index 000000000000..dcb061585c28 --- /dev/null +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/experimental/decorators.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import inspect +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Iterable, + NoReturn, + Optional, + TypeVar, + overload, +) + +from prefect_kubernetes.worker import KubernetesWorker +from typing_extensions import Literal, ParamSpec + +from prefect import Flow, State +from prefect.futures import PrefectFuture +from prefect.utilities.asyncutils import run_coro_as_sync +from prefect.utilities.callables import get_call_parameters + +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") + + +class InfrastructureBoundFlow(Flow[P, R]): + def __init__( + self, + *args: Any, + work_pool: str, + job_variables: dict[str, Any], + # TODO: Update this to use BaseWorker when the .submit method is moved to the base class + worker_cls: type[KubernetesWorker], + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + self.work_pool = work_pool + self.job_variables = job_variables + self.worker_cls = worker_cls + + @classmethod + def from_flow( + cls, + flow: Flow[P, R], + work_pool: str, + job_variables: dict[str, Any], + worker_cls: type[KubernetesWorker], + ) -> InfrastructureBoundFlow[P, R]: + new = cls( + flow.fn, + work_pool=work_pool, + job_variables=job_variables, + worker_cls=worker_cls, + ) + # Copy all attributes from the original flow + for attr, value in flow.__dict__.items(): + setattr(new, attr, value) + return new + + @overload + def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> None: + # `NoReturn` matches if a type can't be inferred for the function which stops a + # sync function from matching the `Coroutine` overload + ... + + @overload + def __call__( + self: "Flow[P, Coroutine[Any, Any, T]]", + *args: P.args, + **kwargs: P.kwargs, + ) -> Coroutine[Any, Any, T]: ... + + @overload + def __call__( + self: "Flow[P, T]", + *args: P.args, + **kwargs: P.kwargs, + ) -> T: ... + + @overload + def __call__( + self: "Flow[P, Coroutine[Any, Any, T]]", + *args: P.args, + return_state: Literal[True], + **kwargs: P.kwargs, + ) -> Awaitable[State[T]]: ... + + @overload + def __call__( + self: "Flow[P, T]", + *args: P.args, + return_state: Literal[True], + **kwargs: P.kwargs, + ) -> State[T]: ... + + def __call__( + self, + *args: "P.args", + return_state: bool = False, + wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, + **kwargs: "P.kwargs", + ): + async def modified_call( + *args: P.args, + return_state: bool = False, + # TODO: Handle wait_for once we have an asynchronous way to wait for futures + wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, + **kwargs: P.kwargs, + ) -> R | State[R]: + async with self.worker_cls(work_pool_name=self.work_pool) as worker: + parameters = get_call_parameters(self, args, kwargs) + future = await worker.submit( + flow=self, + parameters=parameters, + job_variables=self.job_variables, + ) + if return_state: + await future.wait_async() + return future.state + return await future.aresult() + + if inspect.iscoroutinefunction(self.fn): + return modified_call( + *args, return_state=return_state, wait_for=wait_for, **kwargs + ) + else: + return run_coro_as_sync( + modified_call( + *args, + return_state=return_state, + wait_for=wait_for, + **kwargs, + ) + ) + + +def kubernetes( + work_pool: str, **job_variables: Any +) -> Callable[[Flow[P, R]], Flow[P, R]]: + """ + Decorator that binds execution of a flow to a Kubernetes work pool + + Args: + work_pool: The name of the Kubernetes work pool to use + **job_variables: Additional job variables to use for infrastructure configuration + + Example: + ```python + from prefect import flow + from prefect_kubernetes import kubernetes + + @kubernetes(work_pool="my-pool") + @flow + def my_flow(): + ... + + # This will run the flow in a Kubernetes job + my_flow() + ``` + """ + + def decorator(flow: Flow[P, R]) -> InfrastructureBoundFlow[P, R]: + return InfrastructureBoundFlow.from_flow( + flow, + work_pool=work_pool, + job_variables=job_variables, + worker_cls=KubernetesWorker, + ) + + return decorator diff --git a/src/integrations/prefect-kubernetes/tests/experimental/test_decorator.py b/src/integrations/prefect-kubernetes/tests/experimental/test_decorator.py new file mode 100644 index 000000000000..fc7c3ac5f222 --- /dev/null +++ b/src/integrations/prefect-kubernetes/tests/experimental/test_decorator.py @@ -0,0 +1,117 @@ +from typing import Generator +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from prefect_kubernetes.experimental.decorators import kubernetes +from prefect_kubernetes.worker import KubernetesWorker + +from prefect import State, flow +from prefect.futures import PrefectFuture + + +@pytest.fixture +def mock_submit() -> Generator[AsyncMock, None, None]: + """Create a mock for the KubernetesWorker.submit method""" + # Create a mock state + mock_state = MagicMock(spec=State) + mock_state.is_completed.return_value = True + mock_state.message = "Success" + + # Create a mock future + mock_future = MagicMock(spec=PrefectFuture) + mock_future.aresult = AsyncMock(return_value="test_result") + mock_future.wait_async = AsyncMock() + mock_future.state = mock_state + + mock = AsyncMock(return_value=mock_future) + + patcher = patch.object(KubernetesWorker, "submit", mock) + patcher.start() + yield mock + patcher.stop() + + +def test_kubernetes_decorator_sync_flow(mock_submit: AsyncMock) -> None: + """Test that a synchronous flow is correctly decorated and executed""" + + @kubernetes(work_pool="test-pool", memory="2Gi") + @flow + def sync_test_flow(param1, param2="default"): + return f"{param1}-{param2}" + + result = sync_test_flow("test") + + mock_submit.assert_called_once() + args, kwargs = mock_submit.call_args + assert kwargs["parameters"] == {"param1": "test", "param2": "default"} + assert kwargs["job_variables"] == {"memory": "2Gi"} + assert result == "test_result" + + +async def test_kubernetes_decorator_async_flow(mock_submit: AsyncMock) -> None: + """Test that an asynchronous flow is correctly decorated and executed""" + + @kubernetes(work_pool="test-pool", cpu="1") + @flow + async def async_test_flow(param1): + return f"async-{param1}" + + result = await async_test_flow("test") + + mock_submit.assert_called_once() + args, kwargs = mock_submit.call_args + assert kwargs["parameters"] == {"param1": "test"} + assert kwargs["job_variables"] == {"cpu": "1"} + assert result == "test_result" + + +@pytest.mark.usefixtures("mock_submit") +def test_kubernetes_decorator_return_state() -> None: + """Test that return_state=True returns the state instead of the result""" + + @kubernetes(work_pool="test-pool") + @flow + def test_flow(): + return "completed" + + state = test_flow(return_state=True) + + assert state.is_completed() + assert state.message == "Success" + + +@pytest.mark.usefixtures("mock_submit") +def test_kubernetes_decorator_preserves_flow_attributes() -> None: + """Test that the decorator preserves the original flow's attributes""" + + @flow(name="custom-flow-name", description="Custom description") + def original_flow(): + return "test" + + original_name = original_flow.name + original_description = original_flow.description + + decorated_flow = kubernetes(work_pool="test-pool")(original_flow) + + assert decorated_flow.name == original_name + assert decorated_flow.description == original_description + + result = decorated_flow() + assert result == "test_result" + + +def test_submit_method_receives_work_pool_name(mock_submit: AsyncMock) -> None: + """Test that the correct work pool name is passed to submit""" + + @kubernetes(work_pool="specific-pool") + @flow + def test_flow(): + return "test" + + test_flow() + + mock_submit.assert_called_once() + kwargs = mock_submit.call_args.kwargs + assert "flow" in kwargs + assert "parameters" in kwargs + assert "job_variables" in kwargs