From 52f5382dba2977edc2af0f098e02b065154c683c Mon Sep 17 00:00:00 2001 From: tkulin Date: Wed, 26 Feb 2025 17:00:39 -0500 Subject: [PATCH 1/5] add basic fragmentation support --- bellows/ezsp/fragmentation.py | 98 +++++++++++++++++++++++++++++++++++ bellows/ezsp/protocol.py | 53 +++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 bellows/ezsp/fragmentation.py diff --git a/bellows/ezsp/fragmentation.py b/bellows/ezsp/fragmentation.py new file mode 100644 index 00000000..3a43d9da --- /dev/null +++ b/bellows/ezsp/fragmentation.py @@ -0,0 +1,98 @@ +""" +Implements APS fragmentation reassembly on the EZSP Host side, +mirroring the logic from fragmentation.c in the EmberZNet stack. +""" + +import asyncio +import logging +from collections import defaultdict +from typing import Optional, Dict, Tuple + +LOGGER = logging.getLogger(__name__) + +# The maximum time (in seconds) we wait for all fragments of a given message. +# If not all fragments arrive within this time, we discard the partial data. +FRAGMENT_TIMEOUT = 10 + +# store partial data keyed by (sender, aps_sequence, profile_id, cluster_id) +FragmentKey = Tuple[int, int, int, int] + +class _FragmentEntry: + def __init__(self, fragment_count: int): + self.fragment_count = fragment_count + self.fragments_received = 0 + self.fragment_data = {} + self.start_time = asyncio.get_event_loop().time() + + def add_fragment(self, index: int, data: bytes) -> None: + if index not in self.fragment_data: + self.fragment_data[index] = data + self.fragments_received += 1 + + def is_complete(self) -> bool: + return self.fragments_received == self.fragment_count + + def assemble(self) -> bytes: + return b''.join(self.fragment_data[i] for i in sorted(self.fragment_data.keys())) + +class FragmentManager: + def __init__(self): + self._partial: Dict[FragmentKey, _FragmentEntry] = {} + + def handle_incoming_fragment(self, sender_nwk: int, aps_sequence: int, profile_id: int, cluster_id: int, + group_id: int, payload: bytes) -> Tuple[bool, Optional[bytes], int, int]: + """ + Handle a newly received fragment. The group_id field + encodes high byte = total fragment count, low byte = current fragment index. + + :param sender_nwk: NWK address or the short ID of the sender. + :param aps_sequence: The APS sequence from the incoming APS frame. + :param profile_id: The APS frame's profileId. + :param cluster_id: The APS frame's clusterId. + :param group_id: The APS frame's groupId (used to store fragment # / total). + :param payload: The fragment of data for this message. + :return: (complete, reassembled_data, fragment_count, fragment_index) + complete = True if we have all fragments now, else False + reassembled_data = the final complete payload (bytes) if complete is True + fragment_coutn = the total number of fragments holding the complete packet + fragment_index = the index of the current received fragment + """ + fragment_count = (group_id >> 8) & 0xFF + fragment_index = group_id & 0xFF + + key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id) + + # If we have never seen this message, create a reassembly entry. + if key not in self._partial: + entry = _FragmentEntry(fragment_count) + self._partial[key] = entry + else: + entry = self._partial[key] + + LOGGER.debug("Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)", + fragment_index, fragment_count, sender_nwk, aps_sequence, cluster_id) + + entry.add_fragment(fragment_index, payload) + + if entry.is_complete(): + reassembled = entry.assemble() + del self._partial[key] + LOGGER.debug("Message reassembly complete. Total length=%d", len(reassembled)) + return (True, reassembled, fragment_count, fragment_index) + else: + return (False, None, fragment_count, fragment_index) + + def cleanup_expired(self) -> None: + + now = asyncio.get_event_loop().time() + to_remove = [] + for k, entry in self._partial.items(): + if now - entry.start_time > FRAGMENT_TIMEOUT: + to_remove.append(k) + for k in to_remove: + del self._partial[k] + LOGGER.debug("Removed stale fragment reassembly for key=%s", k) + +# Create a single global manager instance +fragment_manager = FragmentManager() + diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index b03df772..fb36ab8f 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -181,6 +181,38 @@ def __call__(self, data: bytes) -> None: if data: LOGGER.debug("Frame contains trailing data: %s", data) + if frame_name == "incomingMessageHandler" and result[1].options & 0x8000: # incoming message with APS_OPTION_FRAGMENT raised + from bellows.ezsp.fragmentation import fragment_manager + + # Extract received APS frame and sender + aps_frame = result[1] + sender = result[4] + + group_id = aps_frame.groupId + profile_id = aps_frame.profileId + cluster_id = aps_frame.clusterId + aps_seq = aps_frame.sequence + + complete, reassembled, frag_count, frag_index = fragment_manager.handle_incoming_fragment( + sender_nwk=sender, + aps_sequence=aps_seq, + profile_id=profile_id, + cluster_id=cluster_id, + group_id=group_id, + payload=result[7] + ) + asyncio.create_task(self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)) # APS Ack + + if not complete: + # Do not pass partial data up the stack + LOGGER.debug("Fragment reassembly not complete. waiting for more data.") + return + else: + # Replace partial data with fully reassembled data + result[7] = reassembled + + LOGGER.debug("Reassembled fragmented message. Proceeding with normal handling.") + if sequence in self._awaiting: expected_id, schema, future = self._awaiting.pop(sequence) try: @@ -205,6 +237,27 @@ def __call__(self, data: bytes) -> None: else: self._handle_callback(frame_name, result) + async def _send_fragment_ack(self, sender: int, incoming_aps: t.EmberApsFrame, fragment_count: int, fragment_index: int): + + ackFrame = t.EmberApsFrame( + profileId=incoming_aps.profileId, + clusterId=incoming_aps.clusterId, + sourceEndpoint=incoming_aps.destinationEndpoint, + destinationEndpoint=incoming_aps.sourceEndpoint, + options=incoming_aps.options, + groupId=((0xFF00) | (fragment_index & 0xFF)), + sequence=incoming_aps.sequence + ) + + LOGGER.debug("Sending fragment ack to 0x%04X for fragment index=%d/%d", sender, fragment_index, fragment_count) + await self.sendReply(sender, ackFrame, b'') + + async def _cleanup_fragments_periodically(self): + from bellows.ezsp.fragmentation import fragment_manager + while True: + await asyncio.sleep(5) + fragment_manager.cleanup_expired() + def __getattr__(self, name: str) -> Callable: if name not in self.COMMANDS: raise AttributeError(f"{name} not found in COMMANDS") From ea381158979c8af4c5491d7d29992f8ccf27a0a0 Mon Sep 17 00:00:00 2001 From: tkulin Date: Wed, 26 Feb 2025 17:08:37 -0500 Subject: [PATCH 2/5] add basic fragmentation support (remembering to cleanup leftover fragments) --- bellows/ezsp/protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index fb36ab8f..d00e8699 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -53,6 +53,7 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: # Cached by `set_extended_timeout` so subsequent calls are a little faster self._address_table_size: int | None = None + self._cleanup_fragments_periodically() def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes: """Serialize the named frame and data.""" From 04ae89baf761ce5f3277e54169bbbbb21dce4130 Mon Sep 17 00:00:00 2001 From: tkulin Date: Mon, 3 Mar 2025 11:13:55 -0500 Subject: [PATCH 3/5] code cleanup and improved partial fragment removal --- bellows/ezsp/fragmentation.py | 78 ++++++++++++++++++++++------------- bellows/ezsp/protocol.py | 64 ++++++++++++++++++---------- 2 files changed, 92 insertions(+), 50 deletions(-) diff --git a/bellows/ezsp/fragmentation.py b/bellows/ezsp/fragmentation.py index 3a43d9da..0e84b2e2 100644 --- a/bellows/ezsp/fragmentation.py +++ b/bellows/ezsp/fragmentation.py @@ -1,12 +1,10 @@ -""" -Implements APS fragmentation reassembly on the EZSP Host side, +"""Implements APS fragmentation reassembly on the EZSP Host side, mirroring the logic from fragmentation.c in the EmberZNet stack. """ import asyncio import logging -from collections import defaultdict -from typing import Optional, Dict, Tuple +from typing import Dict, Optional, Tuple LOGGER = logging.getLogger(__name__) @@ -17,13 +15,14 @@ # store partial data keyed by (sender, aps_sequence, profile_id, cluster_id) FragmentKey = Tuple[int, int, int, int] + class _FragmentEntry: def __init__(self, fragment_count: int): self.fragment_count = fragment_count self.fragments_received = 0 self.fragment_data = {} self.start_time = asyncio.get_event_loop().time() - + def add_fragment(self, index: int, data: bytes) -> None: if index not in self.fragment_data: self.fragment_data[index] = data @@ -33,23 +32,34 @@ def is_complete(self) -> bool: return self.fragments_received == self.fragment_count def assemble(self) -> bytes: - return b''.join(self.fragment_data[i] for i in sorted(self.fragment_data.keys())) + return b"".join( + self.fragment_data[i] for i in sorted(self.fragment_data.keys()) + ) + class FragmentManager: def __init__(self): self._partial: Dict[FragmentKey, _FragmentEntry] = {} - - def handle_incoming_fragment(self, sender_nwk: int, aps_sequence: int, profile_id: int, cluster_id: int, - group_id: int, payload: bytes) -> Tuple[bool, Optional[bytes], int, int]: - """ - Handle a newly received fragment. The group_id field - encodes high byte = total fragment count, low byte = current fragment index. + self._cleanup_timers: Dict[FragmentKey, asyncio.TimerHandle] = {} + + def handle_incoming_fragment( + self, + sender_nwk: int, + aps_sequence: int, + profile_id: int, + cluster_id: int, + fragment_count: int, + fragment_index: int, + payload: bytes, + ) -> Tuple[bool, Optional[bytes], int, int]: + """Handle a newly received fragment. :param sender_nwk: NWK address or the short ID of the sender. :param aps_sequence: The APS sequence from the incoming APS frame. :param profile_id: The APS frame's profileId. :param cluster_id: The APS frame's clusterId. - :param group_id: The APS frame's groupId (used to store fragment # / total). + :param fragment_count: The total number of expected message fragments. + :param fragment_index: The index of the current fragment being processed. :param payload: The fragment of data for this message. :return: (complete, reassembled_data, fragment_count, fragment_index) complete = True if we have all fragments now, else False @@ -57,8 +67,6 @@ def handle_incoming_fragment(self, sender_nwk: int, aps_sequence: int, profile_i fragment_coutn = the total number of fragments holding the complete packet fragment_index = the index of the current received fragment """ - fragment_count = (group_id >> 8) & 0xFF - fragment_index = group_id & 0xFF key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id) @@ -69,30 +77,44 @@ def handle_incoming_fragment(self, sender_nwk: int, aps_sequence: int, profile_i else: entry = self._partial[key] - LOGGER.debug("Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)", - fragment_index, fragment_count, sender_nwk, aps_sequence, cluster_id) + LOGGER.debug( + "Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)", + fragment_index + 1, + fragment_count, + sender_nwk, + aps_sequence, + cluster_id, + ) entry.add_fragment(fragment_index, payload) + loop = asyncio.get_running_loop() + self._cleanup_timers[key] = loop.call_later( + FRAGMENT_TIMEOUT, self.cleanup_partial, key + ) + if entry.is_complete(): reassembled = entry.assemble() del self._partial[key] - LOGGER.debug("Message reassembly complete. Total length=%d", len(reassembled)) + timer = self._cleanup_timers.pop(key, None) + if timer: + timer.cancel() + LOGGER.debug( + "Message reassembly complete. Total length=%d", len(reassembled) + ) return (True, reassembled, fragment_count, fragment_index) else: return (False, None, fragment_count, fragment_index) - def cleanup_expired(self) -> None: + def cleanup_partial(self, key: FragmentKey): + # Called when FRAGMENT_TIMEOUT passes with no new fragments for that key. + LOGGER.debug( + "Timeout for partial reassembly of fragmented message, discarding key=%s", + key, + ) + self._partial.pop(key, None) + self._cleanup_timers.pop(key, None) - now = asyncio.get_event_loop().time() - to_remove = [] - for k, entry in self._partial.items(): - if now - entry.start_time > FRAGMENT_TIMEOUT: - to_remove.append(k) - for k in to_remove: - del self._partial[k] - LOGGER.debug("Removed stale fragment reassembly for key=%s", k) # Create a single global manager instance fragment_manager = FragmentManager() - diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index d00e8699..a787bf54 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -53,7 +53,6 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: # Cached by `set_extended_timeout` so subsequent calls are a little faster self._address_table_size: int | None = None - self._cleanup_fragments_periodically() def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes: """Serialize the named frame and data.""" @@ -182,27 +181,41 @@ def __call__(self, data: bytes) -> None: if data: LOGGER.debug("Frame contains trailing data: %s", data) - if frame_name == "incomingMessageHandler" and result[1].options & 0x8000: # incoming message with APS_OPTION_FRAGMENT raised + if ( + frame_name == "incomingMessageHandler" and result[1].options & 0x8000 + ): # incoming message with APS_OPTION_FRAGMENT raised from bellows.ezsp.fragmentation import fragment_manager # Extract received APS frame and sender aps_frame = result[1] - sender = result[4] + sender = result[4] group_id = aps_frame.groupId profile_id = aps_frame.profileId cluster_id = aps_frame.clusterId aps_seq = aps_frame.sequence - - complete, reassembled, frag_count, frag_index = fragment_manager.handle_incoming_fragment( + + fragment_count = (group_id >> 8) & 0xFF + fragment_index = group_id & 0xFF + + ( + complete, + reassembled, + frag_count, + frag_index, + ) = fragment_manager.handle_incoming_fragment( sender_nwk=sender, aps_sequence=aps_seq, profile_id=profile_id, cluster_id=cluster_id, - group_id=group_id, - payload=result[7] + fragment_count=fragment_count, + fragment_index=fragment_index, + payload=result[7], ) - asyncio.create_task(self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)) # APS Ack + ack_task = asyncio.create_task( + self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) + ) # APS Ack + ack_task.add_done_callback(self._ack_tasks.remove) if not complete: # Do not pass partial data up the stack @@ -211,8 +224,10 @@ def __call__(self, data: bytes) -> None: else: # Replace partial data with fully reassembled data result[7] = reassembled - - LOGGER.debug("Reassembled fragmented message. Proceeding with normal handling.") + + LOGGER.debug( + "Reassembled fragmented message. Proceeding with normal handling." + ) if sequence in self._awaiting: expected_id, schema, future = self._awaiting.pop(sequence) @@ -238,8 +253,13 @@ def __call__(self, data: bytes) -> None: else: self._handle_callback(frame_name, result) - async def _send_fragment_ack(self, sender: int, incoming_aps: t.EmberApsFrame, fragment_count: int, fragment_index: int): - + async def _send_fragment_ack( + self, + sender: int, + incoming_aps: t.EmberApsFrame, + fragment_count: int, + fragment_index: int, + ) -> t.EmberStatus: ackFrame = t.EmberApsFrame( profileId=incoming_aps.profileId, clusterId=incoming_aps.clusterId, @@ -247,18 +267,18 @@ async def _send_fragment_ack(self, sender: int, incoming_aps: t.EmberApsFrame, f destinationEndpoint=incoming_aps.sourceEndpoint, options=incoming_aps.options, groupId=((0xFF00) | (fragment_index & 0xFF)), - sequence=incoming_aps.sequence + sequence=incoming_aps.sequence, + ) + + LOGGER.debug( + "Sending fragment ack to 0x%04X for fragment index=%d/%d", + sender, + fragment_index + 1, + fragment_count, ) + status = await self.sendReply(sender, ackFrame, b"") + return status - LOGGER.debug("Sending fragment ack to 0x%04X for fragment index=%d/%d", sender, fragment_index, fragment_count) - await self.sendReply(sender, ackFrame, b'') - - async def _cleanup_fragments_periodically(self): - from bellows.ezsp.fragmentation import fragment_manager - while True: - await asyncio.sleep(5) - fragment_manager.cleanup_expired() - def __getattr__(self, name: str) -> Callable: if name not in self.COMMANDS: raise AttributeError(f"{name} not found in COMMANDS") From c91fd306426b1a9731539937b1b755b10ec6588a Mon Sep 17 00:00:00 2001 From: tkulin Date: Mon, 3 Mar 2025 14:28:37 -0500 Subject: [PATCH 4/5] minor change to frag ack callback handling, added test case for fragmentation --- bellows/ezsp/protocol.py | 5 +- tests/test_fragmentation.py | 192 ++++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 tests/test_fragmentation.py diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index a787bf54..931265a2 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -212,10 +212,13 @@ def __call__(self, data: bytes) -> None: fragment_index=fragment_index, payload=result[7], ) + if not hasattr(self, "_ack_tasks"): + self._ack_tasks = set() ack_task = asyncio.create_task( self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) ) # APS Ack - ack_task.add_done_callback(self._ack_tasks.remove) + self._ack_tasks.add(ack_task) + ack_task.add_done_callback(lambda t: self._ack_tasks.discard(t)) if not complete: # Do not pass partial data up the stack diff --git a/tests/test_fragmentation.py b/tests/test_fragmentation.py new file mode 100644 index 00000000..a0e60d1e --- /dev/null +++ b/tests/test_fragmentation.py @@ -0,0 +1,192 @@ +import pytest + +from bellows.ezsp.fragmentation import FragmentManager + + +@pytest.fixture +def frag_manager(): + return FragmentManager() + + +@pytest.mark.asyncio +async def test_single_fragment_complete(frag_manager): + # If we receive a single-fragment message, the fragemnt manager should immediately report completion. + + key = (0x1234, 0xAB, 0x1234, 0x5678) + fragment_count = 1 + fragment_index = 0 + payload = b"Single fragment" + + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=fragment_index, + payload=payload, + ) + + assert complete is True + assert reassembled == payload + assert returned_frag_count == fragment_count + assert returned_frag_index == fragment_index + assert key not in frag_manager._partial + assert key not in frag_manager._cleanup_timers + + +@pytest.mark.asyncio +async def test_two_fragments_in_order(frag_manager): + # A two-fragment message should remain partial until we've received both pieces. + key = (0x1111, 0x01, 0x9999, 0x2222) + fragment_count = 2 + + # First fragment + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=0, + payload=b"Frag0-", + ) + assert complete is False + assert reassembled is None + assert key in frag_manager._partial + assert frag_manager._partial[key].fragments_received == 1 + + # Second fragment + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=1, + payload=b"Frag1", + ) + assert complete is True + assert reassembled == b"Frag0-Frag1" + assert key not in frag_manager._partial + assert key not in frag_manager._cleanup_timers + + +@pytest.mark.asyncio +async def test_out_of_order_fragments(frag_manager): + # Receiving fragments in reverse order + key = (0x9999, 0xCD, 0x1234, 0xABCD) + fragment_count = 2 + + # Second fragment arrives first + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=1, + payload=b"World", + ) + assert not complete + assert reassembled is None + + # Then the first fragment + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=0, + payload=b"Hello ", + ) + assert complete + assert reassembled == b"Hello World" + + +@pytest.mark.asyncio +async def test_repeated_fragments_ignored(frag_manager): + # Ensure repeated arrivals of the same fragment index do not double-count or break the logic. + + key = (0xAAA, 0xBB, 0xCCC, 0xDDD) + fragment_count = 2 + + # First fragment + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=0, + payload=b"first", + ) + assert not complete + assert frag_manager._partial[key].fragments_received == 1 + + # Repeat the same fragment index + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=0, + payload=b"first", + ) + assert not complete + assert frag_manager._partial[key].fragments_received == 1, "Should not increment" + + # Second fragment completes + ( + complete, + reassembled, + returned_frag_count, + returned_frag_index, + ) = frag_manager.handle_incoming_fragment( + sender_nwk=key[0], + aps_sequence=key[1], + profile_id=key[2], + cluster_id=key[3], + fragment_count=fragment_count, + fragment_index=1, + payload=b"second", + ) + assert complete + assert reassembled == b"firstsecond" From 13c376a2e627b5e3f11582691ffe32206e9edf06 Mon Sep 17 00:00:00 2001 From: tkulin Date: Tue, 4 Mar 2025 15:06:37 -0500 Subject: [PATCH 5/5] added more tests --- bellows/ezsp/protocol.py | 2 +- tests/test_ezsp_protocol.py | 157 ++++++++++++++++++++++++++++++++++++ tests/test_fragmentation.py | 42 +++++++--- 3 files changed, 191 insertions(+), 10 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 931265a2..d7d968dd 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -280,7 +280,7 @@ async def _send_fragment_ack( fragment_count, ) status = await self.sendReply(sender, ackFrame, b"") - return status + return status[0] def __getattr__(self, name: str) -> Callable: if name not in self.COMMANDS: diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index 98f5678e..a6c2e3c1 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -133,3 +133,160 @@ async def test_parsing_schema_response(prot_hndl_v9): rsp = await coro assert rsp == GetTokenDataRsp(status=t.EmberStatus.LIBRARY_NOT_PRESENT) + + +@pytest.mark.asyncio +async def test_send_fragment_ack(prot_hndl, caplog): + """Test the _send_fragment_ack method.""" + sender = 0x1D6F + incoming_aps = t.EmberApsFrame( + profileId=260, + clusterId=65281, + sourceEndpoint=2, + destinationEndpoint=2, + options=33088, + groupId=512, + sequence=238, + ) + fragment_count = 2 + fragment_index = 0 + + expected_ack_frame = t.EmberApsFrame( + profileId=260, + clusterId=65281, + sourceEndpoint=2, + destinationEndpoint=2, + options=33088, + groupId=((0xFF00) | (fragment_index & 0xFF)), + sequence=238, + ) + + with patch.object(prot_hndl, "sendReply", new=AsyncMock()) as mock_send_reply: + mock_send_reply.return_value = (t.EmberStatus.SUCCESS,) + + caplog.set_level(logging.DEBUG) + status = await prot_hndl._send_fragment_ack( + sender, incoming_aps, fragment_count, fragment_index + ) + + # Assertions + assert status == t.EmberStatus.SUCCESS + assert ( + "Sending fragment ack to 0x1d6f for fragment index=1/2".lower() + in caplog.text.lower() + ) + mock_send_reply.assert_called_once_with(sender, expected_ack_frame, b"") + + +@pytest.mark.asyncio +async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog): + """Test handling of an incomplete fragmented message.""" + packet = b"\x90\x01\x45\x00\x05\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x01\xdd" + + # Parse packet manually to extract parameters for assertions + sender = 0x1D6F + aps_frame = t.EmberApsFrame( + profileId=261, # 0x0105 + clusterId=65281, # 0xFF01 + sourceEndpoint=2, # 0x02 + destinationEndpoint=2, # 0x02 + options=33088, # 0x8140 (APS_OPTION_FRAGMENT + others) + groupId=512, # 0x0002 (fragment_count=2, fragment_index=0) + sequence=238, # 0xEE + ) + + with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack: + mock_ack.return_value = None + + caplog.set_level(logging.DEBUG) + prot_hndl(packet) + + assert hasattr(prot_hndl, "_ack_tasks") + assert len(prot_hndl._ack_tasks) == 1 + ack_task = next(iter(prot_hndl._ack_tasks)) + await asyncio.gather(ack_task) # Ensure task completes and triggers callback + assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task" + + prot_hndl._handle_callback.assert_not_called() + assert "Fragment reassembly not complete. waiting for more data." in caplog.text + mock_ack.assert_called_once_with(sender, aps_frame, 2, 0) + + +@pytest.mark.asyncio +async def test_incoming_fragmented_message_complete(prot_hndl, caplog): + """Test handling of a complete fragmented message.""" + packet1 = ( + b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x09" + + b"complete " + ) # fragment index 0 + packet2 = ( + b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07" + + b"message" + ) # fragment index 1 + sender = 0x1D6F + + aps_frame_1 = t.EmberApsFrame( + profileId=260, + clusterId=65281, + sourceEndpoint=2, + destinationEndpoint=2, + options=33088, # Includes APS_OPTION_FRAGMENT + groupId=512, # fragment_count=2, fragment_index=0 + sequence=238, + ) + aps_frame_2 = t.EmberApsFrame( + profileId=260, + clusterId=65281, + sourceEndpoint=2, + destinationEndpoint=2, + options=33088, + groupId=513, # fragment_count=2, fragment_index=1 + sequence=238, + ) + reassembled = b"complete message" + + with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack: + mock_ack.return_value = None + caplog.set_level(logging.DEBUG) + + # Packet 1 + prot_hndl(packet1) + assert hasattr(prot_hndl, "_ack_tasks") + assert len(prot_hndl._ack_tasks) == 1 + ack_task = next(iter(prot_hndl._ack_tasks)) + await asyncio.gather(ack_task) # Ensure task completes and triggers callback + assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task" + + prot_hndl._handle_callback.assert_not_called() + assert ( + "Reassembled fragmented message. Proceeding with normal handling." + not in caplog.text + ) + mock_ack.assert_called_with(sender, aps_frame_1, 2, 0) + + # Packet 2 + prot_hndl(packet2) + assert hasattr(prot_hndl, "_ack_tasks") + assert len(prot_hndl._ack_tasks) == 1 + ack_task = next(iter(prot_hndl._ack_tasks)) + await asyncio.gather(ack_task) # Ensure task completes and triggers callback + assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task" + + prot_hndl._handle_callback.assert_called_once_with( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_UNICAST, # 0x00 + aps_frame_2, # Parsed APS frame + 255, # lastHopLqi: 0xFF + -8, # lastHopRssi: 0xF8 + sender, # 0x1D6F + 255, # bindingIndex: 0xFF + 255, # addressIndex: 0xFF + reassembled, # Reassembled payload + ], + ) + assert ( + "Reassembled fragmented message. Proceeding with normal handling." + in caplog.text + ) + mock_ack.assert_called_with(sender, aps_frame_2, 2, 1) diff --git a/tests/test_fragmentation.py b/tests/test_fragmentation.py index a0e60d1e..35a816b7 100644 --- a/tests/test_fragmentation.py +++ b/tests/test_fragmentation.py @@ -1,21 +1,26 @@ +from unittest.mock import MagicMock + import pytest -from bellows.ezsp.fragmentation import FragmentManager +from bellows.ezsp.fragmentation import fragment_manager @pytest.fixture def frag_manager(): - return FragmentManager() + """Return a new FragmentManager instance for each test.""" + return fragment_manager @pytest.mark.asyncio async def test_single_fragment_complete(frag_manager): - # If we receive a single-fragment message, the fragemnt manager should immediately report completion. - + """ + If we receive a single-fragment message (fragment_count=1, fragment_index=0), + the manager should immediately report completion. + """ key = (0x1234, 0xAB, 0x1234, 0x5678) fragment_count = 1 fragment_index = 0 - payload = b"Single fragment" + payload = b"Hello single fragment" ( complete, @@ -36,13 +41,16 @@ async def test_single_fragment_complete(frag_manager): assert reassembled == payload assert returned_frag_count == fragment_count assert returned_frag_index == fragment_index + # Make sure it's no longer tracked as partial assert key not in frag_manager._partial assert key not in frag_manager._cleanup_timers @pytest.mark.asyncio async def test_two_fragments_in_order(frag_manager): - # A two-fragment message should remain partial until we've received both pieces. + """ + A two-fragment message should remain partial until we've received both pieces. + """ key = (0x1111, 0x01, 0x9999, 0x2222) fragment_count = 2 @@ -83,13 +91,16 @@ async def test_two_fragments_in_order(frag_manager): ) assert complete is True assert reassembled == b"Frag0-Frag1" + # It's removed from partials after completion assert key not in frag_manager._partial assert key not in frag_manager._cleanup_timers @pytest.mark.asyncio async def test_out_of_order_fragments(frag_manager): - # Receiving fragments in reverse order + """ + Receiving fragments in reverse order should still produce the correct reassembly once all arrive. + """ key = (0x9999, 0xCD, 0x1234, 0xABCD) fragment_count = 2 @@ -132,8 +143,9 @@ async def test_out_of_order_fragments(frag_manager): @pytest.mark.asyncio async def test_repeated_fragments_ignored(frag_manager): - # Ensure repeated arrivals of the same fragment index do not double-count or break the logic. - + """ + Ensure repeated arrivals of the same fragment index do not double-count or break the logic. + """ key = (0xAAA, 0xBB, 0xCCC, 0xDDD) fragment_count = 2 @@ -190,3 +202,15 @@ async def test_repeated_fragments_ignored(frag_manager): ) assert complete assert reassembled == b"firstsecond" + + +@pytest.mark.asyncio +async def test_cleanup_partial(frag_manager, caplog): + key = (0x1234, 0xAB, 0x1234, 0x5678) + + frag_manager._partial[key] = MagicMock() + frag_manager._cleanup_timers[key] = MagicMock() + frag_manager.cleanup_partial(key) + + assert key not in frag_manager._partial + assert key not in frag_manager._cleanup_timers