Skip to content

Commit

Permalink
Add option to always include port in build_host helper.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Feb 1, 2025
1 parent ded4288 commit 1c1ac3f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
10 changes: 8 additions & 2 deletions src/websockets/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@
T = TypeVar("T")


def build_host(host: str, port: int, secure: bool) -> str:
def build_host(
host: str,
port: int,
secure: bool,
*,
always_include_port: bool = False,
) -> str:
"""
Build a ``Host`` header.
Expand All @@ -53,7 +59,7 @@ def build_host(host: str, port: int, secure: bool) -> str:
if address.version == 6:
host = f"[{host}]"

if port != (443 if secure else 80):
if always_include_port or port != (443 if secure else 80):
host = f"{host}:{port}"

return host
Expand Down
43 changes: 25 additions & 18 deletions tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,33 @@

class HeadersTests(unittest.TestCase):
def test_build_host(self):
for (host, port, secure), result in [
(("localhost", 80, False), "localhost"),
(("localhost", 8000, False), "localhost:8000"),
(("localhost", 443, True), "localhost"),
(("localhost", 8443, True), "localhost:8443"),
(("example.com", 80, False), "example.com"),
(("example.com", 8000, False), "example.com:8000"),
(("example.com", 443, True), "example.com"),
(("example.com", 8443, True), "example.com:8443"),
(("127.0.0.1", 80, False), "127.0.0.1"),
(("127.0.0.1", 8000, False), "127.0.0.1:8000"),
(("127.0.0.1", 443, True), "127.0.0.1"),
(("127.0.0.1", 8443, True), "127.0.0.1:8443"),
(("::1", 80, False), "[::1]"),
(("::1", 8000, False), "[::1]:8000"),
(("::1", 443, True), "[::1]"),
(("::1", 8443, True), "[::1]:8443"),
for (host, port, secure), (result, result_with_port) in [
(("localhost", 80, False), ("localhost", "localhost:80")),
(("localhost", 8000, False), ("localhost:8000", "localhost:8000")),
(("localhost", 443, True), ("localhost", "localhost:443")),
(("localhost", 8443, True), ("localhost:8443", "localhost:8443")),
(("example.com", 80, False), ("example.com", "example.com:80")),
(("example.com", 8000, False), ("example.com:8000", "example.com:8000")),
(("example.com", 443, True), ("example.com", "example.com:443")),
(("example.com", 8443, True), ("example.com:8443", "example.com:8443")),
(("127.0.0.1", 80, False), ("127.0.0.1", "127.0.0.1:80")),
(("127.0.0.1", 8000, False), ("127.0.0.1:8000", "127.0.0.1:8000")),
(("127.0.0.1", 443, True), ("127.0.0.1", "127.0.0.1:443")),
(("127.0.0.1", 8443, True), ("127.0.0.1:8443", "127.0.0.1:8443")),
(("::1", 80, False), ("[::1]", "[::1]:80")),
(("::1", 8000, False), ("[::1]:8000", "[::1]:8000")),
(("::1", 443, True), ("[::1]", "[::1]:443")),
(("::1", 8443, True), ("[::1]:8443", "[::1]:8443")),
]:
with self.subTest(host=host, port=port, secure=secure):
self.assertEqual(build_host(host, port, secure), result)
self.assertEqual(
build_host(host, port, secure),
result,
)
self.assertEqual(
build_host(host, port, secure, always_include_port=True),
result_with_port,
)

def test_parse_connection(self):
for header, parsed in [
Expand Down

0 comments on commit 1c1ac3f

Please sign in to comment.