Skip to content

Commit

Permalink
Move NCP ASH implementation into tests
Browse files Browse the repository at this point in the history
  • Loading branch information
puddly committed Apr 15, 2024
1 parent ef3c3c1 commit 7f604e0
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 106 deletions.
167 changes: 63 additions & 104 deletions bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ class NcpState(enum.Enum):
FAILED = "failed"


class AshRole(enum.Enum):
HOST = "host"
NCP = "ncp"


class ParsingError(Exception):
pass

Expand Down Expand Up @@ -327,7 +322,7 @@ class ErrorFrame(AshFrame):


class AshProtocol(asyncio.Protocol):
def __init__(self, ezsp_protocol, *, role: AshRole = AshRole.HOST) -> None:
def __init__(self, ezsp_protocol) -> None:
self._ezsp_protocol = ezsp_protocol
self._transport = None
self._buffer = bytearray()
Expand All @@ -338,7 +333,6 @@ def __init__(self, ezsp_protocol, *, role: AshRole = AshRole.HOST) -> None:
self._rx_seq: int = 0
self._t_rx_ack = T_RX_ACK_INIT

self._role: AshRole = role
self._ncp_reset_code: t.NcpResetCode | None = None
self._ncp_state: NcpState = NcpState.CONNECTED

Expand All @@ -352,12 +346,6 @@ def connection_lost(self, exc):
def eof_received(self):
self._ezsp_protocol.eof_received()

def _get_tx_seq(self) -> int:
result = self._tx_seq
self._tx_seq = (self._tx_seq + 1) % 8

return result

def close(self):
if self._transport is not None:
self._transport.close()
Expand Down Expand Up @@ -471,78 +459,72 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
def frame_received(self, frame: AshFrame) -> None:
_LOGGER.debug("Received frame %r", frame)

if (
self._ncp_reset_code is not None
and self._role == AshRole.NCP
and not isinstance(frame, RstFrame)
):
_LOGGER.debug(
"NCP in failure state %r, ignoring frame: %r",
self._ncp_reset_code,
frame,
)
self._write_frame(ErrorFrame(version=2, reset_code=self._ncp_reset_code))
return

if isinstance(frame, DataFrame):
# The Host may not piggyback acknowledgments and should promptly send an ACK
# frame when it receives a DATA frame.
if frame.frm_num == self._rx_seq:
self._handle_ack(frame)
self._rx_seq = (frame.frm_num + 1) % 8
self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))

self._ezsp_protocol.data_received(frame.ezsp_frame)
elif frame.re_tx:
# Retransmitted frames must be immediately ACKed even if they are out of
# sequence
self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))
else:
_LOGGER.debug("Received an out of sequence frame: %r", frame)
self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))
self.data_frame_received(frame)
elif isinstance(frame, RStackFrame):
self._ncp_reset_code = None
self._ncp_state = NcpState.CONNECTED

self._tx_seq = 0
self._rx_seq = 0
self._change_ack_timeout(T_RX_ACK_INIT)
self._ezsp_protocol.reset_received(frame.reset_code)
self.rstack_frame_received(frame)
elif isinstance(frame, AckFrame):
self._handle_ack(frame)
self.ack_frame_received(frame)
elif isinstance(frame, NakFrame):
error = NotAcked(frame=frame)

for frm_num, fut in self._pending_data_frames.items():
if (
not frame.ack_num - TX_K <= frm_num <= frame.ack_num
and not fut.done()
):
fut.set_exception(error)
self.nak_frame_received(frame)
elif isinstance(frame, RstFrame):
self._ncp_reset_code = None
self._ncp_state = NcpState.CONNECTED

if self._role == AshRole.NCP:
self._tx_seq = 0
self._rx_seq = 0
self._change_ack_timeout(T_RX_ACK_INIT)

self._enter_ncp_error_state(None)
self._write_frame(
RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)
)
elif isinstance(frame, ErrorFrame) and self._role == AshRole.HOST:
_LOGGER.debug("NCP has entered failed state: %s", frame.reset_code)
self._ncp_reset_code = frame.reset_code
self._ncp_state = NcpState.FAILED

# Cancel all pending requests
exc = NcpFailure(code=self._ncp_reset_code)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)
self.rst_frame_received(frame)
elif isinstance(frame, ErrorFrame):
self.error_frame_received(frame)
else:
raise TypeError(f"Unknown frame received: {frame}") # pragma: no cover

def data_frame_received(self, frame: DataFrame) -> None:
# The Host may not piggyback acknowledgments and should promptly send an ACK
# frame when it receives a DATA frame.
if frame.frm_num == self._rx_seq:
self._handle_ack(frame)
self._rx_seq = (frame.frm_num + 1) % 8
self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))

self._ezsp_protocol.data_received(frame.ezsp_frame)
elif frame.re_tx:
# Retransmitted frames must be immediately ACKed even if they are out of
# sequence
self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))
else:
_LOGGER.debug("Received an out of sequence frame: %r", frame)
self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))

def rstack_frame_received(self, frame: RStackFrame) -> None:
self._ncp_reset_code = None
self._ncp_state = NcpState.CONNECTED

self._tx_seq = 0
self._rx_seq = 0
self._change_ack_timeout(T_RX_ACK_INIT)
self._ezsp_protocol.reset_received(frame.reset_code)

def ack_frame_received(self, frame: AckFrame) -> None:
self._handle_ack(frame)

def nak_frame_received(self, frame: NakFrame) -> None:
err = NotAcked(frame=frame)

for frm_num, fut in self._pending_data_frames.items():
if not frame.ack_num - TX_K <= frm_num <= frame.ack_num and not fut.done():
fut.set_exception(err)

def rst_frame_received(self, frame: RstFrame) -> None:
self._ncp_reset_code = None
self._ncp_state = NcpState.CONNECTED

def error_frame_received(self, frame: ErrorFrame) -> None:
_LOGGER.debug("NCP has entered failed state: %s", frame.reset_code)
self._ncp_reset_code = frame.reset_code
self._ncp_state = NcpState.FAILED

# Cancel all pending requests
exc = NcpFailure(code=self._ncp_reset_code)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)

def _write_frame(self, frame: AshFrame) -> None:
_LOGGER.debug("Sending frame %r", frame)
Expand All @@ -561,20 +543,6 @@ def _change_ack_timeout(self, new_value: float) -> None:

self._t_rx_ack = new_value

def _enter_ncp_error_state(self, code: t.NcpResetCode | None) -> None:
self._ncp_reset_code = code

if code is None:
self._ncp_state = NcpState.CONNECTED
else:
self._ncp_state = NcpState.FAILED

_LOGGER.debug("Changing connectivity state: %r", self._ncp_state)
_LOGGER.debug("Changing reset code: %r", self._ncp_reset_code)

if self._ncp_state == NcpState.FAILED:
self._write_frame(ErrorFrame(version=2, reset_code=self._ncp_reset_code))

async def _send_frame(self, frame: AshFrame) -> None:
if not isinstance(frame, DataFrame):
# Non-DATA frames can be sent immediately and do not require an ACK
Expand All @@ -589,10 +557,7 @@ async def _send_frame(self, frame: AshFrame) -> None:

try:
for attempt in range(ACK_TIMEOUTS):
if (
self._role == AshRole.HOST
and self._ncp_state == NcpState.FAILED
):
if self._ncp_state == NcpState.FAILED:
_LOGGER.debug(
"NCP is in a failed state, not re-sending: %r", frame
)
Expand Down Expand Up @@ -647,12 +612,6 @@ async def _send_frame(self, frame: AshFrame) -> None:
self._change_ack_timeout(2 * self._t_rx_ack)

if attempt >= ACK_TIMEOUTS - 1:
# Only a timeout is enough to enter an error state
if self._role == AshRole.NCP:
self._enter_ncp_error_state(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
)

raise
else:
# Whenever an acknowledgement is received, t_rx_ack is set to
Expand Down
64 changes: 62 additions & 2 deletions tests/test_ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,66 @@
import bellows.types as t


class AshNcpProtocol(ash.AshProtocol):
def frame_received(self, frame: ash.AshFrame) -> None:
if self._ncp_reset_code is not None and not isinstance(frame, ash.RstFrame):
ash._LOGGER.debug(
"NCP in failure state %r, ignoring frame: %r",
self._ncp_reset_code,
frame,
)
self._write_frame(
ash.ErrorFrame(version=2, reset_code=self._ncp_reset_code)
)
return

super().frame_received(frame)

def _enter_ncp_error_state(self, code: t.NcpResetCode | None) -> None:
self._ncp_reset_code = code

if code is None:
self._ncp_state = ash.NcpState.CONNECTED
else:
self._ncp_state = ash.NcpState.FAILED

ash._LOGGER.debug("Changing connectivity state: %r", self._ncp_state)
ash._LOGGER.debug("Changing reset code: %r", self._ncp_reset_code)

if self._ncp_state == ash.NcpState.FAILED:
self._write_frame(
ash.ErrorFrame(version=2, reset_code=self._ncp_reset_code)
)

def rst_frame_received(self, frame: ash.RstFrame) -> None:
super().rst_frame_received(frame)

self._tx_seq = 0
self._rx_seq = 0
self._change_ack_timeout(ash.T_RX_ACK_INIT)

self._enter_ncp_error_state(None)
self._write_frame(
ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)
)

async def _send_frame(self, frame: ash.AshFrame) -> None:
try:
return await super()._send_frame(frame)
except asyncio.TimeoutError:
self._enter_ncp_error_state(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
)
raise
if not isinstance(frame, ash.DataFrame):
# Non-DATA frames can be sent immediately and do not require an ACK
self._write_frame(frame)
return

def send_reset(self) -> None:
raise NotImplementedError()


def test_stuffing():
assert ash.AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E"
assert ash.AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31"
Expand Down Expand Up @@ -129,8 +189,8 @@ def write(self, data):
if not self.paused:
self.receiver.data_received(data)

host = ash.AshProtocol(host_ezsp, role=ash.AshRole.HOST)
ncp = ash.AshProtocol(ncp_ezsp, role=ash.AshRole.NCP)
host = ash.AshProtocol(host_ezsp)
ncp = AshNcpProtocol(ncp_ezsp)

host_transport = FakeTransport(ncp)
ncp_transport = FakeTransport(host)
Expand Down

0 comments on commit 7f604e0

Please sign in to comment.