Skip to content

Commit

Permalink
Add @kubernetes decorator (#17248)
Browse files Browse the repository at this point in the history
Co-authored-by: nate nowack <thrast36@gmail.com>
  • Loading branch information
desertaxle and zzstoatzz authored Feb 25, 2025
1 parent 88a030e commit cc5b133
Show file tree
Hide file tree
Showing 2 changed files with 290 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from __future__ import annotations

import inspect
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Iterable,
NoReturn,
Optional,
TypeVar,
overload,
)

from prefect_kubernetes.worker import KubernetesWorker
from typing_extensions import Literal, ParamSpec

from prefect import Flow, State
from prefect.futures import PrefectFuture
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.callables import get_call_parameters

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


class InfrastructureBoundFlow(Flow[P, R]):
def __init__(
self,
*args: Any,
work_pool: str,
job_variables: dict[str, Any],
# TODO: Update this to use BaseWorker when the .submit method is moved to the base class
worker_cls: type[KubernetesWorker],
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.work_pool = work_pool
self.job_variables = job_variables
self.worker_cls = worker_cls

@classmethod
def from_flow(
cls,
flow: Flow[P, R],
work_pool: str,
job_variables: dict[str, Any],
worker_cls: type[KubernetesWorker],
) -> InfrastructureBoundFlow[P, R]:
new = cls(
flow.fn,
work_pool=work_pool,
job_variables=job_variables,
worker_cls=worker_cls,
)
# Copy all attributes from the original flow
for attr, value in flow.__dict__.items():
setattr(new, attr, value)
return new

@overload
def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> None:
# `NoReturn` matches if a type can't be inferred for the function which stops a
# sync function from matching the `Coroutine` overload
...

@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]",
*args: P.args,
**kwargs: P.kwargs,
) -> Coroutine[Any, Any, T]: ...

@overload
def __call__(
self: "Flow[P, T]",
*args: P.args,
**kwargs: P.kwargs,
) -> T: ...

@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]",
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> Awaitable[State[T]]: ...

@overload
def __call__(
self: "Flow[P, T]",
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> State[T]: ...

def __call__(
self,
*args: "P.args",
return_state: bool = False,
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
**kwargs: "P.kwargs",
):
async def modified_call(
*args: P.args,
return_state: bool = False,
# TODO: Handle wait_for once we have an asynchronous way to wait for futures
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
**kwargs: P.kwargs,
) -> R | State[R]:
async with self.worker_cls(work_pool_name=self.work_pool) as worker:
parameters = get_call_parameters(self, args, kwargs)
future = await worker.submit(
flow=self,
parameters=parameters,
job_variables=self.job_variables,
)
if return_state:
await future.wait_async()
return future.state
return await future.aresult()

if inspect.iscoroutinefunction(self.fn):
return modified_call(
*args, return_state=return_state, wait_for=wait_for, **kwargs
)
else:
return run_coro_as_sync(
modified_call(
*args,
return_state=return_state,
wait_for=wait_for,
**kwargs,
)
)


def kubernetes(
work_pool: str, **job_variables: Any
) -> Callable[[Flow[P, R]], Flow[P, R]]:
"""
Decorator that binds execution of a flow to a Kubernetes work pool
Args:
work_pool: The name of the Kubernetes work pool to use
**job_variables: Additional job variables to use for infrastructure configuration
Example:
```python
from prefect import flow
from prefect_kubernetes import kubernetes
@kubernetes(work_pool="my-pool")
@flow
def my_flow():
...
# This will run the flow in a Kubernetes job
my_flow()
```
"""

def decorator(flow: Flow[P, R]) -> InfrastructureBoundFlow[P, R]:
return InfrastructureBoundFlow.from_flow(
flow,
work_pool=work_pool,
job_variables=job_variables,
worker_cls=KubernetesWorker,
)

return decorator
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Generator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from prefect_kubernetes.experimental.decorators import kubernetes
from prefect_kubernetes.worker import KubernetesWorker

from prefect import State, flow
from prefect.futures import PrefectFuture


@pytest.fixture
def mock_submit() -> Generator[AsyncMock, None, None]:
"""Create a mock for the KubernetesWorker.submit method"""
# Create a mock state
mock_state = MagicMock(spec=State)
mock_state.is_completed.return_value = True
mock_state.message = "Success"

# Create a mock future
mock_future = MagicMock(spec=PrefectFuture)
mock_future.aresult = AsyncMock(return_value="test_result")
mock_future.wait_async = AsyncMock()
mock_future.state = mock_state

mock = AsyncMock(return_value=mock_future)

patcher = patch.object(KubernetesWorker, "submit", mock)
patcher.start()
yield mock
patcher.stop()


def test_kubernetes_decorator_sync_flow(mock_submit: AsyncMock) -> None:
"""Test that a synchronous flow is correctly decorated and executed"""

@kubernetes(work_pool="test-pool", memory="2Gi")
@flow
def sync_test_flow(param1, param2="default"):
return f"{param1}-{param2}"

result = sync_test_flow("test")

mock_submit.assert_called_once()
args, kwargs = mock_submit.call_args
assert kwargs["parameters"] == {"param1": "test", "param2": "default"}
assert kwargs["job_variables"] == {"memory": "2Gi"}
assert result == "test_result"


async def test_kubernetes_decorator_async_flow(mock_submit: AsyncMock) -> None:
"""Test that an asynchronous flow is correctly decorated and executed"""

@kubernetes(work_pool="test-pool", cpu="1")
@flow
async def async_test_flow(param1):
return f"async-{param1}"

result = await async_test_flow("test")

mock_submit.assert_called_once()
args, kwargs = mock_submit.call_args
assert kwargs["parameters"] == {"param1": "test"}
assert kwargs["job_variables"] == {"cpu": "1"}
assert result == "test_result"


@pytest.mark.usefixtures("mock_submit")
def test_kubernetes_decorator_return_state() -> None:
"""Test that return_state=True returns the state instead of the result"""

@kubernetes(work_pool="test-pool")
@flow
def test_flow():
return "completed"

state = test_flow(return_state=True)

assert state.is_completed()
assert state.message == "Success"


@pytest.mark.usefixtures("mock_submit")
def test_kubernetes_decorator_preserves_flow_attributes() -> None:
"""Test that the decorator preserves the original flow's attributes"""

@flow(name="custom-flow-name", description="Custom description")
def original_flow():
return "test"

original_name = original_flow.name
original_description = original_flow.description

decorated_flow = kubernetes(work_pool="test-pool")(original_flow)

assert decorated_flow.name == original_name
assert decorated_flow.description == original_description

result = decorated_flow()
assert result == "test_result"


def test_submit_method_receives_work_pool_name(mock_submit: AsyncMock) -> None:
"""Test that the correct work pool name is passed to submit"""

@kubernetes(work_pool="specific-pool")
@flow
def test_flow():
return "test"

test_flow()

mock_submit.assert_called_once()
kwargs = mock_submit.call_args.kwargs
assert "flow" in kwargs
assert "parameters" in kwargs
assert "job_variables" in kwargs

0 comments on commit cc5b133

Please sign in to comment.