Skip to content

Commit

Permalink
Add type hinting to all user facing interfaces
Browse files Browse the repository at this point in the history
Problem:
 Auto-complete / intellisense does not work correctly on IDE's when using this SDK

Solution:
Add type hints to all user facing classes 

Signed-off-by: Serhiy Pikho <Serhiy1@live.co.uk>
  • Loading branch information
serhiy authored and eccles committed Jun 30, 2021
1 parent f42c98f commit 5472d3d
Show file tree
Hide file tree
Showing 9 changed files with 340 additions and 233 deletions.
58 changes: 33 additions & 25 deletions archivist/access_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
"""

from typing import List, Optional
import logging

from copy import deepcopy

# pylint:disable=unused-import # To prevent cyclical import errors forward referencing is used
# pylint:disable=cyclic-import # but pylint doesn't understand this feature
from archivist import archivist as type_helper

from .constants import (
SEP,
ACCESS_POLICIES_SUBPATH,
Expand All @@ -42,21 +46,27 @@
LOGGER = logging.getLogger(__name__)


class AccessPolicy(dict):
"""AccessPolicy object"""


class _AccessPoliciesClient:
"""AccessPoliciesClient
Access to access_policies entitiies using CRUD interface. This class is usually
Access to access_policies entities using CRUD interface. This class is usually
accessed as an attribute of the Archivist class.
Args:
archivist (Archivist): :class:`Archivist` instance
"""

def __init__(self, archivist):
def __init__(self, archivist: "type_helper.Archivist"):
self._archivist = archivist

def create(self, props, filters, access_permissions):
def create(
self, props: dict, filters: List, access_permissions: List
) -> AccessPolicy:
"""Create access policy
Creates access policy with defined attributes.
Expand All @@ -75,7 +85,7 @@ def create(self, props, filters, access_permissions):
self.__query(props, filters=filters, access_permissions=access_permissions),
)

def create_from_data(self, data):
def create_from_data(self, data: dict) -> AccessPolicy:
"""Create access policy
Creates access policy with request body from data stream.
Expand All @@ -95,7 +105,7 @@ def create_from_data(self, data):
)
)

def read(self, identity):
def read(self, identity: str) -> AccessPolicy:
"""Read Access Policy
Reads access policy.
Expand All @@ -114,7 +124,13 @@ def read(self, identity):
)
)

def update(self, identity, props=None, filters=None, access_permissions=None):
def update(
self,
identity,
props: Optional[dict] = None,
filters: Optional[list] = None,
access_permissions: Optional[list] = None,
) -> AccessPolicy:
"""Update Access Policy
Update access policy.
Expand All @@ -139,7 +155,7 @@ def update(self, identity, props=None, filters=None, access_permissions=None):
)
)

def delete(self, identity):
def delete(self, identity: str) -> dict:
"""Delete Access Policy
Deletes access policy.
Expand All @@ -164,7 +180,7 @@ def __query(props, *, filters=None, access_permissions=None):

return query

def count(self, *, display_name=None):
def count(self, *, display_name: Optional[str] = None) -> int:
"""Count access policies.
Counts number of access policies that match criteria.
Expand All @@ -183,10 +199,7 @@ def count(self, *, display_name=None):
)

def list(
self,
*,
page_size=DEFAULT_PAGE_SIZE,
display_name=None,
self, *, page_size: int = DEFAULT_PAGE_SIZE, display_name: Optional[str] = None
):
"""List access policies.
Expand All @@ -212,10 +225,10 @@ def list(
)

# additional queries on different endpoints
def count_matching_assets(self, access_policy_id):
def count_matching_assets(self, access_policy_id: str) -> int:
"""Count assets that match access_policy.
Counts number of assets that match an access_polocy.
Counts number of assets that match an access_policy.
Args:
access_policy_id (str): e.g. access_policies/xxxxxxxxxxxxxxx
Expand All @@ -229,10 +242,7 @@ def count_matching_assets(self, access_policy_id):
)

def list_matching_assets(
self,
access_policy_id,
*,
page_size=DEFAULT_PAGE_SIZE,
self, access_policy_id: str, *, page_size: int = DEFAULT_PAGE_SIZE
):
"""List matching assets.
Expand All @@ -255,7 +265,7 @@ def list_matching_assets(
)
)

def count_matching_access_policies(self, asset_id):
def count_matching_access_policies(self, asset_id: str) -> int:
"""Count access policies that match asset.
Counts number of access policies that match asset.
Expand All @@ -271,7 +281,9 @@ def count_matching_access_policies(self, asset_id):
SEP.join((ACCESS_POLICIES_SUBPATH, asset_id, ACCESS_POLICIES_LABEL)),
)

def list_matching_access_policies(self, asset_id, *, page_size=DEFAULT_PAGE_SIZE):
def list_matching_access_policies(
self, asset_id: str, *, page_size: int = DEFAULT_PAGE_SIZE
):
"""List matching access policies.
List access policies that match asset.
Expand All @@ -292,7 +304,3 @@ def list_matching_access_policies(self, asset_id, *, page_size=DEFAULT_PAGE_SIZE
page_size=page_size,
)
)


class AccessPolicy(dict):
"""AccessPolicy object"""
83 changes: 57 additions & 26 deletions archivist/archivist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
the basic REST verbs to GET, POST, PATCH and DELETE entities..
The REST methods in this class should only be used directly when
a CRUD endpoint for the specific type of entity is unavaliable.
a CRUD endpoint for the specific type of entity is unavailable.
Current CRUD endpoints are assets, events, locations, attachments.
IAM subjects and IAM access policies.
Expand All @@ -24,7 +24,7 @@
auth=authtoken,
)
The arch variable now has additonal endpoints assets,events,locations,
The arch variable now has additional endpoints assets,events,locations,
attachments, IAM subjects and IAM access policies documented elsewhere.
"""
Expand All @@ -33,7 +33,8 @@

import json
from os.path import isfile as os_path_isfile
from typing import Optional
from typing import IO, Optional
from requests.models import Response

from flatten_dict import flatten
import requests
Expand Down Expand Up @@ -92,7 +93,15 @@ class Archivist: # pylint: disable=too-many-instance-attributes
"""

def __init__(self, url, *, auth=None, cert=None, verify=True):
def __init__(
self,
url: str,
*,
auth: Optional[str] = None,
cert: Optional[str] = None,
verify: bool = True,
):

self._headers = {"content-type": "application/json"}
if auth is not None:
self._headers["authorization"] = "Bearer " + auth.strip()
Expand All @@ -114,14 +123,14 @@ def __init__(self, url, *, auth=None, cert=None, verify=True):
self._cert = cert

# keep these in sync with CLIENTS map above
self.assets: Optional[_AssetsClient]
self.events: Optional[_EventsClient]
self.locations: Optional[_LocationsClient]
self.attachments: Optional[_AttachmentsClient]
self.access_policies: Optional[_AccessPoliciesClient]
self.subjects: Optional[_SubjectsClient]

def __getattr__(self, value):
self.assets: _AssetsClient
self.events: _EventsClient
self.locations: _LocationsClient
self.attachments: _AttachmentsClient
self.access_policies: _AccessPoliciesClient
self.subjects: _SubjectsClient

def __getattr__(self, value: str):
"""Create endpoints on demand"""
client = CLIENTS.get(value)

Expand All @@ -133,22 +142,22 @@ def __getattr__(self, value):
return c

@property
def headers(self):
def headers(self) -> dict:
"""dict: Headers REST headers from response"""
return self._headers

@property
def url(self):
def url(self) -> str:
"""str: URL of Archivist endpoint"""
return self._url

@property
def verify(self):
def verify(self) -> bool:
"""bool: Returns True if https connections are to be verified"""
return self._verify

@property
def cert(self):
def cert(self) -> str:
"""str: filepath containing authorisation certificate."""
return self._cert

Expand All @@ -161,7 +170,9 @@ def __add_headers(self, headers):

return newheaders

def get(self, subpath, identity, *, headers=None):
def get(
self, subpath: str, identity: str, *, headers: Optional[dict] = None
) -> dict:
"""GET method (REST)
Args:
Expand All @@ -173,7 +184,6 @@ def get(self, subpath, identity, *, headers=None):
dict representing the response body (entity).
"""
LOGGER.debug("get %s/%s", subpath, identity)
response = requests.get(
SEP.join((self.url, ROOT, subpath, identity)),
headers=self.__add_headers(headers),
Expand All @@ -187,7 +197,9 @@ def get(self, subpath, identity, *, headers=None):

return response.json()

def get_file(self, subpath, identity, fd, *, headers=None):
def get_file(
self, subpath: str, identity: str, fd: IO, *, headers: Optional[dict] = None
) -> Response:
"""GET method (REST) - chunked
Downloads a binary object from upstream storage.
Expand Down Expand Up @@ -220,7 +232,7 @@ def get_file(self, subpath, identity, fd, *, headers=None):

return response

def post(self, path, request, *, headers=None):
def post(self, path: str, request: dict, *, headers: Optional[dict] = None) -> dict:
"""POST method (REST)
Creates an entity
Expand Down Expand Up @@ -249,7 +261,7 @@ def post(self, path, request, *, headers=None):

return response.json()

def post_file(self, path, fd, mtype):
def post_file(self, path: str, fd: IO, mtype: str) -> dict:
"""POST method (REST) - upload binary
Uploads a file to an endpoint
Expand Down Expand Up @@ -286,7 +298,9 @@ def post_file(self, path, fd, mtype):

return response.json()

def delete(self, subpath, identity, *, headers=None):
def delete(
self, subpath: str, identity: str, *, headers: Optional[dict] = None
) -> dict:
"""DELETE method (REST)
Deletes an entity
Expand All @@ -313,7 +327,14 @@ def delete(self, subpath, identity, *, headers=None):

return response.json()

def patch(self, subpath, identity, request, *, headers=None):
def patch(
self,
subpath: str,
identity: str,
request: dict,
*,
headers: Optional[dict] = None,
) -> dict:
"""PATCH method (REST)
Updates the specified entity.
Expand Down Expand Up @@ -365,7 +386,9 @@ def __query(query):
sorted(f"{k}={v}" for k, v in flatten(query, reducer="dot").items())
)

def get_by_signature(self, path, field, query, *, headers=None):
def get_by_signature(
self, path: str, field: str, query: dict, *, headers: Optional[dict] = None
) -> dict:
"""GET method (REST) with query string
Reads an entity indirectly by searching for its signature
Expand Down Expand Up @@ -413,7 +436,7 @@ def get_by_signature(self, path, field, query, *, headers=None):

return records[0]

def count(self, path, *, query=None):
def count(self, path: str, *, query: Optional[dict] = None) -> int:
"""GET method (REST) with query string
Returns the count of objects that match query
Expand All @@ -439,7 +462,15 @@ def count(self, path, *, query=None):

return int(response.headers[HEADERS_TOTAL_COUNT])

def list(self, path, field, *, page_size=None, query=None, headers=None):
def list(
self,
path: str,
field: str,
*,
page_size: Optional[int] = None,
query: Optional[dict] = None,
headers: Optional[dict] = None,
):
"""GET method (REST) with query string
Lists entities that match the query dictionary.
Expand Down
Loading

0 comments on commit 5472d3d

Please sign in to comment.