Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject task run state change if the cache key is too large #16914

Merged
merged 2 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/prefect/server/orchestration/core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from prefect.server.schemas.states import StateType
from prefect.server.task_queue import TaskQueue
from prefect.settings import (
PREFECT_API_TASK_CACHE_KEY_MAX_LENGTH,
PREFECT_DEPLOYMENT_CONCURRENCY_SLOT_WAIT_SECONDS,
PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS,
)
Expand Down Expand Up @@ -589,6 +590,23 @@ class CacheInsertion(TaskRunOrchestrationRule):
FROM_STATES = ALL_ORCHESTRATION_STATES
TO_STATES = {StateType.COMPLETED}

async def before_transition(
self,
initial_state: states.State | None,
proposed_state: states.State | None,
context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
) -> None:
if proposed_state is None:
return

cache_key = proposed_state.state_details.cache_key
if cache_key and len(cache_key) > PREFECT_API_TASK_CACHE_KEY_MAX_LENGTH.value():
await self.reject_transition(
state=proposed_state,
reason=f"Cache key exceeded maximum allowed length of {PREFECT_API_TASK_CACHE_KEY_MAX_LENGTH.value()} characters.",
)
return

@db_injector
async def after_transition(
self,
Expand Down
29 changes: 29 additions & 0 deletions tests/server/orchestration/api/test_task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pendulum
import pytest
from httpx import AsyncClient
from starlette import status

from prefect.client.orchestration import PrefectClient
Expand Down Expand Up @@ -583,6 +584,34 @@ async def test_autonomous_task_run_aborts_if_enters_pending_from_disallowed_stat

assert response_2.status == responses.SetStateStatus.ABORT

async def test_set_task_run_state_with_long_cache_key_rejects_transition(
self, task_run: TaskRun, client: AsyncClient
):
await client.post(
f"/flow_runs/{task_run.flow_run_id}/set_state",
json=dict(state=dict(type="RUNNING")),
)

response = await client.post(
f"/task_runs/{task_run.id}/set_state",
json=dict(
state=dict(
type="COMPLETED",
name="Test State",
state_details={"cache_key": "a" * 5000},
)
),
)
assert response.status_code == status.HTTP_201_CREATED

api_response = OrchestrationResult.model_validate(response.json())
assert api_response.status == responses.SetStateStatus.REJECT
assert isinstance(api_response.details, responses.StateRejectDetails)
assert (
api_response.details.reason
== "Cache key exceeded maximum allowed length of 2000 characters."
)


class TestTaskRunHistory:
async def test_history_interval_must_be_one_second_or_larger(self, client):
Expand Down