From 3573fd38d066b0702f605e561259c5fe027d4e1f Mon Sep 17 00:00:00 2001 From: Randall Leeds Date: Thu, 17 Dec 2020 16:24:37 -0500 Subject: [PATCH] Capture peer name from accept Avoid calls to getpeername by capturing the peer name returned by accept. --- gunicorn/http/message.py | 35 +++++++++++----------------------- gunicorn/http/parser.py | 5 +++-- gunicorn/workers/base_async.py | 2 +- gunicorn/workers/gthread.py | 2 +- gunicorn/workers/sync.py | 2 +- tests/t.py | 2 +- tests/treq.py | 4 ++-- 7 files changed, 20 insertions(+), 32 deletions(-) diff --git a/gunicorn/http/message.py b/gunicorn/http/message.py index c7f795ee5..17d22402b 100644 --- a/gunicorn/http/message.py +++ b/gunicorn/http/message.py @@ -6,9 +6,7 @@ import io import re import socket -from errno import ENOTCONN -from gunicorn.http.unreader import SocketUnreader from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body from gunicorn.http.errors import ( InvalidHeader, InvalidHeaderName, NoMoreData, @@ -29,9 +27,10 @@ class Message(object): - def __init__(self, cfg, unreader): + def __init__(self, cfg, unreader, peer_addr): self.cfg = cfg self.unreader = unreader + self.peer_addr = peer_addr self.version = None self.headers = [] self.trailers = [] @@ -69,16 +68,10 @@ def parse_headers(self, data): # handle scheme headers scheme_header = False secure_scheme_headers = {} - if '*' in cfg.forwarded_allow_ips: + if ('*' in cfg.forwarded_allow_ips or + not isinstance(self.peer_addr, tuple) + or self.peer_addr[0] in cfg.forwarded_allow_ips): secure_scheme_headers = cfg.secure_scheme_headers - elif isinstance(self.unreader, SocketUnreader): - remote_addr = self.unreader.sock.getpeername() - if self.unreader.sock.family in (socket.AF_INET, socket.AF_INET6): - remote_host = remote_addr[0] - if remote_host in cfg.forwarded_allow_ips: - secure_scheme_headers = cfg.secure_scheme_headers - elif self.unreader.sock.family == socket.AF_UNIX: - secure_scheme_headers = cfg.secure_scheme_headers # Parse headers into key/value pairs paying attention # to continuation lines. @@ -169,7 +162,7 @@ def should_close(self): class Request(Message): - def __init__(self, cfg, unreader, req_number=1): + def __init__(self, cfg, unreader, peer_addr, req_number=1): self.method = None self.uri = None self.path = None @@ -184,7 +177,7 @@ def __init__(self, cfg, unreader, req_number=1): self.req_number = req_number self.proxy_protocol_info = None - super().__init__(cfg, unreader) + super().__init__(cfg, unreader, peer_addr) def get_data(self, unreader, buf, stop=False): data = unreader.read() @@ -280,16 +273,10 @@ def proxy_protocol(self, line): def proxy_protocol_access_check(self): # check in allow list - if isinstance(self.unreader, SocketUnreader): - try: - remote_host = self.unreader.sock.getpeername()[0] - except socket.error as e: - if e.args[0] == ENOTCONN: - raise ForbiddenProxyRequest("UNKNOW") - raise - if ("*" not in self.cfg.proxy_allow_ips and - remote_host not in self.cfg.proxy_allow_ips): - raise ForbiddenProxyRequest(remote_host) + if ("*" not in self.cfg.proxy_allow_ips and + isinstance(self.peer_addr, tuple) and + self.peer_addr[0] not in self.cfg.proxy_allow_ips): + raise ForbiddenProxyRequest(self.peer_addr[0]) def parse_proxy_protocol(self, line): bits = line.split() diff --git a/gunicorn/http/parser.py b/gunicorn/http/parser.py index a4a0f1e48..5d689f06a 100644 --- a/gunicorn/http/parser.py +++ b/gunicorn/http/parser.py @@ -11,13 +11,14 @@ class Parser(object): mesg_class = None - def __init__(self, cfg, source): + def __init__(self, cfg, source, source_addr): self.cfg = cfg if hasattr(source, "recv"): self.unreader = SocketUnreader(source) else: self.unreader = IterUnreader(source) self.mesg = None + self.source_addr = source_addr # request counter (for keepalive connetions) self.req_count = 0 @@ -38,7 +39,7 @@ def __next__(self): # Parse the next request self.req_count += 1 - self.mesg = self.mesg_class(self.cfg, self.unreader, self.req_count) + self.mesg = self.mesg_class(self.cfg, self.unreader, self.source_addr, self.req_count) if not self.mesg: raise StopIteration() return self.mesg diff --git a/gunicorn/workers/base_async.py b/gunicorn/workers/base_async.py index 6e7ae3f80..851ddc796 100644 --- a/gunicorn/workers/base_async.py +++ b/gunicorn/workers/base_async.py @@ -33,7 +33,7 @@ def is_already_handled(self, respiter): def handle(self, listener, client, addr): req = None try: - parser = http.RequestParser(self.cfg, client) + parser = http.RequestParser(self.cfg, client, addr) try: listener_name = listener.getsockname() if not self.cfg.keepalive: diff --git a/gunicorn/workers/gthread.py b/gunicorn/workers/gthread.py index 5cd3e8a59..09776aebb 100644 --- a/gunicorn/workers/gthread.py +++ b/gunicorn/workers/gthread.py @@ -53,7 +53,7 @@ def init(self): **self.cfg.ssl_options) # initialize the parser - self.parser = http.RequestParser(self.cfg, self.sock) + self.parser = http.RequestParser(self.cfg, self.sock, self.client) def set_timeout(self): # set the timeout diff --git a/gunicorn/workers/sync.py b/gunicorn/workers/sync.py index fd423bc9e..7831a43a8 100644 --- a/gunicorn/workers/sync.py +++ b/gunicorn/workers/sync.py @@ -131,7 +131,7 @@ def handle(self, listener, client, addr): client = ssl.wrap_socket(client, server_side=True, **self.cfg.ssl_options) - parser = http.RequestParser(self.cfg, client) + parser = http.RequestParser(self.cfg, client, addr) req = next(parser) self.handle_request(listener, req, client, addr) except http.errors.NoMoreData as e: diff --git a/tests/t.py b/tests/t.py index 539df27e8..9b76e7deb 100644 --- a/tests/t.py +++ b/tests/t.py @@ -29,7 +29,7 @@ def __init__(self, name): def __call__(self, func): def run(): src = data_source(self.fname) - func(src, RequestParser(src, None)) + func(src, RequestParser(src, None, None)) run.func_name = func.func_name return run diff --git a/tests/treq.py b/tests/treq.py index 9b6cdd1b9..ffe0691fd 100644 --- a/tests/treq.py +++ b/tests/treq.py @@ -245,7 +245,7 @@ def test_req(sn, sz, mt): def check(self, cfg, sender, sizer, matcher): cases = self.expect[:] - p = RequestParser(cfg, sender()) + p = RequestParser(cfg, sender(), None) for req in p: self.same(req, sizer, matcher, cases.pop(0)) assert not cases @@ -282,5 +282,5 @@ def send(self): read += chunk def check(self, cfg): - p = RequestParser(cfg, self.send()) + p = RequestParser(cfg, self.send(), None) next(p)