Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add allow_none flag for Callable trait #885

Merged
merged 9 commits into from
Feb 14, 2020
54 changes: 48 additions & 6 deletions traits/ctraits.c
Original file line number Diff line number Diff line change
Expand Up @@ -3677,11 +3677,31 @@ validate_trait_callable(
trait_object *trait, has_traits_object *obj, PyObject *name,
PyObject *value)
{

if ((value == Py_None) || PyCallable_Check(value)) {
if (PyCallable_Check(value)) {
Py_INCREF(value);
return value;
}
else if (value == Py_None) {
int allow_none;
int tuple_size = PyTuple_GET_SIZE(trait->py_validate);

//Handle callables without allow_none, default to allow None
if (tuple_size < 2) {
Py_INCREF(value);
return value;
}

allow_none = PyObject_IsTrue(PyTuple_GET_ITEM(trait->py_validate, 1));

if (allow_none == -1) {
return NULL;
}

else if (allow_none) {
Py_INCREF(value);
return value;
}
}

return raise_trait_error(trait, obj, name, value);
}
Expand Down Expand Up @@ -4070,10 +4090,32 @@ validate_trait_complex(
return result;

case 22: /* Callable check: */
if (value == Py_None || PyCallable_Check(value)) {
goto done;
{
if (PyCallable_Check(value)) {
return value;
}
else if (value == Py_None) {
int allow_none;
int tuple_size = PyTuple_GET_SIZE(trait->py_validate);

//Handle callables without allow_none, default to allow None
if (tuple_size < 2) {
Py_INCREF(value);
return value;
}

allow_none = PyObject_IsTrue(PyTuple_GET_ITEM(trait->py_validate, 1));

if (allow_none == -1) {
return NULL;
}

else if (allow_none) {
return value;
}
}
break;
}
break;

default: /* Should never happen...indicates an internal error: */
assert(0); /* invalid validation type */
Expand Down Expand Up @@ -4269,7 +4311,7 @@ _trait_set_validate(trait_object *trait, PyObject *args)
break;

case 22: /* Callable check: */
if (n == 1) {
if (n == 1 || n == 2) {
goto done;
}
break;
Expand Down
62 changes: 58 additions & 4 deletions traits/tests/test_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
Callable,
Either,
HasTraits,
Int,
Str,
TraitError,
Union,
ValidateTrait
)


Expand All @@ -32,14 +35,12 @@ def instance_method(self):


class MyCallable(HasTraits):

value = Callable()

callable_or_str = Either(Callable(), Str())
callable_or_str = Either(Callable(), Str)


class MyBaseCallable(HasTraits):

value = BaseCallable


Expand All @@ -50,7 +51,7 @@ def test_default(self):
self.assertIsNone(a.value)

def test_accepts_lambda(self):
func = lambda v: v + 1 # noqa: E731
func = lambda v: v + 1 # noqa: E731
a = MyCallable(value=func)
self.assertIs(a.value, func)

Expand Down Expand Up @@ -90,12 +91,65 @@ def test_callable_in_complex_trait(self):
a.callable_or_str = value
self.assertEqual(a.callable_or_str, old_value)

def test_disallow_none(self):

class MyNewCallable(HasTraits):
value = Callable(default_value=pow, allow_none=False)

obj = MyNewCallable()

self.assertIsNotNone(obj.value)

with self.assertRaises(TraitError):
obj.value = None

mdickinson marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(8, obj.value(2, 3))

def test_disallow_none_compound(self):

class MyNewCallable2(HasTraits):
value = Callable(pow, allow_none=True)
empty_callable = Callable()
a_non_none_union = Union(Callable(allow_none=False), Int)
a_allow_none_union = Union(Callable(allow_none=True), Int)

obj = MyNewCallable2()
self.assertIsNotNone(obj.value)
self.assertIsNone(obj.empty_callable)

obj.value = None
obj.empty_callable = None
self.assertIsNone(obj.value)
self.assertIsNone(obj.empty_callable)

obj.a_non_none_union = 5
obj.a_allow_none_union = 5

with self.assertRaises(TraitError):
obj.a_non_none_union = None
obj.a_allow_none_union = None

def test_old_style_callable(self):
class OldCallable(Callable):
def __init__(self, value=None, **metadata):
self.fast_validate = (ValidateTrait.callable,)
super(BaseCallable, self).__init__(value, **metadata)

class MyCallable(HasTraits):
# allow_none flag should be ineffective
value = OldCallable()

obj = MyCallable()
obj.value = None
self.assertIsNone(obj.value)


class TestBaseCallable(unittest.TestCase):

def test_override_validate(self):
""" Verify `BaseCallable` can be subclassed to create new traits.
"""

class ZeroArgsCallable(BaseCallable):

def validate(self, object, name, value):
Expand Down
8 changes: 6 additions & 2 deletions traits/trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,13 @@ def validate(self, object, name, value):
class Callable(BaseCallable):
""" A fast-validating trait type whose value must be a Python callable.
"""
def __init__(self, value=None, allow_none=True, **metadata):

self.fast_validate = (ValidateTrait.callable, allow_none)

#: The C-level fast validator to use
fast_validate = (ValidateTrait.callable,)
default_value = metadata.pop("default_value", value)

super().__init__(default_value, **metadata)


class BaseType(TraitType):
Expand Down