From f6eb28cdbdbcc35c87632cbd87cc4a09b2903361 Mon Sep 17 00:00:00 2001 From: tellet-q Date: Tue, 11 Mar 2025 14:23:50 +0100 Subject: [PATCH] Add support for async grpc --- qdrant_client/common/client_exceptions.py | 5 ++- qdrant_client/connection.py | 43 ++++++++++++++++++++--- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/qdrant_client/common/client_exceptions.py b/qdrant_client/common/client_exceptions.py index 647e2960..2dcd513f 100644 --- a/qdrant_client/common/client_exceptions.py +++ b/qdrant_client/common/client_exceptions.py @@ -1,8 +1,11 @@ +from grpc.aio import AioRpcError + + class QdrantException(Exception): """Base class""" -class ResourceExhaustedResponse(QdrantException): +class ResourceExhaustedResponse(QdrantException, AioRpcError): def __init__(self, message: str, retry_after_s: int) -> None: self.message = message if message else "Resource Exhausted Response" try: diff --git a/qdrant_client/connection.py b/qdrant_client/connection.py index 9115979b..f1d69b99 100644 --- a/qdrant_client/connection.py +++ b/qdrant_client/connection.py @@ -1,5 +1,6 @@ import asyncio import collections +from inspect import iscoroutinefunction from typing import Any, Awaitable, Callable, Optional, Union import grpc @@ -72,7 +73,10 @@ async def intercept_unary_unary( ) next_request = next(new_request_iterator) response = await continuation(new_details, next_request) - return postprocess(response) if postprocess else response + if iscoroutinefunction(postprocess): + return await postprocess(response) if postprocess else response + else: + return postprocess(response) if postprocess else response async def intercept_unary_stream( self, continuation: Any, client_call_details: Any, request: Any @@ -81,7 +85,10 @@ async def intercept_unary_stream( client_call_details, iter((request,)), False, True ) response_it = await continuation(new_details, next(new_request_iterator)) - return postprocess(response_it) if postprocess else response_it + if iscoroutinefunction(postprocess): + return await postprocess(response_it) if postprocess else response_it + else: + return postprocess(response_it) if postprocess else response_it async def intercept_stream_unary( self, continuation: Any, client_call_details: Any, request_iterator: Any @@ -90,7 +97,10 @@ async def intercept_stream_unary( client_call_details, request_iterator, True, False ) response = await continuation(new_details, new_request_iterator) - return postprocess(response) if postprocess else response + if iscoroutinefunction(postprocess): + return await postprocess(response) if postprocess else response + else: + return postprocess(response) if postprocess else response async def intercept_stream_stream( self, continuation: Any, client_call_details: Any, request_iterator: Any @@ -99,7 +109,10 @@ async def intercept_stream_stream( client_call_details, request_iterator, True, True ) response_it = await continuation(new_details, new_request_iterator) - return postprocess(response_it) if postprocess else response_it + if iscoroutinefunction(postprocess): + return await postprocess(response_it) if postprocess else response_it + else: + return postprocess(response_it) if postprocess else response_it def create_generic_client_interceptor(intercept_call: Any) -> _GenericClientInterceptor: @@ -186,6 +199,26 @@ def header_adder_async_interceptor( new_metadata: list[tuple[str, str]], auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None, ) -> _GenericAsyncClientInterceptor: + async def process_response(call: Any) -> Any: + try: + return await call + except grpc.aio.AioRpcError as er: + if er.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: + retry_after = None + for item in er.trailing_metadata(): + if item[0] == "retry-after": + try: + retry_after = int(item[1]) + except Exception: + retry_after = None + break + reason_phrase = er.details() if er.details() else "" + if retry_after: + raise ResourceExhaustedResponse( + message=reason_phrase, retry_after_s=retry_after + ) from er + raise + async def intercept_call( client_call_details: grpc.aio.ClientCallDetails, request_iterator: Any, @@ -211,7 +244,7 @@ async def intercept_call( metadata.append(("authorization", f"Bearer {token}")) client_call_details = client_call_details._replace(metadata=metadata) - return client_call_details, request_iterator, None + return client_call_details, request_iterator, process_response return create_generic_async_client_interceptor(intercept_call)