From 87d063329c97c00fb0988e2b255eaa827cccdd50 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Tue, 4 Feb 2020 18:03:01 -0800 Subject: [PATCH 01/19] Switch to HTTP1.1 so that we can reuse sockets. --- adafruit_requests.py | 316 +++++++++++++++++++++++++------------------ 1 file changed, 186 insertions(+), 130 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index 0888937..27bbb2f 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -54,22 +54,40 @@ __version__ = "0.0.0-auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git" -_the_interface = None # pylint: disable=invalid-name -_the_sock = None # pylint: disable=invalid-name +# This module mirror CPython's socket class. It is settable because the network devices are external +# to CircuitPython. +_socket = None # pylint: disable=invalid-name -def set_socket(sock, iface=None): - """Helper to set the global socket and optionally set the global network interface. - :param sock: socket object. - :param iface: internet interface object +def set_socket(socket): + """Helper to set the global socket. + :param sock: socket module """ - global _the_sock # pylint: disable=invalid-name, global-statement - _the_sock = sock - if iface: - global _the_interface # pylint: disable=invalid-name, global-statement - _the_interface = iface - _the_sock.set_interface(iface) + global _socket # pylint: disable=invalid-name, global-statement + _socket = socket + +# Hang onto open sockets so that we can reuse them. +_socket_pool = {} +def _get_socket(host, port, proto, *, timeout=1): + key = (host, port, proto) + if key in _socket_pool: + socket = _socket_pool[key] + if not socket.connected(): + del _socket_pool[key] + else: + return socket + addr_info = _socket.getaddrinfo(host, port, 0, _socket.SOCK_STREAM)[0] + sock = _socket.socket(addr_info[0], addr_info[1], addr_info[2]) + sock.settimeout(timeout) # socket read timeout + + if proto == "https:": + # for SSL we need to know the host name + sock.connect((host, port), _socket.TLS_MODE) + else: + sock.connect(addr_info[-1], _socket.TCP_MODE) + _socket_pool[key] = sock + return sock class Response: @@ -81,10 +99,20 @@ def __init__(self, sock): self.socket = sock self.encoding = "utf-8" self._cached = None - self.status_code = None - self.reason = None - self._read_so_far = 0 - self.headers = {} + self._headers = {} + + if not self.socket.connected(): + raise RuntimeError("We were too slow") + + line = self.socket.readline() + if line is None: + raise RuntimeError("Unable to read HTTP response.") + line = line.split(b" ", 2) + self.status_code = int(line[1]) + self.reason = "" + if len(line) > 2: + self.reason = line[2].rstrip() + self._parse_headers() def __enter__(self): return self @@ -95,26 +123,83 @@ def __exit__(self, exc_type, exc_value, traceback): def close(self): """Close, delete and collect the response data""" if self.socket: - self.socket.close() - del self.socket + # Make sure we've read all of our response. + content_length = None + if "content-length" in self.headers: + content_length = int(self.headers["content-length"]) + + # print("Content length:", content_length) + if self._cached is None: + if content_length: + self.socket.recv(content_length) + else: + while True: + chunk_header = self.socket.readline() + if b";" in chunk_header: + chunk_header = chunk_header.split(b";")[0] + chunk_size = int(chunk_header, 16) + if chunk_size == 0: + break + self.socket.read(chunk_size + 2) + self._parse_headers() + self.socket = None del self._cached gc.collect() + def _parse_headers(self): + """ + Parses the header portion of an HTTP request/response from the socket. + Expects first line of HTTP request/response to have been read already. + """ + while True: + line = self.socket.readline() + if not line or line == b"\r\n": + break + + #print("**line: ", line) + splits = line.split(b': ', 1) + title = splits[0] + content = '' + if len(splits) > 1: + content = splits[1] + if title and content: + title = str(title.lower(), 'utf-8') + content = str(content, 'utf-8') + self._headers[title] = content + + @property + def headers(self): + return self._headers + @property def content(self): """The HTTP content direct from the socket, as bytes""" # print(self.headers) - try: + content_length = None + if "content-length" in self.headers: content_length = int(self.headers["content-length"]) - except KeyError: - content_length = 0 + # print("Content length:", content_length) if self._cached is None: - try: + if content_length: self._cached = self.socket.recv(content_length) - finally: - self.socket.close() - self.socket = None + else: + chunks = [] + while True: + chunk_header = self.socket.readline() + if chunk_header is None: + raise RuntimeError("Reading content timed out.") + if b";" in chunk_header: + chunk_header = chunk_header.split(b";")[0] + chunk_size = int(chunk_header, 16) + if chunk_size == 0: + break + chunks.append(self.socket.read(chunk_size)) + self.socket.read(2) # Read the trailing CR LF + self._parse_headers() + self._cached = b"".join(chunks) + self.socket = None + # print("Buffer length:", len(self._cached)) return self._cached @@ -131,6 +216,7 @@ def json(self): import json as json_module except ImportError: import ujson as json_module + return json_module.loads(self.content) def iter_content(self, chunk_size=1, decode_unicode=False): @@ -139,16 +225,47 @@ def iter_content(self, chunk_size=1, decode_unicode=False): if decode_unicode: raise NotImplementedError("Unicode not supported") - while True: - chunk = self.socket.recv(chunk_size) - if chunk: + content_length = None + if "content-length" in self.headers: + content_length = int(self.headers["content-length"]) + + total_read = 0 + if content_length: + while total_read < content_length: + chunk = self.socket.recv(chunk_size) + total_read += chunk_size yield chunk - else: - return + else: + pending_bytes = 0 + chunks = [] + while True: + chunk_header = self.socket.readline() + if b";" in chunk_header: + chunk_header = chunk_header.split(b";")[0] + http_chunk_size = int(chunk_header, 16) + if http_chunk_size == 0: + break + remaining_in_http_chunk = http_chunk_size + while remaining_in_http_chunk: + read_now = chunk_size - pending_bytes + if read_now > remaining_in_http_chunk: + read_now = remaining_in_http_chunk + chunks.append(self.socket.read(read_now)) + pending_bytes += read_now + if pending_bytes == chunk_size: + yield b"".join(chunks) + pending_bytes = 0 + chunks = [] + + self.socket.read(2) # Read the trailing CR LF + self._parse_headers() + if chunks: + yield b"".join(chunks) + self.socket = None # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals -def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1): +def request(method, url, data=None, json=None, headers=None, stream=False, timeout=60): """Perform an HTTP request to the given url which we will parse to determine whether to use SSL ('https://') or not. We can also send some provided 'data' or a json dictionary which we will stringify. 'headers' is optional HTTP headers @@ -156,7 +273,7 @@ def request(method, url, data=None, json=None, headers=None, stream=False, timeo read only when requested """ global _the_interface # pylint: disable=global-statement, invalid-name - global _the_sock # pylint: disable=global-statement, invalid-name + global _socket # pylint: disable=global-statement, invalid-name if not headers: headers = {} @@ -179,107 +296,46 @@ def request(method, url, data=None, json=None, headers=None, stream=False, timeo host, port = host.split(":", 1) port = int(port) - addr_info = _the_sock.getaddrinfo(host, port, 0, _the_sock.SOCK_STREAM)[0] - sock = _the_sock.socket(addr_info[0], addr_info[1], addr_info[2]) - resp = Response(sock) # our response - - sock.settimeout(timeout) # socket read timeout - - try: - if proto == "https:": - conntype = _the_interface.TLS_MODE - sock.connect( - (host, port), conntype - ) # for SSL we need to know the host name + socket = _get_socket(host, port, proto, timeout=timeout) + socket.send(b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))) + if "Host" not in headers: + socket.send(b"Host: %s\r\n" % bytes(host, "utf-8")) + if "User-Agent" not in headers: + socket.send(b"User-Agent: Adafruit CircuitPython\r\n") + # Iterate over keys to avoid tuple alloc + for k in headers: + socket.send(k.encode()) + socket.send(b": ") + socket.send(headers[k].encode()) + socket.send(b"\r\n") + if json is not None: + assert data is None + try: + import json as json_module + except ImportError: + import ujson as json_module + data = json_module.dumps(json) + socket.send(b"Content-Type: application/json\r\n") + if data: + if isinstance(data, dict): + sock.send(b"Content-Type: application/x-www-form-urlencoded\r\n") + _post_data = "" + for k in data: + _post_data = "{}&{}={}".format(_post_data, k, data[k]) + data = _post_data[1:] + socket.send(b"Content-Length: %d\r\n" % len(data)) + socket.send(b"\r\n") + if data: + if isinstance(data, bytearray): + socket.send(bytes(data)) else: - conntype = _the_interface.TCP_MODE - sock.connect(addr_info[-1], conntype) - sock.send( - b"%s /%s HTTP/1.0\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8")) - ) - if "Host" not in headers: - sock.send(b"Host: %s\r\n" % bytes(host, "utf-8")) - if "User-Agent" not in headers: - sock.send(b"User-Agent: Adafruit CircuitPython\r\n") - # Iterate over keys to avoid tuple alloc - for k in headers: - sock.send(k.encode()) - sock.send(b": ") - sock.send(headers[k].encode()) - sock.send(b"\r\n") - if json is not None: - assert data is None - # pylint: disable=import-outside-toplevel - try: - import json as json_module - except ImportError: - import ujson as json_module - # pylint: enable=import-outside-toplevel - data = json_module.dumps(json) - sock.send(b"Content-Type: application/json\r\n") - if data: - if isinstance(data, dict): - sock.send(b"Content-Type: application/x-www-form-urlencoded\r\n") - _post_data = "" - for k in data: - _post_data = "{}&{}={}".format(_post_data, k, data[k]) - data = _post_data[1:] - sock.send(b"Content-Length: %d\r\n" % len(data)) - sock.send(b"\r\n") - if data: - if isinstance(data, bytearray): - sock.send(bytes(data)) - else: - sock.send(bytes(data, "utf-8")) - - line = sock.readline() - # print(line) - line = line.split(None, 2) - status = int(line[1]) - reason = "" - if len(line) > 2: - reason = line[2].rstrip() - resp.headers = parse_headers(sock) - if resp.headers.get("transfer-encoding"): - if "chunked" in resp.headers.get("transfer-encoding"): - raise ValueError("Unsupported " + resp.headers.get("transfer-encoding")) - elif resp.headers.get("location") and not 200 <= status <= 299: - raise NotImplementedError("Redirects not yet supported") - - except: - sock.close() - raise - - resp.status_code = status - resp.reason = reason - return resp + socket.send(bytes(data, "utf-8")) + resp = Response(socket) # our response + if "location" in resp.headers and not 200 <= resp.status_code <= 299: + raise NotImplementedError("Redirects not yet supported") -def parse_headers(sock): - """ - Parses the header portion of an HTTP request/response from the socket. - Expects first line of HTTP request/response to have been read already - return: header dictionary - rtype: Dict - """ - headers = {} - while True: - line = sock.readline() - if not line or line == b"\r\n": - break - - # print("**line: ", line) - splits = line.split(b": ", 1) - title = splits[0] - content = "" - if len(splits) > 1: - content = splits[1] - if title and content: - title = str(title.lower(), "utf-8") - content = str(content, "utf-8") - headers[title] = content - return headers - + return resp def head(url, **kw): """Send HTTP HEAD request""" From 8d81d58b87086c1ff9b7c64918d8897694961abb Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 12 Feb 2020 12:28:07 -0800 Subject: [PATCH 02/19] Update examples --- examples/requests_advanced.py | 10 ++++++++-- examples/requests_simpletest.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/requests_advanced.py b/examples/requests_advanced.py index d74f9af..ec96596 100644 --- a/examples/requests_advanced.py +++ b/examples/requests_advanced.py @@ -5,6 +5,11 @@ from adafruit_esp32spi import adafruit_esp32spi import adafruit_requests as requests +# Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and +# "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other +# source control. +from secrets import secrets + # If you are using a board with pre-defined ESP32 Pins: esp32_cs = DigitalInOut(board.ESP_CS) esp32_ready = DigitalInOut(board.ESP_BUSY) @@ -21,14 +26,15 @@ print("Connecting to AP...") while not esp.is_connected: try: - esp.connect_AP(b"MY_SSID_NAME", b"MY_SSID_PASSWORD") + esp.connect_AP(secrets["ssid"], secrets["password"]) except RuntimeError as e: print("could not connect to AP, retrying: ", e) continue print("Connected to", str(esp.ssid, "utf-8"), "\tRSSI:", esp.rssi) # Initialize a requests object with a socket and esp32spi interface -requests.set_socket(socket, esp) +socket.set_interface(esp) +requests.set_socket(socket) JSON_GET_URL = "http://httpbin.org/get" diff --git a/examples/requests_simpletest.py b/examples/requests_simpletest.py index 444648b..29ca6ed 100755 --- a/examples/requests_simpletest.py +++ b/examples/requests_simpletest.py @@ -6,6 +6,11 @@ from adafruit_esp32spi import adafruit_esp32spi import adafruit_requests as requests +# Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and +# "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other +# source control. +from secrets import secrets + # If you are using a board with pre-defined ESP32 Pins: esp32_cs = DigitalInOut(board.ESP_CS) esp32_ready = DigitalInOut(board.ESP_BUSY) @@ -22,14 +27,15 @@ print("Connecting to AP...") while not esp.is_connected: try: - esp.connect_AP(b"MY_SSID_NAME", b"MY_SSID_PASSWORD") + esp.connect_AP(secrets["ssid"], secrets["password"]) except RuntimeError as e: print("could not connect to AP, retrying: ", e) continue print("Connected to", str(esp.ssid, "utf-8"), "\tRSSI:", esp.rssi) # Initialize a requests object with a socket and esp32spi interface -requests.set_socket(socket, esp) +socket.set_interface(esp) +requests.set_socket(socket) TEXT_URL = "http://wifitest.adafruit.com/testwifi/index.html" JSON_GET_URL = "http://httpbin.org/get" From a3c61b3ace69a78f9c3612ba14e5e6db51fdd628 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 12 Feb 2020 12:37:42 -0800 Subject: [PATCH 03/19] Lint it all. --- adafruit_requests.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index 27bbb2f..76d4a59 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -68,7 +68,7 @@ def set_socket(socket): _socket = socket # Hang onto open sockets so that we can reuse them. -_socket_pool = {} +_socket_pool = {} # pylint: disable=invalid-name def _get_socket(host, port, proto, *, timeout=1): key = (host, port, proto) if key in _socket_pool: @@ -169,6 +169,10 @@ def _parse_headers(self): @property def headers(self): + """ + The response headers. Does not include headers from the trailer until + the content has been read. + """ return self._headers @property @@ -272,9 +276,6 @@ def request(method, url, data=None, json=None, headers=None, stream=False, timeo sent along. 'stream' will determine if we buffer everything, or whether to only read only when requested """ - global _the_interface # pylint: disable=global-statement, invalid-name - global _socket # pylint: disable=global-statement, invalid-name - if not headers: headers = {} @@ -310,6 +311,7 @@ def request(method, url, data=None, json=None, headers=None, stream=False, timeo socket.send(b"\r\n") if json is not None: assert data is None + # pylint: disable=import-outside-toplevel try: import json as json_module except ImportError: From 6c204254510a443d8e4c170b3cc1fbcc78a9c4bf Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 12 Feb 2020 16:09:07 -0800 Subject: [PATCH 04/19] Exempt secrets from lint since CPython has its own version. --- examples/requests_advanced.py | 1 + examples/requests_simpletest.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/requests_advanced.py b/examples/requests_advanced.py index ec96596..1c89d91 100644 --- a/examples/requests_advanced.py +++ b/examples/requests_advanced.py @@ -8,6 +8,7 @@ # Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and # "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other # source control. +# pylint: disable=no-name-in-module,wrong-import-order from secrets import secrets # If you are using a board with pre-defined ESP32 Pins: diff --git a/examples/requests_simpletest.py b/examples/requests_simpletest.py index 29ca6ed..20b6004 100755 --- a/examples/requests_simpletest.py +++ b/examples/requests_simpletest.py @@ -9,6 +9,7 @@ # Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and # "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other # source control. +# pylint: disable=no-name-in-module,wrong-import-order from secrets import secrets # If you are using a board with pre-defined ESP32 Pins: From 956063ade0d53cd0475915d16abdd3ec09ad9d07 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Tue, 30 Jun 2020 18:37:46 -0700 Subject: [PATCH 05/19] Inital redo of readline within requests --- adafruit_requests.py | 163 ++++++++++++++---------- examples/requests_simpletest_cpython.py | 46 +++++++ 2 files changed, 142 insertions(+), 67 deletions(-) create mode 100755 examples/requests_simpletest_cpython.py diff --git a/adafruit_requests.py b/adafruit_requests.py index 76d4a59..fc175f7 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -56,40 +56,24 @@ # This module mirror CPython's socket class. It is settable because the network devices are external # to CircuitPython. -_socket = None # pylint: disable=invalid-name - - -def set_socket(socket): - """Helper to set the global socket. - :param sock: socket module - - """ - global _socket # pylint: disable=invalid-name, global-statement - _socket = socket +socket_module = None # pylint: disable=invalid-name # Hang onto open sockets so that we can reuse them. _socket_pool = {} # pylint: disable=invalid-name def _get_socket(host, port, proto, *, timeout=1): key = (host, port, proto) if key in _socket_pool: - socket = _socket_pool[key] - if not socket.connected(): - del _socket_pool[key] - else: - return socket - addr_info = _socket.getaddrinfo(host, port, 0, _socket.SOCK_STREAM)[0] - sock = _socket.socket(addr_info[0], addr_info[1], addr_info[2]) + return _socket_pool[key] + if not socket_module: + raise RuntimeError("socket_module must be set before using adafruit_requests") + addr_info = socket_module.getaddrinfo(host, port, 0, socket_module.SOCK_STREAM)[0] + sock = socket_module.socket(addr_info[0], addr_info[1], addr_info[2]) sock.settimeout(timeout) # socket read timeout - if proto == "https:": - # for SSL we need to know the host name - sock.connect((host, port), _socket.TLS_MODE) - else: - sock.connect(addr_info[-1], _socket.TCP_MODE) + sock.connect((host, port)) _socket_pool[key] = sock return sock - class Response: """The response from a request, contains all the headers/content""" @@ -100,18 +84,16 @@ def __init__(self, sock): self.encoding = "utf-8" self._cached = None self._headers = {} + # 0 means the first receive buffer is empty because we always consume some of it. non-zero + # means we need to look at it's tail for our pattern. + self._start_index = 0 + self._receive_buffers = [bytearray(32)] - if not self.socket.connected(): - raise RuntimeError("We were too slow") - - line = self.socket.readline() - if line is None: + http = self._readto(b" ") + if not http: raise RuntimeError("Unable to read HTTP response.") - line = line.split(b" ", 2) - self.status_code = int(line[1]) - self.reason = "" - if len(line) > 2: - self.reason = line[2].rstrip() + self.status_code = int(self._readto(b" ")) + self.reason = self._readto(b"\r\n") self._parse_headers() def __enter__(self): @@ -120,6 +102,73 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + def _readto(self, first, second=b""): + # TODO: Make this work if either pattern spans buffers. + current_buffer = 0 + found = -1 + new_start = 0 + if self._start_index > 0: + first_i = self._receive_buffers[0].find(first, self._start_index) + if second: + second_i = self._receive_buffers[0].find(second, self._start_index) + if second_i >= 0 and (first_i < 0 or first_i > second_i): + found = second_i + new_start = second_i + len(second) + + if found == -1: + if first_i > 0: + found = first_i + new_start = first_i + len(first) + else: + current_buffer = 1 + + while found < 0: + if current_buffer == len(self._receive_buffers): + self._receive_buffers.append(bytearray(len(self._receive_buffers[0]))) + buf = self._receive_buffers[current_buffer] + size = self.socket.recv_into(buf) + if size != len(buf): + raise RuntimeError() + + first_i = buf.find(first) + if second: + second_i = buf.find(second) + if second_i >= 0 and (first_i < 0 or first_i > second_i): + found = second_i + new_start = second_i + len(second) + if found == -1: + if first_i >= 0: + found = first_i + new_start = first_i + len(first) + else: + current_buffer += 1 + + if current_buffer == 0: + b = bytes(self._receive_buffers[0][self._start_index:found]) + self._start_index = new_start + else: + new_len = len(self._receive_buffers[0]) * current_buffer + found - self._start_index + b = bytearray(new_len) + i = 0 + for bufi in range(0, current_buffer + 1): + buf = self._receive_buffers[bufi] + if bufi == 0 and self._start_index > 0: + i = len(buf) - self._start_index + b[:i] = buf[self._start_index:] + elif bufi == current_buffer: + b[i:i+found] = buf[:found] + else: + b[i:i+len(buf)] = buf + i += len(buf) + + self._start_index = new_start + # Swap the current buffer to the front because it has some bytes we + # need to search. + last_buf = self._receive_buffers[current_buffer] + self._receive_buffers[current_buffer] = self._receive_buffers[0] + self._receive_buffers[0] = last_buf + return b + def close(self): """Close, delete and collect the response data""" if self.socket: @@ -134,9 +183,7 @@ def close(self): self.socket.recv(content_length) else: while True: - chunk_header = self.socket.readline() - if b";" in chunk_header: - chunk_header = chunk_header.split(b";")[0] + chunk_header = self._readto(b";", b"\r\n") chunk_size = int(chunk_header, 16) if chunk_size == 0: break @@ -152,16 +199,11 @@ def _parse_headers(self): Expects first line of HTTP request/response to have been read already. """ while True: - line = self.socket.readline() - if not line or line == b"\r\n": + title = self._readto(b": ", b"\r\n") + if not title: break - #print("**line: ", line) - splits = line.split(b': ', 1) - title = splits[0] - content = '' - if len(splits) > 1: - content = splits[1] + content = self._readto(b"\r\n") if title and content: title = str(title.lower(), 'utf-8') content = str(content, 'utf-8') @@ -185,24 +227,7 @@ def content(self): # print("Content length:", content_length) if self._cached is None: - if content_length: - self._cached = self.socket.recv(content_length) - else: - chunks = [] - while True: - chunk_header = self.socket.readline() - if chunk_header is None: - raise RuntimeError("Reading content timed out.") - if b";" in chunk_header: - chunk_header = chunk_header.split(b";")[0] - chunk_size = int(chunk_header, 16) - if chunk_size == 0: - break - chunks.append(self.socket.read(chunk_size)) - self.socket.read(2) # Read the trailing CR LF - self._parse_headers() - self._cached = b"".join(chunks) - self.socket = None + self._cached = b"".join(self.iter_content(chunk_size=32)) # print("Buffer length:", len(self._cached)) return self._cached @@ -236,16 +261,20 @@ def iter_content(self, chunk_size=1, decode_unicode=False): total_read = 0 if content_length: while total_read < content_length: - chunk = self.socket.recv(chunk_size) - total_read += chunk_size + if total_read == 0 and self._start_index > 0: + chunk = bytearray(chunk_size) + left = len(self._receive_buffers[0]) - self._start_index + chunk = b"".join((self._receive_buffers[0][self._start_index:], + self.socket.recv(chunk_size - left))) + else: + chunk = self.socket.recv(chunk_size) + total_read += len(chunk) yield chunk else: pending_bytes = 0 chunks = [] while True: - chunk_header = self.socket.readline() - if b";" in chunk_header: - chunk_header = chunk_header.split(b";")[0] + chunk_header = self._readto(b";", b"\r\n") http_chunk_size = int(chunk_header, 16) if http_chunk_size == 0: break diff --git a/examples/requests_simpletest_cpython.py b/examples/requests_simpletest_cpython.py new file mode 100755 index 0000000..5ca5536 --- /dev/null +++ b/examples/requests_simpletest_cpython.py @@ -0,0 +1,46 @@ +# adafruit_requests usage with a CPython socket +import socket +import adafruit_requests as requests +requests.socket_module = socket + +TEXT_URL = "http://wifitest.adafruit.com/testwifi/index.html" +JSON_GET_URL = "http://httpbin.org/get" +JSON_POST_URL = "http://httpbin.org/post" + +print("Fetching text from %s" % TEXT_URL) +response = requests.get(TEXT_URL) +print("-" * 40) + +print("Text Response: ", response.text) +print("-" * 40) +response.close() + +print("Fetching JSON data from %s" % JSON_GET_URL) +response = requests.get(JSON_GET_URL) +print("-" * 40) + +print("JSON Response: ", response.json()) +print("-" * 40) +response.close() + +data = "31F" +print("POSTing data to {0}: {1}".format(JSON_POST_URL, data)) +response = requests.post(JSON_POST_URL, data=data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'data' key from json_resp dict. +print("Data received from server:", json_resp["data"]) +print("-" * 40) +response.close() + +json_data = {"Date": "July 25, 2019"} +print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data)) +response = requests.post(JSON_POST_URL, json=json_data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'json' key from json_resp dict. +print("JSON Data received from server:", json_resp["json"]) +print("-" * 40) +response.close() From 43defd6ebedc43d4032b8c923317594264c391ab Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Fri, 3 Jul 2020 15:07:55 -0700 Subject: [PATCH 06/19] Switch to _readto from readline --- adafruit_requests.py | 94 +++++++++++++++++++-------- examples/requests_advanced_cpython.py | 27 ++++++++ examples/requests_https_cpython.py | 48 ++++++++++++++ 3 files changed, 143 insertions(+), 26 deletions(-) create mode 100644 examples/requests_advanced_cpython.py create mode 100755 examples/requests_https_cpython.py diff --git a/adafruit_requests.py b/adafruit_requests.py index fc175f7..fb4983c 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -57,6 +57,7 @@ # This module mirror CPython's socket class. It is settable because the network devices are external # to CircuitPython. socket_module = None # pylint: disable=invalid-name +ssl_context = None # Hang onto open sockets so that we can reuse them. _socket_pool = {} # pylint: disable=invalid-name @@ -66,8 +67,12 @@ def _get_socket(host, port, proto, *, timeout=1): return _socket_pool[key] if not socket_module: raise RuntimeError("socket_module must be set before using adafruit_requests") + if proto == "https:" and not ssl_context: + raise RuntimeError("ssl_context must be set before using adafruit_requests for https") addr_info = socket_module.getaddrinfo(host, port, 0, socket_module.SOCK_STREAM)[0] sock = socket_module.socket(addr_info[0], addr_info[1], addr_info[2]) + if proto == "https:": + sock = ssl_context.wrap_socket(sock, server_hostname=host) sock.settimeout(timeout) # socket read timeout sock.connect((host, port)) @@ -87,6 +92,7 @@ def __init__(self, sock): # 0 means the first receive buffer is empty because we always consume some of it. non-zero # means we need to look at it's tail for our pattern. self._start_index = 0 + self._buffer_sizes = [0] self._receive_buffers = [bytearray(32)] http = self._readto(b" ") @@ -104,44 +110,62 @@ def __exit__(self, exc_type, exc_value, traceback): def _readto(self, first, second=b""): # TODO: Make this work if either pattern spans buffers. + if len(first) > 2 or len(second) > 2: + raise ValueError("Pattern too long. Must be less than 3 characters.") current_buffer = 0 - found = -1 + found = -2 new_start = 0 - if self._start_index > 0: + if self._start_index < self._buffer_sizes[0]: first_i = self._receive_buffers[0].find(first, self._start_index) if second: second_i = self._receive_buffers[0].find(second, self._start_index) - if second_i >= 0 and (first_i < 0 or first_i > second_i): + if second_i >= 0 and (first_i <= -1 or first_i > second_i): found = second_i new_start = second_i + len(second) - if found == -1: - if first_i > 0: + if found == -2: + if first_i >= 0: found = first_i new_start = first_i + len(first) else: current_buffer = 1 - while found < 0: + while found < -1: if current_buffer == len(self._receive_buffers): self._receive_buffers.append(bytearray(len(self._receive_buffers[0]))) + self._buffer_sizes.append(0) buf = self._receive_buffers[current_buffer] size = self.socket.recv_into(buf) - if size != len(buf): - raise RuntimeError() + self._buffer_sizes[current_buffer] = size + + if len(first) == 2: + previous_size = self._buffer_sizes[current_buffer - 1] + if (self._receive_buffers[current_buffer - 1][previous_size - 1] == first[0] and + buf[0] == first[1]): + found = -1 + new_start = 1 + break - first_i = buf.find(first) + first_i = buf.find(first, 0, size) if second: - second_i = buf.find(second) + if len(second) == 2: + previous_size = self._buffer_sizes[current_buffer - 1] + if (self._receive_buffers[current_buffer - 1][previous_size - 1] == second[0] and + buf[0] == second[1]): + found = -1 + new_start = 1 + break + second_i = buf.find(second, 0, size) if second_i >= 0 and (first_i < 0 or first_i > second_i): found = second_i new_start = second_i + len(second) - if found == -1: - if first_i >= 0: - found = first_i - new_start = first_i + len(first) - else: - current_buffer += 1 + break + if first_i >= 0: + found = first_i + new_start = first_i + len(first) + break + current_buffer += 1 + if current_buffer == 0: b = bytes(self._receive_buffers[0][self._start_index:found]) @@ -152,21 +176,26 @@ def _readto(self, first, second=b""): i = 0 for bufi in range(0, current_buffer + 1): buf = self._receive_buffers[bufi] + size = self._buffer_sizes[bufi] if bufi == 0 and self._start_index > 0: - i = len(buf) - self._start_index - b[:i] = buf[self._start_index:] + i = size - self._start_index + b[:i] = buf[self._start_index:size] elif bufi == current_buffer: - b[i:i+found] = buf[:found] + if found > 0: + b[i:i+found] = buf[:found] else: - b[i:i+len(buf)] = buf - i += len(buf) + b[i:i+size] = buf[:size] + i += size self._start_index = new_start # Swap the current buffer to the front because it has some bytes we # need to search. last_buf = self._receive_buffers[current_buffer] self._receive_buffers[current_buffer] = self._receive_buffers[0] + self._buffer_sizes[0] = self._buffer_sizes[current_buffer] self._receive_buffers[0] = last_buf + self._buffer_sizes[current_buffer] = 0 + return b def close(self): @@ -230,8 +259,18 @@ def content(self): self._cached = b"".join(self.iter_content(chunk_size=32)) # print("Buffer length:", len(self._cached)) + # print("content", self._cached) return self._cached + def read(self, size=-1): + if size == -1: + return self.content + else: + raise NotImplementedError() + + def readinto(self, buf): + pass + @property def text(self): """The HTTP content, encoded into a string according to the HTTP @@ -246,7 +285,7 @@ def json(self): except ImportError: import ujson as json_module - return json_module.loads(self.content) + return json_module.load(self) def iter_content(self, chunk_size=1, decode_unicode=False): """An iterator that will stream data by only reading 'chunk_size' @@ -261,13 +300,16 @@ def iter_content(self, chunk_size=1, decode_unicode=False): total_read = 0 if content_length: while total_read < content_length: - if total_read == 0 and self._start_index > 0: - chunk = bytearray(chunk_size) - left = len(self._receive_buffers[0]) - self._start_index - chunk = b"".join((self._receive_buffers[0][self._start_index:], + if total_read == 0 and self._start_index < self._buffer_sizes[0]: + # print("remaining", self._start_index, self._buffer_sizes[0], self._receive_buffers[0]) + size = self._buffer_sizes[0] + left = size - self._start_index + chunk = b"".join((self._receive_buffers[0][self._start_index:size], self.socket.recv(chunk_size - left))) + # print("remaining", chunk) else: chunk = self.socket.recv(chunk_size) + # print("recv", chunk) total_read += len(chunk) yield chunk else: diff --git a/examples/requests_advanced_cpython.py b/examples/requests_advanced_cpython.py new file mode 100644 index 0000000..725974f --- /dev/null +++ b/examples/requests_advanced_cpython.py @@ -0,0 +1,27 @@ +import socket +import adafruit_requests as requests +requests.socket_module = socket + +JSON_GET_URL = "http://httpbin.org/get" + +# Define a custom header as a dict. +headers = {"user-agent": "blinka/1.0.0"} + +print("Fetching JSON data from %s..." % JSON_GET_URL) +response = requests.get(JSON_GET_URL, headers=headers) +print("-" * 60) + +json_data = response.json() +headers = json_data["headers"] +print("Response's Custom User-Agent Header: {0}".format(headers["User-Agent"])) +print("-" * 60) + +# Read Response's HTTP status code +print("Response HTTP Status Code: ", response.status_code) +print("-" * 60) + +# Read Response, as raw bytes instead of pretty text +print("Raw Response: ", response.content) + +# Close, delete and collect the response data +response.close() diff --git a/examples/requests_https_cpython.py b/examples/requests_https_cpython.py new file mode 100755 index 0000000..ca2b5a4 --- /dev/null +++ b/examples/requests_https_cpython.py @@ -0,0 +1,48 @@ +# adafruit_requests usage with a CPython socket +import socket +import ssl +import adafruit_requests as requests +requests.socket_module = socket +requests.ssl_context = ssl.create_default_context() + +TEXT_URL = "https://wifitest.adafruit.com/testwifi/index.html" +JSON_GET_URL = "https://httpbin.org/get" +JSON_POST_URL = "https://httpbin.org/post" + +# print("Fetching text from %s" % TEXT_URL) +# response = requests.get(TEXT_URL) +# print("-" * 40) + +# print("Text Response: ", response.text) +# print("-" * 40) +# response.close() + +print("Fetching JSON data from %s" % JSON_GET_URL) +response = requests.get(JSON_GET_URL) +print("-" * 40) + +print("JSON Response: ", response.json()) +print("-" * 40) +response.close() + +data = "31F" +print("POSTing data to {0}: {1}".format(JSON_POST_URL, data)) +response = requests.post(JSON_POST_URL, data=data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'data' key from json_resp dict. +print("Data received from server:", json_resp["data"]) +print("-" * 40) +response.close() + +json_data = {"Date": "July 25, 2019"} +print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data)) +response = requests.post(JSON_POST_URL, json=json_data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'json' key from json_resp dict. +print("JSON Data received from server:", json_resp["json"]) +print("-" * 40) +response.close() From b4e83a588f2ab571219df3e6dc7f7543bf2e1337 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Fri, 17 Jul 2020 17:48:38 -0700 Subject: [PATCH 07/19] WIP, switching to json stream parsing --- adafruit_requests.py | 362 ++++++++++++++++++++++++------------------- 1 file changed, 205 insertions(+), 157 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index fb4983c..abd1243 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -54,30 +54,18 @@ __version__ = "0.0.0-auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git" -# This module mirror CPython's socket class. It is settable because the network devices are external -# to CircuitPython. -socket_module = None # pylint: disable=invalid-name -ssl_context = None - -# Hang onto open sockets so that we can reuse them. -_socket_pool = {} # pylint: disable=invalid-name -def _get_socket(host, port, proto, *, timeout=1): - key = (host, port, proto) - if key in _socket_pool: - return _socket_pool[key] - if not socket_module: - raise RuntimeError("socket_module must be set before using adafruit_requests") - if proto == "https:" and not ssl_context: - raise RuntimeError("ssl_context must be set before using adafruit_requests for https") - addr_info = socket_module.getaddrinfo(host, port, 0, socket_module.SOCK_STREAM)[0] - sock = socket_module.socket(addr_info[0], addr_info[1], addr_info[2]) - if proto == "https:": - sock = ssl_context.wrap_socket(sock, server_hostname=host) - sock.settimeout(timeout) # socket read timeout - - sock.connect((host, port)) - _socket_pool[key] = sock - return sock +class _RawResponse: + def __init__(self, response): + self._response = response + + def read(self, size=-1): + if size == -1: + return self.content + + return self._response.socket.recv(size) + + def readinto(self, buf): + return self._response._readinto(buf) class Response: """The response from a request, contains all the headers/content""" @@ -94,6 +82,7 @@ def __init__(self, sock): self._start_index = 0 self._buffer_sizes = [0] self._receive_buffers = [bytearray(32)] + self._content_length = None http = self._readto(b" ") if not http: @@ -101,6 +90,8 @@ def __init__(self, sock): self.status_code = int(self._readto(b" ")) self.reason = self._readto(b"\r\n") self._parse_headers() + self._raw = None + self._content_read = 0 def __enter__(self): return self @@ -198,29 +189,58 @@ def _readto(self, first, second=b""): return b - def close(self): - """Close, delete and collect the response data""" + def _readinto(self, buf): + remaining = self._content_length - self._content_read + nbytes = len(buf) + if nbytes > remaining: + nbytes = remaining + + if self._start_index < self._buffer_sizes[0]: + # print("remaining", self._start_index, self._buffer_sizes[0], self._receive_buffers[0]) + size = self._buffer_sizes[0] + left = size - self._start_index + if nbytes < left: + left = nbytes + start = self._start_index + end = start + left + if left == 1: + buf[0] = self._receive_buffers[0][start] + else: + buf[:] = self._receive_buffers[0][start:end] + read = left + self._start_index += left + if read < nbytes: + read += self.socket.recv_into(memoryview(buf)[left:nbytes]) + else: + read = self.socket.recv_into(buf, nbytes) + self._content_read += read + return read + + def _throw_away(self, nbytes): + buf = self._receive_buffers[0] + for i in range(nbytes // len(buf)): + self.socket.recv_into(buf) + remaining = nbytes % len(buf) + if remaining: + self.socket.recv_into(buf, remaining) + + def _close(self): + """Drain the remaining ESP socket buffers. We assume we already got what we wanted.""" if self.socket: # Make sure we've read all of our response. - content_length = None - if "content-length" in self.headers: - content_length = int(self.headers["content-length"]) - # print("Content length:", content_length) if self._cached is None: - if content_length: - self.socket.recv(content_length) + if self._content_length: + self._throw_away(self._content_length) else: while True: chunk_header = self._readto(b";", b"\r\n") chunk_size = int(chunk_header, 16) if chunk_size == 0: break - self.socket.read(chunk_size + 2) + self._throw_away(chunk_size + 2) self._parse_headers() self.socket = None - del self._cached - gc.collect() def _parse_headers(self): """ @@ -234,8 +254,11 @@ def _parse_headers(self): content = self._readto(b"\r\n") if title and content: - title = str(title.lower(), 'utf-8') + title = str(title, 'utf-8') content = str(content, 'utf-8') + # Check len first so we can skip the .lower allocation most of the time. + if len(title) == len("content-length") and title.lower() == "content-length": + self._content_length = int(content) self._headers[title] = content @property @@ -249,33 +272,24 @@ def headers(self): @property def content(self): """The HTTP content direct from the socket, as bytes""" - # print(self.headers) - content_length = None - if "content-length" in self.headers: - content_length = int(self.headers["content-length"]) - - # print("Content length:", content_length) - if self._cached is None: - self._cached = b"".join(self.iter_content(chunk_size=32)) + if self._cached is not None: + if isinstance(self._cached, bytes): + return self._cached + raise RuntimeError("Cannot access content after getting text or json") - # print("Buffer length:", len(self._cached)) - # print("content", self._cached) + self._cached = b"".join(self.iter_content(chunk_size=32)) return self._cached - def read(self, size=-1): - if size == -1: - return self.content - else: - raise NotImplementedError() - - def readinto(self, buf): - pass - @property def text(self): """The HTTP content, encoded into a string according to the HTTP header encoding""" - return str(self.content, self.encoding) + if self._cached is not None: + if isinstance(self._cached, str): + return self._cached + raise RuntimeError("Cannot access text after getting content or json") + self._cached = str(self.content, self.encoding) + return self._cached def json(self): """The HTTP content, parsed into a json dictionary""" @@ -285,7 +299,18 @@ def json(self): except ImportError: import ujson as json_module - return json_module.load(self) + # The cached JSON will be a list or dictionary. + if self._cached: + if isinstance(self._cached, (list, dict)): + return self._cached + raise RuntimeError("Cannot access json after getting text or content") + if not self._raw: + self._raw = _RawResponse(self) + + obj = json_module.load(self._raw) + if not self._cached: + self._cached = obj + return obj def iter_content(self, chunk_size=1, decode_unicode=False): """An iterator that will stream data by only reading 'chunk_size' @@ -293,24 +318,19 @@ def iter_content(self, chunk_size=1, decode_unicode=False): if decode_unicode: raise NotImplementedError("Unicode not supported") - content_length = None - if "content-length" in self.headers: - content_length = int(self.headers["content-length"]) - total_read = 0 - if content_length: - while total_read < content_length: - if total_read == 0 and self._start_index < self._buffer_sizes[0]: - # print("remaining", self._start_index, self._buffer_sizes[0], self._receive_buffers[0]) - size = self._buffer_sizes[0] - left = size - self._start_index - chunk = b"".join((self._receive_buffers[0][self._start_index:size], - self.socket.recv(chunk_size - left))) - # print("remaining", chunk) + if self._content_length: + while self._content_read < self._content_length: + remaining = self._content_length - self._content_read + if chunk_size > remaining: + chunk_size = remaining + b = bytearray(chunk_size) + size = self._readinto(b) + total_read += self._content_read + if size < chunk_size: + chunk = bytes(memoryview(b)[:size]) else: - chunk = self.socket.recv(chunk_size) - # print("recv", chunk) - total_read += len(chunk) + chunk = bytes(b) yield chunk else: pending_bytes = 0 @@ -338,103 +358,131 @@ def iter_content(self, chunk_size=1, decode_unicode=False): yield b"".join(chunks) self.socket = None +class Session: + def __init__(self, socket_pool, ssl_context=None): + self._socket_pool = socket_pool + self._ssl_context = ssl_context + # Hang onto open sockets so that we can reuse them. + self._open_sockets = {} + self._last_response = None + + def _get_socket(self, host, port, proto, *, timeout=1): + key = (host, port, proto) + if key in self._open_sockets: + return self._open_sockets[key] + if proto == "https:" and not self._ssl_context: + raise RuntimeError("ssl_context must be set before using adafruit_requests for https") + addr_info = self._socket_pool.getaddrinfo(host, port, 0, self._socket_pool.SOCK_STREAM)[0] + sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + if proto == "https:": + sock = self._ssl_context.wrap_socket(sock, server_hostname=host) + sock.settimeout(timeout) # socket read timeout + + sock.connect((host, port)) + self._open_sockets[key] = sock + return sock + + # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals + def request(self, method, url, data=None, json=None, headers=None, stream=False, timeout=60): + """Perform an HTTP request to the given url which we will parse to determine + whether to use SSL ('https://') or not. We can also send some provided 'data' + or a json dictionary which we will stringify. 'headers' is optional HTTP headers + sent along. 'stream' will determine if we buffer everything, or whether to only + read only when requested + """ + if not headers: + headers = {} -# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals -def request(method, url, data=None, json=None, headers=None, stream=False, timeout=60): - """Perform an HTTP request to the given url which we will parse to determine - whether to use SSL ('https://') or not. We can also send some provided 'data' - or a json dictionary which we will stringify. 'headers' is optional HTTP headers - sent along. 'stream' will determine if we buffer everything, or whether to only - read only when requested - """ - if not headers: - headers = {} - - try: - proto, dummy, host, path = url.split("/", 3) - # replace spaces in path - path = path.replace(" ", "%20") - except ValueError: - proto, dummy, host = url.split("/", 2) - path = "" - if proto == "http:": - port = 80 - elif proto == "https:": - port = 443 - else: - raise ValueError("Unsupported protocol: " + proto) - - if ":" in host: - host, port = host.split(":", 1) - port = int(port) - - socket = _get_socket(host, port, proto, timeout=timeout) - socket.send(b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))) - if "Host" not in headers: - socket.send(b"Host: %s\r\n" % bytes(host, "utf-8")) - if "User-Agent" not in headers: - socket.send(b"User-Agent: Adafruit CircuitPython\r\n") - # Iterate over keys to avoid tuple alloc - for k in headers: - socket.send(k.encode()) - socket.send(b": ") - socket.send(headers[k].encode()) - socket.send(b"\r\n") - if json is not None: - assert data is None - # pylint: disable=import-outside-toplevel try: - import json as json_module - except ImportError: - import ujson as json_module - data = json_module.dumps(json) - socket.send(b"Content-Type: application/json\r\n") - if data: - if isinstance(data, dict): - sock.send(b"Content-Type: application/x-www-form-urlencoded\r\n") - _post_data = "" - for k in data: - _post_data = "{}&{}={}".format(_post_data, k, data[k]) - data = _post_data[1:] - socket.send(b"Content-Length: %d\r\n" % len(data)) - socket.send(b"\r\n") - if data: - if isinstance(data, bytearray): - socket.send(bytes(data)) + proto, dummy, host, path = url.split("/", 3) + # replace spaces in path + path = path.replace(" ", "%20") + except ValueError: + proto, dummy, host = url.split("/", 2) + path = "" + if proto == "http:": + port = 80 + elif proto == "https:": + port = 443 else: - socket.send(bytes(data, "utf-8")) + raise ValueError("Unsupported protocol: " + proto) + + if ":" in host: + host, port = host.split(":", 1) + port = int(port) + + if self._last_response: + self._last_response._close() + self._last_response = None + + socket = self._get_socket(host, port, proto, timeout=timeout) + socket.send(b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))) + if "Host" not in headers: + socket.send(b"Host: %s\r\n" % bytes(host, "utf-8")) + if "User-Agent" not in headers: + socket.send(b"User-Agent: Adafruit CircuitPython\r\n") + # Iterate over keys to avoid tuple alloc + for k in headers: + socket.send(k.encode()) + socket.send(b": ") + socket.send(headers[k].encode()) + socket.send(b"\r\n") + if json is not None: + assert data is None + # pylint: disable=import-outside-toplevel + try: + import json as json_module + except ImportError: + import ujson as json_module + data = json_module.dumps(json) + socket.send(b"Content-Type: application/json\r\n") + if data: + if isinstance(data, dict): + sock.send(b"Content-Type: application/x-www-form-urlencoded\r\n") + _post_data = "" + for k in data: + _post_data = "{}&{}={}".format(_post_data, k, data[k]) + data = _post_data[1:] + socket.send(b"Content-Length: %d\r\n" % len(data)) + socket.send(b"\r\n") + if data: + if isinstance(data, bytearray): + socket.send(bytes(data)) + else: + socket.send(bytes(data, "utf-8")) - resp = Response(socket) # our response - if "location" in resp.headers and not 200 <= resp.status_code <= 299: - raise NotImplementedError("Redirects not yet supported") + resp = Response(socket) # our response + if "location" in resp.headers and not 200 <= resp.status_code <= 299: + raise NotImplementedError("Redirects not yet supported") - return resp + self._last_response = resp + return resp -def head(url, **kw): - """Send HTTP HEAD request""" - return request("HEAD", url, **kw) + def head(self, url, **kw): + """Send HTTP HEAD request""" + return self.request("HEAD", url, **kw) -def get(url, **kw): - """Send HTTP GET request""" - return request("GET", url, **kw) + def get(self, url, **kw): + """Send HTTP GET request""" + return self.request("GET", url, **kw) -def post(url, **kw): - """Send HTTP POST request""" - return request("POST", url, **kw) + def post(self, url, **kw): + """Send HTTP POST request""" + return self.request("POST", url, **kw) -def put(url, **kw): - """Send HTTP PUT request""" - return request("PUT", url, **kw) + def put(self, url, **kw): + """Send HTTP PUT request""" + return self.request("PUT", url, **kw) -def patch(url, **kw): - """Send HTTP PATCH request""" - return request("PATCH", url, **kw) + def patch(self, url, **kw): + """Send HTTP PATCH request""" + return self.request("PATCH", url, **kw) -def delete(url, **kw): - """Send HTTP DELETE request""" - return request("DELETE", url, **kw) + def delete(self, url, **kw): + """Send HTTP DELETE request""" + return self.request("DELETE", url, **kw) From 4f0c493a30c403ca1536e942c9bc84a0699bbce8 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 19 Aug 2020 17:48:27 -0700 Subject: [PATCH 08/19] A couple fixes --- adafruit_requests.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index abd1243..c543d6f 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -120,6 +120,8 @@ def _readto(self, first, second=b""): new_start = first_i + len(first) else: current_buffer = 1 + else: + self._start_index = 0 while found < -1: if current_buffer == len(self._receive_buffers): @@ -206,11 +208,11 @@ def _readinto(self, buf): if left == 1: buf[0] = self._receive_buffers[0][start] else: - buf[:] = self._receive_buffers[0][start:end] + buf[:left] = self._receive_buffers[0][start:end] read = left self._start_index += left if read < nbytes: - read += self.socket.recv_into(memoryview(buf)[left:nbytes]) + read += self.socket.recv_into(memoryview(buf)[read:nbytes]) else: read = self.socket.recv_into(buf, nbytes) self._content_read += read @@ -294,10 +296,7 @@ def text(self): def json(self): """The HTTP content, parsed into a json dictionary""" # pylint: disable=import-outside-toplevel - try: - import json as json_module - except ImportError: - import ujson as json_module + import json as json_module # The cached JSON will be a list or dictionary. if self._cached: @@ -319,7 +318,7 @@ def iter_content(self, chunk_size=1, decode_unicode=False): raise NotImplementedError("Unicode not supported") total_read = 0 - if self._content_length: + if self._content_length is not None: while self._content_read < self._content_length: remaining = self._content_length - self._content_read if chunk_size > remaining: @@ -377,7 +376,6 @@ def _get_socket(self, host, port, proto, *, timeout=1): if proto == "https:": sock = self._ssl_context.wrap_socket(sock, server_hostname=host) sock.settimeout(timeout) # socket read timeout - sock.connect((host, port)) self._open_sockets[key] = sock return sock From a0819852da226521d4f347c9d5f8fabd09016c23 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Fri, 21 Aug 2020 10:10:08 -0700 Subject: [PATCH 09/19] Fix up chunk handling --- adafruit_requests.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index c543d6f..bfb07fa 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -198,7 +198,6 @@ def _readinto(self, buf): nbytes = remaining if self._start_index < self._buffer_sizes[0]: - # print("remaining", self._start_index, self._buffer_sizes[0], self._receive_buffers[0]) size = self._buffer_sizes[0] left = size - self._start_index if nbytes < left: @@ -230,7 +229,6 @@ def _close(self): """Drain the remaining ESP socket buffers. We assume we already got what we wanted.""" if self.socket: # Make sure we've read all of our response. - # print("Content length:", content_length) if self._cached is None: if self._content_length: self._throw_away(self._content_length) @@ -333,28 +331,28 @@ def iter_content(self, chunk_size=1, decode_unicode=False): yield chunk else: pending_bytes = 0 - chunks = [] + buf = memoryview(bytearray(chunk_size)) while True: chunk_header = self._readto(b";", b"\r\n") http_chunk_size = int(chunk_header, 16) if http_chunk_size == 0: break + self._content_length = http_chunk_size remaining_in_http_chunk = http_chunk_size while remaining_in_http_chunk: read_now = chunk_size - pending_bytes if read_now > remaining_in_http_chunk: read_now = remaining_in_http_chunk - chunks.append(self.socket.read(read_now)) + read_now = self._readinto(buf[pending_bytes:pending_bytes+read_now]) pending_bytes += read_now if pending_bytes == chunk_size: - yield b"".join(chunks) + yield bytes(buf) pending_bytes = 0 - chunks = [] - self.socket.read(2) # Read the trailing CR LF + self._throw_away(2) # Read the trailing CR LF self._parse_headers() - if chunks: - yield b"".join(chunks) + if pending_bytes > 0: + yield bytes(buf[:pending_bytes]) self.socket = None class Session: From f32426d994fd0247ffcf51e2e59008937d03a374 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Tue, 25 Aug 2020 16:49:00 -0700 Subject: [PATCH 10/19] Properly close sockets as needed --- adafruit_requests.py | 60 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index bfb07fa..966f299 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -72,7 +72,7 @@ class Response: encoding = None - def __init__(self, sock): + def __init__(self, sock, session=None): self.socket = sock self.encoding = "utf-8" self._cached = None @@ -92,6 +92,7 @@ def __init__(self, sock): self._parse_headers() self._raw = None self._content_read = 0 + self._session = session def __enter__(self): return self @@ -229,10 +230,12 @@ def _close(self): """Drain the remaining ESP socket buffers. We assume we already got what we wanted.""" if self.socket: # Make sure we've read all of our response. + # print("Content length:", content_length) if self._cached is None: - if self._content_length: - self._throw_away(self._content_length) - else: + remaining = self._content_length - self._content_read + if remaining > 0: + self._throw_away(remaining) + elif self._content_length is None: while True: chunk_header = self._readto(b";", b"\r\n") chunk_size = int(chunk_header, 16) @@ -240,7 +243,11 @@ def _close(self): break self._throw_away(chunk_size + 2) self._parse_headers() - self.socket = None + if self._session: + self._session.free_socket(self.socket) + else: + self.socket.close() + self.socket = None def _parse_headers(self): """ @@ -307,6 +314,7 @@ def json(self): obj = json_module.load(self._raw) if not self._cached: self._cached = obj + self._close() return obj def iter_content(self, chunk_size=1, decode_unicode=False): @@ -353,7 +361,7 @@ def iter_content(self, chunk_size=1, decode_unicode=False): self._parse_headers() if pending_bytes > 0: yield bytes(buf[:pending_bytes]) - self.socket = None + self._close() class Session: def __init__(self, socket_pool, ssl_context=None): @@ -361,12 +369,22 @@ def __init__(self, socket_pool, ssl_context=None): self._ssl_context = ssl_context # Hang onto open sockets so that we can reuse them. self._open_sockets = {} + self._socket_free = {} self._last_response = None + def free_socket(self, socket): + if socket not in self._open_sockets.values(): + raise RuntimeError("Socket not from session") + self._socket_free[socket] = True + + def _get_socket(self, host, port, proto, *, timeout=1): key = (host, port, proto) if key in self._open_sockets: - return self._open_sockets[key] + sock = self._open_sockets[key] + if self._socket_free[sock]: + self._socket_free[sock] = False + return sock if proto == "https:" and not self._ssl_context: raise RuntimeError("ssl_context must be set before using adafruit_requests for https") addr_info = self._socket_pool.getaddrinfo(host, port, 0, self._socket_pool.SOCK_STREAM)[0] @@ -374,8 +392,32 @@ def _get_socket(self, host, port, proto, *, timeout=1): if proto == "https:": sock = self._ssl_context.wrap_socket(sock, server_hostname=host) sock.settimeout(timeout) # socket read timeout - sock.connect((host, port)) + ok = True + try: + sock.connect((host, port)) + except MemoryError: + if not any(self._socket_free.items()): + raise + ok = False + + # We couldn't connect due to memory so clean up the open sockets. + if not ok: + for s in self._socket_free: + if self._socket_free[s]: + s.close() + del self._socket_free[s] + for k in self._open_sockets: + if self._open_sockets[k] == s: + del self._open_sockets[k] + # Recreate the socket because the ESP-IDF won't retry the connection if it failed once. + sock = None # Clear first so the first socket can be cleaned up. + sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + if proto == "https:": + sock = self._ssl_context.wrap_socket(sock, server_hostname=host) + sock.settimeout(timeout) # socket read timeout + sock.connect((host, port)) self._open_sockets[key] = sock + self._socket_free[sock] = False return sock # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals @@ -447,7 +489,7 @@ def request(self, method, url, data=None, json=None, headers=None, stream=False, else: socket.send(bytes(data, "utf-8")) - resp = Response(socket) # our response + resp = Response(socket, self) # our response if "location" in resp.headers and not 200 <= resp.status_code <= 299: raise NotImplementedError("Redirects not yet supported") From 6914a593bc221c16043af9a6da3b64f6363a6722 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Tue, 25 Aug 2020 17:02:43 -0700 Subject: [PATCH 11/19] Add backwards compatible API --- adafruit_requests.py | 51 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index 966f299..8a50d01 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # # Copyright (c) 2019 ladyada for Adafruit Industries +# Copyright (c) 2020 Scott Shawcroft for Adafruit Industries # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,7 +27,7 @@ A requests-like library for web interfacing -* Author(s): ladyada, Paul Sokolovsky +* Author(s): ladyada, Paul Sokolovsky, Scott Shawcroft Implementation Notes -------------------- @@ -377,7 +378,6 @@ def free_socket(self, socket): raise RuntimeError("Socket not from session") self._socket_free[socket] = True - def _get_socket(self, host, port, proto, *, timeout=1): key = (host, port, proto) if key in self._open_sockets: @@ -524,3 +524,50 @@ def patch(self, url, **kw): def delete(self, url, **kw): """Send HTTP DELETE request""" return self.request("DELETE", url, **kw) + +# Backwards compatible API: + +_default_session = None + +class FakeSSLContext: + def wrap_socket(self, socket): + return socket + +def set_socket(sock, iface=None): + global _default_session + _default_session = Session(sock, FakeSSLContext()) + if iface: + sock.set_interface(iface) + +def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1): + _default_session.request(method, url, data=data, json=json, headers=headers, stream=stream, timeout=timeout) + + +def head(url, **kw): + """Send HTTP HEAD request""" + return _default_session.request("HEAD", url, **kw) + + +def get(url, **kw): + """Send HTTP GET request""" + return _default_session.request("GET", url, **kw) + + +def post(url, **kw): + """Send HTTP POST request""" + return _default_session.request("POST", url, **kw) + + +def put(url, **kw): + """Send HTTP PUT request""" + return _default_session.request("PUT", url, **kw) + + +def patch(url, **kw): + """Send HTTP PATCH request""" + return _default_session.request("PATCH", url, **kw) + + +def delete(url, **kw): + """Send HTTP DELETE request""" + return _default_session.request("DELETE", url, **kw) From 2afe50dee8a39f1c4a1cfa4e3444547b7c6df49a Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Tue, 1 Sep 2020 18:05:46 -0700 Subject: [PATCH 12/19] WIP: More tests and update API --- .github/workflows/test.yml | 4 - adafruit_requests.py | 325 ++++++++++++++-------------- examples/requests_github_cpython.py | 13 ++ examples/requests_https_cpython.py | 13 +- tests/chunk_test.py | 45 ++++ tests/header_test.py | 2 +- tests/legacy_test.py | 35 +++ tests/mocket.py | 22 +- tests/parse_test.py | 15 +- tests/post_test.py | 8 +- tests/protocol_test.py | 8 +- 11 files changed, 291 insertions(+), 199 deletions(-) create mode 100755 examples/requests_github_cpython.py create mode 100644 tests/chunk_test.py create mode 100644 tests/legacy_test.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d4888ef..2f9163b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,10 +6,6 @@ jobs: test: runs-on: ubuntu-latest steps: - - name: Dump GitHub context - env: - GITHUB_CONTEXT: ${{ toJson(github) }} - run: echo "$GITHUB_CONTEXT" - name: Set up Python 3.6 uses: actions/setup-python@v1 with: diff --git a/adafruit_requests.py b/adafruit_requests.py index 8a50d01..6d1a76d 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -61,8 +61,7 @@ def __init__(self, response): def read(self, size=-1): if size == -1: - return self.content - + return self._response.content return self._response.socket.recv(size) def readinto(self, buf): @@ -78,12 +77,17 @@ def __init__(self, sock, session=None): self.encoding = "utf-8" self._cached = None self._headers = {} - # 0 means the first receive buffer is empty because we always consume some of it. non-zero - # means we need to look at it's tail for our pattern. - self._start_index = 0 - self._buffer_sizes = [0] - self._receive_buffers = [bytearray(32)] - self._content_length = None + + # _start_index and _receive_buffer are used when parsing headers. + # _receive_buffer will grow by 32 bytes everytime it is too small. + self._received_length = 0 + self._receive_buffer = bytearray(32) + self._remaining = None + self._chunked = False + + self._backwards_compatible = not hasattr(sock, "recv_into") + if self._backwards_compatible: + print("Socket missing recv_into. Using more memory to be compatible") http = self._readto(b" ") if not http: @@ -101,131 +105,143 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - def _readto(self, first, second=b""): - # TODO: Make this work if either pattern spans buffers. - if len(first) > 2 or len(second) > 2: - raise ValueError("Pattern too long. Must be less than 3 characters.") - current_buffer = 0 - found = -2 - new_start = 0 - if self._start_index < self._buffer_sizes[0]: - first_i = self._receive_buffers[0].find(first, self._start_index) - if second: - second_i = self._receive_buffers[0].find(second, self._start_index) - if second_i >= 0 and (first_i <= -1 or first_i > second_i): - found = second_i - new_start = second_i + len(second) - - if found == -2: - if first_i >= 0: - found = first_i - new_start = first_i + len(first) - else: - current_buffer = 1 + def _recv_into(self, buf, size=None): + if self._backwards_compatible: + size = len(buf) if size is None else size + b = self.socket.recv(size) + read_size = len(b) + buf[:read_size] = b + return read_size else: - self._start_index = 0 - - while found < -1: - if current_buffer == len(self._receive_buffers): - self._receive_buffers.append(bytearray(len(self._receive_buffers[0]))) - self._buffer_sizes.append(0) - buf = self._receive_buffers[current_buffer] - size = self.socket.recv_into(buf) - self._buffer_sizes[current_buffer] = size - - if len(first) == 2: - previous_size = self._buffer_sizes[current_buffer - 1] - if (self._receive_buffers[current_buffer - 1][previous_size - 1] == first[0] and - buf[0] == first[1]): - found = -1 - new_start = 1 - break - - first_i = buf.find(first, 0, size) - if second: - if len(second) == 2: - previous_size = self._buffer_sizes[current_buffer - 1] - if (self._receive_buffers[current_buffer - 1][previous_size - 1] == second[0] and - buf[0] == second[1]): - found = -1 - new_start = 1 - break - second_i = buf.find(second, 0, size) - if second_i >= 0 and (first_i < 0 or first_i > second_i): - found = second_i - new_start = second_i + len(second) - break - if first_i >= 0: - found = first_i - new_start = first_i + len(first) - break - current_buffer += 1 - + return self.socket.recv_into(buf, size) - if current_buffer == 0: - b = bytes(self._receive_buffers[0][self._start_index:found]) - self._start_index = new_start + def _readto(self, first, second=b""): + buf = self._receive_buffer + end = self._received_length + while True: + print("searching", buf[:end]) + firsti = buf.find(first, 0, end) + secondi = -1 + if second: + secondi = buf.find(second, 0, end) + + i = -1 + needle_len = 0 + if firsti >= 0: + i = firsti + needle_len = len(first) + if secondi >= 0 and (firsti < 0 or secondi < firsti): + i = secondi + needle_len = len(second) + if i >= 0: + result = buf[:i] + new_start = i + needle_len + + if i + needle_len <= end: + new_end = end - new_start + buf[:new_end] = buf[new_start:end] + self._received_length = new_end + return result + + # Not found so load more. + + # If our buffer is full, then make it bigger to load more. + if end == len(buf): + new_size = len(buf) + 32 + new_buf = bytearray(new_size) + new_buf[:len(buf)] = buf + buf = new_buf + self._receive_buffer = buf + + read = self._recv_into(memoryview(buf)[end:]) + if read == 0: + self._received_length = 0 + return buf[:end] + end += read + + return b"" + + def _read_from_buffer(self, buf=None, nbytes=None): + if self._received_length == 0: + return 0 + read = self._received_length + if nbytes < read: + read = nbytes + membuf = memoryview(self._receive_buffer) + if buf: + buf[:read] = membuf[:read] + if read < self._received_length: + new_end = self._received_length - read + self._receive_buffer[:new_end] = membuf[read:self._received_length] + self._received_length = new_end else: - new_len = len(self._receive_buffers[0]) * current_buffer + found - self._start_index - b = bytearray(new_len) - i = 0 - for bufi in range(0, current_buffer + 1): - buf = self._receive_buffers[bufi] - size = self._buffer_sizes[bufi] - if bufi == 0 and self._start_index > 0: - i = size - self._start_index - b[:i] = buf[self._start_index:size] - elif bufi == current_buffer: - if found > 0: - b[i:i+found] = buf[:found] - else: - b[i:i+size] = buf[:size] - i += size - - self._start_index = new_start - # Swap the current buffer to the front because it has some bytes we - # need to search. - last_buf = self._receive_buffers[current_buffer] - self._receive_buffers[current_buffer] = self._receive_buffers[0] - self._buffer_sizes[0] = self._buffer_sizes[current_buffer] - self._receive_buffers[0] = last_buf - self._buffer_sizes[current_buffer] = 0 - - return b + self._received_length = 0 + return read def _readinto(self, buf): - remaining = self._content_length - self._content_read - nbytes = len(buf) - if nbytes > remaining: - nbytes = remaining - - if self._start_index < self._buffer_sizes[0]: - size = self._buffer_sizes[0] - left = size - self._start_index - if nbytes < left: - left = nbytes - start = self._start_index - end = start + left - if left == 1: - buf[0] = self._receive_buffers[0][start] + if not self._remaining: + # Consume the chunk header if need be. + if self._chunked: + # Consume trailing \r\n for chunks 2+ + if self._remaining == 0: + self._throw_away(2) + chunk_header = self._readto(b";", b"\r\n") + http_chunk_size = int(chunk_header, 16) + if http_chunk_size == 0: + self._chunked = False + self._parse_headers() + return 0 + self._remaining = http_chunk_size else: - buf[:left] = self._receive_buffers[0][start:end] - read = left - self._start_index += left - if read < nbytes: - read += self.socket.recv_into(memoryview(buf)[read:nbytes]) - else: - read = self.socket.recv_into(buf, nbytes) - self._content_read += read + return 0 + + nbytes = len(buf) + if nbytes > self._remaining: + nbytes = self._remaining + + read = self._read_from_buffer(buf, nbytes) + if read == 0: + read = self._recv_into(buf, nbytes) + self._remaining -= read + + # else: + # print("chunked") + # pending_bytes = 0 + # buf = memoryview(bytearray(chunk_size)) + # while True: + # print("chunk", self._content_read, self._content_length) + # print("chunk header", chunk_header) + # self._content_length = http_chunk_size + # self._content_read = 0 + # remaining_in_http_chunk = http_chunk_size + + # pending_bytes = 0 + # while remaining_in_http_chunk: + # read_now = chunk_size - pending_bytes + # if read_now > remaining_in_http_chunk: + # read_now = remaining_in_http_chunk + # read_now = self._readinto(buf[pending_bytes:pending_bytes+read_now]) + # pending_bytes += read_now + # if pending_bytes == chunk_size: + # break + # yield bytes(buf) + + # self._throw_away(2) # Read the trailing CR LF + # + # if pending_bytes > 0: + # yield bytes(buf[:pending_bytes]) + return read def _throw_away(self, nbytes): - buf = self._receive_buffers[0] + nbytes -= self._read_from_buffer(nbytes=nbytes) + + buf = self._receive_buffer for i in range(nbytes // len(buf)): - self.socket.recv_into(buf) + self._recv_into(buf) remaining = nbytes % len(buf) if remaining: - self.socket.recv_into(buf, remaining) + self._recv_into(buf, remaining) def _close(self): """Drain the remaining ESP socket buffers. We assume we already got what we wanted.""" @@ -233,10 +249,9 @@ def _close(self): # Make sure we've read all of our response. # print("Content length:", content_length) if self._cached is None: - remaining = self._content_length - self._content_read - if remaining > 0: - self._throw_away(remaining) - elif self._content_length is None: + if self._remaining > 0: + self._throw_away(self._remaining) + elif self._chunked: while True: chunk_header = self._readto(b";", b"\r\n") chunk_size = int(chunk_header, 16) @@ -266,8 +281,10 @@ def _parse_headers(self): content = str(content, 'utf-8') # Check len first so we can skip the .lower allocation most of the time. if len(title) == len("content-length") and title.lower() == "content-length": - self._content_length = int(content) - self._headers[title] = content + self._remaining = int(content) + if len(title) == len("transfer-encoding") and title.lower() == "transfer-encoding": + self._chunked = content.lower() == "chunked" + self._headers[title] = content @property def headers(self): @@ -302,7 +319,7 @@ def text(self): def json(self): """The HTTP content, parsed into a json dictionary""" # pylint: disable=import-outside-toplevel - import json as json_module + import json # The cached JSON will be a list or dictionary. if self._cached: @@ -312,7 +329,7 @@ def json(self): if not self._raw: self._raw = _RawResponse(self) - obj = json_module.load(self._raw) + obj = json.load(self._raw) if not self._cached: self._cached = obj self._close() @@ -324,44 +341,16 @@ def iter_content(self, chunk_size=1, decode_unicode=False): if decode_unicode: raise NotImplementedError("Unicode not supported") - total_read = 0 - if self._content_length is not None: - while self._content_read < self._content_length: - remaining = self._content_length - self._content_read - if chunk_size > remaining: - chunk_size = remaining - b = bytearray(chunk_size) - size = self._readinto(b) - total_read += self._content_read - if size < chunk_size: - chunk = bytes(memoryview(b)[:size]) - else: - chunk = bytes(b) - yield chunk - else: - pending_bytes = 0 - buf = memoryview(bytearray(chunk_size)) - while True: - chunk_header = self._readto(b";", b"\r\n") - http_chunk_size = int(chunk_header, 16) - if http_chunk_size == 0: - break - self._content_length = http_chunk_size - remaining_in_http_chunk = http_chunk_size - while remaining_in_http_chunk: - read_now = chunk_size - pending_bytes - if read_now > remaining_in_http_chunk: - read_now = remaining_in_http_chunk - read_now = self._readinto(buf[pending_bytes:pending_bytes+read_now]) - pending_bytes += read_now - if pending_bytes == chunk_size: - yield bytes(buf) - pending_bytes = 0 - - self._throw_away(2) # Read the trailing CR LF - self._parse_headers() - if pending_bytes > 0: - yield bytes(buf[:pending_bytes]) + b = bytearray(chunk_size) + while True: + size = self._readinto(b) + if size == 0: + break + if size < chunk_size: + chunk = bytes(memoryview(b)[:size]) + else: + chunk = bytes(b) + yield chunk self._close() class Session: @@ -530,7 +519,7 @@ def delete(self, url, **kw): _default_session = None class FakeSSLContext: - def wrap_socket(self, socket): + def wrap_socket(self, socket, server_hostname=None): return socket def set_socket(sock, iface=None): diff --git a/examples/requests_github_cpython.py b/examples/requests_github_cpython.py new file mode 100755 index 0000000..af83549 --- /dev/null +++ b/examples/requests_github_cpython.py @@ -0,0 +1,13 @@ +# adafruit_requests usage with a CPython socket +import socket +import ssl +import adafruit_requests + +http = adafruit_requests.Session(socket, ssl.create_default_context()) + +print("Getting CircuitPython star count") +headers = {"Transfer-Encoding": "chunked"} +response = http.get("https://api.github.com/repos/adafruit/circuitpython", headers=headers) +print(response.headers) +print("circuitpython stars", response.json()["stargazers_count"]) + diff --git a/examples/requests_https_cpython.py b/examples/requests_https_cpython.py index ca2b5a4..71877c7 100755 --- a/examples/requests_https_cpython.py +++ b/examples/requests_https_cpython.py @@ -2,8 +2,8 @@ import socket import ssl import adafruit_requests as requests -requests.socket_module = socket -requests.ssl_context = ssl.create_default_context() + +http = requests.Session(socket, ssl.create_default_context()) TEXT_URL = "https://wifitest.adafruit.com/testwifi/index.html" JSON_GET_URL = "https://httpbin.org/get" @@ -18,31 +18,28 @@ # response.close() print("Fetching JSON data from %s" % JSON_GET_URL) -response = requests.get(JSON_GET_URL) +response = http.get(JSON_GET_URL) print("-" * 40) print("JSON Response: ", response.json()) print("-" * 40) -response.close() data = "31F" print("POSTing data to {0}: {1}".format(JSON_POST_URL, data)) -response = requests.post(JSON_POST_URL, data=data) +response = http.post(JSON_POST_URL, data=data) print("-" * 40) json_resp = response.json() # Parse out the 'data' key from json_resp dict. print("Data received from server:", json_resp["data"]) print("-" * 40) -response.close() json_data = {"Date": "July 25, 2019"} print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data)) -response = requests.post(JSON_POST_URL, json=json_data) +response = http.post(JSON_POST_URL, json=json_data) print("-" * 40) json_resp = response.json() # Parse out the 'json' key from json_resp dict. print("JSON Data received from server:", json_resp["json"]) print("-" * 40) -response.close() diff --git a/tests/chunk_test.py b/tests/chunk_test.py new file mode 100644 index 0000000..2497f2a --- /dev/null +++ b/tests/chunk_test.py @@ -0,0 +1,45 @@ +from unittest import mock +import mocket +import adafruit_requests + +ip = "1.2.3.4" +host = "wifitest.adafruit.com" +path = "/testwifi/index.html" +text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)" +headers = b"HTTP/1.0 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + +def _chunk(response, split): + i = 0 + chunked = b"" + while i < len(response): + remaining = len(response) - i + chunk_size = split + if remaining < chunk_size: + chunk_size = remaining + new_i = i + chunk_size + chunked += hex(chunk_size)[2:].encode("ascii") + b"\r\n" + response[i:new_i] + b"\r\n" + i = new_i + # The final chunk is zero length. + chunked += b"0\r\n\r\n" + return chunked + +def test_get_text(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + c = _chunk(text, 33) + print(c) + sock = mocket.Mocket(headers + c) + pool.socket.return_value = sock + + s = adafruit_requests.Session(pool) + r = s.get("http://" + host + path) + + sock.connect.assert_called_once_with((host, 80)) + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + assert r.text == str(text, "utf-8") + diff --git a/tests/header_test.py b/tests/header_test.py index d1d273e..0bf5659 100644 --- a/tests/header_test.py +++ b/tests/header_test.py @@ -19,7 +19,7 @@ def test_json(): headers = {"user-agent": "blinka/1.0.0"} r = adafruit_requests.get("http://" + host + "/get", headers=headers) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + sock.connect.assert_called_once_with((host, 80)) sent = b"".join(sent).lower() assert b"user-agent: blinka/1.0.0\r\n" in sent # The current implementation sends two user agents. Fix it, and uncomment below. diff --git a/tests/legacy_test.py b/tests/legacy_test.py new file mode 100644 index 0000000..d015d64 --- /dev/null +++ b/tests/legacy_test.py @@ -0,0 +1,35 @@ +from unittest import mock +import mocket +import json +import adafruit_requests + +ip = "1.2.3.4" +host = "httpbin.org" +response = {"Date": "July 25, 2019"} +encoded = json.dumps(response).encode("utf-8") +headers = "HTTP/1.0 200 OK\r\nContent-Length: {}\r\n\r\n".format(len(encoded)).encode( + "utf-8" +) + +def test_get_json(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + del sock.recv_into + mocket.socket.return_value = sock + + adafruit_requests.set_socket(mocket, mocket.interface) + r = adafruit_requests.get("http://" + host + "/get") + sock.connect.assert_called_once_with((host, 80)) + assert r.json() == response + +def test_post_string(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + del sock.recv_into + mocket.socket.return_value = sock + + adafruit_requests.set_socket(mocket, mocket.interface) + data = "31F" + r = adafruit_requests.post("http://" + host + "/post", data=data) + sock.connect.assert_called_once_with((host, 80)) + sock.send.assert_called_with(b"31F") diff --git a/tests/mocket.py b/tests/mocket.py index 684aece..853783f 100644 --- a/tests/mocket.py +++ b/tests/mocket.py @@ -1,13 +1,15 @@ from unittest import mock -SOCK_STREAM = 0 - -getaddrinfo = mock.Mock() -socket = mock.Mock() set_interface = mock.Mock() interface = mock.MagicMock() +class MocketPool: + SOCK_STREAM = 0 + + def __init__(self): + self.getaddrinfo = mock.Mock() + self.socket = mock.Mock() class Mocket: def __init__(self, response): @@ -17,6 +19,7 @@ def __init__(self, response): self.send = mock.Mock() self.readline = mock.Mock(side_effect=self._readline) self.recv = mock.Mock(side_effect=self._recv) + self.recv_into = mock.Mock(side_effect=self._recv_into) self._response = response self._position = 0 @@ -31,3 +34,14 @@ def _recv(self, count): r = self._response[self._position : end] self._position = end return r + + def _recv_into(self, buf, nbytes=None): + assert nbytes is None or nbytes > 0 + read = nbytes if nbytes else len(buf) + remaining = len(self._response) - self._position + if read > remaining: + read = remaining + end = self._position + read + buf[:read] = self._response[self._position : end] + self._position = end + return read diff --git a/tests/parse_test.py b/tests/parse_test.py index 477a128..d1362b8 100644 --- a/tests/parse_test.py +++ b/tests/parse_test.py @@ -7,17 +7,20 @@ host = "httpbin.org" response = {"Date": "July 25, 2019"} encoded = json.dumps(response).encode("utf-8") -headers = "HTTP/1.0 200 OK\r\nContent-Length: {}\r\n\r\n".format(len(encoded)).encode( +# Padding here tests the case where a header line is exactly 32 bytes buffered by +# aligning the Content-Type header after it. +headers = "HTTP/1.0 200 OK\r\npadding: 000\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n".format(len(encoded)).encode( "utf-8" ) def test_json(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.get("http://" + host + "/get") - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + s = adafruit_requests.Session(pool) + r = s.get("http://" + host + "/get") + sock.connect.assert_called_once_with((host, 80)) assert r.json() == response diff --git a/tests/post_test.py b/tests/post_test.py index a2f5977..1f2701f 100644 --- a/tests/post_test.py +++ b/tests/post_test.py @@ -19,9 +19,9 @@ def test_method(): adafruit_requests.set_socket(mocket, mocket.interface) r = adafruit_requests.post("http://" + host + "/post") - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + sock.connect.assert_called_once_with((host, 80)) sock.send.assert_has_calls( - [mock.call(b"POST /post HTTP/1.0\r\n"), mock.call(b"Host: httpbin.org\r\n")] + [mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")] ) @@ -33,7 +33,7 @@ def test_string(): adafruit_requests.set_socket(mocket, mocket.interface) data = "31F" r = adafruit_requests.post("http://" + host + "/post", data=data) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + sock.connect.assert_called_once_with((host, 80)) sock.send.assert_called_with(b"31F") @@ -45,5 +45,5 @@ def test_json(): adafruit_requests.set_socket(mocket, mocket.interface) json_data = {"Date": "July 25, 2019"} r = adafruit_requests.post("http://" + host + "/post", json=json_data) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + sock.connect.assert_called_once_with((host, 80)) sock.send.assert_called_with(b'{"Date": "July 25, 2019"}') diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 0eec002..3202395 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -17,10 +17,10 @@ def test_get_https_text(): adafruit_requests.set_socket(mocket, mocket.interface) r = adafruit_requests.get("https://" + host + path) - sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE) + sock.connect.assert_called_once_with((host, 443)) sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.0\r\n"), + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), mock.call(b"Host: wifitest.adafruit.com\r\n"), ] ) @@ -35,10 +35,10 @@ def test_get_http_text(): adafruit_requests.set_socket(mocket, mocket.interface) r = adafruit_requests.get("http://" + host + path) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + sock.connect.assert_called_once_with((host, 80)) sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.0\r\n"), + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), mock.call(b"Host: wifitest.adafruit.com\r\n"), ] ) From e372574ad19c0b876780ac4d907963a9ac055839 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 2 Sep 2020 12:53:21 -0700 Subject: [PATCH 13/19] Fix close, test it and run black --- adafruit_requests.py | 158 +++++++++++++----------- examples/requests_advanced_cpython.py | 7 +- examples/requests_github_cpython.py | 6 +- examples/requests_https_cpython.py | 8 +- examples/requests_simpletest_cpython.py | 16 ++- tests/chunk_test.py | 7 +- tests/header_test.py | 9 +- tests/legacy_mocket.py | 32 +++++ tests/legacy_test.py | 9 +- tests/mocket.py | 18 ++- tests/parse_test.py | 4 +- tests/post_test.py | 27 ++-- tests/protocol_test.py | 37 ++++-- tests/reuse_test.py | 107 ++++++++++++++++ 14 files changed, 312 insertions(+), 133 deletions(-) create mode 100644 tests/legacy_mocket.py create mode 100644 tests/reuse_test.py diff --git a/adafruit_requests.py b/adafruit_requests.py index 6d1a76d..fb6e17c 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -55,6 +55,7 @@ __version__ = "0.0.0-auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git" + class _RawResponse: def __init__(self, response): self._response = response @@ -67,6 +68,7 @@ def read(self, size=-1): def readinto(self, buf): return self._response._readinto(buf) + class Response: """The response from a request, contains all the headers/content""" @@ -105,9 +107,9 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - def _recv_into(self, buf, size=None): + def _recv_into(self, buf, size=0): if self._backwards_compatible: - size = len(buf) if size is None else size + size = len(buf) if size == 0 else size b = self.socket.recv(size) read_size = len(b) buf[:read_size] = b @@ -119,7 +121,6 @@ def _readto(self, first, second=b""): buf = self._receive_buffer end = self._received_length while True: - print("searching", buf[:end]) firsti = buf.find(first, 0, end) secondi = -1 if second: @@ -149,7 +150,7 @@ def _readto(self, first, second=b""): if end == len(buf): new_size = len(buf) + 32 new_buf = bytearray(new_size) - new_buf[:len(buf)] = buf + new_buf[: len(buf)] = buf buf = new_buf self._receive_buffer = buf @@ -172,13 +173,18 @@ def _read_from_buffer(self, buf=None, nbytes=None): buf[:read] = membuf[:read] if read < self._received_length: new_end = self._received_length - read - self._receive_buffer[:new_end] = membuf[read:self._received_length] + self._receive_buffer[:new_end] = membuf[read : self._received_length] self._received_length = new_end else: self._received_length = 0 return read def _readinto(self, buf): + if not self.socket: + raise RuntimeError( + "Newer Response closed this one. Use Responses immediately." + ) + if not self._remaining: # Consume the chunk header if need be. if self._chunked: @@ -204,33 +210,6 @@ def _readinto(self, buf): read = self._recv_into(buf, nbytes) self._remaining -= read - # else: - # print("chunked") - # pending_bytes = 0 - # buf = memoryview(bytearray(chunk_size)) - # while True: - # print("chunk", self._content_read, self._content_length) - # print("chunk header", chunk_header) - # self._content_length = http_chunk_size - # self._content_read = 0 - # remaining_in_http_chunk = http_chunk_size - - # pending_bytes = 0 - # while remaining_in_http_chunk: - # read_now = chunk_size - pending_bytes - # if read_now > remaining_in_http_chunk: - # read_now = remaining_in_http_chunk - # read_now = self._readinto(buf[pending_bytes:pending_bytes+read_now]) - # pending_bytes += read_now - # if pending_bytes == chunk_size: - # break - # yield bytes(buf) - - # self._throw_away(2) # Read the trailing CR LF - # - # if pending_bytes > 0: - # yield bytes(buf[:pending_bytes]) - return read def _throw_away(self, nbytes): @@ -243,27 +222,27 @@ def _throw_away(self, nbytes): if remaining: self._recv_into(buf, remaining) - def _close(self): + def close(self): """Drain the remaining ESP socket buffers. We assume we already got what we wanted.""" - if self.socket: - # Make sure we've read all of our response. - # print("Content length:", content_length) - if self._cached is None: - if self._remaining > 0: - self._throw_away(self._remaining) - elif self._chunked: - while True: - chunk_header = self._readto(b";", b"\r\n") - chunk_size = int(chunk_header, 16) - if chunk_size == 0: - break - self._throw_away(chunk_size + 2) - self._parse_headers() - if self._session: - self._session.free_socket(self.socket) - else: - self.socket.close() - self.socket = None + if not self.socket: + return + # Make sure we've read all of our response. + if self._cached is None: + if self._remaining > 0: + self._throw_away(self._remaining) + elif self._chunked: + while True: + chunk_header = self._readto(b";", b"\r\n") + chunk_size = int(chunk_header, 16) + if chunk_size == 0: + break + self._throw_away(chunk_size + 2) + self._parse_headers() + if self._session: + self._session.free_socket(self.socket) + else: + self.socket.close() + self.socket = None def _parse_headers(self): """ @@ -277,14 +256,20 @@ def _parse_headers(self): content = self._readto(b"\r\n") if title and content: - title = str(title, 'utf-8') - content = str(content, 'utf-8') + title = str(title, "utf-8") + content = str(content, "utf-8") # Check len first so we can skip the .lower allocation most of the time. - if len(title) == len("content-length") and title.lower() == "content-length": + if ( + len(title) == len("content-length") + and title.lower() == "content-length" + ): self._remaining = int(content) - if len(title) == len("transfer-encoding") and title.lower() == "transfer-encoding": + if ( + len(title) == len("transfer-encoding") + and title.lower() == "transfer-encoding" + ): self._chunked = content.lower() == "chunked" - self._headers[title] = content + self._headers[title] = content @property def headers(self): @@ -332,7 +317,7 @@ def json(self): obj = json.load(self._raw) if not self._cached: self._cached = obj - self._close() + self.close() return obj def iter_content(self, chunk_size=1, decode_unicode=False): @@ -351,7 +336,8 @@ def iter_content(self, chunk_size=1, decode_unicode=False): else: chunk = bytes(b) yield chunk - self._close() + self.close() + class Session: def __init__(self, socket_pool, ssl_context=None): @@ -375,8 +361,12 @@ def _get_socket(self, host, port, proto, *, timeout=1): self._socket_free[sock] = False return sock if proto == "https:" and not self._ssl_context: - raise RuntimeError("ssl_context must be set before using adafruit_requests for https") - addr_info = self._socket_pool.getaddrinfo(host, port, 0, self._socket_pool.SOCK_STREAM)[0] + raise RuntimeError( + "ssl_context must be set before using adafruit_requests for https" + ) + addr_info = self._socket_pool.getaddrinfo( + host, port, 0, self._socket_pool.SOCK_STREAM + )[0] sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) if proto == "https:": sock = self._ssl_context.wrap_socket(sock, server_hostname=host) @@ -391,15 +381,22 @@ def _get_socket(self, host, port, proto, *, timeout=1): # We couldn't connect due to memory so clean up the open sockets. if not ok: + free_sockets = [] for s in self._socket_free: if self._socket_free[s]: s.close() - del self._socket_free[s] - for k in self._open_sockets: - if self._open_sockets[k] == s: - del self._open_sockets[k] + free_sockets.append(s) + for s in free_sockets: + del self._socket_free[s] + key = None + for k in self._open_sockets: + if self._open_sockets[k] == s: + key = k + break + if key: + del self._open_sockets[key] # Recreate the socket because the ESP-IDF won't retry the connection if it failed once. - sock = None # Clear first so the first socket can be cleaned up. + sock = None # Clear first so the first socket can be cleaned up. sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) if proto == "https:": sock = self._ssl_context.wrap_socket(sock, server_hostname=host) @@ -410,7 +407,9 @@ def _get_socket(self, host, port, proto, *, timeout=1): return sock # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals - def request(self, method, url, data=None, json=None, headers=None, stream=False, timeout=60): + def request( + self, method, url, data=None, json=None, headers=None, stream=False, timeout=60 + ): """Perform an HTTP request to the given url which we will parse to determine whether to use SSL ('https://') or not. We can also send some provided 'data' or a json dictionary which we will stringify. 'headers' is optional HTTP headers @@ -439,11 +438,13 @@ def request(self, method, url, data=None, json=None, headers=None, stream=False, port = int(port) if self._last_response: - self._last_response._close() + self._last_response.close() self._last_response = None socket = self._get_socket(host, port, proto, timeout=timeout) - socket.send(b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))) + socket.send( + b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8")) + ) if "Host" not in headers: socket.send(b"Host: %s\r\n" % bytes(host, "utf-8")) if "User-Agent" not in headers: @@ -489,47 +490,54 @@ def head(self, url, **kw): """Send HTTP HEAD request""" return self.request("HEAD", url, **kw) - def get(self, url, **kw): """Send HTTP GET request""" return self.request("GET", url, **kw) - def post(self, url, **kw): """Send HTTP POST request""" return self.request("POST", url, **kw) - def put(self, url, **kw): """Send HTTP PUT request""" return self.request("PUT", url, **kw) - def patch(self, url, **kw): """Send HTTP PATCH request""" return self.request("PATCH", url, **kw) - def delete(self, url, **kw): """Send HTTP DELETE request""" return self.request("DELETE", url, **kw) + # Backwards compatible API: _default_session = None + class FakeSSLContext: def wrap_socket(self, socket, server_hostname=None): return socket + def set_socket(sock, iface=None): global _default_session _default_session = Session(sock, FakeSSLContext()) if iface: sock.set_interface(iface) + def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1): - _default_session.request(method, url, data=data, json=json, headers=headers, stream=stream, timeout=timeout) + _default_session.request( + method, + url, + data=data, + json=json, + headers=headers, + stream=stream, + timeout=timeout, + ) def head(url, **kw): diff --git a/examples/requests_advanced_cpython.py b/examples/requests_advanced_cpython.py index 725974f..379620e 100644 --- a/examples/requests_advanced_cpython.py +++ b/examples/requests_advanced_cpython.py @@ -1,6 +1,7 @@ import socket -import adafruit_requests as requests -requests.socket_module = socket +import adafruit_requests + +http = adafruit_requests.Session(socket) JSON_GET_URL = "http://httpbin.org/get" @@ -8,7 +9,7 @@ headers = {"user-agent": "blinka/1.0.0"} print("Fetching JSON data from %s..." % JSON_GET_URL) -response = requests.get(JSON_GET_URL, headers=headers) +response = http.get(JSON_GET_URL, headers=headers) print("-" * 60) json_data = response.json() diff --git a/examples/requests_github_cpython.py b/examples/requests_github_cpython.py index af83549..e258d1f 100755 --- a/examples/requests_github_cpython.py +++ b/examples/requests_github_cpython.py @@ -7,7 +7,7 @@ print("Getting CircuitPython star count") headers = {"Transfer-Encoding": "chunked"} -response = http.get("https://api.github.com/repos/adafruit/circuitpython", headers=headers) -print(response.headers) +response = http.get( + "https://api.github.com/repos/adafruit/circuitpython", headers=headers +) print("circuitpython stars", response.json()["stargazers_count"]) - diff --git a/examples/requests_https_cpython.py b/examples/requests_https_cpython.py index 71877c7..dca4d31 100755 --- a/examples/requests_https_cpython.py +++ b/examples/requests_https_cpython.py @@ -3,7 +3,7 @@ import ssl import adafruit_requests as requests -http = requests.Session(socket, ssl.create_default_context()) +https = requests.Session(socket, ssl.create_default_context()) TEXT_URL = "https://wifitest.adafruit.com/testwifi/index.html" JSON_GET_URL = "https://httpbin.org/get" @@ -18,7 +18,7 @@ # response.close() print("Fetching JSON data from %s" % JSON_GET_URL) -response = http.get(JSON_GET_URL) +response = https.get(JSON_GET_URL) print("-" * 40) print("JSON Response: ", response.json()) @@ -26,7 +26,7 @@ data = "31F" print("POSTing data to {0}: {1}".format(JSON_POST_URL, data)) -response = http.post(JSON_POST_URL, data=data) +response = https.post(JSON_POST_URL, data=data) print("-" * 40) json_resp = response.json() @@ -36,7 +36,7 @@ json_data = {"Date": "July 25, 2019"} print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data)) -response = http.post(JSON_POST_URL, json=json_data) +response = https.post(JSON_POST_URL, json=json_data) print("-" * 40) json_resp = response.json() diff --git a/examples/requests_simpletest_cpython.py b/examples/requests_simpletest_cpython.py index 5ca5536..db9fca2 100755 --- a/examples/requests_simpletest_cpython.py +++ b/examples/requests_simpletest_cpython.py @@ -1,22 +1,22 @@ # adafruit_requests usage with a CPython socket import socket -import adafruit_requests as requests -requests.socket_module = socket +import adafruit_requests + +http = adafruit_requests.Session(socket) TEXT_URL = "http://wifitest.adafruit.com/testwifi/index.html" JSON_GET_URL = "http://httpbin.org/get" JSON_POST_URL = "http://httpbin.org/post" print("Fetching text from %s" % TEXT_URL) -response = requests.get(TEXT_URL) +response = http.get(TEXT_URL) print("-" * 40) print("Text Response: ", response.text) print("-" * 40) -response.close() print("Fetching JSON data from %s" % JSON_GET_URL) -response = requests.get(JSON_GET_URL) +response = http.get(JSON_GET_URL) print("-" * 40) print("JSON Response: ", response.json()) @@ -25,22 +25,20 @@ data = "31F" print("POSTing data to {0}: {1}".format(JSON_POST_URL, data)) -response = requests.post(JSON_POST_URL, data=data) +response = http.post(JSON_POST_URL, data=data) print("-" * 40) json_resp = response.json() # Parse out the 'data' key from json_resp dict. print("Data received from server:", json_resp["data"]) print("-" * 40) -response.close() json_data = {"Date": "July 25, 2019"} print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data)) -response = requests.post(JSON_POST_URL, json=json_data) +response = http.post(JSON_POST_URL, json=json_data) print("-" * 40) json_resp = response.json() # Parse out the 'json' key from json_resp dict. print("JSON Data received from server:", json_resp["json"]) print("-" * 40) -response.close() diff --git a/tests/chunk_test.py b/tests/chunk_test.py index 2497f2a..cf2f01d 100644 --- a/tests/chunk_test.py +++ b/tests/chunk_test.py @@ -8,6 +8,7 @@ text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)" headers = b"HTTP/1.0 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + def _chunk(response, split): i = 0 chunked = b"" @@ -17,12 +18,15 @@ def _chunk(response, split): if remaining < chunk_size: chunk_size = remaining new_i = i + chunk_size - chunked += hex(chunk_size)[2:].encode("ascii") + b"\r\n" + response[i:new_i] + b"\r\n" + chunked += ( + hex(chunk_size)[2:].encode("ascii") + b"\r\n" + response[i:new_i] + b"\r\n" + ) i = new_i # The final chunk is zero length. chunked += b"0\r\n\r\n" return chunked + def test_get_text(): pool = mocket.MocketPool() pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) @@ -42,4 +46,3 @@ def test_get_text(): ] ) assert r.text == str(text, "utf-8") - diff --git a/tests/header_test.py b/tests/header_test.py index 0bf5659..921fac3 100644 --- a/tests/header_test.py +++ b/tests/header_test.py @@ -9,15 +9,16 @@ def test_json(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(response_headers) - mocket.socket.return_value = sock + pool.socket.return_value = sock sent = [] sock.send.side_effect = sent.append - adafruit_requests.set_socket(mocket, mocket.interface) + s = adafruit_requests.Session(pool) headers = {"user-agent": "blinka/1.0.0"} - r = adafruit_requests.get("http://" + host + "/get", headers=headers) + r = s.get("http://" + host + "/get", headers=headers) sock.connect.assert_called_once_with((host, 80)) sent = b"".join(sent).lower() diff --git a/tests/legacy_mocket.py b/tests/legacy_mocket.py new file mode 100644 index 0000000..4a37bd2 --- /dev/null +++ b/tests/legacy_mocket.py @@ -0,0 +1,32 @@ +from unittest import mock + +SOCK_STREAM = 0 + +set_interface = mock.Mock() +interface = mock.MagicMock() +getaddrinfo = mock.Mock() +socket = mock.Mock() + + +class Mocket: + def __init__(self, response): + self.settimeout = mock.Mock() + self.close = mock.Mock() + self.connect = mock.Mock() + self.send = mock.Mock() + self.readline = mock.Mock(side_effect=self._readline) + self.recv = mock.Mock(side_effect=self._recv) + self._response = response + self._position = 0 + + def _readline(self): + i = self._response.find(b"\r\n", self._position) + r = self._response[self._position : i + 2] + self._position = i + 2 + return r + + def _recv(self, count): + end = self._position + count + r = self._response[self._position : end] + self._position = end + return r diff --git a/tests/legacy_test.py b/tests/legacy_test.py index d015d64..c3662e2 100644 --- a/tests/legacy_test.py +++ b/tests/legacy_test.py @@ -1,5 +1,5 @@ from unittest import mock -import mocket +import legacy_mocket as mocket import json import adafruit_requests @@ -11,21 +11,23 @@ "utf-8" ) + def test_get_json(): mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - del sock.recv_into mocket.socket.return_value = sock adafruit_requests.set_socket(mocket, mocket.interface) r = adafruit_requests.get("http://" + host + "/get") + sock.connect.assert_called_once_with((host, 80)) assert r.json() == response + r.close() + def test_post_string(): mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - del sock.recv_into mocket.socket.return_value = sock adafruit_requests.set_socket(mocket, mocket.interface) @@ -33,3 +35,4 @@ def test_post_string(): r = adafruit_requests.post("http://" + host + "/post", data=data) sock.connect.assert_called_once_with((host, 80)) sock.send.assert_called_with(b"31F") + r.close() diff --git a/tests/mocket.py b/tests/mocket.py index 853783f..ec9a557 100644 --- a/tests/mocket.py +++ b/tests/mocket.py @@ -1,8 +1,5 @@ from unittest import mock -set_interface = mock.Mock() - -interface = mock.MagicMock() class MocketPool: SOCK_STREAM = 0 @@ -11,6 +8,7 @@ def __init__(self): self.getaddrinfo = mock.Mock() self.socket = mock.Mock() + class Mocket: def __init__(self, response): self.settimeout = mock.Mock() @@ -35,9 +33,9 @@ def _recv(self, count): self._position = end return r - def _recv_into(self, buf, nbytes=None): - assert nbytes is None or nbytes > 0 - read = nbytes if nbytes else len(buf) + def _recv_into(self, buf, nbytes=0): + assert isinstance(nbytes, int) and nbytes >= 0 + read = nbytes if nbytes > 0 else len(buf) remaining = len(self._response) - self._position if read > remaining: read = remaining @@ -45,3 +43,11 @@ def _recv_into(self, buf, nbytes=None): buf[:read] = self._response[self._position : end] self._position = end return read + + +class SSLContext: + def __init__(self): + self.wrap_socket = mock.Mock(side_effect=self._wrap_socket) + + def _wrap_socket(self, sock, server_hostname=None): + return sock diff --git a/tests/parse_test.py b/tests/parse_test.py index d1362b8..d2331dc 100644 --- a/tests/parse_test.py +++ b/tests/parse_test.py @@ -9,7 +9,9 @@ encoded = json.dumps(response).encode("utf-8") # Padding here tests the case where a header line is exactly 32 bytes buffered by # aligning the Content-Type header after it. -headers = "HTTP/1.0 200 OK\r\npadding: 000\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n".format(len(encoded)).encode( +headers = "HTTP/1.0 200 OK\r\npadding: 000\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n".format( + len(encoded) +).encode( "utf-8" ) diff --git a/tests/post_test.py b/tests/post_test.py index 1f2701f..20c0618 100644 --- a/tests/post_test.py +++ b/tests/post_test.py @@ -13,12 +13,13 @@ def test_method(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.post("http://" + host + "/post") + s = adafruit_requests.Session(pool) + r = s.post("http://" + host + "/post") sock.connect.assert_called_once_with((host, 80)) sock.send.assert_has_calls( [mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")] @@ -26,24 +27,26 @@ def test_method(): def test_string(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) + s = adafruit_requests.Session(pool) data = "31F" - r = adafruit_requests.post("http://" + host + "/post", data=data) + r = s.post("http://" + host + "/post", data=data) sock.connect.assert_called_once_with((host, 80)) sock.send.assert_called_with(b"31F") def test_json(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) + s = adafruit_requests.Session(pool) json_data = {"Date": "July 25, 2019"} - r = adafruit_requests.post("http://" + host + "/post", json=json_data) + r = s.post("http://" + host + "/post", json=json_data) sock.connect.assert_called_once_with((host, 80)) sock.send.assert_called_with(b'{"Date": "July 25, 2019"}') diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 3202395..76e2289 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -1,5 +1,6 @@ from unittest import mock import mocket +import pytest import adafruit_requests ip = "1.2.3.4" @@ -9,13 +10,26 @@ response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text +def test_get_https_no_ssl(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response) + pool.socket.return_value = sock + + s = adafruit_requests.Session(pool) + with pytest.raises(RuntimeError): + r = s.get("https://" + host + path) + + def test_get_https_text(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(response) - mocket.socket.return_value = sock + pool.socket.return_value = sock + ssl = mocket.SSLContext() - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.get("https://" + host + path) + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) sock.connect.assert_called_once_with((host, 443)) sock.send.assert_has_calls( @@ -26,14 +40,18 @@ def test_get_https_text(): ) assert r.text == str(text, "utf-8") + # Close isn't needed but can be called to release the socket early. + r.close() + def test_get_http_text(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(response) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.get("http://" + host + path) + s = adafruit_requests.Session(pool) + r = s.get("http://" + host + path) sock.connect.assert_called_once_with((host, 80)) sock.send.assert_has_calls( @@ -43,6 +61,3 @@ def test_get_http_text(): ] ) assert r.text == str(text, "utf-8") - - -# Add a chunked response test when we support HTTP 1.1 diff --git a/tests/reuse_test.py b/tests/reuse_test.py new file mode 100644 index 0000000..2e06931 --- /dev/null +++ b/tests/reuse_test.py @@ -0,0 +1,107 @@ +from unittest import mock +import mocket +import pytest +import adafruit_requests + +ip = "1.2.3.4" +host = "wifitest.adafruit.com" +host2 = "wifitest2.adafruit.com" +path = "/testwifi/index.html" +text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)" +response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text + +# def test_get_twice(): +# pool = mocket.MocketPool() +# pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) +# sock = mocket.Mocket(response + response) +# pool.socket.return_value = sock +# ssl = mocket.SSLContext() + +# s = adafruit_requests.Session(pool, ssl) +# r = s.get("https://" + host + path) + +# sock.send.assert_has_calls( +# [ +# mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), +# mock.call(b"Host: wifitest.adafruit.com\r\n"), +# ] +# ) +# assert r.text == str(text, "utf-8") + +# r = s.get("https://" + host + path + "2") +# sock.send.assert_has_calls( +# [ +# mock.call(b"GET /testwifi/index.html2 HTTP/1.1\r\n"), +# mock.call(b"Host: wifitest.adafruit.com\r\n"), +# ] +# ) + +# assert r.text == str(text, "utf-8") +# sock.connect.assert_called_once_with((host, 443)) +# pool.socket.assert_called_once() + + +def test_get_twice_after_second(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response + response) + pool.socket.return_value = sock + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + + r2 = s.get("https://" + host + path + "2") + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html2 HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + sock.connect.assert_called_once_with((host, 443)) + pool.socket.assert_called_once() + + with pytest.raises(RuntimeError): + r.text == str(text, "utf-8") + + +def test_connect_out_of_memory(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response) + sock2 = mocket.Mocket(response) + sock3 = mocket.Mocket(response) + pool.socket.side_effect = [sock, sock2, sock3] + sock2.connect.side_effect = MemoryError() + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + assert r.text == str(text, "utf-8") + + r = s.get("https://" + host2 + path) + sock3.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest2.adafruit.com\r\n"), + ] + ) + + assert r.text == str(text, "utf-8") + sock.close.assert_called_once() + sock.connect.assert_called_once_with((host, 443)) + sock3.connect.assert_called_once_with((host2, 443)) From 6d7a44ae8857d714ae65ca835831477140951a98 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 2 Sep 2020 17:44:42 -0700 Subject: [PATCH 14/19] Lint, test and fix bug with form data --- adafruit_requests.py | 72 ++++++++++++++++++++++++++------------------ tests/post_test.py | 13 ++++++++ 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index fb6e17c..517a54c 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -50,8 +50,6 @@ """ -import gc - __version__ = "0.0.0-auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git" @@ -61,16 +59,23 @@ def __init__(self, response): self._response = response def read(self, size=-1): + """Read as much as available or up to size and return it in a byte string. + + Do NOT use this unless you really need to. Reusing memory with `readinto` is much better. + """ if size == -1: return self._response.content return self._response.socket.recv(size) def readinto(self, buf): - return self._response._readinto(buf) + """Read as much as available into buf or until it is full. Returns the number of bytes read + into buf.""" + return self._response._readinto(buf) # pylint: disable=protected-access class Response: """The response from a request, contains all the headers/content""" + # pylint: disable=too-many-instance-attributes encoding = None @@ -98,7 +103,6 @@ def __init__(self, sock, session=None): self.reason = self._readto(b"\r\n") self._parse_headers() self._raw = None - self._content_read = 0 self._session = session def __enter__(self): @@ -114,8 +118,7 @@ def _recv_into(self, buf, size=0): read_size = len(b) buf[:read_size] = b return read_size - else: - return self.socket.recv_into(buf, size) + return self.socket.recv_into(buf, size) def _readto(self, first, second=b""): buf = self._receive_buffer @@ -216,7 +219,7 @@ def _throw_away(self, nbytes): nbytes -= self._read_from_buffer(nbytes=nbytes) buf = self._receive_buffer - for i in range(nbytes // len(buf)): + for _ in range(nbytes // len(buf)): self._recv_into(buf) remaining = nbytes % len(buf) if remaining: @@ -239,7 +242,7 @@ def close(self): self._throw_away(chunk_size + 2) self._parse_headers() if self._session: - self._session.free_socket(self.socket) + self._session._free_socket(self.socket) # pylint: disable=protected-access else: self.socket.close() self.socket = None @@ -340,6 +343,7 @@ def iter_content(self, chunk_size=1, decode_unicode=False): class Session: + """HTTP session that shares sockets and ssl context.""" def __init__(self, socket_pool, ssl_context=None): self._socket_pool = socket_pool self._ssl_context = ssl_context @@ -348,11 +352,28 @@ def __init__(self, socket_pool, ssl_context=None): self._socket_free = {} self._last_response = None - def free_socket(self, socket): + def _free_socket(self, socket): + if socket not in self._open_sockets.values(): raise RuntimeError("Socket not from session") self._socket_free[socket] = True + def _free_sockets(self): + free_sockets = [] + for sock in self._socket_free: + if self._socket_free[sock]: + sock.close() + free_sockets.append(sock) + for sock in free_sockets: + del self._socket_free[sock] + key = None + for k in self._open_sockets: + if self._open_sockets[k] == sock: + key = k + break + if key: + del self._open_sockets[key] + def _get_socket(self, host, port, proto, *, timeout=1): key = (host, port, proto) if key in self._open_sockets: @@ -381,20 +402,7 @@ def _get_socket(self, host, port, proto, *, timeout=1): # We couldn't connect due to memory so clean up the open sockets. if not ok: - free_sockets = [] - for s in self._socket_free: - if self._socket_free[s]: - s.close() - free_sockets.append(s) - for s in free_sockets: - del self._socket_free[s] - key = None - for k in self._open_sockets: - if self._open_sockets[k] == s: - key = k - break - if key: - del self._open_sockets[key] + self._free_sockets() # Recreate the socket because the ESP-IDF won't retry the connection if it failed once. sock = None # Clear first so the first socket can be cleaned up. sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) @@ -466,7 +474,7 @@ def request( socket.send(b"Content-Type: application/json\r\n") if data: if isinstance(data, dict): - sock.send(b"Content-Type: application/x-www-form-urlencoded\r\n") + socket.send(b"Content-Type: application/x-www-form-urlencoded\r\n") _post_data = "" for k in data: _post_data = "{}&{}={}".format(_post_data, k, data[k]) @@ -513,22 +521,28 @@ def delete(self, url, **kw): # Backwards compatible API: -_default_session = None +_default_session = None # pylint: disable=invalid-name -class FakeSSLContext: - def wrap_socket(self, socket, server_hostname=None): +class _FakeSSLContext: + @staticmethod + def wrap_socket(socket, server_hostname=None): + """Return the same socket""" + # pylint: disable=unused-argument return socket def set_socket(sock, iface=None): - global _default_session - _default_session = Session(sock, FakeSSLContext()) + """Legacy API for setting the socket and network interface. Use a `Session` instead.""" + global _default_session # pylint: disable=global-statement,invalid-name + _default_session = Session(sock, _FakeSSLContext()) if iface: sock.set_interface(iface) def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1): + """Send HTTP request""" + # pylint: disable=too-many-arguments _default_session.request( method, url, diff --git a/tests/post_test.py b/tests/post_test.py index 20c0618..164d85c 100644 --- a/tests/post_test.py +++ b/tests/post_test.py @@ -39,6 +39,19 @@ def test_string(): sock.send.assert_called_with(b"31F") +def test_form(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + pool.socket.return_value = sock + + s = adafruit_requests.Session(pool) + data = {"Date": "July 25, 2019"} + r = s.post("http://" + host + "/post", data=data) + sock.connect.assert_called_once_with((host, 80)) + sock.send.assert_called_with(b"Date=July 25, 2019") + + def test_json(): pool = mocket.MocketPool() pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) From e6a4c8f73ab1a43907eac6ce142e1226e7d78237 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 2 Sep 2020 17:46:57 -0700 Subject: [PATCH 15/19] Black --- adafruit_requests.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index 517a54c..2ebfb03 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -61,20 +61,21 @@ def __init__(self, response): def read(self, size=-1): """Read as much as available or up to size and return it in a byte string. - Do NOT use this unless you really need to. Reusing memory with `readinto` is much better. - """ + Do NOT use this unless you really need to. Reusing memory with `readinto` is much better. + """ if size == -1: return self._response.content return self._response.socket.recv(size) def readinto(self, buf): """Read as much as available into buf or until it is full. Returns the number of bytes read - into buf.""" - return self._response._readinto(buf) # pylint: disable=protected-access + into buf.""" + return self._response._readinto(buf) # pylint: disable=protected-access class Response: """The response from a request, contains all the headers/content""" + # pylint: disable=too-many-instance-attributes encoding = None @@ -242,7 +243,7 @@ def close(self): self._throw_away(chunk_size + 2) self._parse_headers() if self._session: - self._session._free_socket(self.socket) # pylint: disable=protected-access + self._session._free_socket(self.socket) # pylint: disable=protected-access else: self.socket.close() self.socket = None @@ -344,6 +345,7 @@ def iter_content(self, chunk_size=1, decode_unicode=False): class Session: """HTTP session that shares sockets and ssl context.""" + def __init__(self, socket_pool, ssl_context=None): self._socket_pool = socket_pool self._ssl_context = ssl_context @@ -521,7 +523,7 @@ def delete(self, url, **kw): # Backwards compatible API: -_default_session = None # pylint: disable=invalid-name +_default_session = None # pylint: disable=invalid-name class _FakeSSLContext: @@ -534,7 +536,7 @@ def wrap_socket(socket, server_hostname=None): def set_socket(sock, iface=None): """Legacy API for setting the socket and network interface. Use a `Session` instead.""" - global _default_session # pylint: disable=global-statement,invalid-name + global _default_session # pylint: disable=global-statement,invalid-name _default_session = Session(sock, _FakeSSLContext()) if iface: sock.set_interface(iface) From 46921d0837085fdae4bf11dcc8005acb07f04070 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Fri, 11 Sep 2020 15:17:19 -0700 Subject: [PATCH 16/19] Complete backwards compatibility: * Add TLS_MODE to old connect() calls * Provide ip to connect when in HTTP mode --- adafruit_requests.py | 34 ++++++++++++++++++++++++---------- tests/chunk_test.py | 2 +- tests/header_test.py | 2 +- tests/legacy_test.py | 15 +++++++++++++-- tests/parse_test.py | 2 +- tests/post_test.py | 8 ++++---- tests/protocol_test.py | 2 +- 7 files changed, 45 insertions(+), 20 deletions(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index 2ebfb03..f594b72 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -100,7 +100,7 @@ def __init__(self, sock, session=None): http = self._readto(b" ") if not http: raise RuntimeError("Unable to read HTTP response.") - self.status_code = int(self._readto(b" ")) + self.status_code = int(bytes(self._readto(b" "))) self.reason = self._readto(b"\r\n") self._parse_headers() self._raw = None @@ -196,7 +196,7 @@ def _readinto(self, buf): if self._remaining == 0: self._throw_away(2) chunk_header = self._readto(b";", b"\r\n") - http_chunk_size = int(chunk_header, 16) + http_chunk_size = int(bytes(chunk_header), 16) if http_chunk_size == 0: self._chunked = False self._parse_headers() @@ -237,7 +237,7 @@ def close(self): elif self._chunked: while True: chunk_header = self._readto(b";", b"\r\n") - chunk_size = int(chunk_header, 16) + chunk_size = int(bytes(chunk_header), 16) if chunk_size == 0: break self._throw_away(chunk_size + 2) @@ -391,12 +391,14 @@ def _get_socket(self, host, port, proto, *, timeout=1): host, port, 0, self._socket_pool.SOCK_STREAM )[0] sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + connect_host = addr_info[-1][0] if proto == "https:": sock = self._ssl_context.wrap_socket(sock, server_hostname=host) + connect_host = host sock.settimeout(timeout) # socket read timeout ok = True try: - sock.connect((host, port)) + sock.connect((connect_host, port)) except MemoryError: if not any(self._socket_free.items()): raise @@ -411,7 +413,7 @@ def _get_socket(self, host, port, proto, *, timeout=1): if proto == "https:": sock = self._ssl_context.wrap_socket(sock, server_hostname=host) sock.settimeout(timeout) # socket read timeout - sock.connect((host, port)) + sock.connect((connect_host, port)) self._open_sockets[key] = sock self._socket_free[sock] = False return sock @@ -490,7 +492,7 @@ def request( socket.send(bytes(data, "utf-8")) resp = Response(socket, self) # our response - if "location" in resp.headers and not 200 <= resp.status_code <= 299: + if "location" in resp.headers and 300 <= resp.status_code <= 399: raise NotImplementedError("Redirects not yet supported") self._last_response = resp @@ -525,19 +527,31 @@ def delete(self, url, **kw): _default_session = None # pylint: disable=invalid-name +class _FakeSSLSocket: + def __init__(self, socket, tls_mode): + self._socket = socket + self._mode = tls_mode + self.settimeout = socket.settimeout + self.send = socket.send + self.recv = socket.recv + + def connect(self, address): + return self._socket.connect(address, self._mode) class _FakeSSLContext: - @staticmethod - def wrap_socket(socket, server_hostname=None): + def __init__(self, iface): + self._iface = iface + + def wrap_socket(self, socket, server_hostname=None): """Return the same socket""" # pylint: disable=unused-argument - return socket + return _FakeSSLSocket(socket, self._iface.TLS_MODE) def set_socket(sock, iface=None): """Legacy API for setting the socket and network interface. Use a `Session` instead.""" global _default_session # pylint: disable=global-statement,invalid-name - _default_session = Session(sock, _FakeSSLContext()) + _default_session = Session(sock, _FakeSSLContext(iface)) if iface: sock.set_interface(iface) diff --git a/tests/chunk_test.py b/tests/chunk_test.py index cf2f01d..67f09ec 100644 --- a/tests/chunk_test.py +++ b/tests/chunk_test.py @@ -38,7 +38,7 @@ def test_get_text(): s = adafruit_requests.Session(pool) r = s.get("http://" + host + path) - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_has_calls( [ mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), diff --git a/tests/header_test.py b/tests/header_test.py index 921fac3..2c9db60 100644 --- a/tests/header_test.py +++ b/tests/header_test.py @@ -20,7 +20,7 @@ def test_json(): headers = {"user-agent": "blinka/1.0.0"} r = s.get("http://" + host + "/get", headers=headers) - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sent = b"".join(sent).lower() assert b"user-agent: blinka/1.0.0\r\n" in sent # The current implementation sends two user agents. Fix it, and uncomment below. diff --git a/tests/legacy_test.py b/tests/legacy_test.py index c3662e2..b943c36 100644 --- a/tests/legacy_test.py +++ b/tests/legacy_test.py @@ -20,10 +20,21 @@ def test_get_json(): adafruit_requests.set_socket(mocket, mocket.interface) r = adafruit_requests.get("http://" + host + "/get") - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) assert r.json() == response r.close() +def test_tls_mode(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + mocket.socket.return_value = sock + + adafruit_requests.set_socket(mocket, mocket.interface) + r = adafruit_requests.get("https://" + host + "/get") + + sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE) + assert r.json() == response + r.close() def test_post_string(): mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) @@ -33,6 +44,6 @@ def test_post_string(): adafruit_requests.set_socket(mocket, mocket.interface) data = "31F" r = adafruit_requests.post("http://" + host + "/post", data=data) - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_called_with(b"31F") r.close() diff --git a/tests/parse_test.py b/tests/parse_test.py index d2331dc..bef739e 100644 --- a/tests/parse_test.py +++ b/tests/parse_test.py @@ -24,5 +24,5 @@ def test_json(): s = adafruit_requests.Session(pool) r = s.get("http://" + host + "/get") - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) assert r.json() == response diff --git a/tests/post_test.py b/tests/post_test.py index 164d85c..c8660a2 100644 --- a/tests/post_test.py +++ b/tests/post_test.py @@ -20,7 +20,7 @@ def test_method(): s = adafruit_requests.Session(pool) r = s.post("http://" + host + "/post") - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_has_calls( [mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")] ) @@ -35,7 +35,7 @@ def test_string(): s = adafruit_requests.Session(pool) data = "31F" r = s.post("http://" + host + "/post", data=data) - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_called_with(b"31F") @@ -48,7 +48,7 @@ def test_form(): s = adafruit_requests.Session(pool) data = {"Date": "July 25, 2019"} r = s.post("http://" + host + "/post", data=data) - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_called_with(b"Date=July 25, 2019") @@ -61,5 +61,5 @@ def test_json(): s = adafruit_requests.Session(pool) json_data = {"Date": "July 25, 2019"} r = s.post("http://" + host + "/post", json=json_data) - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_called_with(b'{"Date": "July 25, 2019"}') diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 76e2289..c7ad9da 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -53,7 +53,7 @@ def test_get_http_text(): s = adafruit_requests.Session(pool) r = s.get("http://" + host + path) - sock.connect.assert_called_once_with((host, 80)) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_has_calls( [ mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), From def66e3b07c9251b7d67e5133804e06cda0f61ae Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Fri, 11 Sep 2020 15:23:01 -0700 Subject: [PATCH 17/19] More detailed secrets error --- examples/requests_advanced.py | 6 +++++- examples/requests_simpletest.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/requests_advanced.py b/examples/requests_advanced.py index 1c89d91..6fb81d4 100644 --- a/examples/requests_advanced.py +++ b/examples/requests_advanced.py @@ -9,7 +9,11 @@ # "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other # source control. # pylint: disable=no-name-in-module,wrong-import-order -from secrets import secrets +try: + from secrets import secrets +except ImportError: + print("WiFi secrets are kept in secrets.py, please add them there!") + raise # If you are using a board with pre-defined ESP32 Pins: esp32_cs = DigitalInOut(board.ESP_CS) diff --git a/examples/requests_simpletest.py b/examples/requests_simpletest.py index 20b6004..4d4070e 100755 --- a/examples/requests_simpletest.py +++ b/examples/requests_simpletest.py @@ -10,7 +10,11 @@ # "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other # source control. # pylint: disable=no-name-in-module,wrong-import-order -from secrets import secrets +try: + from secrets import secrets +except ImportError: + print("WiFi secrets are kept in secrets.py, please add them there!") + raise # If you are using a board with pre-defined ESP32 Pins: esp32_cs = DigitalInOut(board.ESP_CS) From 29e954d1040accafcf60ecc22031ba364f65a717 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Fri, 11 Sep 2020 15:25:07 -0700 Subject: [PATCH 18/19] Black --- adafruit_requests.py | 2 ++ tests/legacy_test.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/adafruit_requests.py b/adafruit_requests.py index f594b72..72e231b 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -527,6 +527,7 @@ def delete(self, url, **kw): _default_session = None # pylint: disable=invalid-name + class _FakeSSLSocket: def __init__(self, socket, tls_mode): self._socket = socket @@ -538,6 +539,7 @@ def __init__(self, socket, tls_mode): def connect(self, address): return self._socket.connect(address, self._mode) + class _FakeSSLContext: def __init__(self, iface): self._iface = iface diff --git a/tests/legacy_test.py b/tests/legacy_test.py index b943c36..3d9cdbb 100644 --- a/tests/legacy_test.py +++ b/tests/legacy_test.py @@ -24,6 +24,7 @@ def test_get_json(): assert r.json() == response r.close() + def test_tls_mode(): mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) @@ -36,6 +37,7 @@ def test_tls_mode(): assert r.json() == response r.close() + def test_post_string(): mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) From 51d9f09a749bbc3a414215b306f2202d3eaf9264 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Fri, 11 Sep 2020 15:33:52 -0700 Subject: [PATCH 19/19] pylint --- adafruit_requests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/adafruit_requests.py b/adafruit_requests.py index 72e231b..66aac31 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -537,6 +537,7 @@ def __init__(self, socket, tls_mode): self.recv = socket.recv def connect(self, address): + """connect wrapper to add non-standard mode parameter""" return self._socket.connect(address, self._mode)