-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: nate nowack <thrast36@gmail.com>
- Loading branch information
1 parent
88a030e
commit cc5b133
Showing
2 changed files
with
290 additions
and
0 deletions.
There are no files selected for viewing
173 changes: 173 additions & 0 deletions
173
src/integrations/prefect-kubernetes/prefect_kubernetes/experimental/decorators.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
from typing import ( | ||
Any, | ||
Awaitable, | ||
Callable, | ||
Coroutine, | ||
Iterable, | ||
NoReturn, | ||
Optional, | ||
TypeVar, | ||
overload, | ||
) | ||
|
||
from prefect_kubernetes.worker import KubernetesWorker | ||
from typing_extensions import Literal, ParamSpec | ||
|
||
from prefect import Flow, State | ||
from prefect.futures import PrefectFuture | ||
from prefect.utilities.asyncutils import run_coro_as_sync | ||
from prefect.utilities.callables import get_call_parameters | ||
|
||
P = ParamSpec("P") | ||
R = TypeVar("R") | ||
T = TypeVar("T") | ||
|
||
|
||
class InfrastructureBoundFlow(Flow[P, R]): | ||
def __init__( | ||
self, | ||
*args: Any, | ||
work_pool: str, | ||
job_variables: dict[str, Any], | ||
# TODO: Update this to use BaseWorker when the .submit method is moved to the base class | ||
worker_cls: type[KubernetesWorker], | ||
**kwargs: Any, | ||
): | ||
super().__init__(*args, **kwargs) | ||
self.work_pool = work_pool | ||
self.job_variables = job_variables | ||
self.worker_cls = worker_cls | ||
|
||
@classmethod | ||
def from_flow( | ||
cls, | ||
flow: Flow[P, R], | ||
work_pool: str, | ||
job_variables: dict[str, Any], | ||
worker_cls: type[KubernetesWorker], | ||
) -> InfrastructureBoundFlow[P, R]: | ||
new = cls( | ||
flow.fn, | ||
work_pool=work_pool, | ||
job_variables=job_variables, | ||
worker_cls=worker_cls, | ||
) | ||
# Copy all attributes from the original flow | ||
for attr, value in flow.__dict__.items(): | ||
setattr(new, attr, value) | ||
return new | ||
|
||
@overload | ||
def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> None: | ||
# `NoReturn` matches if a type can't be inferred for the function which stops a | ||
# sync function from matching the `Coroutine` overload | ||
... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, Coroutine[Any, Any, T]]", | ||
*args: P.args, | ||
**kwargs: P.kwargs, | ||
) -> Coroutine[Any, Any, T]: ... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, T]", | ||
*args: P.args, | ||
**kwargs: P.kwargs, | ||
) -> T: ... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, Coroutine[Any, Any, T]]", | ||
*args: P.args, | ||
return_state: Literal[True], | ||
**kwargs: P.kwargs, | ||
) -> Awaitable[State[T]]: ... | ||
|
||
@overload | ||
def __call__( | ||
self: "Flow[P, T]", | ||
*args: P.args, | ||
return_state: Literal[True], | ||
**kwargs: P.kwargs, | ||
) -> State[T]: ... | ||
|
||
def __call__( | ||
self, | ||
*args: "P.args", | ||
return_state: bool = False, | ||
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, | ||
**kwargs: "P.kwargs", | ||
): | ||
async def modified_call( | ||
*args: P.args, | ||
return_state: bool = False, | ||
# TODO: Handle wait_for once we have an asynchronous way to wait for futures | ||
wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, | ||
**kwargs: P.kwargs, | ||
) -> R | State[R]: | ||
async with self.worker_cls(work_pool_name=self.work_pool) as worker: | ||
parameters = get_call_parameters(self, args, kwargs) | ||
future = await worker.submit( | ||
flow=self, | ||
parameters=parameters, | ||
job_variables=self.job_variables, | ||
) | ||
if return_state: | ||
await future.wait_async() | ||
return future.state | ||
return await future.aresult() | ||
|
||
if inspect.iscoroutinefunction(self.fn): | ||
return modified_call( | ||
*args, return_state=return_state, wait_for=wait_for, **kwargs | ||
) | ||
else: | ||
return run_coro_as_sync( | ||
modified_call( | ||
*args, | ||
return_state=return_state, | ||
wait_for=wait_for, | ||
**kwargs, | ||
) | ||
) | ||
|
||
|
||
def kubernetes( | ||
work_pool: str, **job_variables: Any | ||
) -> Callable[[Flow[P, R]], Flow[P, R]]: | ||
""" | ||
Decorator that binds execution of a flow to a Kubernetes work pool | ||
Args: | ||
work_pool: The name of the Kubernetes work pool to use | ||
**job_variables: Additional job variables to use for infrastructure configuration | ||
Example: | ||
```python | ||
from prefect import flow | ||
from prefect_kubernetes import kubernetes | ||
@kubernetes(work_pool="my-pool") | ||
@flow | ||
def my_flow(): | ||
... | ||
# This will run the flow in a Kubernetes job | ||
my_flow() | ||
``` | ||
""" | ||
|
||
def decorator(flow: Flow[P, R]) -> InfrastructureBoundFlow[P, R]: | ||
return InfrastructureBoundFlow.from_flow( | ||
flow, | ||
work_pool=work_pool, | ||
job_variables=job_variables, | ||
worker_cls=KubernetesWorker, | ||
) | ||
|
||
return decorator |
117 changes: 117 additions & 0 deletions
117
src/integrations/prefect-kubernetes/tests/experimental/test_decorator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from typing import Generator | ||
from unittest.mock import AsyncMock, MagicMock, patch | ||
|
||
import pytest | ||
from prefect_kubernetes.experimental.decorators import kubernetes | ||
from prefect_kubernetes.worker import KubernetesWorker | ||
|
||
from prefect import State, flow | ||
from prefect.futures import PrefectFuture | ||
|
||
|
||
@pytest.fixture | ||
def mock_submit() -> Generator[AsyncMock, None, None]: | ||
"""Create a mock for the KubernetesWorker.submit method""" | ||
# Create a mock state | ||
mock_state = MagicMock(spec=State) | ||
mock_state.is_completed.return_value = True | ||
mock_state.message = "Success" | ||
|
||
# Create a mock future | ||
mock_future = MagicMock(spec=PrefectFuture) | ||
mock_future.aresult = AsyncMock(return_value="test_result") | ||
mock_future.wait_async = AsyncMock() | ||
mock_future.state = mock_state | ||
|
||
mock = AsyncMock(return_value=mock_future) | ||
|
||
patcher = patch.object(KubernetesWorker, "submit", mock) | ||
patcher.start() | ||
yield mock | ||
patcher.stop() | ||
|
||
|
||
def test_kubernetes_decorator_sync_flow(mock_submit: AsyncMock) -> None: | ||
"""Test that a synchronous flow is correctly decorated and executed""" | ||
|
||
@kubernetes(work_pool="test-pool", memory="2Gi") | ||
@flow | ||
def sync_test_flow(param1, param2="default"): | ||
return f"{param1}-{param2}" | ||
|
||
result = sync_test_flow("test") | ||
|
||
mock_submit.assert_called_once() | ||
args, kwargs = mock_submit.call_args | ||
assert kwargs["parameters"] == {"param1": "test", "param2": "default"} | ||
assert kwargs["job_variables"] == {"memory": "2Gi"} | ||
assert result == "test_result" | ||
|
||
|
||
async def test_kubernetes_decorator_async_flow(mock_submit: AsyncMock) -> None: | ||
"""Test that an asynchronous flow is correctly decorated and executed""" | ||
|
||
@kubernetes(work_pool="test-pool", cpu="1") | ||
@flow | ||
async def async_test_flow(param1): | ||
return f"async-{param1}" | ||
|
||
result = await async_test_flow("test") | ||
|
||
mock_submit.assert_called_once() | ||
args, kwargs = mock_submit.call_args | ||
assert kwargs["parameters"] == {"param1": "test"} | ||
assert kwargs["job_variables"] == {"cpu": "1"} | ||
assert result == "test_result" | ||
|
||
|
||
@pytest.mark.usefixtures("mock_submit") | ||
def test_kubernetes_decorator_return_state() -> None: | ||
"""Test that return_state=True returns the state instead of the result""" | ||
|
||
@kubernetes(work_pool="test-pool") | ||
@flow | ||
def test_flow(): | ||
return "completed" | ||
|
||
state = test_flow(return_state=True) | ||
|
||
assert state.is_completed() | ||
assert state.message == "Success" | ||
|
||
|
||
@pytest.mark.usefixtures("mock_submit") | ||
def test_kubernetes_decorator_preserves_flow_attributes() -> None: | ||
"""Test that the decorator preserves the original flow's attributes""" | ||
|
||
@flow(name="custom-flow-name", description="Custom description") | ||
def original_flow(): | ||
return "test" | ||
|
||
original_name = original_flow.name | ||
original_description = original_flow.description | ||
|
||
decorated_flow = kubernetes(work_pool="test-pool")(original_flow) | ||
|
||
assert decorated_flow.name == original_name | ||
assert decorated_flow.description == original_description | ||
|
||
result = decorated_flow() | ||
assert result == "test_result" | ||
|
||
|
||
def test_submit_method_receives_work_pool_name(mock_submit: AsyncMock) -> None: | ||
"""Test that the correct work pool name is passed to submit""" | ||
|
||
@kubernetes(work_pool="specific-pool") | ||
@flow | ||
def test_flow(): | ||
return "test" | ||
|
||
test_flow() | ||
|
||
mock_submit.assert_called_once() | ||
kwargs = mock_submit.call_args.kwargs | ||
assert "flow" in kwargs | ||
assert "parameters" in kwargs | ||
assert "job_variables" in kwargs |