diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 83e25c04..5c95cce0 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -149,4 +149,6 @@ Contributors - Jonathan Vanasco, 2022-11-15 +- Ananth Bhaskararaman, 2023-04-05 + - Simon King, 2024-11-12 diff --git a/docs/socket-activation.rst b/docs/socket-activation.rst index 63483a31..7343408b 100644 --- a/docs/socket-activation.rst +++ b/docs/socket-activation.rst @@ -4,7 +4,8 @@ Socket Activation While waitress does not support the various implementations of socket activation, for example using systemd or launchd, it is prepared to receive pre-bound sockets from init systems, process and socket managers, or other launchers that can provide -pre-bound sockets. +pre-bound sockets. Waitress supports INET and INET6 sockets, and UNIX stream sockets. +Additionally, on Linux it supports VSOCK stream sockets. The following shows a code example starting waitress with two pre-bound Internet sockets. diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py index 6266245c..91b3548c 100644 --- a/src/waitress/adjustments.py +++ b/src/waitress/adjustments.py @@ -17,7 +17,7 @@ import socket import warnings -from .compat import HAS_IPV6, WIN +from .compat import HAS_IPV6, VSOCK, WIN from .proxy_headers import PROXY_HEADERS truthy = frozenset(("t", "true", "y", "yes", "on", "1")) @@ -81,6 +81,10 @@ def str_iftruthy(s): return str(s) if s else None +def int_iftruthy(s): + return int(s) if s else None + + def as_socket_list(sockets): """Checks if the elements in the list are of type socket and removes them if not.""" @@ -130,6 +134,8 @@ class Adjustments: ("asyncore_use_poll", asbool), ("unix_socket", str), ("unix_socket_perms", asoctal), + ("vsock_socket_cid", int_iftruthy), + ("vsock_socket_port", int_iftruthy), ("sockets", as_socket_list), ("channel_request_lookahead", int), ("server_name", str), @@ -254,6 +260,10 @@ class Adjustments: # Path to a Unix domain socket to use. unix_socket_perms = 0o600 + # The CID and port to use for a vsock socket. + vsock_socket_cid = None + vsock_socket_port = None + # The socket options to set on receiving a connection. It is a list of # (level, optname, value) tuples. TCP_NODELAY disables the Nagle # algorithm for writes (Waitress already buffers its writes). @@ -302,12 +312,41 @@ def __init__(self, **kw): if "sockets" in kw and "unix_socket" in kw: raise ValueError("unix_socket may not be set if sockets is set") + if "sockets" in kw and ("vsock_socket_cid" in kw or "vsock_socket_port" in kw): + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if sockets is set" + ) + if "unix_socket" in kw and ("host" in kw or "port" in kw): raise ValueError("unix_socket may not be set if host or port is set") if "unix_socket" in kw and "listen" in kw: raise ValueError("unix_socket may not be set if listen is set") + if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and not VSOCK: + raise ValueError( + "vsock_socket_cid and vsock_socket_port are not supported on this platform" + ) + + if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and ( + "host" in kw or "port" in kw + ): + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if host or port is set" + ) + + if ("vsock_socket_cid" in kw or "vsock_socket_port" in kw) and "listen" in kw: + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if listen is set" + ) + + if ( + "vsock_socket_cid" in kw or "vsock_socket_port" in kw + ) and "unix_socket" in kw: + raise ValueError( + "vsock_socket_cid or vsock_socket_port may not be set if unix_socket is set" + ) + if "send_bytes" in kw: warnings.warn( "send_bytes will be removed in a future release", DeprecationWarning @@ -353,10 +392,10 @@ def __init__(self, **kw): # Try turning the port into an integer port = int(port) - except Exception: + except Exception as exc: raise ValueError( "Windows does not support service names instead of port numbers" - ) + ) from exc try: if "[" in host and "]" in host: # pragma: nocover @@ -391,20 +430,20 @@ def __init__(self, **kw): wanted_sockets.append((family, socktype, proto, sockaddr)) hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) - except Exception: - raise ValueError("Invalid host/port specified.") + except Exception as exc: + raise ValueError("Invalid host/port specified.") from exc if self.trusted_proxy_count is not None and self.trusted_proxy is None: raise ValueError( - "trusted_proxy_count has no meaning without setting " "trusted_proxy" + "trusted_proxy_count has no meaning without setting trusted_proxy" ) - elif self.trusted_proxy_count is None: + if self.trusted_proxy_count is None: self.trusted_proxy_count = 1 if self.trusted_proxy_headers and self.trusted_proxy is None: raise ValueError( - "trusted_proxy_headers has no meaning without setting " "trusted_proxy" + "trusted_proxy_headers has no meaning without setting trusted_proxy" ) if self.trusted_proxy_headers: @@ -415,9 +454,9 @@ def __init__(self, **kw): unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS if unknown_values: raise ValueError( - "Received unknown trusted_proxy_headers value (%s) expected one " - "of %s" - % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) + "Received unknown trusted_proxy_headers value " + f"({', '.join(unknown_values)}) expected one " + f"of {', '.join(KNOWN_PROXY_HEADERS)}" ) if ( @@ -486,23 +525,22 @@ def parse_args(cls, argv): @classmethod def check_sockets(cls, sockets): - has_unix_socket = False - has_inet_socket = False - has_unsupported_socket = False + supported_families = [socket.AF_INET, socket.AF_INET6] + if hasattr(socket, "AF_UNIX"): + supported_families.append(socket.AF_UNIX) + if hasattr(socket, "AF_VSOCK"): + supported_families.append(socket.AF_VSOCK) + + inet_families = (socket.AF_INET, socket.AF_INET6) + family = None for sock in sockets: - if ( - sock.family == socket.AF_INET or sock.family == socket.AF_INET6 - ) and sock.type == socket.SOCK_STREAM: - has_inet_socket = True - elif ( - hasattr(socket, "AF_UNIX") - and sock.family == socket.AF_UNIX - and sock.type == socket.SOCK_STREAM - ): - has_unix_socket = True - else: - has_unsupported_socket = True - if has_unix_socket and has_inet_socket: - raise ValueError("Internet and UNIX sockets may not be mixed.") - if has_unsupported_socket: - raise ValueError("Only Internet or UNIX stream sockets may be used.") + if sock.type != socket.SOCK_STREAM or sock.family not in supported_families: + raise ValueError( + "Only Internet, UNIX, or VSOCK stream sockets may be used." + ) + if family is None: + family = sock.family + elif family in inet_families and sock.family in inet_families: + pass + elif family != sock.family: + raise ValueError("All sockets must belong to the same family.") diff --git a/src/waitress/compat.py b/src/waitress/compat.py index 67543b9c..8cb82dfa 100644 --- a/src/waitress/compat.py +++ b/src/waitress/compat.py @@ -6,8 +6,9 @@ import sys import warnings -# True if we are running on Windows +# Platform detection. WIN = platform.system() == "Windows" +VSOCK = hasattr(socket, "AF_VSOCK") MAXINT = sys.maxsize HAS_IPV6 = socket.has_ipv6 diff --git a/src/waitress/server.py b/src/waitress/server.py index 1826cb84..73b7b9a3 100644 --- a/src/waitress/server.py +++ b/src/waitress/server.py @@ -25,6 +25,7 @@ from waitress.utilities import cleanup_unix_socket from . import wasyncore +from .compat import VSOCK from .proxy_headers import proxy_headers_middleware @@ -68,6 +69,22 @@ def create_server( sockinfo=sockinfo, ) + if (adj.vsock_socket_cid or adj.vsock_socket_port) and VSOCK: + if not adj.vsock_socket_cid: + adj.vsock_socket_cid = socket.VMADDR_CID_ANY + if not adj.vsock_socket_port: + adj.vsock_socket_port = socket.VMADDR_PORT_ANY + sockinfo = (socket.AF_VSOCK, socket.SOCK_STREAM, None, None) + return VsockWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + effective_listen = [] last_serv = None if not adj.sockets: @@ -90,7 +107,7 @@ def create_server( for sock in adj.sockets: sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) - if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + if sock.family in (socket.AF_INET, socket.AF_INET6): last_serv = TcpWSGIServer( application, map, @@ -118,6 +135,20 @@ def create_server( effective_listen.append( (last_serv.effective_host, last_serv.effective_port) ) + elif VSOCK and sock.family == socket.AF_VSOCK: + last_serv = VsockWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) # We are running a single server, so we can just return the last server, # saves us from having to create one more object @@ -416,5 +447,40 @@ def fix_addr(self, addr): return ("localhost", None) +if VSOCK: + + class VsockWSGIServer(BaseWSGIServer): + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + **kw + ): + if sockinfo is None: + sockinfo = (socket.AF_VSOCK, socket.SOCK_STREAM, None, None) + + super().__init__( + application, + map=map, + _start=_start, + _sock=_sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + **kw, + ) + + def bind_server_socket(self): + self.bind((self.adj.vsock_socket_cid, self.adj.vsock_socket_port)) + + def getsockname(self): + return ("vsock", self.socket.getsockname()) + + # Compatibility alias. WSGIServer = TcpWSGIServer diff --git a/tests/test_adjustments.py b/tests/test_adjustments.py index b7243a92..296d54c2 100644 --- a/tests/test_adjustments.py +++ b/tests/test_adjustments.py @@ -1,8 +1,9 @@ +from re import L import socket import unittest import warnings -from waitress.compat import WIN +from waitress.compat import VSOCK, WIN class Test_asbool(unittest.TestCase): @@ -106,35 +107,40 @@ def _makeOne(self, **kw): return Adjustments(**kw) def test_goodvars(self): - inst = self._makeOne( - host="localhost", - port="8080", - threads="5", - trusted_proxy="192.168.1.1", - trusted_proxy_headers={"forwarded"}, - trusted_proxy_count=2, - log_untrusted_proxy_headers=True, - url_scheme="https", - backlog="20", - recv_bytes="200", - send_bytes="300", - outbuf_overflow="400", - inbuf_overflow="500", - connection_limit="1000", - cleanup_interval="1100", - channel_timeout="1200", - log_socket_errors="true", - max_request_header_size="1300", - max_request_body_size="1400", - expose_tracebacks="true", - ident="abc", - asyncore_loop_timeout="5", - asyncore_use_poll=True, - unix_socket_perms="777", - url_prefix="///foo/", - ipv4=True, - ipv6=False, - ) + kw = { + "host": "localhost", + "port": "8080", + "threads": "5", + "trusted_proxy": "192.168.1.1", + "trusted_proxy_headers": {"forwarded"}, + "trusted_proxy_count": 2, + "log_untrusted_proxy_headers": True, + "url_scheme": "https", + "backlog": "20", + "recv_bytes": "200", + "send_bytes": "300", + "outbuf_overflow": "400", + "inbuf_overflow": "500", + "connection_limit": "1000", + "cleanup_interval": "1100", + "channel_timeout": "1200", + "log_socket_errors": "true", + "max_request_header_size": "1300", + "max_request_body_size": 1400, + "expose_tracebacks": "true", + "ident": "abc", + "asyncore_loop_timeout": "5", + "asyncore_use_poll": True, + "unix_socket_perms": "777", + "url_prefix": "///foo/", + "ipv4": True, + "ipv6": False, + } + if VSOCK: + kw["vsock_socket_cid"] = -1 + kw["vsock_socket_port"] = -1 + + inst = self._makeOne(**kw) self.assertEqual(inst.host, "localhost") self.assertEqual(inst.port, 8080) @@ -164,6 +170,10 @@ def test_goodvars(self): self.assertTrue(inst.ipv4) self.assertFalse(inst.ipv6) + if VSOCK: + self.assertEqual(inst.vsock_socket_cid, -1) + self.assertEqual(inst.vsock_socket_port, -1) + bind_pairs = [ sockaddr[:2] for (family, _, _, sockaddr) in inst.listen @@ -251,6 +261,11 @@ def test_good_sockets(self): sockets[0].close() sockets[1].close() + def test_dont_use_dgram_sockets(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + def test_dont_mix_sockets_and_listen(self): sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] self.assertRaises( @@ -272,6 +287,28 @@ def test_dont_mix_sockets_and_unix_socket(self): ) sockets[0].close() + def test_dont_mix_unix_and_vsock_socket(self): + if not VSOCK: + return + sockets = [ + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + ] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + for sock in sockets: + sock.close() + + def test_dont_mix_tcp_and_vsock_socket(self): + if not VSOCK: + return + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + ] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + for sock in sockets: + sock.close() + def test_dont_mix_unix_socket_and_host_port(self): self.assertRaises( ValueError, @@ -481,3 +518,21 @@ def test_dont_mix_internet_and_unix_sockets(self): self.assertRaises(ValueError, self._makeOne, sockets=sockets) sockets[0].close() sockets[1].close() + + +if VSOCK: + + class TestVsockSocket(unittest.TestCase): + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_dont_mix_internet_and_unix_sockets(self): + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + ] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + sockets[1].close() diff --git a/tests/test_server.py b/tests/test_server.py index cede49a7..c9eba48a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,6 +2,8 @@ import socket import unittest +from waitress.compat import VSOCK + dummy_app = object() @@ -414,6 +416,112 @@ def test_create_with_unix_socket(self): self.assertIsInstance(server[1], UnixWSGIServer) +if VSOCK: + + class TestVsockWSGIServer(unittest.TestCase): + vsock_socket_cid = 2 + vsock_socket_port = -1 + + def _makeOne(self, _start=True, _sock=None): + from waitress.server import create_server + + self.inst = create_server( + dummy_app, + map={}, + _start=_start, + _sock=_sock, + _dispatcher=DummyTaskDispatcher(), + vsock_socket_cid=self.vsock_socket_cid, + vsock_socket_port=self.vsock_socket_port, + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + self.inst.close() + + def _makeDummy(self, *args, **kwargs): + sock = DummySock(*args, **kwargs) + sock.family = socket.AF_VSOCK + return sock + + def test_unix(self): + inst = self._makeOne(_start=False) + self.assertEqual(inst.socket.family, socket.AF_VSOCK) + self.assertEqual(inst.socket.getsockname(), self.vsock_socket_cid) + + def test_handle_accept(self): + # Working on the assumption that we only have to test the happy path + # for Unix domain sockets as the other paths should've been covered + # by inet sockets. + client = self._makeDummy() + listen = self._makeDummy(acceptresult=(client, None)) + inst = self._makeOne(_sock=listen) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(client.opts, []) + self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) + + def test_creates_new_sockinfo(self): + from waitress.server import VsockWSGIServer + + self.inst = VsockWSGIServer( + dummy_app, + vsock_socket_cid=self.vsock_socket_cid, + vsock_socket_port=self.vsock_socket_port, + ) + + self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) + + def test_create_with_unix_socket(self): + from waitress.server import ( + BaseWSGIServer, + MultiSocketServer, + TcpWSGIServer, + VsockWSGIServer, + ) + + sockets = [ + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM), + ] + inst = self._makeWithSockets(sockets=sockets, _start=False) + self.assertTrue(isinstance(inst, MultiSocketServer)) + server = list( + filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) + ) + self.assertTrue(isinstance(server[0], VsockWSGIServer)) + self.assertTrue(isinstance(server[1], VsockWSGIServer)) + + class DummySock(socket.socket): accepted = False blocking = False