Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Python Enums as values for Enum trait #685

Merged
merged 7 commits into from
Jan 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion traits/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,22 @@
#
# Thanks for using Enthought open source!

import enum
import unittest

from traits.api import Enum, HasTraits, List, Property, TraitError
from traits.api import Any, Enum, HasTraits, List, Property, TraitError


class FooEnum(enum.Enum):
foo = 0
bar = 1
baz = 2


class OtherEnum(enum.Enum):
one = 1
two = 2
three = 3


class ExampleModel(HasTraits):
Expand All @@ -21,6 +34,45 @@ def _get_valid_models(self):
return ["model1", "model2", "model3"]


class EnumListExample(HasTraits):

values = Any(['foo', 'bar', 'baz'])

value = Enum(['foo', 'bar', 'baz'])

value_default = Enum('bar', ['foo', 'bar', 'baz'])

value_name = Enum(values='values')

value_name_default = Enum('bar', values='values')


class EnumTupleExample(HasTraits):

values = Any(('foo', 'bar', 'baz'))

value = Enum(('foo', 'bar', 'baz'))

value_default = Enum('bar', ('foo', 'bar', 'baz'))

value_name = Enum(values='values')

value_name_default = Enum('bar', values='values')


class EnumEnumExample(HasTraits):

values = Any(FooEnum)

value = Enum(FooEnum)

value_default = Enum(FooEnum.bar, FooEnum)

value_name = Enum(values='values')

value_name_default = Enum(FooEnum.bar, values='values')


class EnumTestCase(unittest.TestCase):
def test_valid_enum(self):
example_model = ExampleModel(root="model1")
Expand All @@ -33,3 +85,72 @@ def assign_invalid():
example_model.root = "not_valid_model"

self.assertRaises(TraitError, assign_invalid)

def test_enum_list(self):
example = EnumListExample()
self.assertEqual(example.value, 'foo')
self.assertEqual(example.value_default, 'bar')
self.assertEqual(example.value_name, 'foo')
self.assertEqual(example.value_name_default, 'bar')

example.value = 'bar'
self.assertEqual(example.value, 'bar')

with self.assertRaises(TraitError):
example.value = "something"

with self.assertRaises(TraitError):
example.value = 0

example.values = ['one', 'two', 'three']
example.value_name = 'two'
self.assertEqual(example.value_name, 'two')

with self.assertRaises(TraitError):
example.value_name = 'bar'

def test_enum_tuple(self):
example = EnumTupleExample()
self.assertEqual(example.value, 'foo')
self.assertEqual(example.value_default, 'bar')
self.assertEqual(example.value_name, 'foo')
self.assertEqual(example.value_name_default, 'bar')

example.value = 'bar'
self.assertEqual(example.value, 'bar')

with self.assertRaises(TraitError):
example.value = "something"

with self.assertRaises(TraitError):
example.value = 0

example.values = ('one', 'two', 'three')
example.value_name = 'two'
self.assertEqual(example.value_name, 'two')

with self.assertRaises(TraitError):
example.value_name = 'bar'

def test_enum_enum(self):
example = EnumEnumExample()
self.assertEqual(example.value, FooEnum.foo)
self.assertEqual(example.value_default, FooEnum.bar)
self.assertEqual(example.value_name, FooEnum.foo)
self.assertEqual(example.value_name_default, FooEnum.bar)

example.value = FooEnum.bar
self.assertEqual(example.value, FooEnum.bar)

with self.assertRaises(TraitError):
example.value = "foo"

with self.assertRaises(TraitError):
example.value = 0

example.values = OtherEnum
example.value_name = OtherEnum.two
self.assertEqual(example.value_name, OtherEnum.two)

with self.assertRaises(TraitError):
example.value_name = FooEnum.bar
33 changes: 33 additions & 0 deletions traits/trait_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# Imports:
# -------------------------------------------------------------------------------

import enum
import os
import sys
from os import getcwd
Expand All @@ -33,6 +34,8 @@

SequenceTypes = (list, tuple)

EnumTypes = (list, tuple, enum.EnumMeta)

corranwebster marked this conversation as resolved.
Show resolved Hide resolved
ComplexTypes = (float, int)

RangeTypes = (int, float)
Expand Down Expand Up @@ -194,6 +197,36 @@ def strx(arg):
complex: (ValidateTrait.coerce, complex, float, int),
}


def safe_contains(value, container):
""" Perform "in" containment check, allowing for TypeErrors.

This is required because in some circumstances ``x in y`` can raise a
TypeError. In these cases we make the (reasonable) assumption that the
value is _not_ contained in the container.
"""
try:
return value in container
except TypeError:
return False


def collection_default(collection):
""" Get the first item of a collection, returning None if empty.

Parameters
----------
collection : collection
A Python collection, which is presumed to be repeatably iterable.

Returns
-------
default : any
The first item of the collection, or None if the collection is empty.
"""
return next(iter(collection), None)


# -------------------------------------------------------------------------------
# Return a string containing the class name of an object with the correct
# article (a or an) preceding it (e.g. 'an Image', 'a PlotValue'):
Expand Down
55 changes: 35 additions & 20 deletions traits/trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# -------------------------------------------------------------------------------

import datetime
import enum
from importlib import import_module
import operator
import re
Expand All @@ -34,11 +35,15 @@
get_module_name,
HandleWeakRef,
class_of,
collection_default,
EnumTypes,
RangeTypes,
safe_contains,
SequenceTypes,
TypeTypes,
Undefined,
TraitsCache,
xgetattr,
)
from .trait_converters import trait_from
from .trait_dict_object import TraitDictEvent, TraitDictObject
Expand Down Expand Up @@ -1941,41 +1946,53 @@ def __init__(self, *args, **metadata):

Parameters
----------
values : list or tuple
The enumeration of all legal values for the trait
*args : *values or (default, values) or values
The enumeration of all legal values for the trait, either as
positional arguments or a list, enum.Enum or tuple. The default
value is the first positional argument, or the first item of
the sequence.
values : str
The name of a trait holding the values, in which case there
must be at most one positional argument holding the default
value. If there is no default value, then the default value
is the first item of the value stored in the trait.

Default Value
-------------
values[0]
"""
values = metadata.pop("values", None)
if isinstance(values, str):
self.name = values
self.get, self.set, self.validate = self._get, self._set, None
n = len(args)
if n == 0:
default_value = None
super(BaseEnum, self).__init__(**metadata)
elif n == 1:
default_value = args[0]
super(BaseEnum, self).__init__(default_value, **metadata)
else:
raise TraitError(
"Incorrect number of arguments specified "
"when using the 'values' keyword"
)
self.name = values
self.values = compile("object." + values, "<string>", "eval")
self.get, self.set, self.validate = self._get, self._set, None
else:
default_value = args[0]
if (len(args) == 1) and isinstance(default_value, SequenceTypes):
if (len(args) == 1) and isinstance(default_value, EnumTypes):
args = default_value
default_value = args[0]
elif (len(args) == 2) and isinstance(args[1], SequenceTypes):
default_value = collection_default(args)
elif (len(args) == 2) and isinstance(args[1], EnumTypes):
args = args[1]

if isinstance(args, enum.EnumMeta):
metadata.setdefault('format_func', operator.attrgetter('name'))
metadata.setdefault('evaluate', args)

self.name = ""
self.values = tuple(args)
self.init_fast_validate(ValidateTrait.enum, self.values)

super(BaseEnum, self).__init__(default_value, **metadata)
super(BaseEnum, self).__init__(default_value, **metadata)

def init_fast_validate(self, *args):
""" Does nothing for the BaseEnum class. Used in the Enum class to set
Expand All @@ -1987,7 +2004,7 @@ def validate(self, object, name, value):
""" Validates that the value is one of the enumerated set of valid
values.
"""
if value in self.values:
if safe_contains(value, self.values):
return value

self.error(object, name, value)
Expand All @@ -1998,7 +2015,7 @@ def full_info(self, object, name, value):
if self.name == "":
values = self.values
else:
values = eval(self.values)
values = xgetattr(object, self.name)

return " or ".join([repr(x) for x in values])

Expand All @@ -2016,25 +2033,23 @@ def create_editor(self):
name=self.name,
cols=self.cols or 3,
evaluate=self.evaluate,
mode=self.mode or "radio",
format_func=self.format_func,
mode=self.mode if self.mode else "radio",
)

def _get(self, object, name, trait):
""" Returns the current value of a dynamic enum trait.
"""
value = self.get_value(object, name, trait)
values = eval(self.values)
if value not in values:
value = None
if len(values) > 0:
value = values[0]

values = xgetattr(object, self.name)
if not safe_contains(value, values):
value = collection_default(values)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to tackle #389 and this line is all I needed to solve #389.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kitchoi can you clarify: if this PR gets merged, does that mean #389 can be closed? Or is there more work to do between merging this PR and solving #389?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think #389 can be closed by this PR (as of now) without more being done to it.

While #389 concerns the dictionary only, the idea is to relax the requirement on the values supporting __index__. As long as values implements __iter__, Enum can obtain the allowed values.

I understand your concern with the potential non-deterministic behaviour, e.g. for set and the dict prior to Python 3.6. But if, say, values is an instance of OrderedDict, the Enum on master will fail trying to use the first key as a default. One has to make another property traits to return a list from the keys of OrderedDict. I guess it is about how opinionated this Enum should be when it comes to setting a default value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To verify:

from traits.api import HasTraits, Enum, Any


class Foo(HasTraits):

    a = Enum(values="b")

    b = Any()

    def _b_default(self):
        return {"a": 1, "b": 2}


if __name__ == "__main__":
    f = Foo()
    # f.configure_traits()
    print(repr(f.a))

On this branch, a value 'a' was printed.
On master, this results in an error:

Traceback (most recent call last):
  File "dict_enum.py", line 17, in <module>
    print(repr(f.a))
  File "/Users/kchoi/ETS/traits/traits/trait_types.py", line 2030, in _get
    value = values[0]
KeyError: 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I think I liked the error better. :-)

return value

def _set(self, object, name, value):
""" Sets the current value of a dynamic range trait.
"""
if value in eval(self.values):
if safe_contains(value, xgetattr(object, self.name)):
self.set_value(object, name, value)
else:
self.error(object, name, value)
Expand Down