diff --git a/traits/tests/test_regression.py b/traits/tests/test_regression.py index f449ecc26..99b1d16f6 100644 --- a/traits/tests/test_regression.py +++ b/traits/tests/test_regression.py @@ -28,6 +28,7 @@ DelegatesTo, Dict, Either, + Enum, Instance, Int, List, @@ -309,6 +310,24 @@ class A(HasTraits): with self.assertRaises(ZeroDivisionError): a.bar = "foo" + def test_clone_list_of_enum_trait(self): + # Regression test for enthought/traits#1622. + + class Order(HasTraits): + menu = List(Str) + selection = List(Enum(values="menu")) + + order = Order(menu=["fish"], selection=["fish"]) + clone = order.clone_traits() + + self.assertEqual(clone.selection, ["fish"]) + + order.selection.append('fish') + self.assertEqual(clone.selection, ['fish']) + + with self.assertRaises(TraitError): + clone.selection.append("bouillabaisse") + class NestedContainerClass(HasTraits): # Used in regression test for changes to nested containers diff --git a/traits/tests/test_trait_dict_object.py b/traits/tests/test_trait_dict_object.py index bcf6cfda5..65c2e9df9 100644 --- a/traits/tests/test_trait_dict_object.py +++ b/traits/tests/test_trait_dict_object.py @@ -455,3 +455,15 @@ class DifferentName(TraitDictEvent): differnt_name_subclass = DifferentName() self.assertEqual(desired_repr, str(differnt_name_subclass)) self.assertEqual(desired_repr, repr(differnt_name_subclass)) + + def test_disconnected_dict(self): + # Objects that are disconnected from their HasTraits "owner" can arise + # as a result of clone_traits operations, or of serialization and + # deserialization. + disconnected = TraitDictObject( + trait=Dict(Str, Str), + object=None, + name="foo", + value={}, + ) + self.assertEqual(disconnected.object(), None) diff --git a/traits/tests/test_trait_list_object.py b/traits/tests/test_trait_list_object.py index 5ea8b5161..ca90c88e6 100644 --- a/traits/tests/test_trait_list_object.py +++ b/traits/tests/test_trait_list_object.py @@ -1457,3 +1457,15 @@ def test_dead_object_reference(self): self.assertEqual(list_object, [1, 2, 3, 4, 5]) with self.assertRaises(TraitError): list_object.append(4) + + def test_disconnected_list(self): + # Objects that are disconnected from their HasTraits "owner" can arise + # as a result of clone_traits operations, or of serialization and + # deserialization. + disconnected = TraitListObject( + trait=List(Int), + object=None, + name="foo", + value=[1, 2, 3], + ) + self.assertEqual(disconnected.object(), None) diff --git a/traits/tests/test_trait_set_object.py b/traits/tests/test_trait_set_object.py index b07ad4b9d..531071db1 100644 --- a/traits/tests/test_trait_set_object.py +++ b/traits/tests/test_trait_set_object.py @@ -14,7 +14,7 @@ from traits.api import HasTraits, Set, Str from traits.trait_base import _validate_everything from traits.trait_errors import TraitError -from traits.trait_set_object import TraitSet, TraitSetEvent +from traits.trait_set_object import TraitSet, TraitSetEvent, TraitSetObject from traits.trait_types import _validate_int @@ -517,6 +517,18 @@ class Foo(HasTraits): # then notifier.assert_not_called() + def test_disconnected_set(self): + # Objects that are disconnected from their HasTraits "owner" can arise + # as a result of clone_traits operations, or of serialization and + # deserialization. + disconnected = TraitSetObject( + trait=Set(Str), + object=None, + name="foo", + value=set(), + ) + self.assertEqual(disconnected.object(), None) + class TestTraitSetEvent(unittest.TestCase): diff --git a/traits/trait_dict_object.py b/traits/trait_dict_object.py index 69335eca7..9d5ca277c 100644 --- a/traits/trait_dict_object.py +++ b/traits/trait_dict_object.py @@ -414,8 +414,9 @@ class TraitDictObject(TraitDict): trait : CTrait instance The CTrait instance associated with the attribute that this dict has been set to. - object : HasTraits instance - The HasTraits instance that the dict has been set as an attribute for. + object : HasTraits + The object this dict belongs to. Can also be None in cases where the + dict has been disconnected from its HasTraits parent. name : str The name of the attribute on the object. value : dict @@ -426,9 +427,9 @@ class TraitDictObject(TraitDict): trait : CTrait instance The CTrait instance associated with the attribute that this dict has been set to. - object : weak reference to a HasTraits instance - A weak reference to a HasTraits instance that the dict has been set - as an attribute for. + object : callable + A callable that when called with no arguments returns the HasTraits + object that this dict belongs to, or None if there is no such object. name : str The name of the attribute on the object. name_items : str @@ -438,7 +439,7 @@ class TraitDictObject(TraitDict): def __init__(self, trait, object, name, value): self.trait = trait - self.object = ref(object) + self.object = (lambda: None) if object is None else ref(object) self.name = name self.name_items = None if trait.has_items: @@ -585,7 +586,7 @@ def __deepcopy__(self, memo): """ result = TraitDictObject( self.trait, - lambda: None, + None, self.name, dict(copy.deepcopy(x, memo) for x in self.items()), ) diff --git a/traits/trait_list_object.py b/traits/trait_list_object.py index 40d185c04..a4766f82c 100644 --- a/traits/trait_list_object.py +++ b/traits/trait_list_object.py @@ -548,7 +548,8 @@ class TraitListObject(TraitList): trait : CTrait The trait that the list has been assigned to. object : HasTraits - The object the list belongs to. + The object this list belongs to. Can also be None in cases where the + list has been disconnected from its HasTraits parent. name : str The name of the trait on the object. value : iterable @@ -558,8 +559,9 @@ class TraitListObject(TraitList): ---------- trait : CTrait The trait that the list has been assigned to. - object : HasTraits - The object the list belongs to. + object : callable + A callable that when called with no arguments returns the HasTraits + object that this list belongs to, or None if there is no such object. name : str The name of the trait on the object. value : iterable @@ -569,7 +571,7 @@ class TraitListObject(TraitList): def __init__(self, trait, object, name, value): self.trait = trait - self.object = ref(object) + self.object = (lambda: None) if object is None else ref(object) self.name = name self.name_items = None if trait.has_items: @@ -812,7 +814,7 @@ def __deepcopy__(self, memo): """ return TraitListObject( self.trait, - lambda: None, + None, self.name, [copy.deepcopy(x, memo) for x in self], ) diff --git a/traits/trait_set_object.py b/traits/trait_set_object.py index a9d60cbcd..69b66a4a9 100644 --- a/traits/trait_set_object.py +++ b/traits/trait_set_object.py @@ -451,7 +451,8 @@ class TraitSetObject(TraitSet): trait : CTrait The trait that the set has been assigned to. object : HasTraits - The object the set belongs to. + The object this set belongs to. Can also be None in cases where the + set has been disconnected from its HasTraits parent. name : str The name of the trait on the object. value : iterable @@ -461,8 +462,9 @@ class TraitSetObject(TraitSet): ---------- trait : CTrait The trait that the set has been assigned to. - object : HasTraits - The object the set belongs to. + object : callable + A callable that when called with no arguments returns the HasTraits + object that this set belongs to, or None if there is no such object. name : str The name of the trait on the object. value : iterable @@ -472,7 +474,7 @@ class TraitSetObject(TraitSet): def __init__(self, trait, object, name, value): self.trait = trait - self.object = ref(object) + self.object = (lambda: None) if object is None else ref(object) self.name = name self.name_items = None if trait.has_items: @@ -560,7 +562,7 @@ def __deepcopy__(self, memo): result = TraitSetObject( self.trait, - lambda: None, + None, self.name, {copy.deepcopy(x, memo) for x in self}, )