Skip to content

Commit

Permalink
Add allow_none flag for Callable trait (#885)
Browse files Browse the repository at this point in the history
* Add allow_none flag for Callable trait

* Allow none to callable in complex trait validator

* Update tests and raise error on Union validation failure

* check allow_none only if value is Py_None

* Validate callables pickled with old styles

* Add a test with a empty Callable() trait

* Missed commit

* Add test for old stye callable

* Style changes
  • Loading branch information
midhun-pm authored Feb 14, 2020
1 parent 14e0f4a commit ba0c0ca
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 12 deletions.
54 changes: 48 additions & 6 deletions traits/ctraits.c
Original file line number Diff line number Diff line change
Expand Up @@ -3679,11 +3679,31 @@ validate_trait_callable(
trait_object *trait, has_traits_object *obj, PyObject *name,
PyObject *value)
{

if ((value == Py_None) || PyCallable_Check(value)) {
if (PyCallable_Check(value)) {
Py_INCREF(value);
return value;
}
else if (value == Py_None) {
int allow_none;
int tuple_size = PyTuple_GET_SIZE(trait->py_validate);

//Handle callables without allow_none, default to allow None
if (tuple_size < 2) {
Py_INCREF(value);
return value;
}

allow_none = PyObject_IsTrue(PyTuple_GET_ITEM(trait->py_validate, 1));

if (allow_none == -1) {
return NULL;
}

else if (allow_none) {
Py_INCREF(value);
return value;
}
}

return raise_trait_error(trait, obj, name, value);
}
Expand Down Expand Up @@ -4072,10 +4092,32 @@ validate_trait_complex(
return result;

case 22: /* Callable check: */
if (value == Py_None || PyCallable_Check(value)) {
goto done;
{
if (PyCallable_Check(value)) {
return value;
}
else if (value == Py_None) {
int allow_none;
int tuple_size = PyTuple_GET_SIZE(trait->py_validate);

//Handle callables without allow_none, default to allow None
if (tuple_size < 2) {
Py_INCREF(value);
return value;
}

allow_none = PyObject_IsTrue(PyTuple_GET_ITEM(trait->py_validate, 1));

if (allow_none == -1) {
return NULL;
}

else if (allow_none) {
return value;
}
}
break;
}
break;

default: /* Should never happen...indicates an internal error: */
assert(0); /* invalid validation type */
Expand Down Expand Up @@ -4271,7 +4313,7 @@ _trait_set_validate(trait_object *trait, PyObject *args)
break;

case 22: /* Callable check: */
if (n == 1) {
if (n == 1 || n == 2) {
goto done;
}
break;
Expand Down
62 changes: 58 additions & 4 deletions traits/tests/test_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
Callable,
Either,
HasTraits,
Int,
Str,
TraitError,
Union,
ValidateTrait
)


Expand All @@ -32,14 +35,12 @@ def instance_method(self):


class MyCallable(HasTraits):

value = Callable()

callable_or_str = Either(Callable(), Str())
callable_or_str = Either(Callable(), Str)


class MyBaseCallable(HasTraits):

value = BaseCallable


Expand All @@ -50,7 +51,7 @@ def test_default(self):
self.assertIsNone(a.value)

def test_accepts_lambda(self):
func = lambda v: v + 1 # noqa: E731
func = lambda v: v + 1 # noqa: E731
a = MyCallable(value=func)
self.assertIs(a.value, func)

Expand Down Expand Up @@ -90,12 +91,65 @@ def test_callable_in_complex_trait(self):
a.callable_or_str = value
self.assertEqual(a.callable_or_str, old_value)

def test_disallow_none(self):

class MyNewCallable(HasTraits):
value = Callable(default_value=pow, allow_none=False)

obj = MyNewCallable()

self.assertIsNotNone(obj.value)

with self.assertRaises(TraitError):
obj.value = None

self.assertEqual(8, obj.value(2, 3))

def test_disallow_none_compound(self):

class MyNewCallable2(HasTraits):
value = Callable(pow, allow_none=True)
empty_callable = Callable()
a_non_none_union = Union(Callable(allow_none=False), Int)
a_allow_none_union = Union(Callable(allow_none=True), Int)

obj = MyNewCallable2()
self.assertIsNotNone(obj.value)
self.assertIsNone(obj.empty_callable)

obj.value = None
obj.empty_callable = None
self.assertIsNone(obj.value)
self.assertIsNone(obj.empty_callable)

obj.a_non_none_union = 5
obj.a_allow_none_union = 5

with self.assertRaises(TraitError):
obj.a_non_none_union = None
obj.a_allow_none_union = None

def test_old_style_callable(self):
class OldCallable(Callable):
def __init__(self, value=None, **metadata):
self.fast_validate = (ValidateTrait.callable,)
super(BaseCallable, self).__init__(value, **metadata)

class MyCallable(HasTraits):
# allow_none flag should be ineffective
value = OldCallable()

obj = MyCallable()
obj.value = None
self.assertIsNone(obj.value)


class TestBaseCallable(unittest.TestCase):

def test_override_validate(self):
""" Verify `BaseCallable` can be subclassed to create new traits.
"""

class ZeroArgsCallable(BaseCallable):

def validate(self, object, name, value):
Expand Down
8 changes: 6 additions & 2 deletions traits/trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,13 @@ def validate(self, object, name, value):
class Callable(BaseCallable):
""" A fast-validating trait type whose value must be a Python callable.
"""
def __init__(self, value=None, allow_none=True, **metadata):

self.fast_validate = (ValidateTrait.callable, allow_none)

#: The C-level fast validator to use
fast_validate = (ValidateTrait.callable,)
default_value = metadata.pop("default_value", value)

super().__init__(default_value, **metadata)


class BaseType(TraitType):
Expand Down

0 comments on commit ba0c0ca

Please sign in to comment.