diff --git a/traits/tests/test_tuple.py b/traits/tests/test_tuple.py index 20609cdf8..3ffea8e2a 100644 --- a/traits/tests/test_tuple.py +++ b/traits/tests/test_tuple.py @@ -12,7 +12,8 @@ """ import unittest -from traits.api import BaseInt, Either, HasTraits, Int, Tuple +from traits.api import ( + BaseInt, Either, HasTraits, Int, List, Str, TraitError, Tuple) from traits.tests.tuple_test_mixin import TupleTestMixin @@ -45,3 +46,30 @@ class A(HasTraits): with self.assertRaises(ZeroDivisionError): a.bar = (3, 5) + + def test_non_constant_defaults(self): + class A(HasTraits): + foo = Tuple(List(Int),) + + a = A() + a.foo[0].append(35) + self.assertEqual(a.foo[0], [35]) + + # The inner list should be being validated. + with self.assertRaises(TraitError): + a.foo[0].append(3.5) + + # The inner list should not be shared between instances. + b = A() + self.assertEqual(b.foo[0], []) + + def test_constant_defaults(self): + # Exercise the code path where all child traits have a constant + # default type. + class A(HasTraits): + foo = Tuple(Int, Tuple(Str, Int)) + + a = A() + b = A() + self.assertEqual(a.foo, (0, ("", 0))) + self.assertIs(a.foo, b.foo) diff --git a/traits/trait_types.py b/traits/trait_types.py index 8d1d74f08..4583a73df 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -2283,16 +2283,41 @@ def __init__(self, *types, **metadata): if len(types) == 0: types = [Trait(element) for element in default_value] - self.types = tuple([trait_from(type) for type in types]) + self.types = tuple(trait_from(type) for type in types) self.init_fast_validate(ValidateTrait.tuple, self.types) if default_value is None: - default_value = tuple( - [type.default_value()[1] for type in self.types] + # Optimisation: if all child traits have a constant default value, + # we can use a constant default value too. Otherwise the default + # needs to be computed dynamically. + child_defaults = [] + child_default_types = [] + for child_trait in self.types: + child_default_type, child_default = child_trait.default_value() + + child_default_types.append(child_default_type) + child_defaults.append(child_default) + + constant_default = all( + dvt == DefaultValue.constant for dvt in child_default_types ) + if constant_default: + self.default_value_type = DefaultValue.constant + default_value = tuple(child_defaults) + else: + self.default_value_type = DefaultValue.callable + default_value = self._get_default_value super().__init__(default_value, **metadata) + def _get_default_value(self, object): + # Dynamic default, used when at least one of the child traits requires + # a dynamic default. + return tuple( + inner_trait.default_value_for(object, "") + for inner_trait in self.types + ) + def init_fast_validate(self, *args): """ Saves the validation parameters. """