diff --git a/pyproject.toml b/pyproject.toml index 7c7515c..974c0b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,13 +15,14 @@ license = {text = "GPL-3.0"} requires-python = ">=3.8" dependencies = [ "click>=8.0.0", - "zigpy", + "zigpy>=0.70.0", "crc", - "bellows~=0.41.0", + "bellows>=0.42.0", 'gpiod; platform_system=="Linux"', "coloredlogs", "async_timeout", "typing_extensions", + "pyserial-asyncio-fast", ] [tool.setuptools.packages.find] diff --git a/universal_silabs_flasher/common.py b/universal_silabs_flasher/common.py index 7d62318..938fcde 100644 --- a/universal_silabs_flasher/common.py +++ b/universal_silabs_flasher/common.py @@ -12,7 +12,6 @@ import async_timeout import click import crc -import serial_asyncio import zigpy.serial if typing.TYPE_CHECKING: @@ -118,61 +117,6 @@ async def wait_for_state(self, state: str) -> None: self._futures_for_state[state].remove(future) -class SerialProtocol(asyncio.Protocol): - """Base class for packet-parsing serial protocol implementations.""" - - def __init__(self) -> None: - self._buffer = bytearray() - self._transport: serial_asyncio.SerialTransport | None = None - self._connected_event = asyncio.Event() - - async def wait_until_connected(self) -> None: - """Wait for the protocol's transport to be connected.""" - await self._connected_event.wait() - - def connection_made(self, transport: serial_asyncio.SerialTransport) -> None: - _LOGGER.debug("Connection made: %s", transport) - - self._transport = transport - self._connected_event.set() - - def send_data(self, data: bytes) -> None: - """Sends data over the connected transport.""" - assert self._transport is not None - data = bytes(data) - _LOGGER.debug("Sending data %s", data) - self._transport.write(data) - - def data_received(self, data: bytes) -> None: - _LOGGER.debug("Received data %s", data) - self._buffer += data - - def disconnect(self) -> None: - if self._transport is not None: - self._transport.close() - self._buffer.clear() - self._connected_event.clear() - - -def patch_pyserial_asyncio() -> None: - """Patches pyserial-asyncio's `SerialTransport` to support swapping protocols.""" - - if ( - serial_asyncio.SerialTransport.get_protocol - is not asyncio.BaseTransport.get_protocol - ): - return - - def get_protocol(self) -> asyncio.Protocol: - return self._protocol - - def set_protocol(self, protocol: asyncio.Protocol) -> None: - self._protocol = protocol - - serial_asyncio.SerialTransport.get_protocol = get_protocol - serial_asyncio.SerialTransport.set_protocol = set_protocol - - @contextlib.asynccontextmanager async def connect_protocol(port, baudrate, factory): loop = asyncio.get_running_loop() @@ -189,10 +133,7 @@ async def connect_protocol(port, baudrate, factory): try: yield protocol finally: - protocol.disconnect() - - # Required for Windows to be able to re-connect to the same serial port - await asyncio.sleep(0) + await protocol.disconnect() class CommaSeparatedNumbers(click.ParamType): diff --git a/universal_silabs_flasher/cpc.py b/universal_silabs_flasher/cpc.py index 4e82478..6e84b5b 100644 --- a/universal_silabs_flasher/cpc.py +++ b/universal_silabs_flasher/cpc.py @@ -6,10 +6,11 @@ import typing import async_timeout +from zigpy.serial import SerialProtocol import zigpy.types from . import cpc_types -from .common import BufferTooShort, SerialProtocol, Version, crc16_ccitt +from .common import BufferTooShort, Version, crc16_ccitt _LOGGER = logging.getLogger(__name__) @@ -209,6 +210,8 @@ def poll_final(self) -> bool: class CPCProtocol(SerialProtocol): """Partial implementation of the CPC protocol.""" + _buffer: bytearray + def __init__(self) -> None: super().__init__() self._command_seq: int = 0 @@ -279,6 +282,11 @@ async def get_secondary_version(self) -> Version | None: return Version(version_bytes.split(b"\x00", 1)[0].decode("ascii")) + def send_data(self, data: bytes) -> None: + assert self._transport is not None + _LOGGER.debug("Sending data %s", data) + self._transport.write(data) + def data_received(self, data: bytes) -> None: super().data_received(data) diff --git a/universal_silabs_flasher/emberznet.py b/universal_silabs_flasher/emberznet.py index dd2040f..ca97929 100644 --- a/universal_silabs_flasher/emberznet.py +++ b/universal_silabs_flasher/emberznet.py @@ -1,47 +1,32 @@ -import asyncio import contextlib import bellows.config import bellows.ezsp import bellows.types +from bellows.zigbee.application import ControllerApplication import zigpy.config -AFTER_DISCONNECT_DELAY = 0.1 - @contextlib.asynccontextmanager async def connect_ezsp(port: str, baudrate: int = 115200) -> bellows.ezsp.EZSP: """Context manager to return a connected EZSP instance for a serial port.""" - app_config = zigpy.config.CONFIG_SCHEMA( - { - zigpy.config.CONF_DEVICE: { - zigpy.config.CONF_DEVICE_PATH: port, - zigpy.config.CONF_DEVICE_BAUDRATE: baudrate, - }, - bellows.config.CONF_EZSP_CONFIG: { - # Do not set any configuration on startup - "CONFIG_END_DEVICE_POLL_TIMEOUT": None, - "CONFIG_INDIRECT_TRANSMISSION_TIMEOUT": None, - "CONFIG_TC_REJOINS_USING_WELL_KNOWN_KEY_TIMEOUT_S": None, - "CONFIG_SECURITY_LEVEL": None, - "CONFIG_APPLICATION_ZDO_FLAGS": None, - "CONFIG_SUPPORTED_NETWORKS": None, - "CONFIG_PAN_ID_CONFLICT_REPORT_THRESHOLD": None, - "CONFIG_TRUST_CENTER_ADDRESS_CACHE_SIZE": None, - "CONFIG_SOURCE_ROUTE_TABLE_SIZE": None, - "CONFIG_MULTICAST_TABLE_SIZE": None, - "CONFIG_ADDRESS_TABLE_SIZE": None, - "CONFIG_PACKET_BUFFER_COUNT": None, - "CONFIG_STACK_PROFILE": None, - }, - bellows.config.CONF_USE_THREAD: False, - } + + ezsp = bellows.ezsp.EZSP( + # We use this roundabout way to construct the device schema to make sure that + # we are compatible with future changes to the zigpy device config schema. + ControllerApplication.SCHEMA( + { + zigpy.config.CONF_DEVICE: { + zigpy.config.CONF_DEVICE_PATH: port, + zigpy.config.CONF_DEVICE_BAUDRATE: baudrate, + } + } + )[zigpy.config.CONF_DEVICE] ) - ezsp = await bellows.ezsp.EZSP.initialize(app_config) + await ezsp.connect(use_thread=False) try: yield ezsp finally: - ezsp.close() - await asyncio.sleep(AFTER_DISCONNECT_DELAY) + await ezsp.disconnect() diff --git a/universal_silabs_flasher/flash.py b/universal_silabs_flasher/flash.py index 7576386..733cd42 100644 --- a/universal_silabs_flasher/flash.py +++ b/universal_silabs_flasher/flash.py @@ -17,7 +17,7 @@ import zigpy.ota.validators import zigpy.types -from .common import CommaSeparatedNumbers, patch_pyserial_asyncio, put_first +from .common import CommaSeparatedNumbers, put_first from .const import ( DEFAULT_BAUDRATES, FW_IMAGE_TYPE_TO_APPLICATION_TYPE, @@ -28,8 +28,6 @@ from .flasher import Flasher from .xmodemcrc import BLOCK_SIZE as XMODEM_BLOCK_SIZE, ReceiverCancelled -patch_pyserial_asyncio() - _LOGGER = logging.getLogger(__name__) LOG_LEVELS = ["INFO", "DEBUG"] diff --git a/universal_silabs_flasher/flasher.py b/universal_silabs_flasher/flasher.py index 756d3a5..7dd4e84 100644 --- a/universal_silabs_flasher/flasher.py +++ b/universal_silabs_flasher/flasher.py @@ -9,14 +9,9 @@ import bellows.config import bellows.ezsp import bellows.types +from zigpy.serial import SerialProtocol -from .common import ( - PROBE_TIMEOUT, - SerialProtocol, - Version, - connect_protocol, - pad_to_multiple, -) +from .common import PROBE_TIMEOUT, Version, connect_protocol, pad_to_multiple from .const import DEFAULT_BAUDRATES, GPIO_CONFIGS, ApplicationType, ResetTarget from .cpc import CPCProtocol from .emberznet import connect_ezsp @@ -115,8 +110,6 @@ async def probe_gecko_bootloader( if run_firmware: await gecko.run_firmware() _LOGGER.info("Launched application from bootloader") - - await asyncio.sleep(1) except NoFirmwareError: _LOGGER.warning("No application can be launched") return ProbeResult( diff --git a/universal_silabs_flasher/gecko_bootloader.py b/universal_silabs_flasher/gecko_bootloader.py index c3b922c..d423933 100644 --- a/universal_silabs_flasher/gecko_bootloader.py +++ b/universal_silabs_flasher/gecko_bootloader.py @@ -7,8 +7,9 @@ import typing import async_timeout +from zigpy.serial import SerialProtocol -from .common import PROBE_TIMEOUT, SerialProtocol, StateMachine, Version +from .common import PROBE_TIMEOUT, StateMachine, Version from .xmodemcrc import send_xmodem128_crc _LOGGER = logging.getLogger(__name__) @@ -142,6 +143,11 @@ async def upload_firmware( if self._upload_status != "complete": raise UploadError(self._upload_status) + def send_data(self, data: bytes) -> None: + assert self._transport is not None + _LOGGER.debug("Sending data %s", data) + self._transport.write(data) + def data_received(self, data: bytes) -> None: super().data_received(data) diff --git a/universal_silabs_flasher/spinel.py b/universal_silabs_flasher/spinel.py index bae1f5e..c5ce53c 100644 --- a/universal_silabs_flasher/spinel.py +++ b/universal_silabs_flasher/spinel.py @@ -6,9 +6,10 @@ import typing import async_timeout +from zigpy.serial import SerialProtocol import zigpy.types -from .common import SerialProtocol, Version, crc16_kermit +from .common import Version, crc16_kermit from .spinel_types import CommandID, HDLCSpecial, PropertyID, ResetReason _LOGGER = logging.getLogger(__name__) @@ -104,11 +105,18 @@ def serialize(self) -> bytes: class SpinelProtocol(SerialProtocol): + _buffer: bytearray + def __init__(self) -> None: super().__init__() self._transaction_id: int = 1 self._pending_frames: dict[int, asyncio.Future] = {} + def send_data(self, data: bytes) -> None: + assert self._transport is not None + _LOGGER.debug("Sending data %s", data) + self._transport.write(data) + def data_received(self, data: bytes) -> None: super().data_received(data) @@ -260,6 +268,3 @@ async def enter_bootloader(self) -> None: ResetReason.BOOTLOADER.serialize(), wait_response=False, ) - - # A small delay is necessary when switching baudrates - await asyncio.sleep(0.5)