diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d4888ef..2f9163b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,10 +6,6 @@ jobs: test: runs-on: ubuntu-latest steps: - - name: Dump GitHub context - env: - GITHUB_CONTEXT: ${{ toJson(github) }} - run: echo "$GITHUB_CONTEXT" - name: Set up Python 3.6 uses: actions/setup-python@v1 with: diff --git a/adafruit_requests.py b/adafruit_requests.py index 0888937..66aac31 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # # Copyright (c) 2019 ladyada for Adafruit Industries +# Copyright (c) 2020 Scott Shawcroft for Adafruit Industries # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,7 +27,7 @@ A requests-like library for web interfacing -* Author(s): ladyada, Paul Sokolovsky +* Author(s): ladyada, Paul Sokolovsky, Scott Shawcroft Implementation Notes -------------------- @@ -49,42 +50,61 @@ """ -import gc - __version__ = "0.0.0-auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git" -_the_interface = None # pylint: disable=invalid-name -_the_sock = None # pylint: disable=invalid-name +class _RawResponse: + def __init__(self, response): + self._response = response -def set_socket(sock, iface=None): - """Helper to set the global socket and optionally set the global network interface. - :param sock: socket object. - :param iface: internet interface object + def read(self, size=-1): + """Read as much as available or up to size and return it in a byte string. - """ - global _the_sock # pylint: disable=invalid-name, global-statement - _the_sock = sock - if iface: - global _the_interface # pylint: disable=invalid-name, global-statement - _the_interface = iface - _the_sock.set_interface(iface) + Do NOT use this unless you really need to. Reusing memory with `readinto` is much better. + """ + if size == -1: + return self._response.content + return self._response.socket.recv(size) + + def readinto(self, buf): + """Read as much as available into buf or until it is full. Returns the number of bytes read + into buf.""" + return self._response._readinto(buf) # pylint: disable=protected-access class Response: """The response from a request, contains all the headers/content""" + # pylint: disable=too-many-instance-attributes + encoding = None - def __init__(self, sock): + def __init__(self, sock, session=None): self.socket = sock self.encoding = "utf-8" self._cached = None - self.status_code = None - self.reason = None - self._read_so_far = 0 - self.headers = {} + self._headers = {} + + # _start_index and _receive_buffer are used when parsing headers. + # _receive_buffer will grow by 32 bytes everytime it is too small. + self._received_length = 0 + self._receive_buffer = bytearray(32) + self._remaining = None + self._chunked = False + + self._backwards_compatible = not hasattr(sock, "recv_into") + if self._backwards_compatible: + print("Socket missing recv_into. Using more memory to be compatible") + + http = self._readto(b" ") + if not http: + raise RuntimeError("Unable to read HTTP response.") + self.status_code = int(bytes(self._readto(b" "))) + self.reason = self._readto(b"\r\n") + self._parse_headers() + self._raw = None + self._session = session def __enter__(self): return self @@ -92,46 +112,217 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + def _recv_into(self, buf, size=0): + if self._backwards_compatible: + size = len(buf) if size == 0 else size + b = self.socket.recv(size) + read_size = len(b) + buf[:read_size] = b + return read_size + return self.socket.recv_into(buf, size) + + def _readto(self, first, second=b""): + buf = self._receive_buffer + end = self._received_length + while True: + firsti = buf.find(first, 0, end) + secondi = -1 + if second: + secondi = buf.find(second, 0, end) + + i = -1 + needle_len = 0 + if firsti >= 0: + i = firsti + needle_len = len(first) + if secondi >= 0 and (firsti < 0 or secondi < firsti): + i = secondi + needle_len = len(second) + if i >= 0: + result = buf[:i] + new_start = i + needle_len + + if i + needle_len <= end: + new_end = end - new_start + buf[:new_end] = buf[new_start:end] + self._received_length = new_end + return result + + # Not found so load more. + + # If our buffer is full, then make it bigger to load more. + if end == len(buf): + new_size = len(buf) + 32 + new_buf = bytearray(new_size) + new_buf[: len(buf)] = buf + buf = new_buf + self._receive_buffer = buf + + read = self._recv_into(memoryview(buf)[end:]) + if read == 0: + self._received_length = 0 + return buf[:end] + end += read + + return b"" + + def _read_from_buffer(self, buf=None, nbytes=None): + if self._received_length == 0: + return 0 + read = self._received_length + if nbytes < read: + read = nbytes + membuf = memoryview(self._receive_buffer) + if buf: + buf[:read] = membuf[:read] + if read < self._received_length: + new_end = self._received_length - read + self._receive_buffer[:new_end] = membuf[read : self._received_length] + self._received_length = new_end + else: + self._received_length = 0 + return read + + def _readinto(self, buf): + if not self.socket: + raise RuntimeError( + "Newer Response closed this one. Use Responses immediately." + ) + + if not self._remaining: + # Consume the chunk header if need be. + if self._chunked: + # Consume trailing \r\n for chunks 2+ + if self._remaining == 0: + self._throw_away(2) + chunk_header = self._readto(b";", b"\r\n") + http_chunk_size = int(bytes(chunk_header), 16) + if http_chunk_size == 0: + self._chunked = False + self._parse_headers() + return 0 + self._remaining = http_chunk_size + else: + return 0 + + nbytes = len(buf) + if nbytes > self._remaining: + nbytes = self._remaining + + read = self._read_from_buffer(buf, nbytes) + if read == 0: + read = self._recv_into(buf, nbytes) + self._remaining -= read + + return read + + def _throw_away(self, nbytes): + nbytes -= self._read_from_buffer(nbytes=nbytes) + + buf = self._receive_buffer + for _ in range(nbytes // len(buf)): + self._recv_into(buf) + remaining = nbytes % len(buf) + if remaining: + self._recv_into(buf, remaining) + def close(self): - """Close, delete and collect the response data""" - if self.socket: + """Drain the remaining ESP socket buffers. We assume we already got what we wanted.""" + if not self.socket: + return + # Make sure we've read all of our response. + if self._cached is None: + if self._remaining > 0: + self._throw_away(self._remaining) + elif self._chunked: + while True: + chunk_header = self._readto(b";", b"\r\n") + chunk_size = int(bytes(chunk_header), 16) + if chunk_size == 0: + break + self._throw_away(chunk_size + 2) + self._parse_headers() + if self._session: + self._session._free_socket(self.socket) # pylint: disable=protected-access + else: self.socket.close() - del self.socket - del self._cached - gc.collect() + self.socket = None + + def _parse_headers(self): + """ + Parses the header portion of an HTTP request/response from the socket. + Expects first line of HTTP request/response to have been read already. + """ + while True: + title = self._readto(b": ", b"\r\n") + if not title: + break + + content = self._readto(b"\r\n") + if title and content: + title = str(title, "utf-8") + content = str(content, "utf-8") + # Check len first so we can skip the .lower allocation most of the time. + if ( + len(title) == len("content-length") + and title.lower() == "content-length" + ): + self._remaining = int(content) + if ( + len(title) == len("transfer-encoding") + and title.lower() == "transfer-encoding" + ): + self._chunked = content.lower() == "chunked" + self._headers[title] = content + + @property + def headers(self): + """ + The response headers. Does not include headers from the trailer until + the content has been read. + """ + return self._headers @property def content(self): """The HTTP content direct from the socket, as bytes""" - # print(self.headers) - try: - content_length = int(self.headers["content-length"]) - except KeyError: - content_length = 0 - # print("Content length:", content_length) - if self._cached is None: - try: - self._cached = self.socket.recv(content_length) - finally: - self.socket.close() - self.socket = None - # print("Buffer length:", len(self._cached)) + if self._cached is not None: + if isinstance(self._cached, bytes): + return self._cached + raise RuntimeError("Cannot access content after getting text or json") + + self._cached = b"".join(self.iter_content(chunk_size=32)) return self._cached @property def text(self): """The HTTP content, encoded into a string according to the HTTP header encoding""" - return str(self.content, self.encoding) + if self._cached is not None: + if isinstance(self._cached, str): + return self._cached + raise RuntimeError("Cannot access text after getting content or json") + self._cached = str(self.content, self.encoding) + return self._cached def json(self): """The HTTP content, parsed into a json dictionary""" # pylint: disable=import-outside-toplevel - try: - import json as json_module - except ImportError: - import ujson as json_module - return json_module.loads(self.content) + import json + + # The cached JSON will be a list or dictionary. + if self._cached: + if isinstance(self._cached, (list, dict)): + return self._cached + raise RuntimeError("Cannot access json after getting text or content") + if not self._raw: + self._raw = _RawResponse(self) + + obj = json.load(self._raw) + if not self._cached: + self._cached = obj + self.close() + return obj def iter_content(self, chunk_size=1, decode_unicode=False): """An iterator that will stream data by only reading 'chunk_size' @@ -139,74 +330,143 @@ def iter_content(self, chunk_size=1, decode_unicode=False): if decode_unicode: raise NotImplementedError("Unicode not supported") + b = bytearray(chunk_size) while True: - chunk = self.socket.recv(chunk_size) - if chunk: - yield chunk + size = self._readinto(b) + if size == 0: + break + if size < chunk_size: + chunk = bytes(memoryview(b)[:size]) else: - return + chunk = bytes(b) + yield chunk + self.close() -# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals -def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1): - """Perform an HTTP request to the given url which we will parse to determine - whether to use SSL ('https://') or not. We can also send some provided 'data' - or a json dictionary which we will stringify. 'headers' is optional HTTP headers - sent along. 'stream' will determine if we buffer everything, or whether to only - read only when requested - """ - global _the_interface # pylint: disable=global-statement, invalid-name - global _the_sock # pylint: disable=global-statement, invalid-name - - if not headers: - headers = {} - - try: - proto, dummy, host, path = url.split("/", 3) - # replace spaces in path - path = path.replace(" ", "%20") - except ValueError: - proto, dummy, host = url.split("/", 2) - path = "" - if proto == "http:": - port = 80 - elif proto == "https:": - port = 443 - else: - raise ValueError("Unsupported protocol: " + proto) - - if ":" in host: - host, port = host.split(":", 1) - port = int(port) - - addr_info = _the_sock.getaddrinfo(host, port, 0, _the_sock.SOCK_STREAM)[0] - sock = _the_sock.socket(addr_info[0], addr_info[1], addr_info[2]) - resp = Response(sock) # our response - - sock.settimeout(timeout) # socket read timeout - - try: +class Session: + """HTTP session that shares sockets and ssl context.""" + + def __init__(self, socket_pool, ssl_context=None): + self._socket_pool = socket_pool + self._ssl_context = ssl_context + # Hang onto open sockets so that we can reuse them. + self._open_sockets = {} + self._socket_free = {} + self._last_response = None + + def _free_socket(self, socket): + + if socket not in self._open_sockets.values(): + raise RuntimeError("Socket not from session") + self._socket_free[socket] = True + + def _free_sockets(self): + free_sockets = [] + for sock in self._socket_free: + if self._socket_free[sock]: + sock.close() + free_sockets.append(sock) + for sock in free_sockets: + del self._socket_free[sock] + key = None + for k in self._open_sockets: + if self._open_sockets[k] == sock: + key = k + break + if key: + del self._open_sockets[key] + + def _get_socket(self, host, port, proto, *, timeout=1): + key = (host, port, proto) + if key in self._open_sockets: + sock = self._open_sockets[key] + if self._socket_free[sock]: + self._socket_free[sock] = False + return sock + if proto == "https:" and not self._ssl_context: + raise RuntimeError( + "ssl_context must be set before using adafruit_requests for https" + ) + addr_info = self._socket_pool.getaddrinfo( + host, port, 0, self._socket_pool.SOCK_STREAM + )[0] + sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + connect_host = addr_info[-1][0] if proto == "https:": - conntype = _the_interface.TLS_MODE - sock.connect( - (host, port), conntype - ) # for SSL we need to know the host name + sock = self._ssl_context.wrap_socket(sock, server_hostname=host) + connect_host = host + sock.settimeout(timeout) # socket read timeout + ok = True + try: + sock.connect((connect_host, port)) + except MemoryError: + if not any(self._socket_free.items()): + raise + ok = False + + # We couldn't connect due to memory so clean up the open sockets. + if not ok: + self._free_sockets() + # Recreate the socket because the ESP-IDF won't retry the connection if it failed once. + sock = None # Clear first so the first socket can be cleaned up. + sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + if proto == "https:": + sock = self._ssl_context.wrap_socket(sock, server_hostname=host) + sock.settimeout(timeout) # socket read timeout + sock.connect((connect_host, port)) + self._open_sockets[key] = sock + self._socket_free[sock] = False + return sock + + # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals + def request( + self, method, url, data=None, json=None, headers=None, stream=False, timeout=60 + ): + """Perform an HTTP request to the given url which we will parse to determine + whether to use SSL ('https://') or not. We can also send some provided 'data' + or a json dictionary which we will stringify. 'headers' is optional HTTP headers + sent along. 'stream' will determine if we buffer everything, or whether to only + read only when requested + """ + if not headers: + headers = {} + + try: + proto, dummy, host, path = url.split("/", 3) + # replace spaces in path + path = path.replace(" ", "%20") + except ValueError: + proto, dummy, host = url.split("/", 2) + path = "" + if proto == "http:": + port = 80 + elif proto == "https:": + port = 443 else: - conntype = _the_interface.TCP_MODE - sock.connect(addr_info[-1], conntype) - sock.send( - b"%s /%s HTTP/1.0\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8")) + raise ValueError("Unsupported protocol: " + proto) + + if ":" in host: + host, port = host.split(":", 1) + port = int(port) + + if self._last_response: + self._last_response.close() + self._last_response = None + + socket = self._get_socket(host, port, proto, timeout=timeout) + socket.send( + b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8")) ) if "Host" not in headers: - sock.send(b"Host: %s\r\n" % bytes(host, "utf-8")) + socket.send(b"Host: %s\r\n" % bytes(host, "utf-8")) if "User-Agent" not in headers: - sock.send(b"User-Agent: Adafruit CircuitPython\r\n") + socket.send(b"User-Agent: Adafruit CircuitPython\r\n") # Iterate over keys to avoid tuple alloc for k in headers: - sock.send(k.encode()) - sock.send(b": ") - sock.send(headers[k].encode()) - sock.send(b"\r\n") + socket.send(k.encode()) + socket.send(b": ") + socket.send(headers[k].encode()) + socket.send(b"\r\n") if json is not None: assert data is None # pylint: disable=import-outside-toplevel @@ -214,98 +474,130 @@ def request(method, url, data=None, json=None, headers=None, stream=False, timeo import json as json_module except ImportError: import ujson as json_module - # pylint: enable=import-outside-toplevel data = json_module.dumps(json) - sock.send(b"Content-Type: application/json\r\n") + socket.send(b"Content-Type: application/json\r\n") if data: if isinstance(data, dict): - sock.send(b"Content-Type: application/x-www-form-urlencoded\r\n") + socket.send(b"Content-Type: application/x-www-form-urlencoded\r\n") _post_data = "" for k in data: _post_data = "{}&{}={}".format(_post_data, k, data[k]) data = _post_data[1:] - sock.send(b"Content-Length: %d\r\n" % len(data)) - sock.send(b"\r\n") + socket.send(b"Content-Length: %d\r\n" % len(data)) + socket.send(b"\r\n") if data: if isinstance(data, bytearray): - sock.send(bytes(data)) + socket.send(bytes(data)) else: - sock.send(bytes(data, "utf-8")) - - line = sock.readline() - # print(line) - line = line.split(None, 2) - status = int(line[1]) - reason = "" - if len(line) > 2: - reason = line[2].rstrip() - resp.headers = parse_headers(sock) - if resp.headers.get("transfer-encoding"): - if "chunked" in resp.headers.get("transfer-encoding"): - raise ValueError("Unsupported " + resp.headers.get("transfer-encoding")) - elif resp.headers.get("location") and not 200 <= status <= 299: + socket.send(bytes(data, "utf-8")) + + resp = Response(socket, self) # our response + if "location" in resp.headers and 300 <= resp.status_code <= 399: raise NotImplementedError("Redirects not yet supported") - except: - sock.close() - raise - - resp.status_code = status - resp.reason = reason - return resp - - -def parse_headers(sock): - """ - Parses the header portion of an HTTP request/response from the socket. - Expects first line of HTTP request/response to have been read already - return: header dictionary - rtype: Dict - """ - headers = {} - while True: - line = sock.readline() - if not line or line == b"\r\n": - break - - # print("**line: ", line) - splits = line.split(b": ", 1) - title = splits[0] - content = "" - if len(splits) > 1: - content = splits[1] - if title and content: - title = str(title.lower(), "utf-8") - content = str(content, "utf-8") - headers[title] = content - return headers + self._last_response = resp + return resp + + def head(self, url, **kw): + """Send HTTP HEAD request""" + return self.request("HEAD", url, **kw) + + def get(self, url, **kw): + """Send HTTP GET request""" + return self.request("GET", url, **kw) + + def post(self, url, **kw): + """Send HTTP POST request""" + return self.request("POST", url, **kw) + + def put(self, url, **kw): + """Send HTTP PUT request""" + return self.request("PUT", url, **kw) + + def patch(self, url, **kw): + """Send HTTP PATCH request""" + return self.request("PATCH", url, **kw) + + def delete(self, url, **kw): + """Send HTTP DELETE request""" + return self.request("DELETE", url, **kw) + + +# Backwards compatible API: + +_default_session = None # pylint: disable=invalid-name + + +class _FakeSSLSocket: + def __init__(self, socket, tls_mode): + self._socket = socket + self._mode = tls_mode + self.settimeout = socket.settimeout + self.send = socket.send + self.recv = socket.recv + + def connect(self, address): + """connect wrapper to add non-standard mode parameter""" + return self._socket.connect(address, self._mode) + + +class _FakeSSLContext: + def __init__(self, iface): + self._iface = iface + + def wrap_socket(self, socket, server_hostname=None): + """Return the same socket""" + # pylint: disable=unused-argument + return _FakeSSLSocket(socket, self._iface.TLS_MODE) + + +def set_socket(sock, iface=None): + """Legacy API for setting the socket and network interface. Use a `Session` instead.""" + global _default_session # pylint: disable=global-statement,invalid-name + _default_session = Session(sock, _FakeSSLContext(iface)) + if iface: + sock.set_interface(iface) + + +def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1): + """Send HTTP request""" + # pylint: disable=too-many-arguments + _default_session.request( + method, + url, + data=data, + json=json, + headers=headers, + stream=stream, + timeout=timeout, + ) def head(url, **kw): """Send HTTP HEAD request""" - return request("HEAD", url, **kw) + return _default_session.request("HEAD", url, **kw) def get(url, **kw): """Send HTTP GET request""" - return request("GET", url, **kw) + return _default_session.request("GET", url, **kw) def post(url, **kw): """Send HTTP POST request""" - return request("POST", url, **kw) + return _default_session.request("POST", url, **kw) def put(url, **kw): """Send HTTP PUT request""" - return request("PUT", url, **kw) + return _default_session.request("PUT", url, **kw) def patch(url, **kw): """Send HTTP PATCH request""" - return request("PATCH", url, **kw) + return _default_session.request("PATCH", url, **kw) def delete(url, **kw): """Send HTTP DELETE request""" - return request("DELETE", url, **kw) + return _default_session.request("DELETE", url, **kw) diff --git a/examples/requests_advanced.py b/examples/requests_advanced.py index d74f9af..6fb81d4 100644 --- a/examples/requests_advanced.py +++ b/examples/requests_advanced.py @@ -5,6 +5,16 @@ from adafruit_esp32spi import adafruit_esp32spi import adafruit_requests as requests +# Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and +# "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other +# source control. +# pylint: disable=no-name-in-module,wrong-import-order +try: + from secrets import secrets +except ImportError: + print("WiFi secrets are kept in secrets.py, please add them there!") + raise + # If you are using a board with pre-defined ESP32 Pins: esp32_cs = DigitalInOut(board.ESP_CS) esp32_ready = DigitalInOut(board.ESP_BUSY) @@ -21,14 +31,15 @@ print("Connecting to AP...") while not esp.is_connected: try: - esp.connect_AP(b"MY_SSID_NAME", b"MY_SSID_PASSWORD") + esp.connect_AP(secrets["ssid"], secrets["password"]) except RuntimeError as e: print("could not connect to AP, retrying: ", e) continue print("Connected to", str(esp.ssid, "utf-8"), "\tRSSI:", esp.rssi) # Initialize a requests object with a socket and esp32spi interface -requests.set_socket(socket, esp) +socket.set_interface(esp) +requests.set_socket(socket) JSON_GET_URL = "http://httpbin.org/get" diff --git a/examples/requests_advanced_cpython.py b/examples/requests_advanced_cpython.py new file mode 100644 index 0000000..379620e --- /dev/null +++ b/examples/requests_advanced_cpython.py @@ -0,0 +1,28 @@ +import socket +import adafruit_requests + +http = adafruit_requests.Session(socket) + +JSON_GET_URL = "http://httpbin.org/get" + +# Define a custom header as a dict. +headers = {"user-agent": "blinka/1.0.0"} + +print("Fetching JSON data from %s..." % JSON_GET_URL) +response = http.get(JSON_GET_URL, headers=headers) +print("-" * 60) + +json_data = response.json() +headers = json_data["headers"] +print("Response's Custom User-Agent Header: {0}".format(headers["User-Agent"])) +print("-" * 60) + +# Read Response's HTTP status code +print("Response HTTP Status Code: ", response.status_code) +print("-" * 60) + +# Read Response, as raw bytes instead of pretty text +print("Raw Response: ", response.content) + +# Close, delete and collect the response data +response.close() diff --git a/examples/requests_github_cpython.py b/examples/requests_github_cpython.py new file mode 100755 index 0000000..e258d1f --- /dev/null +++ b/examples/requests_github_cpython.py @@ -0,0 +1,13 @@ +# adafruit_requests usage with a CPython socket +import socket +import ssl +import adafruit_requests + +http = adafruit_requests.Session(socket, ssl.create_default_context()) + +print("Getting CircuitPython star count") +headers = {"Transfer-Encoding": "chunked"} +response = http.get( + "https://api.github.com/repos/adafruit/circuitpython", headers=headers +) +print("circuitpython stars", response.json()["stargazers_count"]) diff --git a/examples/requests_https_cpython.py b/examples/requests_https_cpython.py new file mode 100755 index 0000000..dca4d31 --- /dev/null +++ b/examples/requests_https_cpython.py @@ -0,0 +1,45 @@ +# adafruit_requests usage with a CPython socket +import socket +import ssl +import adafruit_requests as requests + +https = requests.Session(socket, ssl.create_default_context()) + +TEXT_URL = "https://wifitest.adafruit.com/testwifi/index.html" +JSON_GET_URL = "https://httpbin.org/get" +JSON_POST_URL = "https://httpbin.org/post" + +# print("Fetching text from %s" % TEXT_URL) +# response = requests.get(TEXT_URL) +# print("-" * 40) + +# print("Text Response: ", response.text) +# print("-" * 40) +# response.close() + +print("Fetching JSON data from %s" % JSON_GET_URL) +response = https.get(JSON_GET_URL) +print("-" * 40) + +print("JSON Response: ", response.json()) +print("-" * 40) + +data = "31F" +print("POSTing data to {0}: {1}".format(JSON_POST_URL, data)) +response = https.post(JSON_POST_URL, data=data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'data' key from json_resp dict. +print("Data received from server:", json_resp["data"]) +print("-" * 40) + +json_data = {"Date": "July 25, 2019"} +print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data)) +response = https.post(JSON_POST_URL, json=json_data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'json' key from json_resp dict. +print("JSON Data received from server:", json_resp["json"]) +print("-" * 40) diff --git a/examples/requests_simpletest.py b/examples/requests_simpletest.py index 444648b..4d4070e 100755 --- a/examples/requests_simpletest.py +++ b/examples/requests_simpletest.py @@ -6,6 +6,16 @@ from adafruit_esp32spi import adafruit_esp32spi import adafruit_requests as requests +# Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and +# "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other +# source control. +# pylint: disable=no-name-in-module,wrong-import-order +try: + from secrets import secrets +except ImportError: + print("WiFi secrets are kept in secrets.py, please add them there!") + raise + # If you are using a board with pre-defined ESP32 Pins: esp32_cs = DigitalInOut(board.ESP_CS) esp32_ready = DigitalInOut(board.ESP_BUSY) @@ -22,14 +32,15 @@ print("Connecting to AP...") while not esp.is_connected: try: - esp.connect_AP(b"MY_SSID_NAME", b"MY_SSID_PASSWORD") + esp.connect_AP(secrets["ssid"], secrets["password"]) except RuntimeError as e: print("could not connect to AP, retrying: ", e) continue print("Connected to", str(esp.ssid, "utf-8"), "\tRSSI:", esp.rssi) # Initialize a requests object with a socket and esp32spi interface -requests.set_socket(socket, esp) +socket.set_interface(esp) +requests.set_socket(socket) TEXT_URL = "http://wifitest.adafruit.com/testwifi/index.html" JSON_GET_URL = "http://httpbin.org/get" diff --git a/examples/requests_simpletest_cpython.py b/examples/requests_simpletest_cpython.py new file mode 100755 index 0000000..db9fca2 --- /dev/null +++ b/examples/requests_simpletest_cpython.py @@ -0,0 +1,44 @@ +# adafruit_requests usage with a CPython socket +import socket +import adafruit_requests + +http = adafruit_requests.Session(socket) + +TEXT_URL = "http://wifitest.adafruit.com/testwifi/index.html" +JSON_GET_URL = "http://httpbin.org/get" +JSON_POST_URL = "http://httpbin.org/post" + +print("Fetching text from %s" % TEXT_URL) +response = http.get(TEXT_URL) +print("-" * 40) + +print("Text Response: ", response.text) +print("-" * 40) + +print("Fetching JSON data from %s" % JSON_GET_URL) +response = http.get(JSON_GET_URL) +print("-" * 40) + +print("JSON Response: ", response.json()) +print("-" * 40) +response.close() + +data = "31F" +print("POSTing data to {0}: {1}".format(JSON_POST_URL, data)) +response = http.post(JSON_POST_URL, data=data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'data' key from json_resp dict. +print("Data received from server:", json_resp["data"]) +print("-" * 40) + +json_data = {"Date": "July 25, 2019"} +print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data)) +response = http.post(JSON_POST_URL, json=json_data) +print("-" * 40) + +json_resp = response.json() +# Parse out the 'json' key from json_resp dict. +print("JSON Data received from server:", json_resp["json"]) +print("-" * 40) diff --git a/tests/chunk_test.py b/tests/chunk_test.py new file mode 100644 index 0000000..67f09ec --- /dev/null +++ b/tests/chunk_test.py @@ -0,0 +1,48 @@ +from unittest import mock +import mocket +import adafruit_requests + +ip = "1.2.3.4" +host = "wifitest.adafruit.com" +path = "/testwifi/index.html" +text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)" +headers = b"HTTP/1.0 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + + +def _chunk(response, split): + i = 0 + chunked = b"" + while i < len(response): + remaining = len(response) - i + chunk_size = split + if remaining < chunk_size: + chunk_size = remaining + new_i = i + chunk_size + chunked += ( + hex(chunk_size)[2:].encode("ascii") + b"\r\n" + response[i:new_i] + b"\r\n" + ) + i = new_i + # The final chunk is zero length. + chunked += b"0\r\n\r\n" + return chunked + + +def test_get_text(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + c = _chunk(text, 33) + print(c) + sock = mocket.Mocket(headers + c) + pool.socket.return_value = sock + + s = adafruit_requests.Session(pool) + r = s.get("http://" + host + path) + + sock.connect.assert_called_once_with((ip, 80)) + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + assert r.text == str(text, "utf-8") diff --git a/tests/header_test.py b/tests/header_test.py index d1d273e..2c9db60 100644 --- a/tests/header_test.py +++ b/tests/header_test.py @@ -9,17 +9,18 @@ def test_json(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(response_headers) - mocket.socket.return_value = sock + pool.socket.return_value = sock sent = [] sock.send.side_effect = sent.append - adafruit_requests.set_socket(mocket, mocket.interface) + s = adafruit_requests.Session(pool) headers = {"user-agent": "blinka/1.0.0"} - r = adafruit_requests.get("http://" + host + "/get", headers=headers) + r = s.get("http://" + host + "/get", headers=headers) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + sock.connect.assert_called_once_with((ip, 80)) sent = b"".join(sent).lower() assert b"user-agent: blinka/1.0.0\r\n" in sent # The current implementation sends two user agents. Fix it, and uncomment below. diff --git a/tests/legacy_mocket.py b/tests/legacy_mocket.py new file mode 100644 index 0000000..4a37bd2 --- /dev/null +++ b/tests/legacy_mocket.py @@ -0,0 +1,32 @@ +from unittest import mock + +SOCK_STREAM = 0 + +set_interface = mock.Mock() +interface = mock.MagicMock() +getaddrinfo = mock.Mock() +socket = mock.Mock() + + +class Mocket: + def __init__(self, response): + self.settimeout = mock.Mock() + self.close = mock.Mock() + self.connect = mock.Mock() + self.send = mock.Mock() + self.readline = mock.Mock(side_effect=self._readline) + self.recv = mock.Mock(side_effect=self._recv) + self._response = response + self._position = 0 + + def _readline(self): + i = self._response.find(b"\r\n", self._position) + r = self._response[self._position : i + 2] + self._position = i + 2 + return r + + def _recv(self, count): + end = self._position + count + r = self._response[self._position : end] + self._position = end + return r diff --git a/tests/legacy_test.py b/tests/legacy_test.py new file mode 100644 index 0000000..3d9cdbb --- /dev/null +++ b/tests/legacy_test.py @@ -0,0 +1,51 @@ +from unittest import mock +import legacy_mocket as mocket +import json +import adafruit_requests + +ip = "1.2.3.4" +host = "httpbin.org" +response = {"Date": "July 25, 2019"} +encoded = json.dumps(response).encode("utf-8") +headers = "HTTP/1.0 200 OK\r\nContent-Length: {}\r\n\r\n".format(len(encoded)).encode( + "utf-8" +) + + +def test_get_json(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + mocket.socket.return_value = sock + + adafruit_requests.set_socket(mocket, mocket.interface) + r = adafruit_requests.get("http://" + host + "/get") + + sock.connect.assert_called_once_with((ip, 80)) + assert r.json() == response + r.close() + + +def test_tls_mode(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + mocket.socket.return_value = sock + + adafruit_requests.set_socket(mocket, mocket.interface) + r = adafruit_requests.get("https://" + host + "/get") + + sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE) + assert r.json() == response + r.close() + + +def test_post_string(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + mocket.socket.return_value = sock + + adafruit_requests.set_socket(mocket, mocket.interface) + data = "31F" + r = adafruit_requests.post("http://" + host + "/post", data=data) + sock.connect.assert_called_once_with((ip, 80)) + sock.send.assert_called_with(b"31F") + r.close() diff --git a/tests/mocket.py b/tests/mocket.py index 684aece..ec9a557 100644 --- a/tests/mocket.py +++ b/tests/mocket.py @@ -1,12 +1,12 @@ from unittest import mock -SOCK_STREAM = 0 -getaddrinfo = mock.Mock() -socket = mock.Mock() -set_interface = mock.Mock() +class MocketPool: + SOCK_STREAM = 0 -interface = mock.MagicMock() + def __init__(self): + self.getaddrinfo = mock.Mock() + self.socket = mock.Mock() class Mocket: @@ -17,6 +17,7 @@ def __init__(self, response): self.send = mock.Mock() self.readline = mock.Mock(side_effect=self._readline) self.recv = mock.Mock(side_effect=self._recv) + self.recv_into = mock.Mock(side_effect=self._recv_into) self._response = response self._position = 0 @@ -31,3 +32,22 @@ def _recv(self, count): r = self._response[self._position : end] self._position = end return r + + def _recv_into(self, buf, nbytes=0): + assert isinstance(nbytes, int) and nbytes >= 0 + read = nbytes if nbytes > 0 else len(buf) + remaining = len(self._response) - self._position + if read > remaining: + read = remaining + end = self._position + read + buf[:read] = self._response[self._position : end] + self._position = end + return read + + +class SSLContext: + def __init__(self): + self.wrap_socket = mock.Mock(side_effect=self._wrap_socket) + + def _wrap_socket(self, sock, server_hostname=None): + return sock diff --git a/tests/parse_test.py b/tests/parse_test.py index 477a128..bef739e 100644 --- a/tests/parse_test.py +++ b/tests/parse_test.py @@ -7,17 +7,22 @@ host = "httpbin.org" response = {"Date": "July 25, 2019"} encoded = json.dumps(response).encode("utf-8") -headers = "HTTP/1.0 200 OK\r\nContent-Length: {}\r\n\r\n".format(len(encoded)).encode( +# Padding here tests the case where a header line is exactly 32 bytes buffered by +# aligning the Content-Type header after it. +headers = "HTTP/1.0 200 OK\r\npadding: 000\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n".format( + len(encoded) +).encode( "utf-8" ) def test_json(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.get("http://" + host + "/get") - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + s = adafruit_requests.Session(pool) + r = s.get("http://" + host + "/get") + sock.connect.assert_called_once_with((ip, 80)) assert r.json() == response diff --git a/tests/post_test.py b/tests/post_test.py index a2f5977..c8660a2 100644 --- a/tests/post_test.py +++ b/tests/post_test.py @@ -13,37 +13,53 @@ def test_method(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.post("http://" + host + "/post") - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + s = adafruit_requests.Session(pool) + r = s.post("http://" + host + "/post") + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_has_calls( - [mock.call(b"POST /post HTTP/1.0\r\n"), mock.call(b"Host: httpbin.org\r\n")] + [mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")] ) def test_string(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) + s = adafruit_requests.Session(pool) data = "31F" - r = adafruit_requests.post("http://" + host + "/post", data=data) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + r = s.post("http://" + host + "/post", data=data) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_called_with(b"31F") +def test_form(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + pool.socket.return_value = sock + + s = adafruit_requests.Session(pool) + data = {"Date": "July 25, 2019"} + r = s.post("http://" + host + "/post", data=data) + sock.connect.assert_called_once_with((ip, 80)) + sock.send.assert_called_with(b"Date=July 25, 2019") + + def test_json(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(headers + encoded) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) + s = adafruit_requests.Session(pool) json_data = {"Date": "July 25, 2019"} - r = adafruit_requests.post("http://" + host + "/post", json=json_data) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + r = s.post("http://" + host + "/post", json=json_data) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_called_with(b'{"Date": "July 25, 2019"}') diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 0eec002..c7ad9da 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -1,5 +1,6 @@ from unittest import mock import mocket +import pytest import adafruit_requests ip = "1.2.3.4" @@ -9,40 +10,54 @@ response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text +def test_get_https_no_ssl(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response) + pool.socket.return_value = sock + + s = adafruit_requests.Session(pool) + with pytest.raises(RuntimeError): + r = s.get("https://" + host + path) + + def test_get_https_text(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(response) - mocket.socket.return_value = sock + pool.socket.return_value = sock + ssl = mocket.SSLContext() - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.get("https://" + host + path) + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) - sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE) + sock.connect.assert_called_once_with((host, 443)) sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.0\r\n"), + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), mock.call(b"Host: wifitest.adafruit.com\r\n"), ] ) assert r.text == str(text, "utf-8") + # Close isn't needed but can be called to release the socket early. + r.close() + def test_get_http_text(): - mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) sock = mocket.Mocket(response) - mocket.socket.return_value = sock + pool.socket.return_value = sock - adafruit_requests.set_socket(mocket, mocket.interface) - r = adafruit_requests.get("http://" + host + path) + s = adafruit_requests.Session(pool) + r = s.get("http://" + host + path) - sock.connect.assert_called_once_with((ip, 80), mocket.interface.TCP_MODE) + sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.0\r\n"), + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), mock.call(b"Host: wifitest.adafruit.com\r\n"), ] ) assert r.text == str(text, "utf-8") - - -# Add a chunked response test when we support HTTP 1.1 diff --git a/tests/reuse_test.py b/tests/reuse_test.py new file mode 100644 index 0000000..2e06931 --- /dev/null +++ b/tests/reuse_test.py @@ -0,0 +1,107 @@ +from unittest import mock +import mocket +import pytest +import adafruit_requests + +ip = "1.2.3.4" +host = "wifitest.adafruit.com" +host2 = "wifitest2.adafruit.com" +path = "/testwifi/index.html" +text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)" +response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text + +# def test_get_twice(): +# pool = mocket.MocketPool() +# pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) +# sock = mocket.Mocket(response + response) +# pool.socket.return_value = sock +# ssl = mocket.SSLContext() + +# s = adafruit_requests.Session(pool, ssl) +# r = s.get("https://" + host + path) + +# sock.send.assert_has_calls( +# [ +# mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), +# mock.call(b"Host: wifitest.adafruit.com\r\n"), +# ] +# ) +# assert r.text == str(text, "utf-8") + +# r = s.get("https://" + host + path + "2") +# sock.send.assert_has_calls( +# [ +# mock.call(b"GET /testwifi/index.html2 HTTP/1.1\r\n"), +# mock.call(b"Host: wifitest.adafruit.com\r\n"), +# ] +# ) + +# assert r.text == str(text, "utf-8") +# sock.connect.assert_called_once_with((host, 443)) +# pool.socket.assert_called_once() + + +def test_get_twice_after_second(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response + response) + pool.socket.return_value = sock + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + + r2 = s.get("https://" + host + path + "2") + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html2 HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + sock.connect.assert_called_once_with((host, 443)) + pool.socket.assert_called_once() + + with pytest.raises(RuntimeError): + r.text == str(text, "utf-8") + + +def test_connect_out_of_memory(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response) + sock2 = mocket.Mocket(response) + sock3 = mocket.Mocket(response) + pool.socket.side_effect = [sock, sock2, sock3] + sock2.connect.side_effect = MemoryError() + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest.adafruit.com\r\n"), + ] + ) + assert r.text == str(text, "utf-8") + + r = s.get("https://" + host2 + path) + sock3.send.assert_has_calls( + [ + mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), + mock.call(b"Host: wifitest2.adafruit.com\r\n"), + ] + ) + + assert r.text == str(text, "utf-8") + sock.close.assert_called_once() + sock.connect.assert_called_once_with((host, 443)) + sock3.connect.assert_called_once_with((host2, 443))