Skip to content

Commit

Permalink
Fix GRPSService
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed Jun 3, 2024
1 parent 57b26b7 commit 41bfd6e
Showing 1 changed file with 47 additions and 85 deletions.
132 changes: 47 additions & 85 deletions aiomisc/service/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import sys
from collections import defaultdict
from concurrent.futures import Executor
from types import MappingProxyType
from typing import (
Any, DefaultDict, Dict, Iterable, Optional, Sequence, Set, Tuple,
Any, DefaultDict, Dict, Mapping, Optional, Sequence, Set, Tuple,
)

from .base import Service
Expand All @@ -27,84 +28,18 @@
PortFuture = asyncio.Future


class LazyServer(grpc.aio.Server):
class GRPCService(Service):
GRACEFUL_STOP_TIME: float = 60.

_ADDRESS_REGEXP = re.compile(
r"(?P<address>(\[((([([0-9a-fA-F:]*)+)])?|([\w.]+))):(\d+)",
)

_server: grpc.aio.Server
_services: Set[grpc.ServiceRpcHandler]
_registered_services: DefaultDict[str, Dict[str, grpc.RpcMethodHandler]]

def __init__(self, *, reflection: bool = False, **kwargs: Any) -> None:
self._server = grpc.aio.server(**kwargs)
self._services = set()
self._registered_services = defaultdict(dict)
self._reflection = reflection

def add_registered_method_handlers(
self, name: str, methods: Dict[str, grpc.RpcMethodHandler],
) -> None:
self._registered_services[name].update(methods)

def add_generic_rpc_handlers(
self,
generic_rpc_handlers: Iterable[grpc.ServiceRpcHandler], # type: ignore
) -> None:
for service in generic_rpc_handlers:
self._services.add(service)

def add_insecure_port(self, address: str) -> int:
return self._server.add_insecure_port(address)

def add_secure_port(
self, address: str,
server_credentials: grpc.ServerCredentials,
) -> int:
return self._server.add_secure_port(address, server_credentials)

async def start(self) -> None:
for name, methods in self._registered_services.items():
self._server.add_registered_method_handlers( # type: ignore
name, methods,
)

if self._reflection:
service_names = [x.service_name() for x in self._services]
service_names.append(reflection.SERVICE_NAME)
reflection.enable_server_reflection(
service_names,
self._server,
)

self._server.add_generic_rpc_handlers(tuple(self._services))
return await self._server.start()

async def wait_for_termination(
self, timeout: Optional[float] = None,
) -> bool:
return await self._server.wait_for_termination(timeout)

async def stop(self, grace: Optional[float] = None) -> None:
return await self._server.stop(grace)

@classmethod
def _log_port(cls, msg: str, address: str, bind_port: Any) -> None:
match: Optional[re.Match] = cls._ADDRESS_REGEXP.match(address)

if match is not None:
groups = match.groupdict()
address = groups["address"]

log.info("%s: grpc://%s:%s", msg, address, bind_port)


class GRPCService(Service):
_server: LazyServer
_server_args: MappingProxyType
_insecure_ports: Set[Tuple[str, PortFuture]]
_secure_ports: Set[Tuple[str, grpc.ServerCredentials, PortFuture]]
_registered_services: DefaultDict[str, Dict[str, grpc.RpcMethodHandler]]

def __init__(
self, *,
Expand All @@ -117,20 +52,34 @@ def __init__(
reflection: bool = False,
**kwds: Any,
):
self._server = LazyServer(
compression=compression,
handlers=handlers,
interceptors=interceptors,
maximum_concurrent_rpcs=maximum_concurrent_rpcs,
migration_thread_pool=migration_thread_pool,
options=options,
reflection=reflection,
)
self._server_args = MappingProxyType({
"compression": compression,
"handlers": handlers,
"interceptors": interceptors,
"maximum_concurrent_rpcs": maximum_concurrent_rpcs,
"migration_thread_pool": migration_thread_pool,
"options": options,
})
self._services: Set[grpc.ServiceRpcHandler] = set()
self._insecure_ports = set()
self._secure_ports = set()
self._reflection = reflection
self._registered_services = defaultdict(dict)
super().__init__(**kwds)

@classmethod
def _log_port(cls, msg: str, address: str, bind_port: Any) -> None:
match: Optional[re.Match] = cls._ADDRESS_REGEXP.match(address)

if match is not None:
groups = match.groupdict()
address = groups["address"]

log.info("%s: grpc://%s:%s", msg, address, bind_port)

async def start(self) -> None:
self._server = grpc.aio.server(**self._server_args)

for address, future in self._insecure_ports:
port = self._server.add_insecure_port(address)
future.set_result(port)
Expand All @@ -141,20 +90,33 @@ async def start(self) -> None:
future.set_result(port)
self._log_port("Listening secure address", address, port)

if self._reflection:
service_names = [x.service_name() for x in self._services]
service_names.append(reflection.SERVICE_NAME)
reflection.enable_server_reflection(service_names, self._server)

for name, handlers in self._registered_services.items():
# noinspection PyUnresolvedReferences
self._server.add_registered_method_handlers( # type: ignore
name, handlers,
)

self._server.add_generic_rpc_handlers(tuple(self._services))
await self._server.start()

async def stop(self, exception: Optional[Exception] = None) -> None:
await self._server.stop(self.GRACEFUL_STOP_TIME)

def add_registered_method_handlers(
self, *args: Any, **kwargs: Any,
def add_generic_rpc_handlers(
self, generic_rpc_handlers: Sequence[grpc.ServiceRpcHandler],
) -> None:
return self._server.add_registered_method_handlers(*args, **kwargs)
for service in generic_rpc_handlers:
self._services.add(service)

def add_generic_rpc_handlers(
self, *args: Any, **kwargs: Any,
def add_registered_method_handlers(
self, name: str, handlers: Mapping[str, grpc.RpcMethodHandler],
) -> None:
return self._server.add_generic_rpc_handlers(*args, **kwargs)
self._registered_services[name].update(handlers)

def add_insecure_port(self, address: str) -> PortFuture:
future: PortFuture = asyncio.Future()
Expand Down

0 comments on commit 41bfd6e

Please sign in to comment.