Skip to content

Commit

Permalink
Merge pull request #65 from brentru/update-cpython-example
Browse files Browse the repository at this point in the history
Fix CPython/Legacy Socket SSL Problems
  • Loading branch information
tannewt authored Feb 24, 2021
2 parents b0b4418 + c9ab6ab commit fabe279
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 81 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
run: |
pylint $( find . -path './adafruit*.py' )
([[ ! -d "examples" ]] || pylint --disable=missing-docstring,invalid-name,bad-whitespace $( find . -path "./examples/*.py" ))
([[ ! -d "examples" ]] || pylint --disable=missing-docstring,invalid-name,bad-whitespace $( find . -path "./examples/*/*.py" ))
- name: Build assets
run: circuitpython-build-bundles --filename_prefix ${{ steps.repo-name.outputs.repo-name }} --library_location .
- name: Archive bundles
Expand Down
103 changes: 32 additions & 71 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
const(0x05): "Connection Refused - Unauthorized",
}

_the_interface = None # pylint: disable=invalid-name
_the_sock = None # pylint: disable=invalid-name
_default_sock = None # pylint: disable=invalid-name
_fake_context = None # pylint: disable=invalid-name


class MMQTTException(Exception):
Expand All @@ -74,17 +74,17 @@ class MMQTTException(Exception):

# Legacy ESP32SPI Socket API
def set_socket(sock, iface=None):
"""Legacy API for setting the socket and network interface, use a Session instead.
"""Legacy API for setting the socket and network interface.
:param sock: socket object.
:param iface: internet interface object
"""
global _the_sock # pylint: disable=invalid-name, global-statement
_the_sock = sock
global _default_sock # pylint: disable=invalid-name, global-statement
global _fake_context # pylint: disable=invalid-name, global-statement
_default_sock = sock
if iface:
global _the_interface # pylint: disable=invalid-name, global-statement
_the_interface = iface
_the_sock.set_interface(iface)
_default_sock.set_interface(iface)
_fake_context = _FakeSSLContext(iface)


class _FakeSSLSocket:
Expand Down Expand Up @@ -144,18 +144,7 @@ def __init__(
):

self._socket_pool = socket_pool
# Legacy API - if we do not have a socket pool, use default socket
if self._socket_pool is None:
self._socket_pool = _the_sock

self._ssl_context = ssl_context
# Legacy API - if we do not have SSL context, fake it
if self._ssl_context is None:
self._ssl_context = _FakeSSLContext(_the_interface)

# Hang onto open sockets so that we can reuse them
self._socket_free = {}
self._open_sockets = {}
self._sock = None
self._backwards_compatible_sock = False

Expand Down Expand Up @@ -214,62 +203,37 @@ def __init__(
self.on_subscribe = None
self.on_unsubscribe = None

# Socket helpers
def _free_socket(self, socket):
"""Frees a socket for re-use."""
if socket not in self._open_sockets.values():
raise RuntimeError("Socket not from MQTT client.")
self._socket_free[socket] = True

def _close_socket(self, socket):
"""Closes a slocket."""
socket.close()
del self._socket_free[socket]
key = None
for k in self._open_sockets:
if self._open_sockets[k] == socket:
key = k
break
if key:
del self._open_sockets[key]

def _free_sockets(self):
"""Closes all free sockets."""
free_sockets = []
for sock in self._socket_free:
if self._socket_free[sock]:
free_sockets.append(sock)
for sock in free_sockets:
self._close_socket(sock)

# pylint: disable=too-many-branches
def _get_socket(self, host, port, *, timeout=1):
key = (host, port)
if key in self._open_sockets:
sock = self._open_sockets[key]
if self._socket_free[sock]:
self._socket_free[sock] = False
return sock
def _get_connect_socket(self, host, port, *, timeout=1):
"""Obtains a new socket and connects to a broker.
:param str host: Desired broker hostname
:param int port: Desired broker port
:param int timeout: Desired socket timeout
"""
# For reconnections - check if we're using a socket already and close it
if self._sock:
self._sock.close()
self._sock = None

# Legacy API - use the interface's socket instead of a passed socket pool
if self._socket_pool is None:
self._socket_pool = _default_sock

# Legacy API - fake the ssl context
if self._ssl_context is None:
self._ssl_context = _fake_context

if port == 8883 and not self._ssl_context:
raise RuntimeError(
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
)

# Legacy API - use a default socket instead of socket pool
if self._socket_pool is None:
self._socket_pool = _the_sock

addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]
retry_count = 0

sock = None
retry_count = 0
while retry_count < 5 and sock is None:
if retry_count > 0:
if any(self._socket_free.items()):
self._free_sockets()
else:
raise RuntimeError("Sending request failed")
retry_count += 1

try:
Expand Down Expand Up @@ -298,9 +262,6 @@ def _get_socket(self, host, port, *, timeout=1):
raise RuntimeError("Repeated socket failures")

self._backwards_compatible_sock = not hasattr(sock, "recv_into")

self._open_sockets[key] = sock
self._socket_free[sock] = False
return sock

def __enter__(self):
Expand Down Expand Up @@ -463,8 +424,8 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
if self.logger:
self.logger.debug("Attempting to establish MQTT connection...")

# Attempt to get a new socket
self._sock = self._get_socket(self.broker, self.port)
# Get a new socket
self._sock = self._get_connect_socket(self.broker, self.port)

# Fixed Header
fixed_header = bytearray([0x10])
Expand Down
19 changes: 9 additions & 10 deletions examples/cpython/minimqtt_simpletest_cpython.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
# SPDX-FileCopyrightText: 2021 ladyada for Adafruit Industries
# SPDX-License-Identifier: MIT

import ssl
import socket
import adafruit_minimqtt.adafruit_minimqtt as MQTT

### Secrets File Setup ###

# Add a secrets.py to your filesystem that has a dictionary called secrets with "ssid" and
# "password" keys with your WiFi credentials. DO NOT share that file or commit it into Git or other
# source control.
# pylint: disable=no-name-in-module,wrong-import-order
try:
from secrets import secrets
except ImportError:
print("Connection secrets are kept in secrets.py, please add them there!")
print("WiFi secrets are kept in secrets.py, please add them there!")
raise

### Topic Setup ###

# MQTT Topic
# Use this topic if you'd like to connect to a standard MQTT broker
# mqtt_topic = "test/topic"
mqtt_topic = "test/topic"

# Adafruit IO-style Topic
# Use this topic if you'd like to connect to io.adafruit.com
mqtt_topic = secrets["aio_username"] + "/feeds/temperature"
# mqtt_topic = secrets["aio_username"] + "/feeds/temperature"

### Code ###

# Keep track of client connection state
disconnect_client = False

# Define callback methods which are called when events occur
# pylint: disable=unused-argument, redefined-outer-name
def connect(mqtt_client, userdata, flags, rc):
Expand Down Expand Up @@ -65,10 +64,10 @@ def message(client, topic, message):
# Set up a MiniMQTT Client
mqtt_client = MQTT.MQTT(
broker=secrets["broker"],
port=1883,
username=secrets["aio_username"],
password=secrets["aio_key"],
socket_pool=socket,
ssl_context=ssl.create_default_context(),
)

# Connect callback handlers to mqtt_client
Expand Down

0 comments on commit fabe279

Please sign in to comment.