From 40ffe13448857df6d2ce34450f4a68234ade5b87 Mon Sep 17 00:00:00 2001 From: Jeff Irion Date: Thu, 4 Oct 2018 13:31:33 -0700 Subject: [PATCH] Fix strings and add tests (#128) * Fix python3-ism with byte v. str (copied from @rustyhowell) * Changed binary strings to regular strings * Version bump to 1.3.0.dev * If 'serial' is bytes, convert it to unicode * Version bump to 1.3.0.1 * Add changes from @tuxuser --- adb/adb_commands.py | 15 ++++-- adb/adb_protocol.py | 5 ++ adb/common.py | 23 ++++++--- test/adb_test.py | 49 +++++++++++++++--- test/common_stub.py | 115 +++++++++++++++++++++++++----------------- test/fastboot_test.py | 2 +- tox.ini | 1 + 7 files changed, 146 insertions(+), 64 deletions(-) diff --git a/adb/adb_commands.py b/adb/adb_commands.py index 667b713..f3667c8 100644 --- a/adb/adb_commands.py +++ b/adb/adb_commands.py @@ -127,12 +127,17 @@ def ConnectDevice(self, port_path=None, serial=None, default_timeout_ms=None, ** # If there isnt a handle override (used by tests), build one here if 'handle' in kwargs: self._handle = kwargs.pop('handle') - elif serial and b':' in serial: - self._handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms) else: - self._handle = common.UsbHandle.FindAndOpen( - DeviceIsAvailable, port_path=port_path, serial=serial, - timeout_ms=default_timeout_ms) + # if necessary, convert serial to a unicode string + if isinstance(serial, (bytes, bytearray)): + serial = serial.decode('utf-8') + + if serial and ':' in serial: + self._handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms) + else: + self._handle = common.UsbHandle.FindAndOpen( + DeviceIsAvailable, port_path=port_path, serial=serial, + timeout_ms=default_timeout_ms) self._Connect(**kwargs) diff --git a/adb/adb_protocol.py b/adb/adb_protocol.py index 9654e51..4ff28c7 100644 --- a/adb/adb_protocol.py +++ b/adb/adb_protocol.py @@ -302,6 +302,11 @@ def Connect(cls, usb, banner=b'notadb', rsa_keys=None, auth_timeout_ms=100): InvalidResponseError: When the device does authentication in an unexpected way. """ + # In py3, convert unicode to bytes. In py2, convert str to bytes. + # It's later joined into a byte string, so in py2, this ends up kind of being a no-op. + if isinstance(banner, str): + banner = bytearray(banner, 'utf-8') + msg = cls( command=b'CNXN', arg0=VERSION, arg1=MAX_ADB_DATA, data=b'host::%s\0' % banner) diff --git a/adb/common.py b/adb/common.py index a68414f..0c78728 100644 --- a/adb/common.py +++ b/adb/common.py @@ -298,15 +298,26 @@ def __init__(self, serial, timeout_ms=None): Host may be an IP address or a host name. """ - if b':' in serial: - (host, port) = serial.split(b':') + # if necessary, convert serial to a unicode string + if isinstance(serial, (bytes, bytearray)): + serial = serial.decode('utf-8') + + if ':' in serial: + self.host, self.port = serial.split(':') else: - host = serial - port = 5555 - self._serial_number = '%s:%s' % (host, port) + self.host = serial + self.port = 5555 + + self._connection = None + self._serial_number = '%s:%s' % (self.host, self.port) self._timeout_ms = float(timeout_ms) if timeout_ms else None + + self._connect() + + def _connect(self): timeout = self.TimeoutSeconds(self._timeout_ms) - self._connection = socket.create_connection((host, port), timeout=timeout) + self._connection = socket.create_connection((self.host, self.port), + timeout=timeout) if timeout: self._connection.setblocking(0) diff --git a/test/adb_test.py b/test/adb_test.py index bdbfce5..0ce1ead 100755 --- a/test/adb_test.py +++ b/test/adb_test.py @@ -17,10 +17,13 @@ from io import BytesIO import struct import unittest +from mock import mock + +from adb import common from adb import adb_commands from adb import adb_protocol -from adb.usb_exceptions import TcpTimeoutException +from adb.usb_exceptions import TcpTimeoutException, DeviceNotFoundError import common_stub @@ -78,10 +81,9 @@ def _Connect(cls, usb): class AdbTest(BaseAdbTest): - @classmethod def _ExpectCommand(cls, service, command, *responses): - usb = common_stub.StubUsb() + usb = common_stub.StubUsb(device=None, setting=None) cls._ExpectConnection(usb) cls._ExpectOpen(usb, b'%s:%s\0' % (service, command)) @@ -91,12 +93,19 @@ def _ExpectCommand(cls, service, command, *responses): return usb def testConnect(self): - usb = common_stub.StubUsb() + usb = common_stub.StubUsb(device=None, setting=None) self._ExpectConnection(usb) dev = adb_commands.AdbCommands() dev.ConnectDevice(handle=usb, banner=BANNER) + def testConnectSerialString(self): + dev = adb_commands.AdbCommands() + + with mock.patch.object(common.UsbHandle, 'FindAndOpen', return_value=None): + with mock.patch.object(adb_commands.AdbCommands, '_Connect', return_value=None): + dev.ConnectDevice(serial='/dev/invalidHandle') + def testSmallResponseShell(self): command = b'keepin it real' response = 'word.' @@ -196,7 +205,7 @@ def _MakeWriteSyncPacket(cls, command, data=b'', size=None): @classmethod def _ExpectSyncCommand(cls, write_commands, read_commands): - usb = common_stub.StubUsb() + usb = common_stub.StubUsb(device=None, setting=None) cls._ExpectConnection(usb) cls._ExpectOpen(usb, b'sync:\0') @@ -246,7 +255,7 @@ class TcpTimeoutAdbTest(BaseAdbTest): @classmethod def _ExpectCommand(cls, service, command, *responses): - tcp = common_stub.StubTcp() + tcp = common_stub.StubTcp('10.0.0.123') cls._ExpectConnection(tcp) cls._ExpectOpen(tcp, b'%s:%s\0' % (service, command)) @@ -262,7 +271,7 @@ def _run_shell(self, cmd, timeout_ms=None): dev.Shell(cmd, timeout_ms=timeout_ms) def testConnect(self): - tcp = common_stub.StubTcp() + tcp = common_stub.StubTcp('10.0.0.123') self._ExpectConnection(tcp) dev = adb_commands.AdbCommands() dev.ConnectDevice(handle=tcp, banner=BANNER) @@ -276,5 +285,31 @@ def testTcpTimeout(self): command, timeout_ms=timeout_ms) + +class TcpHandleTest(unittest.TestCase): + def testInitWithHost(self): + tcp = common_stub.StubTcp('10.11.12.13') + + self.assertEqual('10.11.12.13:5555', tcp._serial_number) + self.assertEqual(None, tcp._timeout_ms) + + def testInitWithHostAndPort(self): + tcp = common_stub.StubTcp('10.11.12.13:5678') + + self.assertEqual('10.11.12.13:5678', tcp._serial_number) + self.assertEqual(None, tcp._timeout_ms) + + def testInitWithTimeout(self): + tcp = common_stub.StubTcp('10.0.0.2', timeout_ms=234.5) + + self.assertEqual('10.0.0.2:5555', tcp._serial_number) + self.assertEqual(234.5, tcp._timeout_ms) + + def testInitWithTimeoutInt(self): + tcp = common_stub.StubTcp('10.0.0.2', timeout_ms=234) + + self.assertEqual('10.0.0.2:5555', tcp._serial_number) + self.assertEqual(234.0, tcp._timeout_ms) + if __name__ == '__main__': unittest.main() diff --git a/test/common_stub.py b/test/common_stub.py index e2a2e3e..f993ef1 100644 --- a/test/common_stub.py +++ b/test/common_stub.py @@ -5,7 +5,10 @@ import string import sys import time -from adb.usb_exceptions import TcpTimeoutException +from mock import mock + +from adb.common import TcpHandle, UsbHandle +from adb.usb_exceptions import TcpTimeoutException PRINTABLE_DATA = set(string.printable) - set(string.whitespace) @@ -16,33 +19,23 @@ def _Dotify(data): return ''.join(char if char in PRINTABLE_DATA else '.' for char in data) -class StubUsb(object): - """UsbHandle stub.""" - - def __init__(self): +class StubHandleBase(object): + def __init__(self, timeout_ms, is_tcp=False): self.written_data = [] self.read_data = [] - self.timeout_ms = 0 + self.is_tcp = is_tcp + self.timeout_ms = timeout_ms - def BulkWrite(self, data, unused_timeout_ms=None): - expected_data = self.written_data.pop(0) - if isinstance(data, bytearray): - data = bytes(data) - if not isinstance(data, bytes): - data = data.encode('utf8') - if expected_data != data: - raise ValueError('Expected %s (%s) got %s (%s)' % ( - binascii.hexlify(expected_data), _Dotify(expected_data), - binascii.hexlify(data), _Dotify(data))) + def _signal_handler(self, signum, frame): + raise TcpTimeoutException('End of time') - def BulkRead(self, length, - timeout_ms=None): # pylint: disable=unused-argument - data = self.read_data.pop(0) - if length < len(data): - raise ValueError( - 'Overflow packet length. Read %d bytes, got %d bytes: %s', - length, len(data)) - return bytearray(data) + def _return_seconds(self, time_ms): + return (float(time_ms)/1000) if time_ms else 0 + + def _alarm_sounder(self, timeout_ms): + signal.signal(signal.SIGALRM, self._signal_handler) + signal.setitimer(signal.ITIMER_REAL, + self._return_seconds(timeout_ms)) def ExpectWrite(self, data): if not isinstance(data, bytes): @@ -54,22 +47,6 @@ def ExpectRead(self, data): data = data.encode('utf8') self.read_data.append(data) - def Timeout(self, timeout_ms): - return timeout_ms if timeout_ms is not None else self.timeout_ms - -class StubTcp(StubUsb): - - def _signal_handler(self, signum, frame): - raise TcpTimeoutException('End of time') - - def _return_seconds(self, time_ms): - return (float(time_ms)/1000) if time_ms else 0 - - def _alarm_sounder(self, timeout_ms): - signal.signal(signal.SIGALRM, self._signal_handler) - signal.setitimer(signal.ITIMER_REAL, - self._return_seconds(timeout_ms)) - def BulkWrite(self, data, timeout_ms=None): expected_data = self.written_data.pop(0) if isinstance(data, bytearray): @@ -80,8 +57,8 @@ def BulkWrite(self, data, timeout_ms=None): raise ValueError('Expected %s (%s) got %s (%s)' % ( binascii.hexlify(expected_data), _Dotify(expected_data), binascii.hexlify(data), _Dotify(data))) - if b'i_need_a_timeout' in data: - self._alarm_sounder(timeout_ms) + if self.is_tcp and b'i_need_a_timeout' in data: + self._alarm_sounder(timeout_ms) time.sleep(2*self._return_seconds(timeout_ms)) def BulkRead(self, length, @@ -91,8 +68,56 @@ def BulkRead(self, length, raise ValueError( 'Overflow packet length. Read %d bytes, got %d bytes: %s', length, len(data)) - if b'i_need_a_timeout' in data: - self._alarm_sounder(timeout_ms) + if self.is_tcp and b'i_need_a_timeout' in data: + self._alarm_sounder(timeout_ms) time.sleep(2*self._return_seconds(timeout_ms)) - return bytearray(data) + return bytearray(data) + + def Timeout(self, timeout_ms): + return timeout_ms if timeout_ms is not None else self.timeout_ms + + +class StubUsb(UsbHandle): + """UsbHandle stub.""" + def __init__(self, device, setting, usb_info=None, timeout_ms=None): + super(StubUsb, self).__init__(device, setting, usb_info, timeout_ms) + self.stub_base = StubHandleBase(0) + + def ExpectWrite(self, data): + return self.stub_base.ExpectWrite(data) + + def ExpectRead(self, data): + return self.stub_base.ExpectRead(data) + + def BulkWrite(self, data, unused_timeout_ms=None): + return self.stub_base.BulkWrite(data, unused_timeout_ms) + + def BulkRead(self, length, timeout_ms=None): + return self.stub_base.BulkRead(length, timeout_ms) + + def Timeout(self, timeout_ms): + return self.stub_base.Timeout(timeout_ms) + + +class StubTcp(TcpHandle): + def __init__(self, serial, timeout_ms=None): + """TcpHandle stub.""" + self._connect = mock.MagicMock(return_value=None) + + super(StubTcp, self).__init__(serial, timeout_ms) + self.stub_base = StubHandleBase(0, is_tcp=True) + def ExpectWrite(self, data): + return self.stub_base.ExpectWrite(data) + + def ExpectRead(self, data): + return self.stub_base.ExpectRead(data) + + def BulkWrite(self, data, unused_timeout_ms=None): + return self.stub_base.BulkWrite(data, unused_timeout_ms) + + def BulkRead(self, length, timeout_ms=None): + return self.stub_base.BulkRead(length, timeout_ms) + + def Timeout(self, timeout_ms): + return self.stub_base.Timeout(timeout_ms) diff --git a/test/fastboot_test.py b/test/fastboot_test.py index 32c96fa..58ccced 100755 --- a/test/fastboot_test.py +++ b/test/fastboot_test.py @@ -26,7 +26,7 @@ class FastbootTest(unittest.TestCase): def setUp(self): - self.usb = common_stub.StubUsb() + self.usb = common_stub.StubUsb(device=None, setting=None) @staticmethod def _SumLengths(items): diff --git a/tox.ini b/tox.ini index 4880a55..0f9881a 100644 --- a/tox.ini +++ b/tox.ini @@ -12,5 +12,6 @@ envlist = deps = pytest pytest-cov + mock usedevelop = True commands = py.test --cov adb test