diff --git a/src/prefect/client/base.py b/src/prefect/client/base.py index 4a018e2f7352..5071387668aa 100644 --- a/src/prefect/client/base.py +++ b/src/prefect/client/base.py @@ -161,7 +161,7 @@ class PrefectResponse(httpx.Response): Provides more informative error messages. """ - def raise_for_status(self) -> None: + def raise_for_status(self) -> Response: """ Raise an exception if the response contains an HTTPStatusError. @@ -174,7 +174,7 @@ def raise_for_status(self) -> None: raise PrefectHTTPStatusError.from_httpx_error(exc) from exc.__cause__ @classmethod - def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Self: + def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Response: """ Create a `PrefectReponse` from an `httpx.Response`. @@ -200,10 +200,10 @@ class PrefectHttpxAsyncClient(httpx.AsyncClient): def __init__( self, - *args, + *args: Any, enable_csrf_support: bool = False, raise_on_all_errors: bool = True, - **kwargs, + **kwargs: Any, ): self.enable_csrf_support: bool = enable_csrf_support self.csrf_token: Optional[str] = None @@ -222,10 +222,10 @@ async def _send_with_retry( self, request: Request, send: Callable[[Request], Awaitable[Response]], - send_args: Tuple, - send_kwargs: Dict, + send_args: Tuple[Any, ...], + send_kwargs: Dict[str, Any], retry_codes: Set[int] = set(), - retry_exceptions: Tuple[Exception, ...] = tuple(), + retry_exceptions: Tuple[Type[Exception], ...] = tuple(), ): """ Send a request and retry it if it fails. @@ -297,7 +297,7 @@ async def _send_with_retry( if exc_info else ( "Received response with retryable status code" - f" {response.status_code}. " + f" {response.status_code if response else 'unknown'}. " ) ) + f"Another attempt will be made in {retry_seconds}s. " @@ -314,7 +314,7 @@ async def _send_with_retry( # We ran out of retries, return the failed response return response - async def send(self, request: Request, *args, **kwargs) -> Response: + async def send(self, request: Request, *args: Any, **kwargs: Any) -> Response: """ Send a request with automatic retry behavior for the following status codes: @@ -414,10 +414,10 @@ class PrefectHttpxSyncClient(httpx.Client): def __init__( self, - *args, + *args: Any, enable_csrf_support: bool = False, raise_on_all_errors: bool = True, - **kwargs, + **kwargs: Any, ): self.enable_csrf_support: bool = enable_csrf_support self.csrf_token: Optional[str] = None @@ -436,10 +436,10 @@ def _send_with_retry( self, request: Request, send: Callable[[Request], Response], - send_args: Tuple, - send_kwargs: Dict, + send_args: Tuple[Any, ...], + send_kwargs: Dict[str, Any], retry_codes: Set[int] = set(), - retry_exceptions: Tuple[Exception, ...] = tuple(), + retry_exceptions: Tuple[Type[Exception], ...] = tuple(), ): """ Send a request and retry it if it fails. @@ -511,7 +511,7 @@ def _send_with_retry( if exc_info else ( "Received response with retryable status code" - f" {response.status_code}. " + f" {response.status_code if response else 'unknown'}. " ) ) + f"Another attempt will be made in {retry_seconds}s. " @@ -528,7 +528,7 @@ def _send_with_retry( # We ran out of retries, return the failed response return response - def send(self, request: Request, *args, **kwargs) -> Response: + def send(self, request: Request, *args: Any, **kwargs: Any) -> Response: """ Send a request with automatic retry behavior for the following status codes: diff --git a/src/prefect/client/cloud.py b/src/prefect/client/cloud.py index 38a69150e922..6542393ed4b7 100644 --- a/src/prefect/client/cloud.py +++ b/src/prefect/client/cloud.py @@ -30,7 +30,7 @@ def get_cloud_client( host: Optional[str] = None, api_key: Optional[str] = None, - httpx_settings: Optional[dict] = None, + httpx_settings: Optional[Dict[str, Any]] = None, infer_cloud_url: bool = False, ) -> "CloudClient": """ @@ -45,6 +45,9 @@ def get_cloud_client( configured_url = prefect.settings.PREFECT_API_URL.value() host = re.sub(PARSE_API_URL_REGEX, "", configured_url) + if host is None: + raise ValueError("Host was not provided and could not be inferred") + return CloudClient( host=host, api_key=api_key or PREFECT_API_KEY.value(), @@ -176,7 +179,7 @@ async def __aenter__(self): await self._client.__aenter__() return self - async def __aexit__(self, *exc_info): + async def __aexit__(self, *exc_info: Any) -> None: return await self._client.__aexit__(*exc_info) def __enter__(self): @@ -188,10 +191,10 @@ def __enter__(self): def __exit__(self, *_): assert False, "This should never be called but must be defined for __enter__" - async def get(self, route, **kwargs): + async def get(self, route: str, **kwargs: Any) -> Any: return await self.request("GET", route, **kwargs) - async def request(self, method, route, **kwargs): + async def request(self, method: str, route: str, **kwargs: Any) -> Any: try: res = await self._client.request(method, route, **kwargs) res.raise_for_status() diff --git a/src/prefect/client/collections.py b/src/prefect/client/collections.py index 12285d50a3d1..e5bd79f04325 100644 --- a/src/prefect/client/collections.py +++ b/src/prefect/client/collections.py @@ -13,12 +13,12 @@ async def read_worker_metadata(self) -> Dict[str, Any]: async def __aenter__(self) -> "CollectionsMetadataClient": ... - async def __aexit__(self, *exc_info) -> Any: + async def __aexit__(self, *exc_info: Any) -> Any: ... def get_collections_metadata_client( - httpx_settings: Optional[Dict] = None, + httpx_settings: Optional[Dict[str, Any]] = None, ) -> "CollectionsMetadataClient": """ Creates a client that can be used to fetch metadata for diff --git a/src/prefect/client/subscriptions.py b/src/prefect/client/subscriptions.py index c2ebf0ab673e..d13873e14b05 100644 --- a/src/prefect/client/subscriptions.py +++ b/src/prefect/client/subscriptions.py @@ -27,27 +27,33 @@ def __init__( ): self.model = model self.client_id = client_id - base_url = base_url.replace("http", "ws", 1) + base_url = base_url.replace("http", "ws", 1) if base_url else None self.subscription_url = f"{base_url}{path}" self.keys = list(keys) self._connect = websockets.connect( self.subscription_url, - subprotocols=["prefect"], + subprotocols=[websockets.Subprotocol("prefect")], ) self._websocket = None def __aiter__(self) -> Self: return self + @property + def websocket(self) -> websockets.WebSocketClientProtocol: + if not self._websocket: + raise RuntimeError("Subscription is not connected") + return self._websocket + async def __anext__(self) -> S: while True: try: await self._ensure_connected() - message = await self._websocket.recv() + message = await self.websocket.recv() - await self._websocket.send(orjson.dumps({"type": "ack"}).decode()) + await self.websocket.send(orjson.dumps({"type": "ack"}).decode()) return self.model.model_validate_json(message) except ( @@ -84,13 +90,19 @@ async def _ensure_connected(self): AssertionError, websockets.exceptions.ConnectionClosedError, ) as e: - if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION: + if isinstance(e, AssertionError) or ( + e.rcvd and e.rcvd.code == WS_1008_POLICY_VIOLATION + ): if isinstance(e, AssertionError): reason = e.args[0] - elif isinstance(e, websockets.exceptions.ConnectionClosedError): + elif e.rcvd and e.rcvd.reason: reason = e.rcvd.reason + else: + reason = "unknown" + else: + reason = None - if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION: + if reason: raise Exception( "Unable to authenticate to the subscription. Please " "ensure the provided `PREFECT_API_KEY` you are using is " diff --git a/src/prefect/client/utilities.py b/src/prefect/client/utilities.py index ffe42e63195f..81ff31199e6e 100644 --- a/src/prefect/client/utilities.py +++ b/src/prefect/client/utilities.py @@ -15,13 +15,14 @@ Optional, Tuple, TypeVar, + Union, cast, ) from typing_extensions import Concatenate, ParamSpec if TYPE_CHECKING: - from prefect.client.orchestration import PrefectClient + from prefect.client.orchestration import PrefectClient, SyncPrefectClient P = ParamSpec("P") R = TypeVar("R") @@ -29,7 +30,7 @@ def get_or_create_client( client: Optional["PrefectClient"] = None, -) -> Tuple["PrefectClient", bool]: +) -> Tuple[Union["PrefectClient", "SyncPrefectClient"], bool]: """ Returns provided client, infers a client from context if available, or creates a new client. @@ -48,7 +49,7 @@ def get_or_create_client( flow_run_context = FlowRunContext.get() task_run_context = TaskRunContext.get() - if async_client_context and async_client_context.client._loop == get_running_loop(): + if async_client_context and async_client_context.client._loop == get_running_loop(): # type: ignore[reportPrivateUsage] return async_client_context.client, True elif ( flow_run_context @@ -72,7 +73,7 @@ def client_injector( @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: client, _ = get_or_create_client() - return await func(client, *args, **kwargs) + return await func(cast("PrefectClient", client), *args, **kwargs) return wrapper @@ -90,16 +91,18 @@ def inject_client( @wraps(fn) async def with_injected_client(*args: P.args, **kwargs: P.kwargs) -> R: - client = cast(Optional["PrefectClient"], kwargs.pop("client", None)) - client, inferred = get_or_create_client(client) + client, inferred = get_or_create_client( + cast(Optional["PrefectClient"], kwargs.pop("client", None)) + ) + _client = cast("PrefectClient", client) if not inferred: - context = client + context = _client else: from prefect.utilities.asyncutils import asyncnullcontext context = asyncnullcontext() async with context as new_client: - kwargs.setdefault("client", new_client or client) + kwargs.setdefault("client", new_client or _client) return await fn(*args, **kwargs) return with_injected_client diff --git a/src/prefect/utilities/asyncutils.py b/src/prefect/utilities/asyncutils.py index 3939632e1641..ce5a0229b049 100644 --- a/src/prefect/utilities/asyncutils.py +++ b/src/prefect/utilities/asyncutils.py @@ -12,6 +12,7 @@ from functools import partial, wraps from typing import ( Any, + AsyncGenerator, Awaitable, Callable, Coroutine, @@ -410,7 +411,9 @@ async def ctx_call(): @asynccontextmanager -async def asyncnullcontext(value=None, *args, **kwargs): +async def asyncnullcontext( + value: Optional[Any] = None, *args: Any, **kwargs: Any +) -> AsyncGenerator[Any, None]: yield value diff --git a/src/prefect/utilities/math.py b/src/prefect/utilities/math.py index 2ece5eb85fa3..9daca1c74186 100644 --- a/src/prefect/utilities/math.py +++ b/src/prefect/utilities/math.py @@ -2,7 +2,9 @@ import random -def poisson_interval(average_interval, lower=0, upper=1): +def poisson_interval( + average_interval: float, lower: float = 0, upper: float = 1 +) -> float: """ Generates an "inter-arrival time" for a Poisson process. @@ -16,12 +18,12 @@ def poisson_interval(average_interval, lower=0, upper=1): return -math.log(max(1 - random.uniform(lower, upper), 1e-10)) * average_interval -def exponential_cdf(x, average_interval): +def exponential_cdf(x: float, average_interval: float) -> float: ld = 1 / average_interval return 1 - math.exp(-ld * x) -def lower_clamp_multiple(k): +def lower_clamp_multiple(k: float) -> float: """ Computes a lower clamp multiple that can be used to bound a random variate drawn from an exponential distribution. @@ -38,7 +40,9 @@ def lower_clamp_multiple(k): return math.log(max(2**k / (2**k - 1), 1e-10), 2) -def clamped_poisson_interval(average_interval, clamping_factor=0.3): +def clamped_poisson_interval( + average_interval: float, clamping_factor: float = 0.3 +) -> float: """ Bounds Poisson "inter-arrival times" to a range defined by the clamping factor. @@ -57,7 +61,7 @@ def clamped_poisson_interval(average_interval, clamping_factor=0.3): return poisson_interval(average_interval, lower_rv, upper_rv) -def bounded_poisson_interval(lower_bound, upper_bound): +def bounded_poisson_interval(lower_bound: float, upper_bound: float) -> float: """ Bounds Poisson "inter-arrival times" to a range.