Skip to content

Commit

Permalink
Hide H2Connection inside _LockedObject (#3318)
Browse files Browse the repository at this point in the history
  • Loading branch information
pquentin authored Jan 25, 2024
1 parent 26a07db commit 8c8e26d
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions src/urllib3/http2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import contextlib
import threading
import types
import typing

import h2.config # type: ignore[import]
Expand All @@ -18,15 +18,41 @@

orig_HTTPSConnection = HTTPSConnection

T = typing.TypeVar("T")


class _LockedObject(typing.Generic[T]):
"""
A wrapper class that hides a specific object behind a lock.
The goal here is to provide a simple way to protect access to an object
that cannot safely be simultaneously accessed from multiple threads. The
intended use of this class is simple: take hold of it with a context
manager, which returns the protected object.
"""

def __init__(self, obj: T):
self.lock = threading.RLock()
self._obj = obj

def __enter__(self) -> T:
self.lock.acquire()
return self._obj

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
self.lock.release()


class HTTP2Connection(HTTPSConnection):
def __init__(
self, host: str, port: int | None = None, **kwargs: typing.Any
) -> None:
self._h2_lock = threading.RLock()
self._h2_conn = h2.connection.H2Connection(
config=h2.config.H2Configuration(client_side=True)
)
self._h2_conn = self._new_h2_conn()
self._h2_stream: int | None = None
self._h2_headers: list[tuple[bytes, bytes]] = []

Expand All @@ -35,15 +61,14 @@ def __init__(

super().__init__(host, port, **kwargs)

@contextlib.contextmanager
def _lock_h2_conn(self) -> typing.Generator[h2.connection.H2Connection, None, None]:
with self._h2_lock:
yield self._h2_conn
def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
config = h2.config.H2Configuration(client_side=True)
return _LockedObject(h2.connection.H2Connection(config=config))

def connect(self) -> None:
super().connect()

with self._lock_h2_conn() as h2_conn:
with self._h2_conn as h2_conn:
h2_conn.initiate_connection()
self.sock.sendall(h2_conn.data_to_send())

Expand All @@ -54,7 +79,7 @@ def putrequest(
skip_host: bool = False,
skip_accept_encoding: bool = False,
) -> None:
with self._lock_h2_conn() as h2_conn:
with self._h2_conn as h2_conn:
self._request_url = url
self._h2_stream = h2_conn.get_next_available_stream_id()

Expand All @@ -79,7 +104,7 @@ def putheader(self, header: str, *values: str) -> None:
)

def endheaders(self) -> None: # type: ignore[override]
with self._lock_h2_conn() as h2_conn:
with self._h2_conn as h2_conn:
h2_conn.send_headers(
stream_id=self._h2_stream,
headers=self._h2_headers,
Expand All @@ -98,7 +123,7 @@ def getresponse( # type: ignore[override]
) -> HTTP2Response:
status = None
data = bytearray()
with self._lock_h2_conn() as h2_conn:
with self._h2_conn as h2_conn:
end_stream = False
while not end_stream:
# TODO: Arbitrary read value.
Expand Down Expand Up @@ -144,18 +169,16 @@ def getresponse( # type: ignore[override]
)

def close(self) -> None:
with self._lock_h2_conn() as h2_conn:
with self._h2_conn as h2_conn:
try:
self._h2_conn.close_connection()
h2_conn.close_connection()
if data := h2_conn.data_to_send():
self.sock.sendall(data)
except Exception:
pass

# Reset all our HTTP/2 connection state.
self._h2_conn = h2.connection.H2Connection(
config=h2.config.H2Configuration(client_side=True)
)
self._h2_conn = self._new_h2_conn()
self._h2_stream = None
self._h2_headers = []

Expand Down

0 comments on commit 8c8e26d

Please sign in to comment.