From 0b26d064dbeb3bc644a26b0e37335666ccc34214 Mon Sep 17 00:00:00 2001 From: xs5871 <60395129+xs5871@users.noreply.github.com> Date: Wed, 8 Jan 2025 22:40:45 +0000 Subject: [PATCH] Refactor hid.py (#1068) * Small memory and performance improvements, mostly const() and bytearray ops. * Separate out report specific code into report classes. Serves as a base for easier addition of custom endpoints in "the future". * Lots of small integration changes, for example a simpler mechanism to record reports for unit tests (now with potential to verify other reports than 6KRO). --- kmk/hid.py | 487 ++++++++++++++++----------------------- kmk/keys.py | 4 +- kmk/kmk_keyboard.py | 19 +- kmk/modules/autoshift.py | 2 +- tests/keyboard_test.py | 26 ++- tests/mocks.py | 20 ++ tests/test_autoshift.py | 2 +- 7 files changed, 249 insertions(+), 311 deletions(-) diff --git a/kmk/hid.py b/kmk/hid.py index a87c3d234..f201c14d3 100644 --- a/kmk/hid.py +++ b/kmk/hid.py @@ -2,9 +2,9 @@ import usb_hid from micropython import const -from storage import getmount +from struct import pack, pack_into -from kmk.keys import ConsumerKey, KeyboardKey, ModifierKey, MouseKey +from kmk.keys import Axis, ConsumerKey, KeyboardKey, ModifierKey, MouseKey from kmk.scheduler import cancel_task, create_task from kmk.utils import Debug, clamp @@ -12,6 +12,9 @@ from adafruit_ble import BLERadio from adafruit_ble.advertising.standard import ProvideServicesAdvertisement from adafruit_ble.services.standard.hid import HIDService + from storage import getmount + + _BLE_APPEARANCE_HID_KEYBOARD = const(961) except ImportError: # BLE not supported on this platform pass @@ -25,342 +28,264 @@ class HIDModes: USB = 1 BLE = 2 - ALL_MODES = (NOOP, USB, BLE) +_USAGE_PAGE_CONSUMER = const(0x0C) +_USAGE_PAGE_KEYBOARD = const(0x01) +_USAGE_PAGE_MOUSE = const(0x01) +_USAGE_PAGE_SYSCONTROL = const(0x01) -class HIDReportTypes: - KEYBOARD = 1 - MOUSE = 2 - CONSUMER = 3 - SYSCONTROL = 4 +_USAGE_CONSUMER = const(0x01) +_USAGE_KEYBOARD = const(0x06) +_USAGE_MOUSE = const(0x02) +_USAGE_SYSCONTROL = const(0x80) +_REPORT_SIZE_CONSUMER = const(2) +_REPORT_SIZE_KEYBOARD = const(8) +_REPORT_SIZE_KEYBOARD_NKRO = const(16) +_REPORT_SIZE_MOUSE = const(4) +_REPORT_SIZE_MOUSE_HSCROLL = const(5) +_REPORT_SIZE_SYSCONTROL = const(8) -class HIDUsage: - KEYBOARD = 0x06 - MOUSE = 0x02 - CONSUMER = 0x01 - SYSCONTROL = 0x80 +def find_device(devices, usage_page, usage): + for device in devices: + if ( + device.usage_page == usage_page + and device.usage == usage + and hasattr(device, 'send_report') + ): + return device -class HIDUsagePage: - CONSUMER = 0x0C - KEYBOARD = MOUSE = SYSCONTROL = 0x01 +class Report: + def __init__(self, size): + self.buffer = bytearray(size) + self.pending = False -HID_REPORT_SIZES = { - HIDReportTypes.KEYBOARD: 8, - HIDReportTypes.MOUSE: 4, - HIDReportTypes.CONSUMER: 2, - HIDReportTypes.SYSCONTROL: 8, # TODO find the correct value for this -} + def clear(self): + for k, v in enumerate(self.buffer): + if v: + self.buffer[k] = 0x00 + self.pending = True + def get_action_map(self): + return {} -class AbstractHID: - report_bytes_default = 8 - report_bytes_nkro = 17 - REPORT_BYTES = report_bytes_default - hid_devices = {} - hid_ready = False - - def __init__(self, **kwargs): - self._nkro = False - self._mouse = True - self._pan = False - self.find_devices() - self.setup_keyboard_hid() - self.setup_consumer_control() - self.setup_mouse_hid() - def show_debug(self): - if self._nkro: - debug('use NKRO') - else: - debug('use 6KRO') - if self._mouse and self._pan: - debug('enable horizontal scrolling mouse') - elif self._mouse: - debug('enable mouse') - else: - debug('disable mouse') - - def find_devices(self): - self.devices = {} - - for device in self.hid_devices: - if not hasattr(device, 'send_report'): - continue - us = device.usage - up = device.usage_page - - if up == HIDUsagePage.CONSUMER and us == HIDUsage.CONSUMER: - self.devices[HIDReportTypes.CONSUMER] = device - elif up == HIDUsagePage.KEYBOARD and us == HIDUsage.KEYBOARD: - self.devices[HIDReportTypes.KEYBOARD] = device - elif up == HIDUsagePage.MOUSE and us == HIDUsage.MOUSE: - self.devices[HIDReportTypes.MOUSE] = device - elif up == HIDUsagePage.SYSCONTROL and us == HIDUsage.SYSCONTROL: - self.devices[HIDReportTypes.SYSCONTROL] = device +class KeyboardReport(Report): + def __init__(self, size=_REPORT_SIZE_KEYBOARD): + self.buffer = bytearray(size) + self.prev_buffer = bytearray(size) - def setup_keyboard_hid(self): - self.REPORT_BYTES = self.report_bytes_default - self._evt = bytearray(self.REPORT_BYTES) - self._evt[0] = HIDReportTypes.KEYBOARD + @property + def pending(self): + return self.buffer != self.prev_buffer - # bodgy NKRO autodetect - try: - self.hid_send(self._evt) - except ValueError: - self.REPORT_BYTES = self.report_bytes_nkro - self._evt = bytearray(self.REPORT_BYTES) - self._evt[0] = HIDReportTypes.KEYBOARD - self._nkro = True + @pending.setter + def pending(self, v): + if v is False: + self.prev_buffer[:] = self.buffer[:] - self._prev_evt = bytearray(self.REPORT_BYTES) + def clear(self): + for idx in range(len(self.buffer)): + self.buffer[idx] = 0x00 - # Landmine alert for HIDReportTypes.KEYBOARD: byte index 1 of this view - # is "reserved" and evidently (mostly?) unused. However, other modes (or - # at least consumer, so far) will use this byte, which is the main reason - # this view exists. For KEYBOARD, use report_mods and report_non_mods - self.report_keys = memoryview(self._evt)[1:] + def add_key(self, key): + # Find the first empty slot in the key report, and fill it; drop key if + # report is full. + idx = self.buffer.find(b'\x00', 2) - self.report_mods = memoryview(self._evt)[1:2] - self.report_non_mods = memoryview(self._evt)[3:] + if 0 < idx < _REPORT_SIZE_KEYBOARD: + self.buffer[idx] = key.code - def setup_consumer_control(self): - self._cc_report = bytearray(HID_REPORT_SIZES[HIDReportTypes.CONSUMER] + 1) - self._cc_report[0] = HIDReportTypes.CONSUMER - self._cc_pending = False + def remove_key(self, key): + idx = self.buffer.find(pack('B', key.code), 2) + if 0 < idx: + self.buffer[idx] = 0x00 - def setup_mouse_hid(self): - self._pd_report = bytearray(HID_REPORT_SIZES[HIDReportTypes.MOUSE] + 1) - self._pd_report[0] = HIDReportTypes.MOUSE - self._pd_pending = False + def add_modifier(self, modifier): + self.buffer[0] |= modifier.code - # bodgy pointing device panning autodetect - try: - self.hid_send(self._pd_report) - except ValueError: - self._pd_report = bytearray(6) - self._pd_report[0] = HIDReportTypes.MOUSE - self._pan = True - except KeyError: - self._mouse = False + def remove_modifier(self, modifier): + self.buffer[0] &= ~modifier.code - def __repr__(self): - return f'{self.__class__.__name__}(REPORT_BYTES={self.REPORT_BYTES})' - - def create_report(self, keys_pressed, axes): - self.clear_all() - - for key in keys_pressed: - if isinstance(key, KeyboardKey): - self.add_key(key) - elif isinstance(key, ModifierKey): - self.add_modifier(key) - elif isinstance(key, ConsumerKey): - self.add_cc(key) - elif isinstance(key, MouseKey): - self.add_pd(key) - - for axis in axes: - self.move_axis(axis) - - def hid_send(self, evt): - # Don't raise a NotImplementedError so this can serve as our "dummy" HID - # when MCU/board doesn't define one to use (which should almost always be - # the CircuitPython-targeting one, except when unit testing or doing - # something truly bizarre. This will likely change eventually when Bluetooth - # is added) - pass + def get_action_map(self): + return {KeyboardKey: self.add_key, ModifierKey: self.add_modifier} - def send(self): - if self._evt != self._prev_evt: - self._prev_evt[:] = self._evt - self.hid_send(self._evt) - if self._cc_pending: - self.hid_send(self._cc_report) - self._cc_pending = False +class NKROKeyboardReport(KeyboardReport): + def __init__(self): + super().__init__(_REPORT_SIZE_KEYBOARD_NKRO) - if self._pd_pending: - self.hid_send(self._pd_report) - self._pd_pending = False + def add_key(self, key): + self.buffer[(key.code >> 3) + 1] |= 1 << (key.code & 0x07) - return self + def remove_key(self, key): + self.buffer[(key.code >> 3) + 1] &= ~(1 << (key.code & 0x07)) - def clear_all(self): - for idx, _ in enumerate(self.report_keys): - self.report_keys[idx] = 0x00 - self.remove_cc() - self.remove_pd() - self.clear_axis() +class ConsumerControlReport(Report): + def __init__(self): + super().__init__(_REPORT_SIZE_CONSUMER) - return self + def add_cc(self, cc): + pack_into('> 3) + 1] |= 1 << (key.code & 0x07) + def get_action_map(self): + return {Axis: self.move_axis, MouseKey: self.add_button} - def remove_key(self, key): - if not self._nkro: - code = key.code.to_bytes(1, 'little') - idx = self._evt.find(code, 3) - self._evt[idx] = 0x00 - else: - self.report_keys[(key.code >> 3) + 1] &= ~(1 << (key.code & 0x07)) - def add_cc(self, cc): - # Add (or write over) consumer control report. There can only be one CC - # active at any time. - memoryview(self._cc_report)[1:3] = cc.code.to_bytes(2, 'little') - self._cc_pending = True +class HSPointingDeviceReport(PointingDeviceReport): + def __init__(self): + super().__init__(_REPORT_SIZE_MOUSE_HSCROLL) - def remove_cc(self): - # Remove consumer control report. - report = memoryview(self._cc_report)[1:3] - if report != b'\x00\x00': - report[:] = b'\x00\x00' - self._cc_pending = True - def add_pd(self, key): - self._pd_report[1] |= key.code - self._pd_pending = True +class AbstractHID: + def __init__(self): + self.report_map = {} + self.device_map = {} + self._setup_task = create_task(self.setup, period_ms=100) - def remove_pd(self): - if self._pd_report[1]: - self._pd_pending = True - self._pd_report[1] = 0x00 + def __repr__(self): + return self.__class__.__name__ + + def create_report(self, keys): + for report in self.device_map.keys(): + report.clear() + + for key in keys: + if action := self.report_map.get(type(key)): + action(key) + + def send(self): + for report in self.device_map.keys(): + if report.pending: + self.device_map[report].send_report(report.buffer) + report.pending = False + + def setup(self): + if not self.connected: + return - def move_axis(self, axis): - delta = clamp(axis.delta, -127, 127) - axis.delta -= delta try: - self._pd_report[axis.code + 2] = 0xFF & delta - self._pd_pending = True - except IndexError: + self.setup_keyboard_hid() + self.setup_consumer_control() + self.setup_mouse_hid() + + cancel_task(self._setup_task) + self._setup_task = None if debug.enabled: - debug('Axis(', axis.code, ') not supported') - - def clear_axis(self): - for idx in range(2, len(self._pd_report)): - self._pd_report[idx] = 0x00 - - def has_key(self, key): - if isinstance(key, ModifierKey): - return bool(self.report_mods[0] & key.code) - else: - if not self._nkro: - code = key.code.to_bytes(1, 'little') - return self.report_non_mods.find(code) > 0 - else: - part = self.report_keys[(key.code >> 3) + 1] - return bool(part & (1 << (key.code & 0x07))) - return False + self.show_debug() + except OSError as e: + if debug.enabled: + debug(type(e), ':', e) -class USBHID(AbstractHID): - report_bytes_default = 9 - REPORT_BYTES = report_bytes_default + def setup_keyboard_hid(self): + if device := find_device(self.devices, _USAGE_PAGE_KEYBOARD, _USAGE_KEYBOARD): + # bodgy NKRO autodetect + try: + report = KeyboardReport() + device.send_report(report.buffer) + except ValueError: + report = NKROKeyboardReport() + + self.report_map.update(report.get_action_map()) + self.device_map[report] = device - def __init__(self, **kwargs): - self.hid = usb_hid - self.hid_devices = self.hid.devices - super().__init__(**kwargs) - self._setup_task = self.wait_until_connected() + def setup_consumer_control(self): + if device := find_device(self.devices, _USAGE_PAGE_CONSUMER, _USAGE_CONSUMER): + report = ConsumerControlReport() + self.report_map.update(report.get_action_map()) + self.device_map[report] = device - def test_reports(self): - if self._connected(): + def setup_mouse_hid(self): + if device := find_device(self.devices, _USAGE_PAGE_MOUSE, _USAGE_MOUSE): + # bodgy pointing device panning autodetect try: - self.hid_ready = True - self.setup_keyboard_hid() - self.setup_consumer_control() - self.setup_mouse_hid() - cancel_task(self._setup_task) - self._setup_task = None - if debug.enabled: - self.show_debug() - self.hid_ready = True - except OSError as e: - if debug.enabled: - debug(type(e), ':', e) + report = PointingDeviceReport() + device.send_report(report.buffer) + except ValueError: + report = HSPointingDeviceReport() - def wait_until_connected(self, period_ms=200): - return create_task(self.test_reports, period_ms=period_ms) + self.report_map.update(report.get_action_map()) + self.device_map[report] = device - def _connected(self): - return supervisor.runtime.usb_connected + def show_debug(self): + for report in self.device_map.keys(): + debug('use ', report.__class__.__name__) - def hid_send(self, evt): - if not self.hid_ready or not self._connected(): - return - # int, can be looked up in HIDReportTypes - reporting_device_const = evt[0] +class USBHID(AbstractHID): + @property + def connected(self): + return supervisor.runtime.usb_connected - return self.devices[reporting_device_const].send_report(evt[1:]) + @property + def devices(self): + return usb_hid.devices class BLEHID(AbstractHID): - BLE_APPEARANCE_HID_KEYBOARD = const(961) - # Hardcoded in CPy - MAX_CONNECTIONS = const(2) - ble_connected = False + def __init__(self, ble_name=None): + super().__init__() - def __init__(self, ble_name=str(getmount('/').label), **kwargs): - self.ble_name = ble_name self.ble = BLERadio() - self.ble.name = self.ble_name + self.ble.name = ble_name if ble_name else getmount('/').label + self.ble_connected = False + self.hid = HIDService() - self.hid_devices = self.hid.devices self.hid.protocol_mode = 0 # Boot protocol - super().__init__(**kwargs) - self.start_ble_monitor() - def _connected(self): + create_task(self.ble_monitor, period_ms=1000) + + @property + def connected(self): return self.ble.connected + @property + def devices(self): + return self.hid.devices + def ble_monitor(self): - if self.ble_connected != self._connected(): - self.ble_connected = self._connected() - if self._connected(): - self.find_devices() - self.hid_ready = True + if self.ble_connected != self.connected: + self.ble_connected = self.connected + if self._connected: if debug.enabled: debug('BLE connected') else: - self.hid_ready = False # Security-wise this is not right. While you're away someone turns # on your keyboard and they can pair with it nice and clean and then # listen to keystrokes. @@ -370,24 +295,6 @@ def ble_monitor(self): if debug.enabled: debug('BLE disconnected') - def start_ble_monitor(self, period_ms=200): - return create_task(self.setup, period_ms=period_ms) - - def hid_send(self, evt): - if not self.hid_ready or not self._connected(): - return - - # int, can be looked up in HIDReportTypes - reporting_device_const = evt[0] - - device = self.devices[reporting_device_const] - - report_size = len(device._characteristic.value) - while len(evt) < report_size + 1: - evt.append(0) - - return device.send_report(evt[1 : report_size + 1]) # noqa: E203 - def clear_bonds(self): import _bleio @@ -396,7 +303,7 @@ def clear_bonds(self): def start_advertising(self): if not self.ble.advertising: advertisement = ProvideServicesAdvertisement(self.hid) - advertisement.appearance = self.BLE_APPEARANCE_HID_KEYBOARD + advertisement.appearance = _BLE_APPEARANCE_HID_KEYBOARD self.ble.start_advertising(advertisement) diff --git a/kmk/keys.py b/kmk/keys.py index e935988a8..27e9127fc 100644 --- a/kmk/keys.py +++ b/kmk/keys.py @@ -30,10 +30,10 @@ def __repr__(self) -> str: def move(self, keyboard: Keyboard, delta: int): self.delta += delta if self.delta: - keyboard.axes.add(self) + keyboard.keys_pressed.add(self) keyboard.hid_pending = True else: - keyboard.axes.discard(self) + keyboard.keys_pressed.discard(self) class AX: diff --git a/kmk/kmk_keyboard.py b/kmk/kmk_keyboard.py index 6a0d19aee..395ca0a80 100644 --- a/kmk/kmk_keyboard.py +++ b/kmk/kmk_keyboard.py @@ -7,7 +7,7 @@ from keypad import Event as KeyEvent from kmk.hid import BLEHID, USBHID, AbstractHID, HIDModes -from kmk.keys import KC, Key +from kmk.keys import KC, Axis, Key from kmk.modules import Module from kmk.scanners.keypad import MatrixScanner from kmk.scheduler import Task, cancel_task, create_task, get_due_task @@ -51,7 +51,6 @@ class KMKKeyboard: ##### # Internal State keys_pressed = set() - axes = set() _coordkeys_pressed = {} implicit_modifier = None hid_type = HIDModes.USB @@ -83,10 +82,8 @@ def _send_hid(self) -> None: if debug.enabled: if self.keys_pressed: debug('keys_pressed=', self.keys_pressed) - if self.axes: - debug('axes=', self.axes) - self._hid_helper.create_report(self.keys_pressed, self.axes) + self._hid_helper.create_report(self.keys_pressed) try: self._hid_helper.send() except Exception as err: @@ -94,8 +91,9 @@ def _send_hid(self) -> None: self.hid_pending = False - for axis in self.axes: - axis.move(self, 0) + for key in self.keys_pressed: + if isinstance(key, Axis): + key.move(self, 0) def _handle_matrix_report(self, kevent: KeyEvent) -> None: if kevent is not None: @@ -293,8 +291,11 @@ def _init_hid(self) -> None: debug('hid=', self._hid_helper) def _deinit_hid(self) -> None: - self._hid_helper.clear_all() - self._hid_helper.send() + try: + self._hid_helper.create_report({}) + self._hid_helper.send() + except Exception as e: + debug_error(self, '_deinit_hid', e) def _init_matrix(self) -> None: if self.matrix is None: diff --git a/kmk/modules/autoshift.py b/kmk/modules/autoshift.py index fc48f0f83..44c0695b5 100644 --- a/kmk/modules/autoshift.py +++ b/kmk/modules/autoshift.py @@ -30,7 +30,7 @@ def process_key(self, keyboard, key, is_pressed, int_coord): return key # Only shift from an unshifted state - if keyboard._hid_helper.has_key(KC.LSHIFT): + if KC.LSFT in keyboard.keys_pressed: return key # Ignore rolls from tapped to hold diff --git a/tests/keyboard_test.py b/tests/keyboard_test.py index 2d44bdca8..20d489316 100644 --- a/tests/keyboard_test.py +++ b/tests/keyboard_test.py @@ -1,7 +1,8 @@ import digitalio +import mock_hid import time -from unittest.mock import Mock, patch +from unittest.mock import Mock from kmk import scheduler from kmk.hid import HIDModes @@ -64,15 +65,16 @@ def __init__( scheduler._task_queue = scheduler.TaskQueue() self.keyboard._init(hid_type=HIDModes.NOOP) + self.keyboard._hid_helper.connected = True + self.keyboard._hid_helper.devices = mock_hid.devices + self.keyboard._hid_helper.setup() + for hid in mock_hid.devices: + hid.reports.clear() - @patch('kmk.hid.AbstractHID.hid_send') - def test(self, testname, key_events, assert_reports, hid_send): - if self.debug_enabled: - print(testname) - + def get_keyboard_report(self, key_events): # setup report recording - hid_reports = [] - hid_send.side_effect = lambda report: hid_reports.append(report[1:]) + keyboard_hid = self.keyboard._hid_helper.devices[0] + keyboard_hid.reports.clear() # inject key switch events self.keyboard._main_loop() @@ -95,6 +97,14 @@ def test(self, testname, key_events, assert_reports, hid_send): break assert timeout > time.time_ns(), 'infinite loop detected' + return keyboard_hid.reports + + def test(self, testname, key_events, assert_reports): + if self.debug_enabled: + print(testname) + + hid_reports = self.get_keyboard_report(key_events) + matching = True for i in range(max(len(hid_reports), len(assert_reports))): # prepare the generated report codes diff --git a/tests/mocks.py b/tests/mocks.py index d64bdb5b3..c32a008c4 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -13,8 +13,28 @@ def ticks_ms(): return (time.time_ns() // 1_000_000) % (1 << 29) +class Device: + def __init__(self, usage_page, usage): + self.usage_page = usage_page + self.usage = usage + self.reports = [] + + def send_report(self, report): + self.reports.append(report[:]) + + def init_circuit_python_modules_mocks(): sys.modules['usb_hid'] = Mock() + sys.modules['mock_hid'] = Mock() + sys.modules['mock_hid'].devices = [ + Device(p, u) + for p, u in [ + (0x01, 0x06), # keyboard + (0x01, 0x02), # mouse + (0x0C, 0x01), # consumer control + ] + ] + sys.modules['digitalio'] = Mock() sys.modules['neopixel'] = Mock() sys.modules['pulseio'] = Mock() diff --git a/tests/test_autoshift.py b/tests/test_autoshift.py index e38fd3195..5f1aa6073 100644 --- a/tests/test_autoshift.py +++ b/tests/test_autoshift.py @@ -76,7 +76,7 @@ def test_hold_shifted_hold_alpha(self): self.kb.test( '', [(2, True), (0, True), t_after, (2, False), (0, False)], - [{KC.LSHIFT, KC.N3}, {KC.N3, KC.A}, {KC.A}, {}], + [{KC.LSHIFT, KC.N3}, {KC.LSHIFT, KC.N3, KC.A}, {KC.A}, {}], ) def test_hold_internal(self):