Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up PrefixList and PrefixMap #1564

Merged
merged 11 commits into from
Oct 5, 2021
65 changes: 56 additions & 9 deletions traits/tests/test_prefix_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class A(HasTraits):
def test_repeated_prefix(self):
class A(HasTraits):
foo = PrefixList(("abc1", "abc2"))

a = A()

a.foo = "abc1"
Expand All @@ -59,16 +60,45 @@ class A(HasTraits):
with self.assertRaises(TraitError):
a.foo = "abc"

def test_invalid_default(self):
with self.assertRaises(TraitError) as exception_context:
def test_default_default(self):
class A(HasTraits):
foo = PrefixList(["zero", "one", "two"], default_value="zero")

a = A()
self.assertEqual(a.foo, "zero")

def test_explicit_default(self):
class A(HasTraits):
foo = PrefixList(["zero", "one", "two"], default_value="one")

a = A()
self.assertEqual(a.foo, "one")

def test_default_subject_to_completion(self):
class A(HasTraits):
foo = PrefixList(["zero", "one", "two"], default_value="o")

a = A()
self.assertEqual(a.foo, "one")

def test_default_subject_to_validation(self):
with self.assertRaises(ValueError):

class A(HasTraits):
foo = PrefixList(["zero", "one", "two"], default_value="uno")

self.assertIn(
"The value of a PrefixList trait must be 'zero' or 'one' or 'two' "
"(or any unique prefix), but a value of 'uno'",
str(exception_context.exception),
)
def test_default_legal_but_not_unique_prefix(self):
# Corner case to exercise internal logic: the default is not a unique
# prefix, but it is one of the list of values, so it's legal.
class A(HasTraits):
foo = PrefixList(["live", "modal", "livemodal"], default="live")

a = A()
self.assertEqual(a.foo, "live")

def test_default_value_cant_be_passed_by_position(self):
with self.assertRaises(TypeError):
PrefixList(["zero", "one", "two"], "one")

def test_values_not_sequence(self):
# Defining values with this signature is not supported
Expand All @@ -82,8 +112,8 @@ def test_values_not_all_iterables(self):

self.assertEqual(
str(exception_context.exception),
"Legal values should be provided via an iterable of strings, "
"got 'zero'."

"values should be a collection of strings, not 'zero'"
)

def test_values_is_empty(self):
Expand All @@ -92,6 +122,23 @@ def test_values_is_empty(self):
with self.assertRaises(ValueError):
PrefixList([])

def test_values_is_empty_with_default_value(self):
# Raise even if we give a default value.
with self.assertRaises(ValueError):
PrefixList([], default_value="one")

def test_no_nested_exception(self):
# Regression test for enthought/traits#1155
class A(HasTraits):
foo = PrefixList(["zero", "one", "two"])

a = A()
try:
a.foo = "three"
except TraitError as exc:
self.assertIsNone(exc.__context__)
self.assertIsNone(exc.__cause__)

def test_pickle_roundtrip(self):
class A(HasTraits):
foo = PrefixList(["zero", "one", "two"], default_value="one")
Expand Down
21 changes: 16 additions & 5 deletions traits/tests/test_prefix_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class Person(HasTraits):
self.assertEqual(p.married, "nah")
self.assertEqual(p.married_, 0)

def test_default_keyword_only(self):
with self.assertRaises(TypeError):
PrefixMap({"yes": 1, "no": 0}, "yes")

def test_default_method(self):
class Person(HasTraits):
married = PrefixMap({"yes": 1, "yeah": 1, "no": 0, "nah": 0})
Expand Down Expand Up @@ -180,15 +184,22 @@ class Person(HasTraits):
self.assertEqual(p.married, "yeah")

def test_static_default_validation_error(self):
with self.assertRaises(TraitError) as exception_context:
with self.assertRaises(ValueError):
class Person(HasTraits):
married = PrefixMap(
{"yes": 1, "yeah": 1, "no": 0}, default_value="meh")

self.assertIn(
"but a value 'meh' was specified",
str(exception_context.exception),
)
def test_no_nested_exception(self):
# Regression test for enthought/traits#1155
class A(HasTraits):
washable = PrefixMap({"yes": 1, "no": 0})

a = A()
try:
a.washable = "affirmatron"
except TraitError as exc:
self.assertIsNone(exc.__context__)
self.assertIsNone(exc.__cause__)

def test_pickle_roundtrip(self):
class Person(HasTraits):
Expand Down
180 changes: 107 additions & 73 deletions traits/trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2789,76 +2789,91 @@ class Person(HasTraits):
to either 'yes' or 'no'. That is, if the value 'y' is assigned to the
**married** attribute, the actual value assigned will be 'yes'.

Note that the algorithm used by PrefixList in determining whether
a string is a valid value is fairly efficient in terms of both time and
space, and is not based on a brute force set of comparisons.

Parameters
----------
values
A single iterable of legal string values.
A list or other iterable of legal string values for this trait.

Attributes
----------
values : tuple of strings
Enumeration of all legal values for a trait.
values : list of str
The list of legal values for this trait.
"""

#: The default value for the trait:
default_value = None

#: The default value type to use (i.e. 'constant'):
#: The default value type to use.
default_value_type = DefaultValue.constant

def __init__(self, values, **metadata):
def __init__(self, values, *, default_value=None, **metadata):
# Avoid confusion from treating a string-like object as an iterable.
if isinstance(values, (str, bytes, bytearray)):
raise TypeError(
"Legal values should be provided via an iterable of strings, "
"got {!r}.".format(values)
"values should be a collection of strings, "
f"not {values!r}"
)
self.values = list(values)
self.values_ = values_ = {}
for key in values:
values_[key] = key
values = list(values)
if not values:
raise ValueError("values must be nonempty")

default = self.default_value
if 'default_value' in metadata:
default = metadata.pop('default_value')
default = self.value_for(default)
elif self.values:
default = self.values[0]
self.values = values
# Use a set for faster lookup in the common case that the value
# to be validated is one of the elements of 'values' (rather than
# a strict prefix).
self._values_as_set = frozenset(values)

if default_value is not None:
default_value = self._complete_value(default_value)
else:
raise ValueError(
"The iterable of legal string values can not be empty."
)
default_value = self.values[0]

super().__init__(default, **metadata)
super().__init__(default_value, **metadata)

def value_for(self, value):
if not isinstance(value, str):
raise TraitError(
"The value of a {} trait must be {}, but a value of {!r} {!r} "
"was specified.".format(
self.__class__.__name__, self.info(), value, type(value))
)
def _complete_value(self, value):
"""
Validate and complete a given value.

if value in self.values_:
return self.values_[value]
Parameters
----------
value : str
Value to be validated.

Returns
-------
completion : str
Equal to *value*, if *value* is already a member of self.values.
Otherwise, the unique member of self.values for which *value*
is a prefix.

Raises
------
ValueError
If value is not in self.values, and is not a prefix of any
element of self.values, or is a prefix of multiple elements
of self.values.
"""
if value in self._values_as_set:
return value

matches = [key for key in self.values if key.startswith(value)]
if len(matches) == 1:
self.values_[value] = match = matches[0]
return match
return matches[0]

raise TraitError(
"The value of a {} trait must be {}, but a value of {!r} {!r} was "
"specified.".format(
self.__class__.__name__, self.info(), value, type(value))
raise ValueError(
f"{value!r} is neither a member nor a unique prefix of a member "
f"of {self.values}"
)

def validate(self, object, name, value):
if isinstance(value, str):
try:
return self._complete_value(value)
except ValueError:
pass

self.error(object, name, value)

def info(self):
return (
" or ".join([repr(x) for x in self.values])
" or ".join(repr(x) for x in self.values)
+ " (or any unique prefix)"
)

Expand Down Expand Up @@ -3268,46 +3283,63 @@ class PrefixMap(TraitType):

is_mapped = True

def __init__(self, map, **metadata):
def __init__(self, map, *, default_value=None, **metadata):
map = dict(map)
if not map:
raise ValueError("map must be nonempty")
self.map = map
self._map = {}
for key in map.keys():
self._map[key] = key

try:
default_value = metadata.pop("default_value")
except KeyError:
if len(self.map) > 0:
default_value = next(iter(self.map))
else:
raise ValueError(
"The dictionary of valid values can not be empty."
) from None
if default_value is not None:
default_value = self._complete_value(default_value)
else:
default_value = self.value_for(default_value)
default_value = next(iter(self.map))

super().__init__(default_value, **metadata)

def value_for(self, value):
if not isinstance(value, str):
raise TraitError(
"Value must be {}, but a value {!r} was specified.".format(
self.info(), value)
)
def _complete_value(self, value):
corranwebster marked this conversation as resolved.
Show resolved Hide resolved
"""
Validate and complete a given value.

Parameters
----------
value : str
Value to be validated.

if value in self._map:
return self._map[value]
Returns
-------
completion : str
Equal to *value*, if *value* is already a member of self.map.
Otherwise, the unique member of self.values for which *value*
is a prefix.

Raises
------
ValueError
If value is not in self.map, and is not a prefix of any
element of self.map, or is a prefix of multiple elements
of self.map.
"""
if value in self.map:
return value

matches = [key for key in self.map if key.startswith(value)]
if len(matches) == 1:
self._map[value] = match = matches[0]
return match
return matches[0]

raise TraitError(
"Value must be {}, but a value {!r} was specified.".format(
self.info(), value)
raise ValueError(
f"{value!r} is neither a member nor a unique prefix of a member "
f"of {list(self.map)}"
)

def validate(self, object, name, value):
if isinstance(value, str):
try:
return self._complete_value(value)
except ValueError:
pass

self.error(object, name, value)

def mapped_value(self, value):
""" Get the mapped value for a value. """
return self.map[value]
Expand All @@ -3316,8 +3348,10 @@ def post_setattr(self, object, name, value):
setattr(object, name + "_", self.mapped_value(value))

def info(self):
keys = sorted(repr(x) for x in self.map.keys())
return " or ".join(keys) + " (or any unique prefix)"
return (
" or ".join(repr(x) for x in self.map)
+ " (or any unique prefix)"
)

def get_editor(self, trait):
from traitsui.api import EnumEditor
Expand Down