Skip to content

Commit

Permalink
Merge pull request #34 from ronnie-llamado/feature/add-type-annotations
Browse files Browse the repository at this point in the history
Add type annotations
  • Loading branch information
kattni authored Oct 27, 2021
2 parents e3188ad + 331dfa0 commit 23f7ca7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 49 deletions.
2 changes: 1 addition & 1 deletion adafruit_azureiot/device_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DeviceRegistrationError(Exception):
An error from the device registration
"""

def __init__(self, message):
def __init__(self, message: str):
super().__init__(message)
self.message = message

Expand Down
51 changes: 30 additions & 21 deletions adafruit_azureiot/hmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@

# pylint: disable=C0103, W0108, R0915, C0116, C0115

try:
from typing import Union
except ImportError:
pass

def __translate(key, translation):

def __translate(key: Union[bytes, bytearray], translation: bytes) -> bytes:
return bytes(translation[x] for x in key)


Expand All @@ -28,7 +33,7 @@ def __translate(key, translation):
SHA_DIGESTSIZE = 32


def new_shaobject():
def new_shaobject() -> dict:
"""Struct. for storing SHA information."""
return {
"digest": [0] * 8,
Expand All @@ -40,7 +45,7 @@ def new_shaobject():
}


def sha_init():
def sha_init() -> dict:
"""Initialize the SHA digest."""
sha_info = new_shaobject()
sha_info["digest"] = [
Expand Down Expand Up @@ -73,7 +78,7 @@ def sha_init():
Gamma1 = lambda x: (S(x, 17) ^ S(x, 19) ^ R(x, 10))


def sha_transform(sha_info):
def sha_transform(sha_info: dict) -> None:
W = []

d = sha_info["data"]
Expand All @@ -90,7 +95,7 @@ def sha_transform(sha_info):
ss = sha_info["digest"][:]

# pylint: disable=too-many-arguments, line-too-long
def RND(a, b, c, d, e, f, g, h, i, ki):
def RND(a, b, c, d, e, f, g, h, i, ki): # type: ignore[no-untyped-def]
"""Compress"""
t0 = h + Sigma1(e) + Ch(e, f, g) + ki + W[i]
t1 = Sigma0(a) + Maj(a, b, c)
Expand Down Expand Up @@ -298,7 +303,7 @@ def RND(a, b, c, d, e, f, g, h, i, ki):
sha_info["digest"] = dig


def sha_update(sha_info, buffer):
def sha_update(sha_info: dict, buffer: Union[bytes, bytearray]) -> None:
"""Update the SHA digest.
:param dict sha_info: SHA Digest.
:param str buffer: SHA buffer size.
Expand Down Expand Up @@ -346,13 +351,13 @@ def sha_update(sha_info, buffer):
sha_info["local"] = count


def getbuf(s):
def getbuf(s: Union[str, bytes, bytearray]) -> Union[bytes, bytearray]:
if isinstance(s, str):
return s.encode("ascii")
return bytes(s)


def sha_final(sha_info):
def sha_final(sha_info: dict) -> bytes:
"""Finish computing the SHA Digest."""
lo_bit_count = sha_info["count_lo"]
hi_bit_count = sha_info["count_hi"]
Expand Down Expand Up @@ -393,28 +398,28 @@ class sha256:
block_size = SHA_BLOCKSIZE
name = "sha256"

def __init__(self, s=None):
def __init__(self, s: Union[str, bytes, bytearray] = None):
"""Constructs a SHA256 hash object."""
self._sha = sha_init()
if s:
sha_update(self._sha, getbuf(s))

def update(self, s):
def update(self, s: Union[str, bytes, bytearray]) -> None:
"""Updates the hash object with a bytes-like object, s."""
sha_update(self._sha, getbuf(s))

def digest(self):
def digest(self) -> bytes:
"""Returns the digest of the data passed to the update()
method so far."""
return sha_final(self._sha.copy())[: self._sha["digestsize"]]

def hexdigest(self):
def hexdigest(self) -> str:
"""Like digest() except the digest is returned as a string object of
double length, containing only hexadecimal digits.
"""
return "".join(["%.2x" % i for i in self.digest()])

def copy(self):
def copy(self) -> "sha256":
"""Return a copy (“clone”) of the hash object."""
new = sha256()
new._sha = self._sha.copy()
Expand All @@ -429,7 +434,9 @@ class HMAC:

blocksize = 64 # 512-bit HMAC; can be changed in subclasses.

def __init__(self, key, msg=None):
def __init__(
self, key: Union[bytes, bytearray], msg: Union[bytes, bytearray] = None
):
"""Create a new HMAC object.
key: key for the keyed hash object.
Expand Down Expand Up @@ -478,15 +485,15 @@ def __init__(self, key, msg=None):
self.update(msg)

@property
def name(self):
def name(self) -> str:
"""Return the name of this object"""
return "hmac-" + self.inner.name

def update(self, msg):
def update(self, msg: Union[bytes, bytearray]) -> None:
"""Update this hashing object with the string msg."""
self.inner.update(msg)

def copy(self):
def copy(self) -> "HMAC":
"""Return a separate copy of this hashing object.
An update to this copy won't affect the original object.
Expand All @@ -499,7 +506,7 @@ def copy(self):
other.outer = self.outer.copy()
return other

def _current(self):
def _current(self) -> "sha256":
"""Return a hash object for the current state.
To be used only internally with digest() and hexdigest().
Expand All @@ -508,7 +515,7 @@ def _current(self):
hmac.update(self.inner.digest())
return hmac

def digest(self):
def digest(self) -> bytes:
"""Return the hash value of this hashing object.
This returns a string containing 8-bit data. The object is
Expand All @@ -518,13 +525,15 @@ def digest(self):
hmac = self._current()
return hmac.digest()

def hexdigest(self):
def hexdigest(self) -> str:
"""Like digest(), but returns a string of hexadecimal digits instead."""
hmac = self._current()
return hmac.hexdigest()


def new_hmac(key, msg=None):
def new_hmac(
key: Union[bytes, bytearray], msg: Union[bytes, bytearray] = None
) -> "HMAC":
"""Create a new hashing object and return it.
key: The starting key for the hash.
Expand Down
50 changes: 33 additions & 17 deletions adafruit_azureiot/iothub_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
* Author(s): Jim Bennett, Elena Horton
"""

try:
from typing import Any, Callable, Mapping, Union
except ImportError:
pass

import json
import adafruit_logging as logging
from .iot_error import IoTError
from .iot_mqtt import IoTMQTT, IoTMQTTCallback, IoTResponse


def _validate_keys(connection_string_parts):
def _validate_keys(connection_string_parts: Mapping) -> None:
"""Raise ValueError if incorrect combination of keys"""
host_name = connection_string_parts.get(HOST_NAME)
shared_access_key_name = connection_string_parts.get(SHARED_ACCESS_KEY_NAME)
Expand Down Expand Up @@ -67,7 +72,7 @@ def connection_status_change(self, connected: bool) -> None:
self._on_connection_status_changed(connected)

# pylint: disable=W0613, R0201
def direct_method_invoked(self, method_name: str, payload) -> IoTResponse:
def direct_method_invoked(self, method_name: str, payload: str) -> IoTResponse:
"""Called when a direct method is invoked
:param str method_name: The name of the method that was invoked
:param str payload: The payload with the message
Expand All @@ -91,7 +96,10 @@ def cloud_to_device_message_received(self, body: str, properties: dict) -> None:
self._on_cloud_to_device_message_received(body, properties)

def device_twin_desired_updated(
self, desired_property_name: str, desired_property_value, desired_version: int
self,
desired_property_name: str,
desired_property_value: Any,
desired_version: int,
) -> None:
"""Called when the device twin desired properties are updated
:param str desired_property_name: The name of the desired property that was updated
Expand All @@ -107,7 +115,7 @@ def device_twin_desired_updated(
def device_twin_reported_updated(
self,
reported_property_name: str,
reported_property_value,
reported_property_value: Any,
reported_version: int,
) -> None:
"""Called when the device twin reported values are updated
Expand Down Expand Up @@ -175,21 +183,23 @@ def __init__(
self._mqtt = None

@property
def on_connection_status_changed(self):
def on_connection_status_changed(self) -> Callable:
"""A callback method that is called when the connection status is changed. This method should have the following signature:
def connection_status_changed(connected: bool) -> None
"""
return self._on_connection_status_changed

@on_connection_status_changed.setter
def on_connection_status_changed(self, new_on_connection_status_changed):
def on_connection_status_changed(
self, new_on_connection_status_changed: Callable
) -> None:
"""A callback method that is called when the connection status is changed. This method should have the following signature:
def connection_status_changed(connected: bool) -> None
"""
self._on_connection_status_changed = new_on_connection_status_changed

@property
def on_direct_method_invoked(self):
def on_direct_method_invoked(self) -> Callable:
"""A callback method that is called when a direct method is invoked. This method should have the following signature:
def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:
Expand All @@ -202,7 +212,7 @@ def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:
return self._on_direct_method_invoked

@on_direct_method_invoked.setter
def on_direct_method_invoked(self, new_on_direct_method_invoked):
def on_direct_method_invoked(self, new_on_direct_method_invoked: Callable) -> None:
"""A callback method that is called when a direct method is invoked. This method should have the following signature:
def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:
Expand All @@ -215,16 +225,16 @@ def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:
self._on_direct_method_invoked = new_on_direct_method_invoked

@property
def on_cloud_to_device_message_received(self):
def on_cloud_to_device_message_received(self) -> Callable:
"""A callback method that is called when a cloud to device message is received. This method should have the following signature:
def cloud_to_device_message_received(body: str, properties: dict) -> None:
"""
return self._on_cloud_to_device_message_received

@on_cloud_to_device_message_received.setter
def on_cloud_to_device_message_received(
self, new_on_cloud_to_device_message_received
):
self, new_on_cloud_to_device_message_received: Callable
) -> None:
"""A callback method that is called when a cloud to device message is received. This method should have the following signature:
def cloud_to_device_message_received(body: str, properties: dict) -> None:
"""
Expand All @@ -233,15 +243,17 @@ def cloud_to_device_message_received(body: str, properties: dict) -> None:
)

@property
def on_device_twin_desired_updated(self):
def on_device_twin_desired_updated(self) -> Callable:
"""A callback method that is called when the desired properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_desired_updated(desired_property_name: str, desired_property_value, desired_version: int) -> None:
"""
return self._on_device_twin_desired_updated

@on_device_twin_desired_updated.setter
def on_device_twin_desired_updated(self, new_on_device_twin_desired_updated):
def on_device_twin_desired_updated(
self, new_on_device_twin_desired_updated: Callable
) -> None:
"""A callback method that is called when the desired properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_desired_updated(desired_property_name: str, desired_property_value, desired_version: int) -> None:
Expand All @@ -252,15 +264,17 @@ def device_twin_desired_updated(desired_property_name: str, desired_property_val
self._mqtt.subscribe_to_twins()

@property
def on_device_twin_reported_updated(self):
def on_device_twin_reported_updated(self) -> Callable:
"""A callback method that is called when the reported properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_reported_updated(reported_property_name: str, reported_property_value, reported_version: int) -> None:
"""
return self._on_device_twin_reported_updated

@on_device_twin_reported_updated.setter
def on_device_twin_reported_updated(self, new_on_device_twin_reported_updated):
def on_device_twin_reported_updated(
self, new_on_device_twin_reported_updated: Callable
) -> None:
"""A callback method that is called when the reported properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_reported_updated(reported_property_name: str, reported_property_value, reported_version: int) -> None:
Expand Down Expand Up @@ -327,7 +341,9 @@ def is_connected(self) -> bool:

return False

def send_device_to_cloud_message(self, message, system_properties=None) -> None:
def send_device_to_cloud_message(
self, message: Union[str, dict], system_properties: dict = None
) -> None:
"""Send a device to cloud message from this device to Azure IoT Hub
:param message: The message data as a JSON string or a dictionary
:param system_properties: System properties to send with the message
Expand All @@ -339,7 +355,7 @@ def send_device_to_cloud_message(self, message, system_properties=None) -> None:

self._mqtt.send_device_to_cloud_message(message, system_properties)

def update_twin(self, patch) -> None:
def update_twin(self, patch: Union[str, dict]) -> None:
"""Updates the reported properties in the devices device twin
:param patch: The JSON patch to apply to the device twin reported properties
"""
Expand Down
Loading

0 comments on commit 23f7ca7

Please sign in to comment.