From f0811bcd0f917c6395842f6c5f20b0723186dbf6 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 12 Aug 2024 09:35:21 -0400 Subject: [PATCH 1/4] Implement `set_extended_timeout` in the protocol handler --- bellows/ezsp/protocol.py | 6 ++++++ bellows/ezsp/v4/__init__.py | 35 +++++++++++++++++++++++++++++++++++ bellows/zigbee/application.py | 6 ++++-- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 1fe87ad1..0e6ad367 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -252,3 +252,9 @@ async def read_counters(self) -> dict[t.EmberCounterType, int]: @abc.abstractmethod async def read_and_clear_counters(self) -> dict[t.EmberCounterType, int]: raise NotImplementedError + + @abc.abstractmethod + async def set_extended_timeout( + self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True + ) -> None: + raise NotImplementedError() diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index b1e9fb6d..b70b2664 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import random from typing import AsyncGenerator, Iterable import voluptuous as vol @@ -193,3 +194,37 @@ async def read_counters(self) -> dict[t.EmberCounterType, t.uint16_t]: async def read_and_clear_counters(self) -> dict[t.EmberCounterType, t.uint16_t]: (res,) = await self.readAndClearCounters() return dict(zip(t.EmberCounterType, res)) + + async def set_extended_timeout( + self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True + ) -> None: + (curr_extended_timeout,) = await self.getExtendedTimeout(remoteEui64=ieee) + + if curr_extended_timeout == extended_timeout: + return + + (node_id,) = await self.lookupNodeIdByEui64(eui64=ieee) + + # Check to see if we have an address table entry + if node_id != 0xFFFF: + await self.setExtendedTimeout( + remoteEui64=ieee, extendedTimeout=extended_timeout + ) + return + + (status, addr_table_size) = await self.getConfigurationValue( + t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE + ) + + if t.sl_Status.from_ember_status(status) != t.sl_Status.OK: + return + + # Replace a random entry in the address table + index = random.randint(0, addr_table_size - 1) + + await self.replaceAddressTableEntry( + addressTableIndex=index, + newEui64=ieee, + newId=nwk, + newExtendedTimeout=extended_timeout, + ) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 22247178..08a1169e 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -749,8 +749,10 @@ async def send_packet(self, packet: zigpy.types.ZigbeePacket) -> None: async with self._req_lock: if packet.dst.addr_mode == zigpy.types.AddrMode.NWK: if packet.extended_timeout and device is not None: - await self._ezsp.setExtendedTimeout( - remoteEui64=device.ieee, extendedTimeout=True + await self._ezsp.set_extended_timeout( + nwk=device.nwk, + ieee=device.ieee, + extended_timeout=True, ) if packet.source_route is not None: From 7792587ac1a9856064226d141d9d8a14f955cfdc Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 13 Aug 2024 23:41:07 -0400 Subject: [PATCH 2/4] Add tests --- bellows/ezsp/v4/__init__.py | 4 ++ tests/test_application.py | 1 + tests/test_ezsp_v4.py | 109 +++++++++++++++++++++++++++++++++++- 3 files changed, 113 insertions(+), 1 deletion(-) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index b70b2664..ebf18001 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -217,6 +217,10 @@ async def set_extended_timeout( ) if t.sl_Status.from_ember_status(status) != t.sl_Status.OK: + # Last-ditch effort + await self.setExtendedTimeout( + remoteEui64=ieee, extendedTimeout=extended_timeout + ) return # Replace a random entry in the address table diff --git a/tests/test_application.py b/tests/test_application.py index c846f6bf..fce1a8b0 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -178,6 +178,7 @@ def form_network(): ) proto.factory_reset = AsyncMock(proto=proto.factory_reset) + proto.set_extended_timeout = AsyncMock(proto=proto.set_extended_timeout) proto.read_link_keys = MagicMock() proto.read_link_keys.return_value.__aiter__.return_value = [ diff --git a/tests/test_ezsp_v4.py b/tests/test_ezsp_v4.py index 69e00243..3f035366 100644 --- a/tests/test_ezsp_v4.py +++ b/tests/test_ezsp_v4.py @@ -1,5 +1,5 @@ import logging -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock, call, patch import pytest import zigpy.state @@ -379,3 +379,110 @@ async def test_read_counters(ezsp_f, length: int) -> None: ) assert counters1 == counters2 == {t.EmberCounterType(i): i for i in range(length)} + + +async def test_set_extended_timeout_no_entry(ezsp_f) -> None: + # Typical invocation + ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,) + ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,) # No address table entry + ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.SUCCESS, 8) + ezsp_f.replaceAddressTableEntry.return_value = ( + t.EmberStatus.SUCCESS, + t.EUI64.convert("ff:ff:ff:ff:ff:ff:ff:ff"), + 0xFFFF, + t.Bool.false, + ) + + with patch("bellows.ezsp.v4.random.randint") as mock_random: + mock_random.return_value = 0 + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.lookupNodeIdByEui64.mock_calls == [ + call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.getConfigurationValue.mock_calls == [ + call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE) + ] + assert mock_random.mock_calls == [call(0, 8 - 1)] + assert ezsp_f.replaceAddressTableEntry.mock_calls == [ + call( + addressTableIndex=0, + newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + newId=0x1234, + newExtendedTimeout=True, + ) + ] + + +async def test_set_extended_timeout_already_set(ezsp_f) -> None: + # No-op, it's already set + ezsp_f.setExtendedTimeout.return_value = () + ezsp_f.getExtendedTimeout.return_value = (t.Bool.true,) + + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.setExtendedTimeout.mock_calls == [] + + +async def test_set_extended_timeout_already_have_entry(ezsp_f) -> None: + # An address table entry is present + ezsp_f.setExtendedTimeout.return_value = () + ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,) + ezsp_f.lookupNodeIdByEui64.return_value = (0x1234,) + + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.lookupNodeIdByEui64.mock_calls == [ + call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.setExtendedTimeout.mock_calls == [ + call( + remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), extendedTimeout=True + ) + ] + + +async def test_set_extended_timeout_bad_table_size(ezsp_f) -> None: + ezsp_f.setExtendedTimeout.return_value = () + ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,) + ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,) + ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.ERR_FATAL, 0xFF) + + with patch("bellows.ezsp.v4.random.randint") as mock_random: + mock_random.return_value = 0 + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.lookupNodeIdByEui64.mock_calls == [ + call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.getConfigurationValue.mock_calls == [ + call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE) + ] From 2d4605aed40dc92cd7e9fde7f327653d89a78309 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 13 Aug 2024 23:51:31 -0400 Subject: [PATCH 3/4] Add test for `send_packet` as well --- tests/test_application.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_application.py b/tests/test_application.py index fce1a8b0..e35adfc5 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -843,6 +843,19 @@ async def test_send_packet_unicast_source_route(make_app, packet): ) +async def test_send_packet_unicast_extended_timeout(app, ieee, packet): + app.add_device(nwk=packet.dst.address, ieee=ieee) + + await _test_send_packet_unicast( + app, + packet.replace(extended_timeout=True), + ) + + assert app._ezsp._protocol.set_extended_timeout.mock_calls == [ + call(nwk=packet.dst.address, ieee=ieee, extended_timeout=True) + ] + + @patch("bellows.zigbee.application.RETRY_DELAYS", [0.01, 0.01, 0.01]) async def test_send_packet_unicast_retries_success(app, packet): await _test_send_packet_unicast( From ec398b9958938dba29fc57943f2cb30ebdc6c9d2 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 14 Aug 2024 12:55:09 -0400 Subject: [PATCH 4/4] Cache the address table size once it's read --- bellows/ezsp/protocol.py | 3 +++ bellows/ezsp/v4/__init__.py | 23 +++++++++++++---------- tests/test_ezsp_v4.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 0e6ad367..f9eca74e 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -43,6 +43,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: } self.tc_policy = 0 + # Cached by `set_extended_timeout` so subsequent calls are a little faster + self._address_table_size: int | None = None + def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes: """Serialize the named frame and data.""" c, tx_schema, rx_schema = self.COMMANDS[name] diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index ebf18001..534c842a 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -212,19 +212,22 @@ async def set_extended_timeout( ) return - (status, addr_table_size) = await self.getConfigurationValue( - t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE - ) - - if t.sl_Status.from_ember_status(status) != t.sl_Status.OK: - # Last-ditch effort - await self.setExtendedTimeout( - remoteEui64=ieee, extendedTimeout=extended_timeout + if self._address_table_size is None: + (status, addr_table_size) = await self.getConfigurationValue( + t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE ) - return + + if t.sl_Status.from_ember_status(status) != t.sl_Status.OK: + # Last-ditch effort + await self.setExtendedTimeout( + remoteEui64=ieee, extendedTimeout=extended_timeout + ) + return + + self._address_table_size = addr_table_size # Replace a random entry in the address table - index = random.randint(0, addr_table_size - 1) + index = random.randint(0, self._address_table_size - 1) await self.replaceAddressTableEntry( addressTableIndex=index, diff --git a/tests/test_ezsp_v4.py b/tests/test_ezsp_v4.py index 3f035366..1fec0187 100644 --- a/tests/test_ezsp_v4.py +++ b/tests/test_ezsp_v4.py @@ -420,6 +420,35 @@ async def test_set_extended_timeout_no_entry(ezsp_f) -> None: ) ] + # The address table size is cached + with patch("bellows.ezsp.v4.random.randint") as mock_random: + mock_random.return_value = 1 + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + # Still called only once + assert ezsp_f.getConfigurationValue.mock_calls == [ + call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE) + ] + + assert ezsp_f.replaceAddressTableEntry.mock_calls == [ + call( + addressTableIndex=0, + newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + newId=0x1234, + newExtendedTimeout=True, + ), + call( + addressTableIndex=1, + newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + newId=0x1234, + newExtendedTimeout=True, + ), + ] + async def test_set_extended_timeout_already_set(ezsp_f) -> None: # No-op, it's already set