Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subscribe to state changes in wait_for_flow_run #17243

Merged
merged 7 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions src/prefect/cli/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,6 @@ async def run(
"--watch",
help=("Whether to poll the flow run until a terminal state is reached."),
),
watch_interval: Optional[int] = typer.Option(
None,
"--watch-interval",
help=("How often to poll the flow run for state changes (in seconds)."),
),
watch_timeout: Optional[int] = typer.Option(
None,
"--watch-timeout",
Expand Down Expand Up @@ -768,10 +763,6 @@ async def run(
multi_params = json.loads(multiparams)
except ValueError as exc:
exit_with_error(f"Failed to parse JSON: {exc}")
if watch_interval and not watch:
exit_with_error(
"`--watch-interval` can only be used with `--watch`.",
)
cli_params: dict[str, Any] = _load_json_key_values(params or [], "parameter")
conflicting_keys = set(cli_params.keys()).intersection(multi_params.keys())
if conflicting_keys:
Expand Down Expand Up @@ -894,12 +885,10 @@ async def run(
soft_wrap=True,
)
if watch:
watch_interval = 5 if watch_interval is None else watch_interval
app.console.print(f"Watching flow run {flow_run.name!r}...")
finished_flow_run = await wait_for_flow_run(
flow_run.id,
timeout=watch_timeout,
poll_interval=watch_interval,
log_states=True,
)
finished_flow_run_state = finished_flow_run.state
Expand Down
39 changes: 28 additions & 11 deletions src/prefect/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas import FlowRun
from prefect.client.schemas.objects import (
StateType,
State, StateType,
)
from prefect.client.schemas.responses import SetStateStatus
from prefect.client.utilities import inject_client
from prefect.context import (
FlowRunContext,
TaskRunContext,
)
from prefect.events.clients import get_events_subscriber
from prefect.events.filters import EventFilter, EventNameFilter, EventResourceFilter
from prefect.exceptions import (
Abort,
FlowPauseTimeout,
Expand Down Expand Up @@ -54,7 +56,6 @@
async def wait_for_flow_run(
flow_run_id: UUID,
timeout: int | None = 10800,
poll_interval: int = 5,
Copy link
Collaborator

@zzstoatzz zzstoatzz Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically this is a breaking change (unfortunately since its likely infrequently used) so I think we should keep it and if a non-default value is provided we log some warning to say it has no effect and will be removed in a future release

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, that's actually how I had it before so I just reverted 53a1242

client: "PrefectClient | None" = None,
log_states: bool = False,
) -> FlowRun:
Expand All @@ -64,7 +65,8 @@ async def wait_for_flow_run(
Args:
flow_run_id: The flow run ID for the flow run to wait for.
timeout: The wait timeout in seconds. Defaults to 10800 (3 hours).
poll_interval: The poll interval in seconds. Defaults to 5.
client: Optional Prefect client. If not provided, one will be injected.
log_states: If True, log state changes. Defaults to False.

Returns:
FlowRun: The finished flow run.
Expand Down Expand Up @@ -116,15 +118,30 @@ async def main(num_runs: int):
"""
assert client is not None, "Client injection failed"
logger = get_logger()

flow_run = await client.read_flow_run(flow_run_id)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't know that this is even necessary to check before we enter the subscriber iterable, maybe we can remove this whole block?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I think to avoid any missed events / race conditions due to the time it takes to make this call, we should check for a final state immediately after entering the subscriber (but before the async for) , but yeah we should be able to remove that first one outside the subscriber

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, let me know if the new version is what you had in mind

if flow_run.state and flow_run.state.is_final():
if log_states:
logger.info(f"Flow run is in state {flow_run.state.name!r}")
return flow_run

filter = EventFilter(
event=EventNameFilter(prefix=["prefect.flow-run"]),
resource=EventResourceFilter(id=[f"prefect.flow-run.{flow_run_id}"])
)

with anyio.move_on_after(timeout):
while True:
flow_run = await client.read_flow_run(flow_run_id)
flow_state = flow_run.state
if log_states and flow_state:
logger.info(f"Flow run is in state {flow_state.name!r}")
if flow_state and flow_state.is_final():
return flow_run
await anyio.sleep(poll_interval)
async with get_events_subscriber(filter=filter) as subscriber:
async for event in subscriber:
state_type = StateType(event.resource["prefect.state-type"])
state = State(type=state_type)

if log_states:
logger.info(f"Flow run is in state {state.name!r}")

if state.is_final():
return await client.read_flow_run(flow_run_id)

raise FlowRunWaitTimeout(
f"Flow run with ID {flow_run_id} exceeded watch timeout of {timeout} seconds"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def foo():
flow_run = await prefect_client.create_flow_run(foo, state=Completed())
assert isinstance(flow_run, client_schemas.FlowRun)

lookup = await wait_for_flow_run(flow_run.id, poll_interval=0)
lookup = await wait_for_flow_run(flow_run.id)
# Estimates will not be equal since time has passed
assert lookup == flow_run
assert flow_run.state
Expand Down
Loading