diff --git a/.gitignore b/.gitignore index fd6e3ad..7fe2de5 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,9 @@ ENV/ # PyTests .pytest_cache + +# CLion +/cmake-build-debug/ +/cmake-build-release/ + +.DS_Store diff --git a/src/purerpc/grpc_socket.py b/src/purerpc/grpc_socket.py index bdcbc57..2acd3ec 100644 --- a/src/purerpc/grpc_socket.py +++ b/src/purerpc/grpc_socket.py @@ -6,15 +6,13 @@ import anyio import async_exit_stack from async_generator import async_generator, yield_, yield_from_ -import h2 -import h2.events -import h2.exceptions from purerpc.utils import is_darwin from purerpc.grpclib.exceptions import ProtocolError from .grpclib.connection import GRPCConfiguration, GRPCConnection -from .grpclib.events import RequestReceived, RequestEnded, ResponseEnded, MessageReceived +from .grpclib.events import RequestReceived, RequestEnded, ResponseEnded, MessageReceived, WindowUpdated from .grpclib.buffers import MessageWriteBuffer, MessageReadBuffer +from .grpclib.exceptions import StreamClosedError class SocketWrapper(async_exit_stack.AsyncExitStack): @@ -174,7 +172,7 @@ async def close(self, status=None, content_type_suffix="", custom_metadata=()): if self.client_side: try: self._grpc_connection.end_request(self._stream_id) - except h2.exceptions.StreamClosedError: + except StreamClosedError: # Remote end already closed connection, do nothing here pass elif self._response_started: @@ -236,7 +234,7 @@ async def _listen(self): events = self._grpc_connection.receive_data(data) await self._socket.flush() for event in events: - if isinstance(event, h2.events.WindowUpdated): + if isinstance(event, WindowUpdated): if event.stream_id == 0: for stream in self._streams.values(): await stream._set_flow_control_update() diff --git a/src/purerpc/grpclib/config.py b/src/purerpc/grpclib/config.py index 6b74e89..ba420b6 100644 --- a/src/purerpc/grpclib/config.py +++ b/src/purerpc/grpclib/config.py @@ -1,11 +1,8 @@ -import h2.config - - class GRPCConfiguration: def __init__(self, client_side: bool, server_string=None, user_agent=None, message_encoding=None, message_accept_encoding=None, max_message_length=4194304): - self._h2_config = h2.config.H2Configuration(client_side=client_side, header_encoding="utf-8") + self._client_side = client_side if client_side and server_string is not None: raise ValueError("Passed client_side=True and server_string at the same time") if not client_side and user_agent is not None: @@ -24,7 +21,7 @@ def __init__(self, client_side: bool, server_string=None, user_agent=None, @property def client_side(self): - return self._h2_config.client_side + return self._client_side @property def server_string(self): diff --git a/src/purerpc/grpclib/connection.py b/src/purerpc/grpclib/connection.py index 3080825..37da73d 100644 --- a/src/purerpc/grpclib/connection.py +++ b/src/purerpc/grpclib/connection.py @@ -2,20 +2,18 @@ import logging import datetime -import ngh2 - -import h2.stream -import h2.errors +import h2.config import h2.events import h2.connection import h2.exceptions from h2.settings import SettingCodes +from h2.errors import ErrorCodes from .headers import HeaderDict, sanitize_headers from .status import Status from .config import GRPCConfiguration -from .events import MessageReceived, RequestReceived, RequestEnded, ResponseReceived, ResponseEnded -from .exceptions import ProtocolError +from .events import MessageReceived, RequestReceived, RequestEnded, ResponseReceived, ResponseEnded, WindowUpdated +from .exceptions import ProtocolError, StreamClosedError from .buffers import MessageReadBuffer, MessageWriteBuffer from ._h2_monkeypatch import apply_patch @@ -31,7 +29,8 @@ class GRPCConnection: def __init__(self, config: GRPCConfiguration): self.config = config - self.h2_connection = h2.connection.H2Connection(config._h2_config) + self.h2_connection = h2.connection.H2Connection(h2.config.H2Configuration(client_side=config.client_side, + header_encoding="utf-8")) self.h2_connection.clear_outbound_data_buffer() self._set_h2_connection_local_settings() self.message_read_buffers = {} @@ -88,7 +87,7 @@ def _data_received(self, event: h2.events.DataReceived): self.message_read_buffers[event.stream_id].data_received(event.data, event.flow_controlled_length) except KeyError: - self.h2_connection.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR) + self.h2_connection.reset_stream(event.stream_id, ErrorCodes.PROTOCOL_ERROR) iterator = (self.message_read_buffers[event.stream_id] .read_all_complete_messages_flowcontrol()) @@ -98,13 +97,10 @@ def _data_received(self, event: h2.events.DataReceived): return events def _window_updated(self, event: h2.events.WindowUpdated): - return [event] + return [WindowUpdated(stream_id=event.stream_id, delta=event.delta)] def _remote_settings_changed(self, event: h2.events.RemoteSettingsChanged): - fake_event = h2.events.WindowUpdated() - fake_event.stream_id = 0 - fake_event.delta = 1 - return [fake_event] + return [WindowUpdated(stream_id=0, delta=1)] def _ping_acknowledged(self, event: h2.events.PingAcknowledged): return [] @@ -232,7 +228,10 @@ def start_request(self, stream_id: int, scheme: str, service_name: str, method_n self.h2_connection.send_headers(stream_id, headers, end_stream=False) def end_request(self, stream_id: int): - self.h2_connection.send_data(stream_id, b"", end_stream=True) + try: + self.h2_connection.send_data(stream_id, b"", end_stream=True) + except h2.exceptions.StreamClosedError as e: + raise StreamClosedError(stream_id=e.stream_id, error_code=e.error_code) def start_response(self, stream_id: int, content_type_suffix="", custom_metadata=()): headers = [ diff --git a/src/purerpc/grpclib/events.py b/src/purerpc/grpclib/events.py index ccc3746..5aac485 100644 --- a/src/purerpc/grpclib/events.py +++ b/src/purerpc/grpclib/events.py @@ -10,6 +10,18 @@ class Event: pass +class WindowUpdated(Event): + def __init__(self, stream_id, delta): + self.stream_id = stream_id + self.delta = delta + + def __repr__(self): + return "" % ( + self.stream_id, self.delta + ) + + + class RequestReceived(Event): def __init__(self, stream_id: int, scheme: str, service_name: str, method_name: str, content_type: str): diff --git a/src/purerpc/grpclib/exceptions.py b/src/purerpc/grpclib/exceptions.py index 4c3c519..2909c8b 100644 --- a/src/purerpc/grpclib/exceptions.py +++ b/src/purerpc/grpclib/exceptions.py @@ -5,6 +5,12 @@ class GRPCError(Exception): pass +class StreamClosedError(GRPCError): + def __init__(self, stream_id, error_code): + self.stream_id = stream_id + self.error_code = error_code + + class ProtocolError(GRPCError): pass