diff --git a/traits/ctraits.c b/traits/ctraits.c index 342e16bbd..09247c575 100644 --- a/traits/ctraits.c +++ b/traits/ctraits.c @@ -3339,7 +3339,7 @@ validate_trait_integer( */ static PyObject * -as_float(PyObject *value) +validate_float(PyObject *value) { double value_as_double; @@ -3357,6 +3357,12 @@ as_float(PyObject *value) return PyFloat_FromDouble(value_as_double); } +static PyObject * +_ctraits_validate_float(PyObject *self, PyObject *value) +{ + return validate_float(value); +} + /*----------------------------------------------------------------------------- | Verifies that a Python value is convertible to float | @@ -3374,7 +3380,7 @@ validate_trait_float( trait_object *trait, has_traits_object *obj, PyObject *name, PyObject *value) { - PyObject *result = as_float(value); + PyObject *result = validate_float(value); /* A TypeError represents a type validation failure, and should be re-raised as a TraitError. Other exceptions should be propagated. */ if (result == NULL && PyErr_ExceptionMatches(PyExc_TypeError)) { @@ -3450,7 +3456,7 @@ validate_trait_float_range( PyObject *result; int in_range; - result = as_float(value); + result = validate_float(value); if (result == NULL) { if (PyErr_ExceptionMatches(PyExc_TypeError)) { /* Reraise any TypeError as a TraitError. */ @@ -3914,7 +3920,7 @@ validate_trait_complex( break; case 4: /* Floating point range check: */ - result = as_float(value); + result = validate_float(value); if (result == NULL) { if (PyErr_ExceptionMatches(PyExc_TypeError)) { /* A TypeError should ultimately get re-raised @@ -4132,7 +4138,7 @@ validate_trait_complex( /* A TypeError indicates that we don't have a match. Clear the error and continue with the next item in the complex sequence. */ - result = as_float(value); + result = validate_float(value); if (result == NULL && PyErr_ExceptionMatches(PyExc_TypeError)) { PyErr_Clear(); @@ -5547,6 +5553,15 @@ _ctraits_ctrait(PyObject *self, PyObject *args) | 'CTrait' instance methods: +----------------------------------------------------------------------------*/ + +PyDoc_STRVAR( + _ctraits_validate_float_doc, + "_validate_float(number)\n" + "\n" + "Return *number* converted to a float. Raise TypeError if \n" + "conversion is not possible.\n" +); + static PyMethodDef ctraits_methods[] = { {"_list_classes", (PyCFunction)_ctraits_list_classes, METH_VARARGS, PyDoc_STR( @@ -5555,6 +5570,8 @@ static PyMethodDef ctraits_methods[] = { PyDoc_STR("_adapt(adaptation_function)")}, {"_ctrait", (PyCFunction)_ctraits_ctrait, METH_VARARGS, PyDoc_STR("_ctrait(CTrait_class)")}, + {"_validate_float", (PyCFunction)_ctraits_validate_float, METH_O, + _ctraits_validate_float_doc}, {NULL, NULL}, }; diff --git a/traits/tests/test_float.py b/traits/tests/test_float.py index f10590773..0e442447e 100644 --- a/traits/tests/test_float.py +++ b/traits/tests/test_float.py @@ -18,6 +18,24 @@ from traits.testing.optional_dependencies import numpy, requires_numpy +class IntegerLike: + def __init__(self, value): + self._value = value + + def __index__(self): + return self._value + + +# Python versions < 3.8 don't support conversion of something with __index__ +# to float. +try: + float(IntegerLike(3)) +except TypeError: + float_accepts_index = False +else: + float_accepts_index = True + + class MyFloat(object): def __init__(self, value): self._value = value @@ -93,6 +111,16 @@ def test_accepts_int(self): self.assertIs(type(a.value_or_none), float) self.assertEqual(a.value_or_none, 2.0) + @unittest.skipUnless( + float_accepts_index, + "float does not support __index__ for this Python version", + ) + def test_accepts_integer_like(self): + a = self.test_class() + a.value = IntegerLike(3) + self.assertIs(type(a.value), float) + self.assertEqual(a.value, 3.0) + def test_accepts_float_like(self): a = self.test_class() diff --git a/traits/trait_types.py b/traits/trait_types.py index aa659ebdd..e8eb3d14e 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -24,6 +24,7 @@ import warnings from .constants import DefaultValue, TraitKind, ValidateTrait +from .ctraits import _validate_float from .trait_base import ( strx, get_module_name, @@ -169,20 +170,6 @@ def _validate_int(value): return int(operator.index(value)) -def _validate_float(value): - """ Convert an arbitrary Python object to a float, or raise TypeError. - """ - if type(value) is float: # fast path for common case - return value - try: - nb_float = type(value).__float__ - except AttributeError: - raise TypeError( - "Object of type {!r} not convertible to float".format(type(value)) - ) - return nb_float(value) - - # Trait Types class Any(TraitType):