Skip to content

Commit

Permalink
Add KuberenetesWorker.submit for ad-hoc submission to a Kubernetes …
Browse files Browse the repository at this point in the history
…work pool (#17218)
  • Loading branch information
desertaxle authored Feb 25, 2025
1 parent 1ed5edf commit 88a030e
Show file tree
Hide file tree
Showing 7 changed files with 489 additions and 25 deletions.
177 changes: 168 additions & 9 deletions src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,24 +100,32 @@
checkout out the [Prefect docs](https://docs.prefect.io/concepts/work-pools/).
"""

from __future__ import annotations

import asyncio
import base64
import enum
import json
import logging
import shlex
import tempfile
import warnings
from contextlib import asynccontextmanager
from datetime import datetime
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
List,
Optional,
Tuple,
TypeVar,
Union,
)

import anyio
import anyio.abc
import kubernetes_asyncio
from jsonpatch import JsonPatch
Expand All @@ -141,12 +149,13 @@
from typing_extensions import Literal, Self

import prefect
from prefect.client.schemas.objects import FlowRun
from prefect.client.schemas.objects import Flow as APIFlow
from prefect.exceptions import (
InfrastructureError,
)
from prefect.server.schemas.core import Flow
from prefect.server.schemas.responses import DeploymentResponse
from prefect.futures import PrefectFlowRunFuture
from prefect.states import Pending
from prefect.utilities.collections import get_from_dict
from prefect.utilities.dockerutils import get_prefect_image_name
from prefect.utilities.templating import find_placeholders
from prefect.utilities.timeout import timeout_async
Expand All @@ -166,6 +175,15 @@
_slugify_name,
)

if TYPE_CHECKING:
from prefect.client.schemas.objects import FlowRun
from prefect.client.schemas.responses import DeploymentResponse
from prefect.flows import Flow

# Captures flow return type
R = TypeVar("R")


MAX_ATTEMPTS = 3
RETRY_MIN_DELAY_SECONDS = 1
RETRY_MIN_DELAY_JITTER_SECONDS = 0
Expand Down Expand Up @@ -346,7 +364,7 @@ def prepare_for_flow_run(
self,
flow_run: "FlowRun",
deployment: Optional["DeploymentResponse"] = None,
flow: Optional["Flow"] = None,
flow: Optional["APIFlow"] = None,
):
"""
Prepares the job configuration for a flow run.
Expand Down Expand Up @@ -554,7 +572,13 @@ class KubernetesWorkerResult(BaseWorkerResult):
"""Contains information about the final state of a completed process"""


class KubernetesWorker(BaseWorker):
class KubernetesWorker(
BaseWorker[
"KubernetesWorkerJobConfiguration",
"KubernetesWorkerVariables",
"KubernetesWorkerResult",
]
):
"""Prefect worker that executes flow runs within Kubernetes Jobs."""

type: str = "kubernetes"
Expand All @@ -568,15 +592,17 @@ class KubernetesWorker(BaseWorker):
_documentation_url = "https://docs.prefect.io/integrations/prefect-kubernetes"
_logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/2d0b896006ad463b49c28aaac14f31e00e32cfab-250x250.png" # noqa

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._created_secrets = {}
self._created_secrets: dict[
tuple[str, str], KubernetesWorkerJobConfiguration
] = {}

async def run(
self,
flow_run: "FlowRun",
configuration: KubernetesWorkerJobConfiguration,
task_status: Optional[anyio.abc.TaskStatus] = None,
task_status: Optional[anyio.abc.TaskStatus[int]] = None,
) -> KubernetesWorkerResult:
"""
Executes a flow run within a Kubernetes Job and waits for the flow run
Expand Down Expand Up @@ -621,7 +647,140 @@ async def run(

return KubernetesWorkerResult(identifier=pid, status_code=status_code)

async def teardown(self, *exc_info):
async def submit(
self,
flow: "Flow[..., R]",
parameters: dict[str, Any] | None = None,
job_variables: dict[str, Any] | None = None,
) -> "PrefectFlowRunFuture[R]":
"""
EXPERIMENTAL: The interface for this method is subject to change.
Submits a flow to run in a Kubernetes job.
Args:
flow: The flow to submit
parameters: The parameters to pass to the flow
Returns:
A flow run object
"""
warnings.warn(
"The `submit` method on the Kubernetes worker is experimental. The interface "
"and behavior of this method are subject to change.",
category=FutureWarning,
)
if self._runs_task_group is None:
raise RuntimeError("Worker not properly initialized")
flow_run = await self._runs_task_group.start(
partial(
self._submit_adhoc_run,
flow=flow,
parameters=parameters,
job_variables=job_variables,
),
)
return PrefectFlowRunFuture(flow_run_id=flow_run.id)

async def _submit_adhoc_run(
self,
flow: "Flow[..., R]",
parameters: dict[str, Any] | None = None,
job_variables: dict[str, Any] | None = None,
task_status: anyio.abc.TaskStatus["FlowRun"] | None = None,
):
"""
Submits a flow run to the Kubernetes worker.
"""
from prefect._experimental.bundles import (
convert_step_to_command,
create_bundle_for_flow_run,
)

if TYPE_CHECKING:
assert self._client is not None
assert self._work_pool is not None
flow_run = await self._client.create_flow_run(
flow, parameters=parameters, state=Pending()
)
if task_status is not None:
# Emit the flow run object to .submit to allow it to return a future as soon as possible
task_status.started(flow_run)
# Avoid an API call to get the flow
api_flow = APIFlow(id=flow_run.flow_id, name=flow.name, labels={})
logger = self.get_flow_run_logger(flow_run)

# TODO: Migrate steps to their own attributes on work pool when hardening this design
upload_step = json.loads(
get_from_dict(
self._work_pool.base_job_template,
"variables.properties.env.default.PREFECT__BUNDLE_UPLOAD_STEP",
"{}",
)
)
execute_step = json.loads(
get_from_dict(
self._work_pool.base_job_template,
"variables.properties.env.default.PREFECT__BUNDLE_EXECUTE_STEP",
"{}",
)
)

upload_command = convert_step_to_command(upload_step, str(flow_run.id))
execute_command = convert_step_to_command(execute_step, str(flow_run.id))

job_variables = (job_variables or {}) | {"command": " ".join(execute_command)}

configuration = await self.job_configuration.from_template_and_values(
base_job_template=self._work_pool.base_job_template,
values=job_variables,
client=self._client,
)
configuration.prepare_for_flow_run(flow_run=flow_run, flow=api_flow)

bundle = create_bundle_for_flow_run(flow=flow, flow_run=flow_run)

logger.debug("Uploading execution bundle")
with tempfile.TemporaryDirectory() as temp_dir:
await (
anyio.Path(temp_dir)
.joinpath(str(flow_run.id))
.write_bytes(json.dumps(bundle).encode("utf-8"))
)

try:
await anyio.run_process(
upload_command + [str(flow_run.id)],
cwd=temp_dir,
)
except Exception as e:
self._logger.error(
"Failed to upload bundle: %s", e.stderr.decode("utf-8")
)
raise e

logger.debug("Successfully uploaded execution bundle")

try:
result = await self.run(flow_run, configuration)

if result.status_code != 0:
await self._propose_crashed_state(
flow_run,
(
"Flow run infrastructure exited with non-zero status code"
f" {result.status_code}."
),
)
except Exception as exc:
# This flow run was being submitted and did not start successfully
logger.exception(
f"Failed to submit flow run '{flow_run.id}' to infrastructure."
)
message = f"Flow run could not be submitted to infrastructure:\n{exc!r}"
await self._propose_crashed_state(flow_run, message)

async def teardown(self, *exc_info: Any):
await super().teardown(*exc_info)

await self._clean_up_created_secrets()
Expand Down
Loading

0 comments on commit 88a030e

Please sign in to comment.