Skip to content

Commit

Permalink
Replace equality definition on ObserverExpression (#1517)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdickinson authored Sep 10, 2021
1 parent 74d249c commit e75d199
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 21 deletions.
47 changes: 32 additions & 15 deletions traits/observation/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,6 @@ class ObserverExpression:

__slots__ = ()

def __eq__(self, other):
""" Return true if the other value is an ObserverExpression with
equivalent content.
Returns
-------
bool
"""
if type(other) is not type(self):
return False
return self._as_graphs() == other._as_graphs()

def __or__(self, expression):
""" Create a new expression that matches this expression OR
the given expression.
Expand Down Expand Up @@ -280,14 +268,23 @@ class SingleObserverExpression(ObserverExpression):
""" Container of ObserverExpression for wrapping a single observer.
"""

__slots__ = ("observer",)
__slots__ = ("_observer",)

def __init__(self, observer):
self.observer = observer
self._observer = observer

def __hash__(self):
return hash((type(self).__name__, self._observer))

def __eq__(self, other):
return (
type(self) is type(other)
and self._observer == other._observer
)

def _create_graphs(self, branches):
return [
ObserverGraph(node=self.observer, children=branches),
ObserverGraph(node=self._observer, children=branches),
]


Expand All @@ -308,6 +305,16 @@ def __init__(self, first, second):
self._first = first
self._second = second

def __hash__(self):
return hash((type(self).__name__, self._first, self._second))

def __eq__(self, other):
return (
type(self) is type(other)
and self._first == other._first
and self._second == other._second
)

def _create_graphs(self, branches):
branches = self._second._create_graphs(branches=branches)
return self._first._create_graphs(branches=branches)
Expand All @@ -330,6 +337,16 @@ def __init__(self, left, right):
self._left = left
self._right = right

def __hash__(self):
return hash((type(self).__name__, self._left, self._right))

def __eq__(self, other):
return (
type(self) is type(other)
and self._left == other._left
and self._right == other._right
)

def _create_graphs(self, branches):
left_graphs = self._left._create_graphs(branches=branches)
right_graphs = self._right._create_graphs(branches=branches)
Expand Down
12 changes: 10 additions & 2 deletions traits/observation/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,13 +660,14 @@ def test_call_signatures(self):
)


class TestObserverExpressionEquality(unittest.TestCase):
""" Test ObserverExpression.__eq__ """
class TestObserverExpressionEqualityAndHashing(unittest.TestCase):
""" Test ObserverExpression.__eq__ and ObserverExpression.__hash__. """

def test_trait_equality(self):
expr1 = create_expression(1)
expr2 = create_expression(1)
self.assertEqual(expr1, expr2)
self.assertEqual(hash(expr1), hash(expr2))

def test_join_equality_with_then(self):
# The following all result in the same graphs
Expand All @@ -677,6 +678,13 @@ def test_join_equality_with_then(self):
combined2 = expr1.then(expr2)

self.assertEqual(combined1, combined2)
self.assertEqual(hash(combined1), hash(combined2))

def test_equality_of_parallel_expressions(self):
expr1 = create_expression(1) | create_expression(2)
expr2 = create_expression(1) | create_expression(2)
self.assertEqual(expr1, expr2)
self.assertEqual(hash(expr1), hash(expr2))

def test_equality_different_type(self):
expr = create_expression(1)
Expand Down
9 changes: 5 additions & 4 deletions traits/observation/tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,11 @@ def test_deep_nesting(self):

actual = parse("[a:[b.[c:d]]]")
expected = (
trait("a", notify=False)
.trait("b")
.trait("c", notify=False)
.trait("d")
trait("a", notify=False).then(
trait("b").then(
trait("c", notify=False).then(trait("d"))
)
)
)
self.assertEqual(actual, expected)

Expand Down

0 comments on commit e75d199

Please sign in to comment.