Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CPython/Legacy Socket SSL Problems #65

Merged
merged 6 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
run: |
pylint $( find . -path './adafruit*.py' )
([[ ! -d "examples" ]] || pylint --disable=missing-docstring,invalid-name,bad-whitespace $( find . -path "./examples/*.py" ))
([[ ! -d "examples" ]] || pylint --disable=missing-docstring,invalid-name,bad-whitespace $( find . -path "./examples/*/*.py" ))
- name: Build assets
run: circuitpython-build-bundles --filename_prefix ${{ steps.repo-name.outputs.repo-name }} --library_location .
- name: Archive bundles
Expand Down
103 changes: 32 additions & 71 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
const(0x05): "Connection Refused - Unauthorized",
}

_the_interface = None # pylint: disable=invalid-name
_the_sock = None # pylint: disable=invalid-name
_default_sock = None # pylint: disable=invalid-name
_fake_context = None # pylint: disable=invalid-name


class MMQTTException(Exception):
Expand All @@ -74,17 +74,17 @@ class MMQTTException(Exception):

# Legacy ESP32SPI Socket API
def set_socket(sock, iface=None):
"""Legacy API for setting the socket and network interface, use a Session instead.

"""Legacy API for setting the socket and network interface.
:param sock: socket object.
:param iface: internet interface object

"""
global _the_sock # pylint: disable=invalid-name, global-statement
_the_sock = sock
global _default_sock # pylint: disable=invalid-name, global-statement
global _fake_context # pylint: disable=invalid-name, global-statement
_default_sock = sock
if iface:
global _the_interface # pylint: disable=invalid-name, global-statement
_the_interface = iface
_the_sock.set_interface(iface)
_default_sock.set_interface(iface)
_fake_context = _FakeSSLContext(iface)


class _FakeSSLSocket:
Expand Down Expand Up @@ -144,18 +144,7 @@ def __init__(
):

self._socket_pool = socket_pool
# Legacy API - if we do not have a socket pool, use default socket
if self._socket_pool is None:
self._socket_pool = _the_sock

self._ssl_context = ssl_context
# Legacy API - if we do not have SSL context, fake it
if self._ssl_context is None:
self._ssl_context = _FakeSSLContext(_the_interface)

# Hang onto open sockets so that we can reuse them
self._socket_free = {}
self._open_sockets = {}
self._sock = None
self._backwards_compatible_sock = False

Expand Down Expand Up @@ -214,62 +203,37 @@ def __init__(
self.on_subscribe = None
self.on_unsubscribe = None

# Socket helpers
def _free_socket(self, socket):
"""Frees a socket for re-use."""
if socket not in self._open_sockets.values():
raise RuntimeError("Socket not from MQTT client.")
self._socket_free[socket] = True

def _close_socket(self, socket):
"""Closes a slocket."""
socket.close()
del self._socket_free[socket]
key = None
for k in self._open_sockets:
if self._open_sockets[k] == socket:
key = k
break
if key:
del self._open_sockets[key]

def _free_sockets(self):
"""Closes all free sockets."""
free_sockets = []
for sock in self._socket_free:
if self._socket_free[sock]:
free_sockets.append(sock)
for sock in free_sockets:
self._close_socket(sock)

# pylint: disable=too-many-branches
def _get_socket(self, host, port, *, timeout=1):
key = (host, port)
if key in self._open_sockets:
sock = self._open_sockets[key]
if self._socket_free[sock]:
self._socket_free[sock] = False
return sock
def _get_connect_socket(self, host, port, *, timeout=1):
"""Obtains a new socket and connects to a broker.
:param str host: Desired broker hostname
:param int port: Desired broker port
:param int timeout: Desired socket timeout
"""
# For reconnections - check if we're using a socket already and close it
if self._sock:
self._sock.close()
self._sock = None

# Legacy API - use the interface's socket instead of a passed socket pool
if self._socket_pool is None:
self._socket_pool = _default_sock

# Legacy API - fake the ssl context
if self._ssl_context is None:
self._ssl_context = _fake_context

if port == 8883 and not self._ssl_context:
raise RuntimeError(
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
)

# Legacy API - use a default socket instead of socket pool
if self._socket_pool is None:
self._socket_pool = _the_sock

addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]
retry_count = 0

sock = None
retry_count = 0
while retry_count < 5 and sock is None:
if retry_count > 0:
if any(self._socket_free.items()):
self._free_sockets()
else:
raise RuntimeError("Sending request failed")
retry_count += 1

try:
Expand Down Expand Up @@ -298,9 +262,6 @@ def _get_socket(self, host, port, *, timeout=1):
raise RuntimeError("Repeated socket failures")

self._backwards_compatible_sock = not hasattr(sock, "recv_into")

self._open_sockets[key] = sock
self._socket_free[sock] = False
return sock

def __enter__(self):
Expand Down Expand Up @@ -463,8 +424,8 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
if self.logger:
self.logger.debug("Attempting to establish MQTT connection...")

# Attempt to get a new socket
self._sock = self._get_socket(self.broker, self.port)
# Get a new socket
self._sock = self._get_connect_socket(self.broker, self.port)

# Fixed Header
fixed_header = bytearray([0x10])
Expand Down
19 changes: 9 additions & 10 deletions examples/cpython/minimqtt_simpletest_cpython.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
# SPDX-FileCopyrightText: 2021 ladyada for Adafruit Industries
# SPDX-License-Identifier: MIT

import ssl
import socket
import adafruit_minimqtt.adafruit_minimqtt as MQTT

### Secrets File Setup ###

# Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and
# "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other
# source control.
# pylint: disable=no-name-in-module,wrong-import-order
try:
from secrets import secrets
except ImportError:
print("Connection secrets are kept in secrets.py, please add them there!")
print("WiFi secrets are kept in secrets.py, please add them there!")
raise

### Topic Setup ###

# MQTT Topic
# Use this topic if you'd like to connect to a standard MQTT broker
# mqtt_topic = "test/topic"
mqtt_topic = "test/topic"

# Adafruit IO-style Topic
# Use this topic if you'd like to connect to io.adafruit.com
mqtt_topic = secrets["aio_username"] + "/feeds/temperature"
# mqtt_topic = secrets["aio_username"] + "/feeds/temperature"

### Code ###

# Keep track of client connection state
disconnect_client = False

# Define callback methods which are called when events occur
# pylint: disable=unused-argument, redefined-outer-name
def connect(mqtt_client, userdata, flags, rc):
Expand Down Expand Up @@ -65,10 +64,10 @@ def message(client, topic, message):
# Set up a MiniMQTT Client
mqtt_client = MQTT.MQTT(
broker=secrets["broker"],
port=1883,
username=secrets["aio_username"],
password=secrets["aio_key"],
socket_pool=socket,
ssl_context=ssl.create_default_context(),
)

# Connect callback handlers to mqtt_client
Expand Down