Skip to content

Commit

Permalink
fixing some types based on pyright report
Browse files Browse the repository at this point in the history
  • Loading branch information
seperman committed Mar 5, 2025
1 parent a7b4a45 commit 661c3b9
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 57 deletions.
2 changes: 1 addition & 1 deletion deepdiff/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from inspect import getmembers
from itertools import zip_longest
from functools import lru_cache
from deepdiff.helper import (strings, bytes_type, numbers, uuids, datetimes, ListItemRemovedOrAdded, notpresent,
from deepdiff.helper import (strings, bytes_type, numbers, uuids, ListItemRemovedOrAdded, notpresent,
IndexedHash, unprocessed, add_to_frozen_set, basic_types,
convert_item_or_items_into_set_else_none, get_type,
convert_item_or_items_into_compiled_regexes_else_none,
Expand Down
37 changes: 17 additions & 20 deletions deepdiff/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from ast import literal_eval
from decimal import Decimal, localcontext, InvalidOperation as InvalidDecimalOperation
from itertools import repeat
# from orderly_set import OrderlySet as SetOrderedBase # median: 0.806 s, some tests are failing
# from orderly_set import SetOrdered as SetOrderedBase # median 1.011 s, didn't work for tests
from orderly_set import StableSetEq as SetOrderedBase # median: 1.0867 s for cache test, 5.63s for all tests
# from orderly_set import OrderedSet as SetOrderedBase # median 1.1256 s for cache test, 5.63s for all tests
from threading import Timer


Expand Down Expand Up @@ -91,14 +88,14 @@ def __repr__(self):
)

numpy_dtypes = set(numpy_numbers)
numpy_dtypes.add(np_bool_)
numpy_dtypes.add(np_bool_) # type: ignore

numpy_dtype_str_to_type = {
item.__name__: item for item in numpy_dtypes
}

try:
from pydantic.main import BaseModel as PydanticBaseModel
from pydantic.main import BaseModel as PydanticBaseModel # type: ignore
except ImportError:
PydanticBaseModel = pydantic_base_model_type

Expand Down Expand Up @@ -367,7 +364,7 @@ def get_type(obj):
Get the type of object or if it is a class, return the class itself.
"""
if isinstance(obj, np_ndarray):
return obj.dtype.type
return obj.dtype.type # type: ignore
return obj if type(obj) is type else type(obj)


Expand Down Expand Up @@ -409,7 +406,7 @@ def number_to_string(number, significant_digits, number_format_notation="f"):
except KeyError:
raise ValueError("number_format_notation got invalid value of {}. The valid values are 'f' and 'e'".format(number_format_notation)) from None

if not isinstance(number, numbers):
if not isinstance(number, numbers): # type: ignore
return number
elif isinstance(number, Decimal):
with localcontext() as ctx:
Expand All @@ -423,32 +420,31 @@ def number_to_string(number, significant_digits, number_format_notation="f"):
# For example '999.99999999' will become '1000.000000' after quantize
ctx.prec += 1
number = number.quantize(Decimal('0.' + '0' * significant_digits))
elif isinstance(number, only_complex_number):
elif isinstance(number, only_complex_number): # type: ignore
# Case for complex numbers.
number = number.__class__(
"{real}+{imag}j".format(
"{real}+{imag}j".format( # type: ignore
real=number_to_string(
number=number.real,
number=number.real, # type: ignore
significant_digits=significant_digits,
number_format_notation=number_format_notation
),
imag=number_to_string(
number=number.imag,
number=number.imag, # type: ignore
significant_digits=significant_digits,
number_format_notation=number_format_notation
)
)
) # type: ignore
)
else:
# import pytest; pytest.set_trace()
number = round(number=number, ndigits=significant_digits)
number = round(number=number, ndigits=significant_digits) # type: ignore

if significant_digits == 0:
number = int(number)

if number == 0.0:
# Special case for 0: "-0.xx" should compare equal to "0.xx"
number = abs(number)
number = abs(number) # type: ignore

# Cast number to string
result = (using % significant_digits).format(number)
Expand Down Expand Up @@ -565,7 +561,8 @@ def start(self):

def stop(self):
duration = self._get_duration_sec()
self._timer.cancel()
if self._timer is not None:
self._timer.cancel()
self.is_running = False
return duration

Expand Down Expand Up @@ -661,8 +658,8 @@ def cartesian_product_numpy(*arrays):
https://stackoverflow.com/a/49445693/1497443
"""
la = len(arrays)
dtype = np.result_type(*arrays)
arr = np.empty((la, *map(len, arrays)), dtype=dtype)
dtype = np.result_type(*arrays) # type: ignore
arr = np.empty((la, *map(len, arrays)), dtype=dtype) # type: ignore
idx = slice(None), *repeat(None, la)
for i, a in enumerate(arrays):
arr[i, ...] = a[idx[:la - i]]
Expand All @@ -676,7 +673,7 @@ def diff_numpy_array(A, B):
By Divakar
https://stackoverflow.com/a/52417967/1497443
"""
return A[~np.isin(A, B)]
return A[~np.isin(A, B)] # type: ignore


PYTHON_TYPE_TO_NUMPY_TYPE = {
Expand Down Expand Up @@ -754,7 +751,7 @@ class OpcodeTag(EnumBase):
insert = 'insert'
delete = 'delete'
equal = 'equal'
replace = 'replace'
replace = 'replace' # type: ignore
# swapped = 'swapped' # in the future we should support reporting of items swapped with each other


Expand Down
38 changes: 20 additions & 18 deletions deepdiff/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,26 @@ def mutual_add_removes_to_become_value_changes(self):
This function should only be run on the Tree Result.
"""
if self.get('iterable_item_added') and self.get('iterable_item_removed'):
added_paths = {i.path(): i for i in self['iterable_item_added']}
removed_paths = {i.path(): i for i in self['iterable_item_removed']}
iterable_item_added = self.get('iterable_item_added')
iterable_item_removed = self.get('iterable_item_removed')
if iterable_item_added is not None and iterable_item_removed is not None:
added_paths = {i.path(): i for i in iterable_item_added}
removed_paths = {i.path(): i for i in iterable_item_removed}
mutual_paths = set(added_paths) & set(removed_paths)

if mutual_paths and 'values_changed' not in self:
if mutual_paths and 'values_changed' not in self or self['values_changed'] is None:
self['values_changed'] = SetOrdered()
for path in mutual_paths:
level_before = removed_paths[path]
self['iterable_item_removed'].remove(level_before)
iterable_item_removed.remove(level_before)
level_after = added_paths[path]
self['iterable_item_added'].remove(level_after)
iterable_item_added.remove(level_after)
level_before.t2 = level_after.t2
self['values_changed'].add(level_before)
self['values_changed'].add(level_before) # type: ignore
level_before.report_type = 'values_changed'
if 'iterable_item_removed' in self and not self['iterable_item_removed']:
if 'iterable_item_removed' in self and not iterable_item_removed:
del self['iterable_item_removed']
if 'iterable_item_added' in self and not self['iterable_item_added']:
if 'iterable_item_added' in self and not iterable_item_added:
del self['iterable_item_added']

def __getitem__(self, item):
Expand Down Expand Up @@ -242,7 +244,7 @@ def _from_tree_set_item_added_or_removed(self, tree, key):
item = "'%s'" % item
if is_dict:
if path not in set_item_info:
set_item_info[path] = set()
set_item_info[path] = set() # type: ignore
set_item_info[path].add(item)
else:
set_item_info.add("{}[{}]".format(path, str(item)))
Expand Down Expand Up @@ -619,12 +621,12 @@ def auto_generate_child_rel(self, klass, param, param2=None):
:param param: A ChildRelationship subclass-dependent parameter describing how to get from parent to child,
e.g. the key in a dict
"""
if self.down.t1 is not notpresent:
if self.down.t1 is not notpresent: # type: ignore
self.t1_child_rel = ChildRelationship.create(
klass=klass, parent=self.t1, child=self.down.t1, param=param)
if self.down.t2 is not notpresent:
klass=klass, parent=self.t1, child=self.down.t1, param=param) # type: ignore
if self.down.t2 is not notpresent: # type: ignore
self.t2_child_rel = ChildRelationship.create(
klass=klass, parent=self.t2, child=self.down.t2, param=param if param2 is None else param2)
klass=klass, parent=self.t2, child=self.down.t2, param=param if param2 is None else param2) # type: ignore

@property
def all_up(self):
Expand Down Expand Up @@ -739,15 +741,15 @@ def path(self, root="root", force=None, get_parent_too=False, use_t2=False, outp
result = None
break
elif output_format == 'list':
result.append(next_rel.param)
result.append(next_rel.param) # type: ignore

# Prepare processing next level
level = level.down

if output_format == 'str':
if get_parent_too:
self._path[cache_key] = (parent, param, result)
output = (self._format_result(root, parent), param, self._format_result(root, result))
self._path[cache_key] = (parent, param, result) # type: ignore
output = (self._format_result(root, parent), param, self._format_result(root, result)) # type: ignore
else:
self._path[cache_key] = result
output = self._format_result(root, result)
Expand Down Expand Up @@ -907,7 +909,7 @@ def stringify_param(self, force=None):
elif isinstance(param, tuple): # Currently only for numpy ndarrays
result = ']['.join(map(repr, param))
elif hasattr(param, '__dataclass_fields__'):
attrs_to_values = [f"{key}={value}" for key, value in [(i, getattr(param, i)) for i in param.__dataclass_fields__]]
attrs_to_values = [f"{key}={value}" for key, value in [(i, getattr(param, i)) for i in param.__dataclass_fields__]] # type: ignore
result = f"{param.__class__.__name__}({','.join(attrs_to_values)})"
else:
candidate = repr(param)
Expand Down
40 changes: 22 additions & 18 deletions deepdiff/summarize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
from deepdiff.serialization import json_dumps


Expand All @@ -13,22 +14,23 @@ def _truncate(s, max_len):
return s[:max_len - 5] + "..." + s[-2:]

class JSONNode:
def __init__(self, data, key=None):
def __init__(self, data: Any, key=None):
"""
Build a tree node for the JSON data.
If this node is a child of a dict, key is its key name.
"""
self.key = key
self.children_list: list[JSONNode] = []
self.children_dict: list[tuple[Any, JSONNode]] = []
if isinstance(data, dict):
self.type = "dict"
self.children = []
# Preserve insertion order: list of (key, child) pairs.
for k, v in data.items():
child = JSONNode(v, key=k)
self.children.append((k, child))
self.children_dict.append((k, child))
elif isinstance(data, list):
self.type = "list"
self.children = [JSONNode(item) for item in data]
self.children_list = [JSONNode(item) for item in data]
else:
self.type = "primitive"
# For primitives, use json.dumps to get a compact representation.
Expand All @@ -37,24 +39,25 @@ def __init__(self, data, key=None):
except Exception:
self.value = str(data)

def full_repr(self):
def full_repr(self) -> str:
"""Return the full minimized JSON representation (without trimming) for this node."""
if self.type == "primitive":
return self.value
elif self.type == "dict":
parts = []
for k, child in self.children:
for k, child in self.children_dict:
parts.append(f'"{k}":{child.full_repr()}')
return "{" + ",".join(parts) + "}"
elif self.type == "list":
parts = [child.full_repr() for child in self.children]
parts = [child.full_repr() for child in self.children_list]
return "[" + ",".join(parts) + "]"
return self.value

def full_weight(self):
"""Return the character count of the full representation."""
return len(self.full_repr())

def summarize(self, budget):
def _summarize(self, budget) -> str:
"""
Return a summary string for this node that fits within budget characters.
The algorithm may drop whole sub-branches (for dicts) or truncate long primitives.
Expand All @@ -69,16 +72,17 @@ def summarize(self, budget):
return self._summarize_dict(budget)
elif self.type == "list":
return self._summarize_list(budget)
return self.value

def _summarize_dict(self, budget):
def _summarize_dict(self, budget) -> str:
# If the dict is empty, return {}
if not self.children:
if not self.children_dict:
return "{}"
# Build a list of pairs with fixed parts:
# Each pair: key_repr is f'"{key}":'
# Also store the full (untrimmed) child representation.
pairs = []
for k, child in self.children:
for k, child in self.children_dict:
key_repr = f'"{k}":'
child_full = child.full_repr()
pair_full = key_repr + child_full
Expand All @@ -103,7 +107,7 @@ def _summarize_dict(self, budget):
# Heuristic: while the representation is too long, drop the pair whose child_full is longest.
while kept:
# Sort kept pairs in original insertion order.
kept_sorted = sorted(kept, key=lambda p: self.children.index((p["key"], p["child"])))
kept_sorted = sorted(kept, key=lambda p: self.children_dict.index((p["key"], p["child"])))
current_n = len(kept_sorted)
fixed = sum(len(p["key_repr"]) for p in kept_sorted) + (current_n - 1) + 2
remaining_budget = budget - fixed
Expand All @@ -116,7 +120,7 @@ def _summarize_dict(self, budget):
child_summaries = []
for p in kept_sorted:
ideal = int(remaining_budget * (len(p["child_full"]) / total_child_full)) if total_child_full > 0 else 0
summary_child = p["child"].summarize(ideal)
summary_child = p["child"]._summarize(ideal)
child_summaries.append(summary_child)
candidate = "{" + ",".join([p["key_repr"] + s for p, s in zip(kept_sorted, child_summaries)]) + "}"
if len(candidate) <= budget:
Expand All @@ -127,17 +131,17 @@ def _summarize_dict(self, budget):
# If nothing remains, return a truncated empty object.
return _truncate("{}", budget)

def _summarize_list(self, budget):
def _summarize_list(self, budget) -> str:
# If the list is empty, return []
if not self.children:
if not self.children_list:
return "[]"
full_repr = self.full_repr()
if len(full_repr) <= budget:
return full_repr
# For lists, show only the first element and an omission indicator if more elements exist.
suffix = ",..." if len(self.children) > 1 else ""
suffix = ",..." if len(self.children_list) > 1 else ""
inner_budget = budget - 2 - len(suffix) # subtract brackets and suffix
first_summary = self.children[0].summarize(inner_budget)
first_summary = self.children_list[0]._summarize(inner_budget)
candidate = "[" + first_summary + suffix + "]"
if len(candidate) <= budget:
return candidate
Expand All @@ -150,4 +154,4 @@ def summarize(data, max_length=200):
ensuring the final string is no longer than max_length.
"""
root = JSONNode(data)
return root.summarize(max_length).replace("{,", "{")
return root._summarize(max_length).replace("{,", "{")

0 comments on commit 661c3b9

Please sign in to comment.