From 91cd08cc60db942557be45eb1663578d66f8ce73 Mon Sep 17 00:00:00 2001 From: Kevin Duff Date: Mon, 3 Aug 2020 17:22:42 +0100 Subject: [PATCH] Add casting keyword to numeric array types. (#547) * Add casting keyowrd to AbstractArray and its children * Explicitly add to Array and CArray init * Add docstring * Update docstring * Change requests * Docstring tweak * Docstring tweak * make casting kwarg-onlt for ArrayAbstract --- traits/tests/test_array.py | 20 ++++++++- traits/tests/test_array_or_none.py | 16 +++++++ traits/trait_numeric.py | 69 ++++++++++++++++++++++++++---- 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/traits/tests/test_array.py b/traits/tests/test_array.py index 5f836242e..3d5b18c5d 100644 --- a/traits/tests/test_array.py +++ b/traits/tests/test_array.py @@ -10,7 +10,7 @@ import unittest -from traits.api import Array, Bool, HasTraits +from traits.api import Array, Bool, HasTraits, TraitError from traits.testing.optional_dependencies import numpy, requires_numpy @@ -25,10 +25,10 @@ def _a_changed(self): self.event_fired = True +@requires_numpy class ArrayTestCase(unittest.TestCase): """ Test cases for delegated traits. """ - @requires_numpy def test_zero_to_one_element(self): """ Test that an event fires when an Array trait changes from zero to one element. @@ -43,3 +43,19 @@ def test_zero_to_one_element(self): # Confirm that the static trait handler was invoked. self.assertEqual(f.event_fired, True) + + def test_safe_casting(self): + class Bar(HasTraits): + unsafe_f32 = Array(dtype="float32") + safe_f32 = Array(dtype="float32", casting="safe") + + f64 = numpy.array([1], dtype="float64") + f32 = numpy.array([1], dtype="float32") + + b = Bar() + + b.unsafe_f32 = f32 + b.unsafe_f32 = f64 + b.safe_f32 = f32 + with self.assertRaises(TraitError): + b.safe_f32 = f64 diff --git a/traits/tests/test_array_or_none.py b/traits/tests/test_array_or_none.py index 940a918e7..6d5f87673 100644 --- a/traits/tests/test_array_or_none.py +++ b/traits/tests/test_array_or_none.py @@ -166,3 +166,19 @@ class FooBar(HasTraits): foo = foo_bar.foo foo += 1729.0 self.assertFalse((foo_bar.foo == foo_bar.bar).all()) + + def test_safe_casting(self): + class Bar(HasTraits): + unsafe_f32 = ArrayOrNone(dtype="float32") + safe_f32 = ArrayOrNone(dtype="float32", casting="safe") + + f64 = numpy.array([1], dtype="float64") + f32 = numpy.array([1], dtype="float32") + + b = Bar() + + b.unsafe_f32 = f32 + b.unsafe_f32 = f64 + b.safe_f32 = f32 + with self.assertRaises(TraitError): + b.safe_f32 = f64 diff --git a/traits/trait_numeric.py b/traits/trait_numeric.py index 575da992c..942dc2471 100644 --- a/traits/trait_numeric.py +++ b/traits/trait_numeric.py @@ -55,6 +55,8 @@ def __init__( value=None, coerce=False, typecode=None, + *, + casting="unsafe", **metadata ): global ndarray, asarray @@ -130,6 +132,7 @@ def __init__( self.dtype = dtype self.shape = shape self.coerce = coerce + self.casting = casting super(AbstractArray, self).__init__(value, **metadata) @@ -148,11 +151,7 @@ def validate(self, object, name, value): # Make sure the array is of the right type: if (self.dtype is not None) and (value.dtype != self.dtype): - if self.coerce: - value = value.astype(self.dtype) - else: - # XXX: this also coerces. - value = asarray(value, self.dtype) + value = value.astype(self.dtype, casting=self.casting) # If no shape requirements, then return the value: trait_shape = self.shape @@ -310,13 +309,39 @@ class Array(AbstractArray): second dimension must be at least 2.) value : numpy array A default value for the array. + casting : str + Casting rule for the array's dtype. If ``dtype`` is set, a value can + only be assigned if it passes the casting rule. Values can be: + + - "no": No casting is allowed + - "equiv": Only byte-order changes are allowed + - "safe": Only allow casting that fully preserves values (e.g. + "float32" to "float64") + - "same-kind": Only safe casts or casts within a kind (e.g. "float64" + to "float32") are allowed + - "unsafe": Any casting is allowed + + Default is "unsafe". """ def __init__( - self, dtype=None, shape=None, value=None, typecode=None, **metadata + self, + dtype=None, + shape=None, + value=None, + typecode=None, + *, + casting="unsafe", + **metadata ): super(Array, self).__init__( - dtype, shape, value, False, typecode=typecode, **metadata + dtype, + shape, + value, + False, + typecode=typecode, + casting=casting, + **metadata ) @@ -351,13 +376,39 @@ class CArray(AbstractArray): second dimension must be at least 2.) value : numpy array A default value for the array. + casting : str + Casting rule for the array's dtype. If ``dtype`` is set, a value can + only be assigned if it passes the casting rule. Values can be: + + - "no": No casting is allowed + - "equiv": Only byte-order changes are allowed + - "safe": Only allow casting that fully preserves values (e.g. + "float32" to "float64") + - "same-kind": Only safe casts or casts within a kind (e.g. "float64" + to "float32") are allowed + - "unsafe": Any casting is allowed + + Default is "unsafe". """ def __init__( - self, dtype=None, shape=None, value=None, typecode=None, **metadata + self, + dtype=None, + shape=None, + value=None, + typecode=None, + *, + casting="unsafe", + **metadata ): super(CArray, self).__init__( - dtype, shape, value, True, typecode=typecode, **metadata + dtype, + shape, + value, + True, + typecode=typecode, + casting=casting, + **metadata )