Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for receiving basic fragmented messages #669

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions bellows/ezsp/fragmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Implements APS fragmentation reassembly on the EZSP Host side,
mirroring the logic from fragmentation.c in the EmberZNet stack.
"""

import asyncio
import logging
import time
from typing import Dict, Optional, Tuple

LOGGER = logging.getLogger(__name__)

# The maximum time (in seconds) we wait for all fragments of a given message.
# If not all fragments arrive within this time, we discard the partial data.
FRAGMENT_TIMEOUT = 10

# store partial data keyed by (sender, aps_sequence, profile_id, cluster_id)
FragmentKey = Tuple[int, int, int, int]


class _FragmentEntry:
def __init__(self, fragment_count: int):
self.fragment_count = fragment_count
self.fragments_received = 0
self.fragment_data = {}
self.start_time = asyncio.get_event_loop().time()

def add_fragment(self, index: int, data: bytes) -> None:
if index not in self.fragment_data:
self.fragment_data[index] = data
self.fragments_received += 1

def is_complete(self) -> bool:
return self.fragments_received == self.fragment_count

def assemble(self) -> bytes:
return b"".join(
self.fragment_data[i] for i in sorted(self.fragment_data.keys())
)


class FragmentManager:
    """Reassembles fragmented APS messages received on the EZSP host side."""

    def __init__(self):
        # In-progress reassemblies, keyed by (sender, APS seq, profile, cluster).
        self._partial: Dict[FragmentKey, _FragmentEntry] = {}
        # One pending expiry timer per in-progress reassembly.
        self._cleanup_timers: Dict[FragmentKey, asyncio.TimerHandle] = {}

    def handle_incoming_fragment(
        self,
        sender_nwk: int,
        aps_sequence: int,
        profile_id: int,
        cluster_id: int,
        fragment_count: int,
        fragment_index: int,
        payload: bytes,
    ) -> Tuple[bool, Optional[bytes], int, int]:
        """Handle a newly received fragment.

        :param sender_nwk: NWK address or the short ID of the sender.
        :param aps_sequence: The APS sequence from the incoming APS frame.
        :param profile_id: The APS frame's profileId.
        :param cluster_id: The APS frame's clusterId.
        :param fragment_count: The total number of expected message fragments.
        :param fragment_index: The index of the current fragment being processed.
        :param payload: The fragment of data for this message.
        :return: (complete, reassembled_data, fragment_count, fragment_index)
            complete = True if we have all fragments now, else False
            reassembled_data = the final complete payload (bytes) if complete is True
            fragment_count = the total number of fragments holding the complete packet
            fragment_index = the index of the current received fragment
        """

        key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id)

        # If we have never seen this message, create a reassembly entry.
        if key not in self._partial:
            entry = _FragmentEntry(fragment_count)
            self._partial[key] = entry
        else:
            entry = self._partial[key]

        LOGGER.debug(
            "Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)",
            fragment_index + 1,
            fragment_count,
            sender_nwk,
            aps_sequence,
            cluster_id,
        )

        entry.add_fragment(fragment_index, payload)

        # Restart the expiry timer. The previously scheduled TimerHandle must
        # be cancelled before being replaced: simply overwriting the dict
        # entry leaves the stale timer running, and it would fire
        # FRAGMENT_TIMEOUT after the *first* fragment and discard the partial
        # message even though newer fragments have since arrived.
        old_timer = self._cleanup_timers.pop(key, None)
        if old_timer is not None:
            old_timer.cancel()
        loop = asyncio.get_running_loop()
        self._cleanup_timers[key] = loop.call_later(
            FRAGMENT_TIMEOUT, self.cleanup_partial, key
        )

        if entry.is_complete():
            reassembled = entry.assemble()
            del self._partial[key]
            timer = self._cleanup_timers.pop(key, None)
            if timer:
                timer.cancel()
            LOGGER.debug(
                "Message reassembly complete. Total length=%d", len(reassembled)
            )
            return (True, reassembled, fragment_count, fragment_index)
        else:
            return (False, None, fragment_count, fragment_index)

    def cleanup_partial(self, key: FragmentKey):
        """Discard a partial reassembly whose fragments stopped arriving.

        Called when FRAGMENT_TIMEOUT passes with no new fragments for *key*.
        """
        LOGGER.debug(
            "Timeout for partial reassembly of fragmented message, discarding key=%s",
            key,
        )
        self._partial.pop(key, None)
        self._cleanup_timers.pop(key, None)


# Create a single global manager instance
fragment_manager = FragmentManager()
77 changes: 77 additions & 0 deletions bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,57 @@ def __call__(self, data: bytes) -> None:
if data:
LOGGER.debug("Frame contains trailing data: %s", data)

if (
frame_name == "incomingMessageHandler" and result[1].options & 0x8000
): # incoming message with APS_OPTION_FRAGMENT raised
from bellows.ezsp.fragmentation import fragment_manager

# Extract received APS frame and sender
aps_frame = result[1]
sender = result[4]

group_id = aps_frame.groupId
profile_id = aps_frame.profileId
cluster_id = aps_frame.clusterId
aps_seq = aps_frame.sequence

fragment_count = (group_id >> 8) & 0xFF
fragment_index = group_id & 0xFF

(
complete,
reassembled,
frag_count,
frag_index,
) = fragment_manager.handle_incoming_fragment(
sender_nwk=sender,
aps_sequence=aps_seq,
profile_id=profile_id,
cluster_id=cluster_id,
fragment_count=fragment_count,
fragment_index=fragment_index,
payload=result[7],
)
if not hasattr(self, "_ack_tasks"):
self._ack_tasks = set()
ack_task = asyncio.create_task(
self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)
) # APS Ack
self._ack_tasks.add(ack_task)
ack_task.add_done_callback(lambda t: self._ack_tasks.discard(t))

if not complete:
# Do not pass partial data up the stack
LOGGER.debug("Fragment reassembly not complete. waiting for more data.")
return
else:
# Replace partial data with fully reassembled data
result[7] = reassembled

LOGGER.debug(
"Reassembled fragmented message. Proceeding with normal handling."
)

if sequence in self._awaiting:
expected_id, schema, future = self._awaiting.pop(sequence)
try:
Expand All @@ -205,6 +256,32 @@ def __call__(self, data: bytes) -> None:
else:
self._handle_callback(frame_name, result)

async def _send_fragment_ack(
    self,
    sender: int,
    incoming_aps: t.EmberApsFrame,
    fragment_count: int,
    fragment_index: int,
) -> t.EmberStatus:
    """Reply to one received fragment with an APS fragment acknowledgement.

    The ack mirrors the incoming APS frame with source and destination
    endpoints swapped. Its groupId carries 0xFF in the high byte (where the
    sender encodes the fragment count) and the acknowledged fragment index
    in the low byte.

    :param sender: short/NWK address the ack is sent back to.
    :param incoming_aps: the APS frame of the fragment being acknowledged.
    :param fragment_count: total fragments in the message (used for logging).
    :param fragment_index: index of the fragment being acknowledged.
    :return: the EmberStatus reported by ``sendReply``.
    """
    LOGGER.debug(
        "Sending fragment ack to 0x%04X for fragment index=%d/%d",
        sender,
        fragment_index + 1,
        fragment_count,
    )

    reply_frame = t.EmberApsFrame(
        profileId=incoming_aps.profileId,
        clusterId=incoming_aps.clusterId,
        sourceEndpoint=incoming_aps.destinationEndpoint,
        destinationEndpoint=incoming_aps.sourceEndpoint,
        options=incoming_aps.options,
        groupId=((0xFF00) | (fragment_index & 0xFF)),
        sequence=incoming_aps.sequence,
    )

    result = await self.sendReply(sender, reply_frame, b"")
    return result[0]

def __getattr__(self, name: str) -> Callable:
if name not in self.COMMANDS:
raise AttributeError(f"{name} not found in COMMANDS")
Expand Down
157 changes: 157 additions & 0 deletions tests/test_ezsp_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,160 @@ async def test_parsing_schema_response(prot_hndl_v9):

rsp = await coro
assert rsp == GetTokenDataRsp(status=t.EmberStatus.LIBRARY_NOT_PRESENT)


@pytest.mark.asyncio
async def test_send_fragment_ack(prot_hndl, caplog):
    """Test the _send_fragment_ack method."""
    sender = 0x1D6F
    incoming_aps = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,
        groupId=512,  # 0x0200: fragment_count=2 (high byte), fragment_index=0 (low byte)
        sequence=238,
    )
    fragment_count = 2
    fragment_index = 0

    # The ack must mirror the incoming frame, with the groupId high byte set
    # to 0xFF and the low byte carrying the acknowledged fragment index.
    expected_ack_frame = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,
        groupId=((0xFF00) | (fragment_index & 0xFF)),
        sequence=238,
    )

    with patch.object(prot_hndl, "sendReply", new=AsyncMock()) as mock_send_reply:
        mock_send_reply.return_value = (t.EmberStatus.SUCCESS,)

        caplog.set_level(logging.DEBUG)
        status = await prot_hndl._send_fragment_ack(
            sender, incoming_aps, fragment_count, fragment_index
        )

        # Assertions
        assert status == t.EmberStatus.SUCCESS
        assert (
            "Sending fragment ack to 0x1d6f for fragment index=1/2".lower()
            in caplog.text.lower()
        )
        mock_send_reply.assert_called_once_with(sender, expected_ack_frame, b"")


@pytest.mark.asyncio
async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog):
    """Test handling of an incomplete fragmented message."""
    # Raw EZSP incomingMessageHandler frame carrying fragment 0 of 2.
    packet = b"\x90\x01\x45\x00\x05\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x01\xdd"

    # Parse packet manually to extract parameters for assertions
    sender = 0x1D6F
    aps_frame = t.EmberApsFrame(
        profileId=261,  # 0x0105
        clusterId=65281,  # 0xFF01
        sourceEndpoint=2,  # 0x02
        destinationEndpoint=2,  # 0x02
        options=33088,  # 0x8140 (APS_OPTION_FRAGMENT + others)
        groupId=512,  # 0x0200 (fragment_count=2, fragment_index=0)
        sequence=238,  # 0xEE
    )

    with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
        mock_ack.return_value = None

        caplog.set_level(logging.DEBUG)
        prot_hndl(packet)

        # An ack task is scheduled for every received fragment; let it finish.
        assert hasattr(prot_hndl, "_ack_tasks")
        assert len(prot_hndl._ack_tasks) == 1
        ack_task = next(iter(prot_hndl._ack_tasks))
        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
        assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task"

        # Only one of the two fragments arrived: nothing is passed up the stack.
        prot_hndl._handle_callback.assert_not_called()
        assert "Fragment reassembly not complete. waiting for more data." in caplog.text
        mock_ack.assert_called_once_with(sender, aps_frame, 2, 0)


@pytest.mark.asyncio
async def test_incoming_fragmented_message_complete(prot_hndl, caplog):
    """Test handling of a complete fragmented message."""
    # Two raw EZSP incomingMessageHandler frames for the same APS sequence.
    packet1 = (
        b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x09"
        + b"complete "
    )  # fragment index 0
    packet2 = (
        b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07"
        + b"message"
    )  # fragment index 1
    sender = 0x1D6F

    aps_frame_1 = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,  # Includes APS_OPTION_FRAGMENT
        groupId=512,  # fragment_count=2, fragment_index=0
        sequence=238,
    )
    aps_frame_2 = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,
        groupId=513,  # fragment_count=2, fragment_index=1
        sequence=238,
    )
    reassembled = b"complete message"

    with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
        mock_ack.return_value = None
        caplog.set_level(logging.DEBUG)

        # Packet 1: first fragment only -> acked, but nothing passed up the stack.
        prot_hndl(packet1)
        assert hasattr(prot_hndl, "_ack_tasks")
        assert len(prot_hndl._ack_tasks) == 1
        ack_task = next(iter(prot_hndl._ack_tasks))
        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
        assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task"

        prot_hndl._handle_callback.assert_not_called()
        assert (
            "Reassembled fragmented message. Proceeding with normal handling."
            not in caplog.text
        )
        mock_ack.assert_called_with(sender, aps_frame_1, 2, 0)

        # Packet 2: final fragment -> acked, reassembled payload delivered upward.
        prot_hndl(packet2)
        assert hasattr(prot_hndl, "_ack_tasks")
        assert len(prot_hndl._ack_tasks) == 1
        ack_task = next(iter(prot_hndl._ack_tasks))
        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
        assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task"

        prot_hndl._handle_callback.assert_called_once_with(
            "incomingMessageHandler",
            [
                t.EmberIncomingMessageType.INCOMING_UNICAST,  # 0x00
                aps_frame_2,  # Parsed APS frame
                255,  # lastHopLqi: 0xFF
                -8,  # lastHopRssi: 0xF8
                sender,  # 0x1D6F
                255,  # bindingIndex: 0xFF
                255,  # addressIndex: 0xFF
                reassembled,  # Reassembled payload
            ],
        )
        assert (
            "Reassembled fragmented message. Proceeding with normal handling."
            in caplog.text
        )
        mock_ack.assert_called_with(sender, aps_frame_2, 2, 1)
Loading
Loading