Skip to content

Commit

Permalink
Type hinting overhaul.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kircheneer committed Mar 17, 2023
1 parent 95fb491 commit b0dcde9
Show file tree
Hide file tree
Showing 19 changed files with 210 additions and 137 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ B.load()

# Show the difference between both systems, that is, what would change if we applied changes from System B to System A
diff_a_b = A.diff_from(B)
print(diff_a_b.str())
print(diff_a_b.to_detailed_string())

# Update System A to align with the current status of system B
A.sync_from(B)
Expand Down
72 changes: 37 additions & 35 deletions diffsync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
limitations under the License.
"""
from inspect import isclass
from typing import Callable, ClassVar, Dict, List, Mapping, Optional, Text, Tuple, Type, Union
from typing import Callable, ClassVar, Dict, List, Mapping, Optional, Text, Tuple, Type, Union, Any, Set

from pydantic import BaseModel, PrivateAttr
import structlog # type: ignore
Expand Down Expand Up @@ -72,7 +72,7 @@ class DiffSyncModel(BaseModel):
Note: inclusion in `_attributes` is mutually exclusive from inclusion in `_identifiers`; a field cannot be in both!
"""

_children: ClassVar[Mapping[str, str]] = {}
_children: ClassVar[Dict[str, str]] = {}
"""Optional: dict of `{_modelname: field_name}` entries describing how to store "child" models in this model.
When calculating a Diff or performing a sync, DiffSync will automatically recurse into these child models.
Expand Down Expand Up @@ -101,7 +101,7 @@ class Config: # pylint: disable=too-few-public-methods
# Let us have a DiffSync as an instance variable even though DiffSync is not a Pydantic model itself.
arbitrary_types_allowed = True

def __init_subclass__(cls):
def __init_subclass__(cls) -> None:
"""Validate that the various class attribute declarations correspond to actual instance fields.
Called automatically on subclass declaration.
Expand Down Expand Up @@ -132,27 +132,27 @@ def __init_subclass__(cls):
if attr_child_overlap:
raise AttributeError(f"Fields {attr_child_overlap} are included in both _attributes and _children.")

def __repr__(self):
def __repr__(self) -> str:
return f'{self.get_type()} "{self.get_unique_id()}"'

def __str__(self):
def __str__(self) -> str:
return self.get_unique_id()

def dict(self, **kwargs) -> dict:
def dict(self, **kwargs: Any) -> Dict:
"""Convert this DiffSyncModel to a dict, excluding the diffsync field by default as it is not serializable."""
if "exclude" not in kwargs:
kwargs["exclude"] = {"diffsync"}
return super().dict(**kwargs)

def json(self, **kwargs) -> str:
def json(self, **kwargs: Any) -> str:
"""Convert this DiffSyncModel to a JSON string, excluding the diffsync field by default as it is not serializable."""
if "exclude" not in kwargs:
kwargs["exclude"] = {"diffsync"}
if "exclude_defaults" not in kwargs:
kwargs["exclude_defaults"] = True
return super().json(**kwargs)

def str(self, include_children: bool = True, indent: int = 0) -> str:
def to_detailed_string(self, include_children: bool = True, indent: int = 0) -> str:
"""Build a detailed string representation of this DiffSyncModel and optionally its children."""
margin = " " * indent
output = f"{margin}{self.get_type()}: {self.get_unique_id()}: {self.get_attrs()}"
Expand All @@ -167,12 +167,12 @@ def str(self, include_children: bool = True, indent: int = 0) -> str:
for child_id in child_ids:
try:
child = self.diffsync.get(modelname, child_id)
output += "\n" + child.str(include_children=include_children, indent=indent + 4)
output += "\n" + child.to_detailed_string(include_children=include_children, indent=indent + 4)
except ObjectNotFound:
output += f"\n{margin} {child_id} (ERROR: details unavailable)"
return output

def set_status(self, status: DiffSyncStatus, message: Text = ""):
def set_status(self, status: DiffSyncStatus, message: Text = "") -> None:
"""Update the status (and optionally status message) of this model in response to a create/update/delete call."""
self._status = status
self._status_message = message
Expand Down Expand Up @@ -288,7 +288,7 @@ def get_type(cls) -> Text:
return cls._modelname

@classmethod
def create_unique_id(cls, **identifiers) -> Text:
def create_unique_id(cls, **identifiers: Dict[str, Any]) -> str:
"""Construct a unique identifier for this model class.
Args:
Expand All @@ -297,7 +297,7 @@ def create_unique_id(cls, **identifiers) -> Text:
return "__".join(str(identifiers[key]) for key in cls._identifiers)

@classmethod
def get_children_mapping(cls) -> Mapping[Text, Text]:
def get_children_mapping(cls) -> Dict[str, str]:
"""Get the mapping of types to fieldnames for child models of this model."""
return cls._children

Expand Down Expand Up @@ -347,9 +347,9 @@ def get_shortname(self) -> Text:

def get_status(self) -> Tuple[DiffSyncStatus, Text]:
"""Get the status of the last create/update/delete operation on this object, and any associated message."""
return (self._status, self._status_message)
return self._status, self._status_message

def add_child(self, child: "DiffSyncModel"):
def add_child(self, child: "DiffSyncModel") -> None:
"""Add a child reference to an object.
The child object isn't stored, only its unique id.
Expand All @@ -373,7 +373,7 @@ def add_child(self, child: "DiffSyncModel"):
raise ObjectAlreadyExists(f"Already storing a {child_type} with unique_id {child.get_unique_id()}", child)
childs.append(child.get_unique_id())

def remove_child(self, child: "DiffSyncModel"):
def remove_child(self, child: "DiffSyncModel") -> None:
"""Remove a child reference from an object.
The name of the storage attribute is defined in `_children` per object type.
Expand Down Expand Up @@ -404,13 +404,15 @@ class DiffSync: # pylint: disable=too-many-public-methods
# modelname1 = MyModelClass1
# modelname2 = MyModelClass2

type: ClassVar[Optional[str]] = None
type: Optional[str] = None
"""Type of the object, will default to the name of the class if not provided."""

top_level: ClassVar[List[str]] = []
"""List of top-level modelnames to begin from when diffing or synchronizing."""

def __init__(self, name=None, internal_storage_engine=LocalStore):
def __init__(
self, name: Optional[str] = None, internal_storage_engine: Union[Type[BaseStore], BaseStore] = LocalStore
) -> None:
"""Generic initialization function.
Subclasses should be careful to call super().__init__() if they override this method.
Expand All @@ -429,7 +431,7 @@ def __init__(self, name=None, internal_storage_engine=LocalStore):
# If the name has not been provided, use the type as the name
self.name = name if name else self.type

def __init_subclass__(cls):
def __init_subclass__(cls) -> None:
"""Validate that references to specific DiffSyncModels use the correct modelnames.
Called automatically on subclass declaration.
Expand All @@ -448,16 +450,16 @@ def __init_subclass__(cls):
if not isclass(value) or not issubclass(value, DiffSyncModel):
raise AttributeError(f'top_level references attribute "{name}" but it is not a DiffSyncModel subclass!')

def __str__(self):
def __str__(self) -> str:
"""String representation of a DiffSync."""
if self.type != self.name:
return f'{self.type} "{self.name}"'
return self.type

def __repr__(self):
def __repr__(self) -> str:
return f"<{str(self)}>"

def __len__(self):
def __len__(self) -> int:
"""Total number of elements stored."""
return self.store.count()

Expand All @@ -481,11 +483,11 @@ def _get_initial_value_order(cls) -> List[str]:
value_order.append(item)
return value_order

def load(self):
def load(self) -> None:
"""Load all desired data from whatever backend data source into this instance."""
# No-op in this generic class

def dict(self, exclude_defaults: bool = True, **kwargs) -> Mapping:
def dict(self, exclude_defaults: bool = True, **kwargs: Any) -> Dict[str, Dict[str, Dict]]:
"""Represent the DiffSync contents as a dict, as if it were a Pydantic model."""
data: Dict[str, Dict[str, Dict]] = {}
for modelname in self.store.get_all_model_names():
Expand All @@ -494,7 +496,7 @@ def dict(self, exclude_defaults: bool = True, **kwargs) -> Mapping:
data[obj.get_type()][obj.get_unique_id()] = obj.dict(exclude_defaults=exclude_defaults, **kwargs)
return data

def str(self, indent: int = 0) -> str:
def to_detailed_string(self, indent: int = 0) -> str:
"""Build a detailed string representation of this DiffSync."""
margin = " " * indent
output = ""
Expand All @@ -507,10 +509,10 @@ def str(self, indent: int = 0) -> str:
output += ": []"
else:
for model in models:
output += "\n" + model.str(indent=indent + 2)
output += "\n" + model.to_detailed_string(indent=indent + 2)
return output

def load_from_dict(self, data: Dict):
def load_from_dict(self, data: Dict) -> None:
"""The reverse of `dict` method, taking a dictionary and loading into the inventory.
Args:
Expand Down Expand Up @@ -594,7 +596,7 @@ def sync_complete(
diff: Diff,
flags: DiffSyncFlags = DiffSyncFlags.NONE,
logger: Optional[structlog.BoundLogger] = None,
):
) -> None:
"""Callback triggered after a `sync_from` operation has completed and updated the model data of this instance.
Note that this callback is **only** triggered if the sync actually resulted in data changes. If there are no
Expand Down Expand Up @@ -657,11 +659,11 @@ def diff_to(
# Object Storage Management
# ------------------------------------------------------------------------------

def get_all_model_names(self):
def get_all_model_names(self) -> Set[str]:
"""Get all model names.
Returns:
List[str]: List of model names
List of model names
"""
return self.store.get_all_model_names()

Expand Down Expand Up @@ -730,7 +732,7 @@ def get_tree_traversal(cls, as_dict: bool = False) -> Union[Text, Mapping]:
"""Get a string describing the tree traversal for the diffsync object.
Args:
as_dict: Whether or not to return as a dictionary
as_dict: Whether to return as a dictionary
Returns:
A string or dictionary representation of tree
Expand All @@ -751,7 +753,7 @@ def get_tree_traversal(cls, as_dict: bool = False) -> Union[Text, Mapping]:
return output_dict
return tree_string(output_dict, cls.__name__)

def add(self, obj: DiffSyncModel):
def add(self, obj: DiffSyncModel) -> None:
"""Add a DiffSyncModel object to the store.
Args:
Expand All @@ -762,7 +764,7 @@ def add(self, obj: DiffSyncModel):
"""
return self.store.add(obj=obj)

def update(self, obj: DiffSyncModel):
def update(self, obj: DiffSyncModel) -> None:
"""Update a DiffSyncModel object to the store.
Args:
Expand All @@ -773,7 +775,7 @@ def update(self, obj: DiffSyncModel):
"""
return self.store.update(obj=obj)

def remove(self, obj: DiffSyncModel, remove_children: bool = False):
def remove(self, obj: DiffSyncModel, remove_children: bool = False) -> None:
"""Remove a DiffSyncModel object from the store.
Args:
Expand Down Expand Up @@ -835,14 +837,14 @@ def update_or_add_model_instance(self, obj: DiffSyncModel) -> Tuple[DiffSyncMode
"""
return self.store.update_or_add_model_instance(obj=obj)

def count(self, model: Union[Text, "DiffSyncModel", Type["DiffSyncModel"], None] = None):
def count(self, model: Union[Text, "DiffSyncModel", Type["DiffSyncModel"], None] = None) -> int:
"""Count how many objects of one model type exist in the backend store.
Args:
model (DiffSyncModel): The DiffSyncModel to check the number of elements. If not provided, default to all.
Returns:
Int: Number of elements of the model type
Number of elements of the model type
"""
return self.store.count(model=model)

Expand Down
Loading

0 comments on commit b0dcde9

Please sign in to comment.