diff --git a/docs/source/traits_api_reference/traits.observers.rst b/docs/source/traits_api_reference/traits.observers.rst index 9db4355ee..7b0e340f4 100644 --- a/docs/source/traits_api_reference/traits.observers.rst +++ b/docs/source/traits_api_reference/traits.observers.rst @@ -33,6 +33,10 @@ :members: :inherited-members: +.. autoclass:: SetChangeEvent + :members: + :inherited-members: + .. autoclass:: TraitChangeEvent :members: :inherited-members: diff --git a/traits/observers/_set_change_event.py b/traits/observers/_set_change_event.py new file mode 100644 index 000000000..f0a0dde5f --- /dev/null +++ b/traits/observers/_set_change_event.py @@ -0,0 +1,61 @@ +# (C) Copyright 2005-2020 Enthought, Inc., Austin, TX +# All rights reserved. +# +# This software is provided without warranty under the terms of the BSD +# license included in LICENSE.txt and may be redistributed only under +# the conditions described in the aforementioned license. The license +# is also available online at http://www.enthought.com/licenses/BSD.txt +# +# Thanks for using Enthought open source! + + +# SetChangeEvent is in the public API. + + +class SetChangeEvent: + """ Event object to represent mutations on a set. + + Attributes + ---------- + trait_set : traits.trait_set_object.TraitSet + The set being mutated. + removed : set + Values removed from the set. + added : set + Values added to the set. + """ + + def __init__(self, *, trait_set, removed, added): + self.trait_set = trait_set + self.removed = removed + self.added = added + + def __repr__(self): + return ( + "{event.__class__.__name__}(" + "trait_set={event.trait_set!r}, " + "removed={event.removed!r}, " + "added={event.added!r}" + ")".format(event=self) + ) + + +def set_event_factory(trait_set, removed, added): + """ Adapt the call signature of TraitSet.notify to create an event. + + Parameters + ---------- + trait_set : traits.trait_set_object.TraitSet + The set being mutated. + removed : set + Values removed from the set. + added : set + Values added to the set. + + Returns + ------- + SetChangeEvent + """ + return SetChangeEvent( + trait_set=trait_set, added=added, removed=removed, + ) diff --git a/traits/observers/_set_item_observer.py b/traits/observers/_set_item_observer.py new file mode 100644 index 000000000..e9ac5e39a --- /dev/null +++ b/traits/observers/_set_item_observer.py @@ -0,0 +1,203 @@ +# (C) Copyright 2005-2020 Enthought, Inc., Austin, TX +# All rights reserved. +# +# This software is provided without warranty under the terms of the BSD +# license included in LICENSE.txt and may be redistributed only under +# the conditions described in the aforementioned license. The license +# is also available online at http://www.enthought.com/licenses/BSD.txt +# +# Thanks for using Enthought open source! + +from traits.observers._i_observer import IObserver +from traits.observers._observe import add_or_remove_notifiers +from traits.observers._observer_change_notifier import ObserverChangeNotifier +from traits.observers._set_change_event import set_event_factory +from traits.observers._trait_event_notifier import TraitEventNotifier +from traits.trait_set_object import TraitSet + + +@IObserver.register +class SetItemObserver: + """ Observer for observing mutations on a set. + + Parameters + ---------- + notify : boolean + Whether to notify for changes. + optional : boolean + If False, this observer will complain if the incoming object is not + an observable set. If True and the incoming object is not a set, this + observer will do nothing. Useful for the 'items' keyword in the text + parser, where the source container type is ambiguous. + """ + + def __init__(self, *, notify, optional): + self.notify = notify + self.optional = optional + + def __hash__(self): + """ Return a hash of this object.""" + return hash((type(self).__name__, self.notify, self.optional)) + + def __eq__(self, other): + """ Return true if this observer is equal to the given one.""" + return ( + type(self) is type(other) + and self.notify == other.notify + and self.optional == other.optional + ) + + def iter_observables(self, object): + """ If the given object is an observable set, yield that set. + Otherwise, raise an error, unless this observer is optional + + Parameters + ---------- + object: object + Object provided by another observers or by the user. + + Yields + ------ + IObservable + + Raises + ------ + ValueError + If the given object is not an observable set and optional is false. + """ + if not isinstance(object, TraitSet): + if self.optional: + return + raise ValueError( + "Expected a TraitSet to be observed, " + "got {!r} (type: {!r})".format(object, type(object)) + ) + + yield object + + def iter_objects(self, object): + """ Yield the content of the set if the given object is an observable + set. Otherwise, raise an error, unless the observer is optional. + + The content of the set will be passed onto the children observer(s) + following this one in an ObserverGraph. + + Parameters + ---------- + object: object + Object provided by another observers or by the user. + + Yields + ------ + value : object + + Raises + ------ + ValueError + If the given object is not an observable set and optional is false. + """ + if not isinstance(object, TraitSet): + if self.optional: + return + raise ValueError( + "Expected a TraitSet to be observed, " + "got {!r} (type: {!r})".format(object, type(object)) + ) + + yield from object + + def get_notifier(self, handler, target, dispatcher): + """ Return a notifier for calling the user handler with the change + event. + + Returns + ------- + notifier : TraitEventNotifier + """ + return TraitEventNotifier( + handler=handler, + target=target, + dispatcher=dispatcher, + event_factory=set_event_factory, + prevent_event=lambda event: False, + ) + + def get_maintainer(self, graph, handler, target, dispatcher): + """ Return a notifier for maintaining downstream observers when + a set is mutated. + + Parameters + ---------- + graph : ObserverGraph + Description for the *downstream* observers, i.e. excluding self. + handler : callable + User handler. + target : object + Object seen by the user as the owner of the observer. + dispatcher : callable + Callable for dispatching the handler. + + Returns + ------- + notifier : ObserverChangeNotifier + """ + return ObserverChangeNotifier( + observer_handler=_observer_change_handler, + event_factory=set_event_factory, + prevent_event=lambda event: False, + graph=graph, + handler=handler, + target=target, + dispatcher=dispatcher, + ) + + def iter_extra_graphs(self, graph): + """ Yield new ObserverGraph to be contributed by this observer. + + Parameters + ---------- + graph : ObserverGraph + The graph this observer is part of. + + Yields + ------ + ObserverGraph + """ + # Unlike CTrait, no need to handle trait_added + yield from () + + +def _observer_change_handler(event, graph, handler, target, dispatcher): + """ Handler for maintaining observers. Used by ObserverChangeNotifier. + + Parameters + ---------- + event : SetChangeEvent + Change event that triggers the maintainer. + graph : ObserverGraph + Description for the *downstream* observers, i.e. excluding self. + handler : callable + User handler. + target : object + Object seen by the user as the owner of the observer. + dispatcher : callable + Callable for dispatching the handler. + """ + for removed_item in event.removed: + add_or_remove_notifiers( + object=removed_item, + graph=graph, + handler=handler, + target=target, + dispatcher=dispatcher, + remove=True, + ) + for added_item in event.added: + add_or_remove_notifiers( + object=added_item, + graph=graph, + handler=handler, + target=target, + dispatcher=dispatcher, + remove=False, + ) diff --git a/traits/observers/events.py b/traits/observers/events.py index 879cfff6a..4e03d7a4a 100644 --- a/traits/observers/events.py +++ b/traits/observers/events.py @@ -19,6 +19,10 @@ ListChangeEvent, ) +from traits.observers._set_change_event import ( # noqa: F401 + SetChangeEvent, +) + from traits.observers._trait_change_event import ( # noqa: F401 TraitChangeEvent, ) diff --git a/traits/observers/tests/test_set_change_event.py b/traits/observers/tests/test_set_change_event.py new file mode 100644 index 000000000..c844cfc1c --- /dev/null +++ b/traits/observers/tests/test_set_change_event.py @@ -0,0 +1,66 @@ +# (C) Copyright 2005-2020 Enthought, Inc., Austin, TX +# All rights reserved. +# +# This software is provided without warranty under the terms of the BSD +# license included in LICENSE.txt and may be redistributed only under +# the conditions described in the aforementioned license. The license +# is also available online at http://www.enthought.com/licenses/BSD.txt +# +# Thanks for using Enthought open source! + +import unittest + +from traits.observers._set_change_event import ( + SetChangeEvent, + set_event_factory, +) +from traits.trait_set_object import TraitSet + + +class TestSetChangeEvent(unittest.TestCase): + + def test_set_change_event_repr(self): + event = SetChangeEvent( + trait_set=set(), + added={1}, + removed={3}, + ) + actual = repr(event) + self.assertEqual( + actual, + "SetChangeEvent(trait_set=set(), removed={3}, added={1})", + ) + + +class TestSetEventFactory(unittest.TestCase): + """ Test event factory compatibility with TraitSet.notify """ + + def test_trait_set_notification_compat(self): + + events = [] + + def notifier(*args, **kwargs): + event = set_event_factory(*args, **kwargs) + events.append(event) + + trait_set = TraitSet( + [1, 2, 3], + notifiers=[notifier], + ) + + # when + trait_set.add(4) + + # then + event, = events + self.assertEqual(event.added, {4}) + self.assertEqual(event.removed, set()) + + # when + events.clear() + trait_set.remove(4) + + # then + event, = events + self.assertEqual(event.added, set()) + self.assertEqual(event.removed, {4}) diff --git a/traits/observers/tests/test_set_item_observer.py b/traits/observers/tests/test_set_item_observer.py new file mode 100644 index 000000000..f19ffeaf3 --- /dev/null +++ b/traits/observers/tests/test_set_item_observer.py @@ -0,0 +1,223 @@ +# (C) Copyright 2005-2020 Enthought, Inc., Austin, TX +# All rights reserved. +# +# This software is provided without warranty under the terms of the BSD +# license included in LICENSE.txt and may be redistributed only under +# the conditions described in the aforementioned license. The license +# is also available online at http://www.enthought.com/licenses/BSD.txt +# +# Thanks for using Enthought open source! + +import unittest +from unittest import mock + +from traits.has_traits import HasTraits +from traits.observers._set_item_observer import SetItemObserver +from traits.observers._testing import ( + call_add_or_remove_notifiers, + create_graph, + DummyObservable, + DummyObserver, + DummyNotifier, +) +from traits.trait_set_object import TraitSet +from traits.trait_types import Set + + +def create_observer(**kwargs): + """ Convenience function for creating SetItemObserver with default values. + """ + values = dict( + notify=True, + optional=False, + ) + values.update(kwargs) + return SetItemObserver(**values) + + +class TestSetItemObserverEqualHash(unittest.TestCase): + """ Test SetItemObserver __eq__, __hash__ and immutability. """ + + def test_not_equal_notify(self): + observer1 = SetItemObserver(notify=False, optional=False) + observer2 = SetItemObserver(notify=True, optional=False) + self.assertNotEqual(observer1, observer2) + + def test_not_equal_optional(self): + observer1 = SetItemObserver(notify=True, optional=True) + observer2 = SetItemObserver(notify=True, optional=False) + self.assertNotEqual(observer1, observer2) + + def test_not_equal_different_type(self): + observer1 = SetItemObserver(notify=False, optional=False) + imposter = mock.Mock() + imposter.notify = False + imposter.optional = False + self.assertNotEqual(observer1, imposter) + + def test_equal_observers(self): + observer1 = SetItemObserver(notify=False, optional=False) + observer2 = SetItemObserver(notify=False, optional=False) + self.assertEqual(observer1, observer2) + self.assertEqual(hash(observer1), hash(observer2)) + + +class CustomSet(set): + # This is a set, but not an observable + pass + + +class CustomTraitSet(TraitSet): + # This can be used with SetItemObserver + pass + + +class ClassWithSet(HasTraits): + values = Set() + + +class TestSetItemObserverIterObservable(unittest.TestCase): + """ Test SetItemObserver.iter_observables """ + + def test_trait_set_iter_observables(self): + instance = ClassWithSet() + observer = create_observer(optional=False) + actual_item, = list(observer.iter_observables(instance.values)) + + self.assertIs(actual_item, instance.values) + + def test_set_but_not_a_trait_set(self): + observer = create_observer(optional=False) + with self.assertRaises(ValueError) as exception_context: + list(observer.iter_observables(CustomSet())) + + self.assertIn( + "Expected a TraitSet to be observed, got", + str(exception_context.exception) + ) + + def test_iter_observables_custom_trait_set(self): + # A subcalss of TraitSet can also be used. + custom_trait_set = CustomTraitSet() + observer = create_observer() + + actual_item, = list(observer.iter_observables(custom_trait_set)) + + self.assertIs(actual_item, custom_trait_set) + + def test_not_a_set(self): + observer = create_observer(optional=False) + with self.assertRaises(ValueError) as exception_context: + list(observer.iter_observables(None)) + + self.assertIn( + "Expected a TraitSet to be observed, got", + str(exception_context.exception) + ) + + def test_optional_flag_not_a_set(self): + observer = create_observer(optional=True) + actual = list(observer.iter_observables(None)) + self.assertEqual(actual, []) + + def test_optional_flag_not_an_observable(self): + observer = create_observer(optional=True) + actual = list(observer.iter_observables(CustomSet())) + self.assertEqual(actual, []) + + +class TestSetItemObserverIterObjects(unittest.TestCase): + """ Test SetItemObserver.iter_objects """ + + def test_iter_objects_from_set(self): + instance = ClassWithSet() + instance.values = set([1, 2, 3]) + observer = create_observer() + actual = list(observer.iter_objects(instance.values)) + self.assertCountEqual(actual, [1, 2, 3]) + + def test_iter_observables_custom_trait_set(self): + # A subcalss of TraitSet can also be used. + custom_trait_set = CustomTraitSet([1, 2, 3]) + observer = create_observer() + + actual = list(observer.iter_objects(custom_trait_set)) + self.assertCountEqual(actual, [1, 2, 3]) + + def test_iter_objects_sanity_check(self): + # sanity check if the given object is a set + observer = create_observer(optional=False) + with self.assertRaises(ValueError) as exception_context: + list(observer.iter_objects(None)) + + self.assertIn( + "Expected a TraitSet to be observed, got", + str(exception_context.exception), + ) + + def test_iter_objects_optional(self): + observer = create_observer(optional=True) + actual = list(observer.iter_objects(None)) + self.assertEqual(actual, []) + + +class TestSetItemObserverNotifications(unittest.TestCase): + """ Integration tests with notifiers and HasTraits. """ + + def test_notify_set_change(self): + instance = ClassWithSet(values=set()) + graph = create_graph( + create_observer(notify=True), + ) + handler = mock.Mock() + call_add_or_remove_notifiers( + object=instance.values, + graph=graph, + handler=handler, + ) + + # when + instance.values.add(1) + + # then + ((event, ), _), = handler.call_args_list + self.assertEqual(event.added, set([1])) + self.assertEqual(event.removed, set()) + + def test_maintain_notifier(self): + # Test maintaining downstream notifier + + class ChildObserver(DummyObserver): + + def iter_observables(self, object): + yield object + + instance = ClassWithSet() + instance.values = set() + + notifier = DummyNotifier() + child_observer = ChildObserver(notifier=notifier) + graph = create_graph( + create_observer(notify=False, optional=False), + child_observer, + ) + + handler = mock.Mock() + call_add_or_remove_notifiers( + object=instance.values, + graph=graph, + handler=handler, + ) + + # when + observable = DummyObservable() + instance.values.add(observable) + + # then + self.assertEqual(observable.notifiers, [notifier]) + + # when + instance.values.remove(observable) + + # then + self.assertEqual(observable.notifiers, []) diff --git a/traits/trait_set_object.py b/traits/trait_set_object.py index 12f4bd6ed..5291e8a01 100644 --- a/traits/trait_set_object.py +++ b/traits/trait_set_object.py @@ -15,6 +15,7 @@ from traits.trait_base import _validate_everything from traits.trait_errors import TraitError +from traits.observers._i_observable import IObservable as _IObservable class TraitSetEvent(object): @@ -51,6 +52,7 @@ def __repr__(self): ) +@_IObservable.register class TraitSet(set): """ A subclass of set that validates and notifies listeners of changes. @@ -423,6 +425,20 @@ def __setstate__(self, state): state['notifiers'] = [] self.__dict__.update(state) + # -- Implement IObservable ------------------------------------------------ + + def _notifiers(self, force_create): + """ Return a list of callables where each callable is a notifier. + The list is expected to be mutated for contributing or removing + notifiers from the object. + + Parameters + ---------- + force_create: boolean + Not used here. + """ + return self.notifiers + class TraitSetObject(TraitSet): """ A specialization of TraitSet with a default validator and notifier