Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Issue #90: Remaining length issues #91

Merged
merged 4 commits into from
Dec 22, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def __init__(
randint(0, int(time.monotonic() * 100) % 1000), randint(0, 99)
)
# generated client_id's enforce spec.'s length rules
if len(self.client_id) > 23 or not self.client_id:
if len(self.client_id.encode("utf-8")) > 23 or not self.client_id:
raise ValueError("MQTT Client ID must be between 1 and 23 bytes")

# LWT
Expand Down Expand Up @@ -450,16 +450,16 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
var_header[6] = clean_session << 1

# Set up variable header and remaining_length
remaining_length = 12 + len(self.client_id)
remaining_length = 12 + len(self.client_id.encode("utf-8"))
if self._username:
remaining_length += 2 + len(self._username) + 2 + len(self._password)
remaining_length += 2 + len(self._username.encode("utf-8")) + 2 + len(self._password.encode("utf-8"))
var_header[6] |= 0xC0
if self.keep_alive:
assert self.keep_alive < MQTT_TOPIC_LENGTH_LIMIT
var_header[7] |= self.keep_alive >> 8
var_header[8] |= self.keep_alive & 0x00FF
if self._lw_topic:
remaining_length += 2 + len(self._lw_topic) + 2 + len(self._lw_msg)
remaining_length += 2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg)
var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
var_header[6] |= self._lw_retain << 5

Expand Down Expand Up @@ -586,10 +586,10 @@ def publish(self, topic, msg, retain=False, qos=0):
pub_hdr_fixed = bytearray([0x30 | retain | qos << 1])

# variable header = 2-byte Topic length (big endian)
pub_hdr_var = bytearray(struct.pack(">H", len(topic)))
pub_hdr_var = bytearray(struct.pack(">H", len(topic.encode("utf-8"))))
pub_hdr_var.extend(topic.encode("utf-8")) # Topic name

remaining_length = 2 + len(msg) + len(topic)
remaining_length = 2 + len(msg) + len(topic.encode("utf-8"))
if qos > 0:
# packet identifier where QoS level is 1 or 2. [3.3.2.2]
remaining_length += 2
Expand Down Expand Up @@ -668,15 +668,15 @@ def subscribe(self, topic, qos=0):
topics.append((t, q))
# Assemble packet
packet_length = 2 + (2 * len(topics)) + (1 * len(topics))
packet_length += sum(len(topic) for topic, qos in topics)
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
packet_length_byte = packet_length.to_bytes(1, "big")
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
# Packet with variable and fixed headers
packet = MQTT_SUB + packet_length_byte + packet_id_bytes
# attaching topic and QOS level to the packet
for t, q in topics:
topic_size = len(t).to_bytes(2, "big")
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
qos_byte = q.to_bytes(1, "big")
packet += topic_size + t.encode() + qos_byte
if self.logger:
Expand Down Expand Up @@ -717,13 +717,13 @@ def unsubscribe(self, topic):
)
# Assemble packet
packet_length = 2 + (2 * len(topics))
packet_length += sum(len(topic) for topic in topics)
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
packet_length_byte = packet_length.to_bytes(1, "big")
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
packet = MQTT_UNSUB + packet_length_byte + packet_id_bytes
for t in topics:
topic_size = len(t).to_bytes(2, "big")
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
packet += topic_size + t.encode()
if self.logger:
for t in topics:
Expand Down Expand Up @@ -914,10 +914,11 @@ def _send_str(self, string):
:param str string: String to write to the socket.

"""
self._sock.send(struct.pack("!H", len(string)))
if isinstance(string, str):
self._sock.send(struct.pack("!H", len(string.encode("utf-8"))))
self._sock.send(str.encode(string, "utf-8"))
else:
self._sock.send(struct.pack("!H", len(string)))
self._sock.send(string)

@staticmethod
Expand Down