Skip to content

Commit

Permalink
Add support for async grpc
Browse files Browse the repository at this point in the history
  • Loading branch information
tellet-q committed Mar 11, 2025
1 parent 87bc1a1 commit f6eb28c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
5 changes: 4 additions & 1 deletion qdrant_client/common/client_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from grpc.aio import AioRpcError


class QdrantException(Exception):
"""Base class"""


class ResourceExhaustedResponse(QdrantException):
class ResourceExhaustedResponse(QdrantException, AioRpcError):
def __init__(self, message: str, retry_after_s: int) -> None:
self.message = message if message else "Resource Exhausted Response"
try:
Expand Down
43 changes: 38 additions & 5 deletions qdrant_client/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import collections
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Optional, Union

import grpc
Expand Down Expand Up @@ -72,7 +73,10 @@ async def intercept_unary_unary(
)
next_request = next(new_request_iterator)
response = await continuation(new_details, next_request)
return postprocess(response) if postprocess else response
if iscoroutinefunction(postprocess):
return await postprocess(response) if postprocess else response
else:
return postprocess(response) if postprocess else response

async def intercept_unary_stream(
self, continuation: Any, client_call_details: Any, request: Any
Expand All @@ -81,7 +85,10 @@ async def intercept_unary_stream(
client_call_details, iter((request,)), False, True
)
response_it = await continuation(new_details, next(new_request_iterator))
return postprocess(response_it) if postprocess else response_it
if iscoroutinefunction(postprocess):
return await postprocess(response_it) if postprocess else response_it
else:
return postprocess(response_it) if postprocess else response_it

async def intercept_stream_unary(
self, continuation: Any, client_call_details: Any, request_iterator: Any
Expand All @@ -90,7 +97,10 @@ async def intercept_stream_unary(
client_call_details, request_iterator, True, False
)
response = await continuation(new_details, new_request_iterator)
return postprocess(response) if postprocess else response
if iscoroutinefunction(postprocess):
return await postprocess(response) if postprocess else response
else:
return postprocess(response) if postprocess else response

async def intercept_stream_stream(
self, continuation: Any, client_call_details: Any, request_iterator: Any
Expand All @@ -99,7 +109,10 @@ async def intercept_stream_stream(
client_call_details, request_iterator, True, True
)
response_it = await continuation(new_details, new_request_iterator)
return postprocess(response_it) if postprocess else response_it
if iscoroutinefunction(postprocess):
return await postprocess(response_it) if postprocess else response_it
else:
return postprocess(response_it) if postprocess else response_it


def create_generic_client_interceptor(intercept_call: Any) -> _GenericClientInterceptor:
Expand Down Expand Up @@ -186,6 +199,26 @@ def header_adder_async_interceptor(
new_metadata: list[tuple[str, str]],
auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
) -> _GenericAsyncClientInterceptor:
async def process_response(call: Any) -> Any:
try:
return await call
except grpc.aio.AioRpcError as er:
if er.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
retry_after = None
for item in er.trailing_metadata():
if item[0] == "retry-after":
try:
retry_after = int(item[1])
except Exception:
retry_after = None
break
reason_phrase = er.details() if er.details() else ""
if retry_after:
raise ResourceExhaustedResponse(
message=reason_phrase, retry_after_s=retry_after
) from er
raise

async def intercept_call(
client_call_details: grpc.aio.ClientCallDetails,
request_iterator: Any,
Expand All @@ -211,7 +244,7 @@ async def intercept_call(
metadata.append(("authorization", f"Bearer {token}"))

client_call_details = client_call_details._replace(metadata=metadata)
return client_call_details, request_iterator, None
return client_call_details, request_iterator, process_response

return create_generic_async_client_interceptor(intercept_call)

Expand Down

0 comments on commit f6eb28c

Please sign in to comment.