Skip to content
This repository was archived by the owner on Jan 10, 2023. It is now read-only.

Commit

Permalink
Fix strings and add tests (#128)
Browse files Browse the repository at this point in the history
* Fix python3-ism with byte v. str (copied from @rustyhowell)

* Changed binary strings to regular strings

* Version bump to 1.3.0.dev

* If 'serial' is bytes, convert it to unicode

* Version bump to 1.3.0.1

* Add changes from @tuxuser
  • Loading branch information
JeffLIrion authored and fahhem committed Oct 4, 2018
1 parent 82bfd52 commit 40ffe13
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 64 deletions.
15 changes: 10 additions & 5 deletions adb/adb_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,17 @@ def ConnectDevice(self, port_path=None, serial=None, default_timeout_ms=None, **
# If there isnt a handle override (used by tests), build one here
if 'handle' in kwargs:
self._handle = kwargs.pop('handle')
elif serial and b':' in serial:
self._handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms)
else:
self._handle = common.UsbHandle.FindAndOpen(
DeviceIsAvailable, port_path=port_path, serial=serial,
timeout_ms=default_timeout_ms)
# if necessary, convert serial to a unicode string
if isinstance(serial, (bytes, bytearray)):
serial = serial.decode('utf-8')

if serial and ':' in serial:
self._handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms)
else:
self._handle = common.UsbHandle.FindAndOpen(
DeviceIsAvailable, port_path=port_path, serial=serial,
timeout_ms=default_timeout_ms)

self._Connect(**kwargs)

Expand Down
5 changes: 5 additions & 0 deletions adb/adb_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ def Connect(cls, usb, banner=b'notadb', rsa_keys=None, auth_timeout_ms=100):
InvalidResponseError: When the device does authentication in an
unexpected way.
"""
# In py3, convert unicode to bytes. In py2, convert str to bytes.
# It's later joined into a byte string, so in py2, this ends up kind of being a no-op.
if isinstance(banner, str):
banner = bytearray(banner, 'utf-8')

msg = cls(
command=b'CNXN', arg0=VERSION, arg1=MAX_ADB_DATA,
data=b'host::%s\0' % banner)
Expand Down
23 changes: 17 additions & 6 deletions adb/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,26 @@ def __init__(self, serial, timeout_ms=None):
Host may be an IP address or a host name.
"""
if b':' in serial:
(host, port) = serial.split(b':')
# if necessary, convert serial to a unicode string
if isinstance(serial, (bytes, bytearray)):
serial = serial.decode('utf-8')

if ':' in serial:
self.host, self.port = serial.split(':')
else:
host = serial
port = 5555
self._serial_number = '%s:%s' % (host, port)
self.host = serial
self.port = 5555

self._connection = None
self._serial_number = '%s:%s' % (self.host, self.port)
self._timeout_ms = float(timeout_ms) if timeout_ms else None

self._connect()

def _connect(self):
timeout = self.TimeoutSeconds(self._timeout_ms)
self._connection = socket.create_connection((host, port), timeout=timeout)
self._connection = socket.create_connection((self.host, self.port),
timeout=timeout)
if timeout:
self._connection.setblocking(0)

Expand Down
49 changes: 42 additions & 7 deletions test/adb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
from io import BytesIO
import struct
import unittest
from mock import mock


from adb import common
from adb import adb_commands
from adb import adb_protocol
from adb.usb_exceptions import TcpTimeoutException
from adb.usb_exceptions import TcpTimeoutException, DeviceNotFoundError
import common_stub


Expand Down Expand Up @@ -78,10 +81,9 @@ def _Connect(cls, usb):


class AdbTest(BaseAdbTest):

@classmethod
def _ExpectCommand(cls, service, command, *responses):
usb = common_stub.StubUsb()
usb = common_stub.StubUsb(device=None, setting=None)
cls._ExpectConnection(usb)
cls._ExpectOpen(usb, b'%s:%s\0' % (service, command))

Expand All @@ -91,12 +93,19 @@ def _ExpectCommand(cls, service, command, *responses):
return usb

def testConnect(self):
usb = common_stub.StubUsb()
usb = common_stub.StubUsb(device=None, setting=None)
self._ExpectConnection(usb)

dev = adb_commands.AdbCommands()
dev.ConnectDevice(handle=usb, banner=BANNER)

def testConnectSerialString(self):
dev = adb_commands.AdbCommands()

with mock.patch.object(common.UsbHandle, 'FindAndOpen', return_value=None):
with mock.patch.object(adb_commands.AdbCommands, '_Connect', return_value=None):
dev.ConnectDevice(serial='/dev/invalidHandle')

def testSmallResponseShell(self):
command = b'keepin it real'
response = 'word.'
Expand Down Expand Up @@ -196,7 +205,7 @@ def _MakeWriteSyncPacket(cls, command, data=b'', size=None):

@classmethod
def _ExpectSyncCommand(cls, write_commands, read_commands):
usb = common_stub.StubUsb()
usb = common_stub.StubUsb(device=None, setting=None)
cls._ExpectConnection(usb)
cls._ExpectOpen(usb, b'sync:\0')

Expand Down Expand Up @@ -246,7 +255,7 @@ class TcpTimeoutAdbTest(BaseAdbTest):

@classmethod
def _ExpectCommand(cls, service, command, *responses):
tcp = common_stub.StubTcp()
tcp = common_stub.StubTcp('10.0.0.123')
cls._ExpectConnection(tcp)
cls._ExpectOpen(tcp, b'%s:%s\0' % (service, command))

Expand All @@ -262,7 +271,7 @@ def _run_shell(self, cmd, timeout_ms=None):
dev.Shell(cmd, timeout_ms=timeout_ms)

def testConnect(self):
tcp = common_stub.StubTcp()
tcp = common_stub.StubTcp('10.0.0.123')
self._ExpectConnection(tcp)
dev = adb_commands.AdbCommands()
dev.ConnectDevice(handle=tcp, banner=BANNER)
Expand All @@ -276,5 +285,31 @@ def testTcpTimeout(self):
command,
timeout_ms=timeout_ms)


class TcpHandleTest(unittest.TestCase):
def testInitWithHost(self):
tcp = common_stub.StubTcp('10.11.12.13')

self.assertEqual('10.11.12.13:5555', tcp._serial_number)
self.assertEqual(None, tcp._timeout_ms)

def testInitWithHostAndPort(self):
tcp = common_stub.StubTcp('10.11.12.13:5678')

self.assertEqual('10.11.12.13:5678', tcp._serial_number)
self.assertEqual(None, tcp._timeout_ms)

def testInitWithTimeout(self):
tcp = common_stub.StubTcp('10.0.0.2', timeout_ms=234.5)

self.assertEqual('10.0.0.2:5555', tcp._serial_number)
self.assertEqual(234.5, tcp._timeout_ms)

def testInitWithTimeoutInt(self):
tcp = common_stub.StubTcp('10.0.0.2', timeout_ms=234)

self.assertEqual('10.0.0.2:5555', tcp._serial_number)
self.assertEqual(234.0, tcp._timeout_ms)

if __name__ == '__main__':
unittest.main()
115 changes: 70 additions & 45 deletions test/common_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import string
import sys
import time
from adb.usb_exceptions import TcpTimeoutException
from mock import mock

from adb.common import TcpHandle, UsbHandle
from adb.usb_exceptions import TcpTimeoutException

PRINTABLE_DATA = set(string.printable) - set(string.whitespace)

Expand All @@ -16,33 +19,23 @@ def _Dotify(data):
return ''.join(char if char in PRINTABLE_DATA else '.' for char in data)


class StubUsb(object):
"""UsbHandle stub."""

def __init__(self):
class StubHandleBase(object):
def __init__(self, timeout_ms, is_tcp=False):
self.written_data = []
self.read_data = []
self.timeout_ms = 0
self.is_tcp = is_tcp
self.timeout_ms = timeout_ms

def BulkWrite(self, data, unused_timeout_ms=None):
expected_data = self.written_data.pop(0)
if isinstance(data, bytearray):
data = bytes(data)
if not isinstance(data, bytes):
data = data.encode('utf8')
if expected_data != data:
raise ValueError('Expected %s (%s) got %s (%s)' % (
binascii.hexlify(expected_data), _Dotify(expected_data),
binascii.hexlify(data), _Dotify(data)))
def _signal_handler(self, signum, frame):
raise TcpTimeoutException('End of time')

def BulkRead(self, length,
timeout_ms=None): # pylint: disable=unused-argument
data = self.read_data.pop(0)
if length < len(data):
raise ValueError(
'Overflow packet length. Read %d bytes, got %d bytes: %s',
length, len(data))
return bytearray(data)
def _return_seconds(self, time_ms):
return (float(time_ms)/1000) if time_ms else 0

def _alarm_sounder(self, timeout_ms):
signal.signal(signal.SIGALRM, self._signal_handler)
signal.setitimer(signal.ITIMER_REAL,
self._return_seconds(timeout_ms))

def ExpectWrite(self, data):
if not isinstance(data, bytes):
Expand All @@ -54,22 +47,6 @@ def ExpectRead(self, data):
data = data.encode('utf8')
self.read_data.append(data)

def Timeout(self, timeout_ms):
return timeout_ms if timeout_ms is not None else self.timeout_ms

class StubTcp(StubUsb):

def _signal_handler(self, signum, frame):
raise TcpTimeoutException('End of time')

def _return_seconds(self, time_ms):
return (float(time_ms)/1000) if time_ms else 0

def _alarm_sounder(self, timeout_ms):
signal.signal(signal.SIGALRM, self._signal_handler)
signal.setitimer(signal.ITIMER_REAL,
self._return_seconds(timeout_ms))

def BulkWrite(self, data, timeout_ms=None):
expected_data = self.written_data.pop(0)
if isinstance(data, bytearray):
Expand All @@ -80,8 +57,8 @@ def BulkWrite(self, data, timeout_ms=None):
raise ValueError('Expected %s (%s) got %s (%s)' % (
binascii.hexlify(expected_data), _Dotify(expected_data),
binascii.hexlify(data), _Dotify(data)))
if b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
if self.is_tcp and b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
time.sleep(2*self._return_seconds(timeout_ms))

def BulkRead(self, length,
Expand All @@ -91,8 +68,56 @@ def BulkRead(self, length,
raise ValueError(
'Overflow packet length. Read %d bytes, got %d bytes: %s',
length, len(data))
if b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
if self.is_tcp and b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
time.sleep(2*self._return_seconds(timeout_ms))
return bytearray(data)
return bytearray(data)

def Timeout(self, timeout_ms):
return timeout_ms if timeout_ms is not None else self.timeout_ms


class StubUsb(UsbHandle):
"""UsbHandle stub."""
def __init__(self, device, setting, usb_info=None, timeout_ms=None):
super(StubUsb, self).__init__(device, setting, usb_info, timeout_ms)
self.stub_base = StubHandleBase(0)

def ExpectWrite(self, data):
return self.stub_base.ExpectWrite(data)

def ExpectRead(self, data):
return self.stub_base.ExpectRead(data)

def BulkWrite(self, data, unused_timeout_ms=None):
return self.stub_base.BulkWrite(data, unused_timeout_ms)

def BulkRead(self, length, timeout_ms=None):
return self.stub_base.BulkRead(length, timeout_ms)

def Timeout(self, timeout_ms):
return self.stub_base.Timeout(timeout_ms)


class StubTcp(TcpHandle):
def __init__(self, serial, timeout_ms=None):
"""TcpHandle stub."""
self._connect = mock.MagicMock(return_value=None)

super(StubTcp, self).__init__(serial, timeout_ms)
self.stub_base = StubHandleBase(0, is_tcp=True)

def ExpectWrite(self, data):
return self.stub_base.ExpectWrite(data)

def ExpectRead(self, data):
return self.stub_base.ExpectRead(data)

def BulkWrite(self, data, unused_timeout_ms=None):
return self.stub_base.BulkWrite(data, unused_timeout_ms)

def BulkRead(self, length, timeout_ms=None):
return self.stub_base.BulkRead(length, timeout_ms)

def Timeout(self, timeout_ms):
return self.stub_base.Timeout(timeout_ms)
2 changes: 1 addition & 1 deletion test/fastboot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class FastbootTest(unittest.TestCase):

def setUp(self):
self.usb = common_stub.StubUsb()
self.usb = common_stub.StubUsb(device=None, setting=None)

@staticmethod
def _SumLengths(items):
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ envlist =
deps =
pytest
pytest-cov
mock
usedevelop = True
commands = py.test --cov adb test

0 comments on commit 40ffe13

Please sign in to comment.