Skip to content

Commit

Permalink
[PR aio-libs#7713/d697d421 backport][3.9] Add check to validate absol…
Browse files Browse the repository at this point in the history
…ute URIs (aio-libs#7714)

**This is a backport of PR aio-libs#7713 as merged into master
(d697d42).**
  • Loading branch information
patchback[bot] authored and Xiang Li committed Dec 4, 2023
1 parent 41f4e95 commit 62e2764
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGES/7712.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add check to validate that absolute URIs have schemes.
103 changes: 50 additions & 53 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,7 @@
from . import hdrs
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import (
DEBUG,
NO_EXTENSIONS,
BaseTimerContext,
method_must_be_empty_body,
status_code_must_be_empty_body,
)
from .helpers import DEBUG, NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
BadHttpMessage,
BadStatusLine,
Expand Down Expand Up @@ -71,17 +65,15 @@
# token = 1*tchar
METHRE: Final[Pattern[str]] = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d).(\d)")
HDRRE: Final[Pattern[bytes]] = re.compile(
rb"[\x00-\x1F\x7F-\xFF()<>@,;:\[\]={} \t\"\\]"
)
HDRRE: Final[Pattern[bytes]] = re.compile(rb"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\"\\]")
HEXDIGIT = re.compile(rb"[0-9a-fA-F]+")


class RawRequestMessage(NamedTuple):
method: str
path: str
version: HttpVersion
headers: CIMultiDictProxy[str]
headers: "CIMultiDictProxy[str]"
raw_headers: RawHeaders
should_close: bool
compression: Optional[str]
Expand All @@ -106,6 +98,7 @@ class RawResponseMessage(NamedTuple):


class ParseState(IntEnum):

PARSE_NONE = 0
PARSE_LENGTH = 1
PARSE_CHUNKED = 2
Expand All @@ -124,9 +117,11 @@ class HeadersParser:
def __init__(
self,
max_line_size: int = 8190,
max_headers: int = 32768,
max_field_size: int = 8190,
) -> None:
self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size

def parse_headers(
Expand Down Expand Up @@ -225,10 +220,11 @@ class HttpParser(abc.ABC, Generic[_MsgT]):

def __init__(
self,
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
limit: int,
protocol: Optional[BaseProtocol] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
limit: int = 2**16,
max_line_size: int = 8190,
max_headers: int = 32768,
max_field_size: int = 8190,
timer: Optional[BaseTimerContext] = None,
code: Optional[int] = None,
Expand All @@ -242,6 +238,7 @@ def __init__(
self.protocol = protocol
self.loop = loop
self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size
self.timer = timer
self.code = code
Expand All @@ -258,7 +255,7 @@ def __init__(
self._payload_parser: Optional[HttpPayloadParser] = None
self._auto_decompress = auto_decompress
self._limit = limit
self._headers_parser = HeadersParser(max_line_size, max_field_size)
self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size)

@abc.abstractmethod
def parse_message(self, lines: List[bytes]) -> _MsgT:
Expand Down Expand Up @@ -289,6 +286,7 @@ def feed_data(
METH_CONNECT: str = hdrs.METH_CONNECT,
SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1,
) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]:

messages = []

if self._tail:
Expand All @@ -299,6 +297,7 @@ def feed_data(
loop = self.loop

while start_pos < data_len:

# read HTTP message (request/response line + headers), \r\n\r\n
# and split by lines
if self._payload_parser is None and not self._upgraded:
Expand Down Expand Up @@ -344,15 +343,10 @@ def get_content_length() -> Optional[int]:
self._upgraded = msg.upgrade

method = getattr(msg, "method", self.method)
# code is only present on responses
code = getattr(msg, "code", 0)

assert self.protocol is not None
# calculate payload
empty_body = status_code_must_be_empty_body(code) or bool(
method and method_must_be_empty_body(method)
)
if not empty_body and (
if (
(length is not None and length > 0)
or msg.chunked
and not msg.upgrade
Expand Down Expand Up @@ -394,29 +388,34 @@ def get_content_length() -> Optional[int]:
auto_decompress=self._auto_decompress,
lax=self.lax,
)
elif not empty_body and length is None and self.read_until_eof:
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD
if (
getattr(msg, "code", 100) >= 199
and length is None
and self.read_until_eof
):
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD

messages.append((msg, payload))
else:
Expand Down Expand Up @@ -503,8 +502,7 @@ def parse_headers(
close_conn = True
elif v == "keep-alive":
close_conn = False
# https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols
elif v == "upgrade" and headers.get(hdrs.UPGRADE):
elif v == "upgrade":
upgrade = True

# encoding
Expand Down Expand Up @@ -549,7 +547,7 @@ def parse_message(self, lines: List[bytes]) -> RawRequestMessage:
# request line
line = lines[0].decode("utf-8", "surrogateescape")
try:
method, path, version = line.split(" ", maxsplit=2)
method, path, version = line.split(maxsplit=2)
except ValueError:
raise BadStatusLine(line) from None

Expand Down Expand Up @@ -597,9 +595,7 @@ def parse_message(self, lines: List[bytes]) -> RawRequestMessage:
url = URL(path, encoded=True)
if url.scheme == "":
# not absolute-form
raise InvalidURLError(
path.encode(errors="surrogateescape").decode("latin1")
)
raise InvalidURLError(line)

# read headers
(
Expand Down Expand Up @@ -662,7 +658,6 @@ def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
try:
status, reason = status.split(maxsplit=1)
except ValueError:
status = status.strip()
reason = ""

if len(reason) > self.max_line_size:
Expand Down Expand Up @@ -805,6 +800,7 @@ def feed_data(
self._chunk_tail = b""

while chunk:

# read next chunk size
if self._chunk == ChunkState.PARSE_CHUNKED_SIZE:
pos = chunk.find(SEP)
Expand Down Expand Up @@ -868,7 +864,7 @@ def feed_data(

# if stream does not contain trailer, after 0\r\n
# we should get another \r\n otherwise
# trailers needs to be skipped until \r\n\r\n
# trailers needs to be skiped until \r\n\r\n
if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
head = chunk[: len(SEP)]
if head == SEP:
Expand Down Expand Up @@ -908,6 +904,8 @@ def feed_data(
class DeflateBuffer:
"""DeflateStream decompress stream and feed data into specified stream."""

decompressor: Any

def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
self.out = out
self.size = 0
Expand Down Expand Up @@ -965,8 +963,7 @@ def feed_eof(self) -> None:

if chunk or self.size > 0:
self.out.feed_data(chunk, len(chunk))
# decompressor is not brotli unless encoding is "br"
if self.encoding == "deflate" and not self.decompressor.eof: # type: ignore[union-attr]
if self.encoding == "deflate" and not self.decompressor.eof:
raise ContentEncodingError("deflate")

self.out.feed_eof()
Expand Down
5 changes: 5 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,11 @@ def test_http_request_parser_bad_version_number(parser: Any) -> None:
parser.feed_data(b"GET /test HTTP/1.32\r\n\r\n")


def test_http_request_parser_bad_uri(parser: Any) -> None:
with pytest.raises(http_exceptions.InvalidURLError):
parser.feed_data(b"GET ! HTTP/1.1\r\n\r\n")


@pytest.mark.parametrize("size", [40965, 8191])
def test_http_request_max_status_line(parser, size) -> None:
path = b"t" * (size - 5)
Expand Down

0 comments on commit 62e2764

Please sign in to comment.