Skip to content

Commit

Permalink
Merge branch 'master' into patch-4
Browse files Browse the repository at this point in the history
  • Loading branch information
T-256 authored Feb 20, 2024
2 parents 15d32f4 + accae7b commit 36c8021
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 12 deletions.
7 changes: 6 additions & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@

version: 2
updates:
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "monthly"
groups:
python-packages:
patterns:
- "*"
exclude-patterns:
- "werkzeug"
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## Unreleased

- Fix support for connection Upgrade and CONNECT when some data in the stream has been read. (#882)

## 1.0.3 (February 13th, 2024)

- Fix support for async cancellations. (#880)
- Fix trace extension when used with socks proxy. (#849)
- Fix SSL context for connections using the "wss" scheme (#869)
Expand Down
2 changes: 1 addition & 1 deletion httpcore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, *args, **kwargs): # type: ignore
"WriteError",
]

__version__ = "1.0.2"
__version__ = "1.0.3"


__locals = locals()
Expand Down
50 changes: 47 additions & 3 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Any,
AsyncIterable,
AsyncIterator,
List,
Expand Down Expand Up @@ -107,6 +109,7 @@ async def handle_async_request(self, request: Request) -> Response:
status,
reason_phrase,
headers,
trailing_data,
) = await self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
Expand All @@ -115,14 +118,22 @@ async def handle_async_request(self, request: Request) -> Response:
headers,
)

network_stream = self._network_stream

# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -167,7 +178,7 @@ async def _send_event(

async def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand All @@ -187,7 +198,9 @@ async def _receive_response_headers(
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()

return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data

return http_version, event.status_code, event.reason, headers, trailing_data

async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
Expand Down Expand Up @@ -340,3 +353,34 @@ async def aclose(self) -> None:
self._closed = True
async with Trace("response_closed", logger, self._request):
await self._connection._response_closed()


class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data

async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return await self._stream.read(max_bytes, timeout)

async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
await self._stream.write(buffer, timeout)

async def aclose(self) -> None:
await self._stream.aclose()

async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> AsyncNetworkStream:
return await self._stream.start_tls(ssl_context, server_hostname, timeout)

def get_extra_info(self, info: str) -> Any:
return self._stream.get_extra_info(info)
50 changes: 47 additions & 3 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Any,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -107,6 +109,7 @@ def handle_request(self, request: Request) -> Response:
status,
reason_phrase,
headers,
trailing_data,
) = self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
Expand All @@ -115,14 +118,22 @@ def handle_request(self, request: Request) -> Response:
headers,
)

network_stream = self._network_stream

# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = HTTP11UpgradeStream(network_stream, trailing_data)

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -167,7 +178,7 @@ def _send_event(

def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand All @@ -187,7 +198,9 @@ def _receive_response_headers(
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()

return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data

return http_version, event.status_code, event.reason, headers, trailing_data

def _receive_response_body(self, request: Request) -> Iterator[bytes]:
timeouts = request.extensions.get("timeout", {})
Expand Down Expand Up @@ -340,3 +353,34 @@ def close(self) -> None:
self._closed = True
with Trace("response_closed", logger, self._request):
self._connection._response_closed()


class HTTP11UpgradeStream(NetworkStream):
def __init__(self, stream: NetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data

def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return self._stream.read(max_bytes, timeout)

def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
self._stream.write(buffer, timeout)

def close(self) -> None:
self._stream.close()

def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> NetworkStream:
return self._stream.start_tls(ssl_context, server_hostname, timeout)

def get_extra_info(self, info: str) -> Any:
return self._stream.get_extra_info(info)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ socks = [
"socksio==1.*",
]
trio = [
"trio>=0.22.0,<0.24.0",
"trio>=0.22.0,<0.25.0",
]
asyncio = [
"anyio>=4.0,<5.0",
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Docs
mkdocs==1.5.3
mkdocs-autorefs==0.5.0
mkdocs-material==9.5.7
mkdocs-material==9.5.10
mkdocs-material-extensions==1.3.1
mkdocstrings[python-legacy]==0.24.0
jinja2==3.1.3
Expand All @@ -14,10 +14,10 @@ twine==5.0.0

# Tests & Linting
coverage[toml]==7.4.1
ruff==0.2.1
ruff==0.2.2
mypy==1.8.0
trio-typing==0.10.0
pytest==8.0.0
pytest==8.0.1
pytest-httpbin==2.0.0
pytest-trio==0.8.0
werkzeug<2.1 # See: https://github.com/psf/httpbin/issues/35
51 changes: 51 additions & 0 deletions tests/_async/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,57 @@ async def test_http11_upgrade_connection():
assert content == b"..."


@pytest.mark.anyio
async def test_http11_upgrade_with_trailing_data():
"""
HTTP "101 Switching Protocols" indicates an upgraded connection.
In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
in the h11.Connection object.
https://h11.readthedocs.io/en/latest/api.html#switching-protocols
"""
origin = httpcore.Origin(b"wss", b"example.com", 443)
stream = httpcore.AsyncMockStream(
# The first element of this mock network stream buffer simulates networking
# in which response headers and data are received at once.
# This means that "foobar" becomes trailing data.
[
(
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: custom\r\n"
b"\r\n"
b"foobar"
),
b"baz",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
async with conn.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]

content = await network_stream.read(max_bytes=3)
assert content == b"foo"
content = await network_stream.read(max_bytes=3)
assert content == b"bar"
content = await network_stream.read(max_bytes=3)
assert content == b"baz"

# Lazy tests for AsyncHTTP11UpgradeStream
await network_stream.write(b"spam")
invalid = network_stream.get_extra_info("invalid")
assert invalid is None
await network_stream.aclose()


@pytest.mark.anyio
async def test_http11_early_hints():
"""
Expand Down
51 changes: 51 additions & 0 deletions tests/_sync/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,57 @@ def test_http11_upgrade_connection():



def test_http11_upgrade_with_trailing_data():
"""
HTTP "101 Switching Protocols" indicates an upgraded connection.
In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
in the h11.Connection object.
https://h11.readthedocs.io/en/latest/api.html#switching-protocols
"""
origin = httpcore.Origin(b"wss", b"example.com", 443)
stream = httpcore.MockStream(
# The first element of this mock network stream buffer simulates networking
# in which response headers and data are received at once.
# This means that "foobar" becomes trailing data.
[
(
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: custom\r\n"
b"\r\n"
b"foobar"
),
b"baz",
]
)
with httpcore.HTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
with conn.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]

content = network_stream.read(max_bytes=3)
assert content == b"foo"
content = network_stream.read(max_bytes=3)
assert content == b"bar"
content = network_stream.read(max_bytes=3)
assert content == b"baz"

# Lazy tests for HTTP11UpgradeStream
network_stream.write(b"spam")
invalid = network_stream.get_extra_info("invalid")
assert invalid is None
network_stream.close()



def test_http11_early_hints():
"""
HTTP "103 Early Hints" is an interim response.
Expand Down

0 comments on commit 36c8021

Please sign in to comment.