Skip to content

Commit

Permalink
Add casting keyword to numeric array types. (#547)
Browse files Browse the repository at this point in the history
* Add casting keyowrd to AbstractArray and its children

* Explicitly add  to Array and CArray init

* Add docstring

* Update docstring

* Change requests

* Docstring tweak

* Docstring tweak

* make casting kwarg-onlt for ArrayAbstract
  • Loading branch information
k2bd authored Aug 3, 2020
1 parent 6c46701 commit 91cd08c
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 11 deletions.
20 changes: 18 additions & 2 deletions traits/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import unittest

from traits.api import Array, Bool, HasTraits
from traits.api import Array, Bool, HasTraits, TraitError
from traits.testing.optional_dependencies import numpy, requires_numpy


Expand All @@ -25,10 +25,10 @@ def _a_changed(self):
self.event_fired = True


@requires_numpy
class ArrayTestCase(unittest.TestCase):
""" Test cases for delegated traits. """

@requires_numpy
def test_zero_to_one_element(self):
""" Test that an event fires when an Array trait changes from zero to
one element.
Expand All @@ -43,3 +43,19 @@ def test_zero_to_one_element(self):

# Confirm that the static trait handler was invoked.
self.assertEqual(f.event_fired, True)

def test_safe_casting(self):
class Bar(HasTraits):
unsafe_f32 = Array(dtype="float32")
safe_f32 = Array(dtype="float32", casting="safe")

f64 = numpy.array([1], dtype="float64")
f32 = numpy.array([1], dtype="float32")

b = Bar()

b.unsafe_f32 = f32
b.unsafe_f32 = f64
b.safe_f32 = f32
with self.assertRaises(TraitError):
b.safe_f32 = f64
16 changes: 16 additions & 0 deletions traits/tests/test_array_or_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,19 @@ class FooBar(HasTraits):
foo = foo_bar.foo
foo += 1729.0
self.assertFalse((foo_bar.foo == foo_bar.bar).all())

def test_safe_casting(self):
class Bar(HasTraits):
unsafe_f32 = ArrayOrNone(dtype="float32")
safe_f32 = ArrayOrNone(dtype="float32", casting="safe")

f64 = numpy.array([1], dtype="float64")
f32 = numpy.array([1], dtype="float32")

b = Bar()

b.unsafe_f32 = f32
b.unsafe_f32 = f64
b.safe_f32 = f32
with self.assertRaises(TraitError):
b.safe_f32 = f64
69 changes: 60 additions & 9 deletions traits/trait_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(
value=None,
coerce=False,
typecode=None,
*,
casting="unsafe",
**metadata
):
global ndarray, asarray
Expand Down Expand Up @@ -130,6 +132,7 @@ def __init__(
self.dtype = dtype
self.shape = shape
self.coerce = coerce
self.casting = casting

super(AbstractArray, self).__init__(value, **metadata)

Expand All @@ -148,11 +151,7 @@ def validate(self, object, name, value):

# Make sure the array is of the right type:
if (self.dtype is not None) and (value.dtype != self.dtype):
if self.coerce:
value = value.astype(self.dtype)
else:
# XXX: this also coerces.
value = asarray(value, self.dtype)
value = value.astype(self.dtype, casting=self.casting)

# If no shape requirements, then return the value:
trait_shape = self.shape
Expand Down Expand Up @@ -310,13 +309,39 @@ class Array(AbstractArray):
second dimension must be at least 2.)
value : numpy array
A default value for the array.
casting : str
Casting rule for the array's dtype. If ``dtype`` is set, a value can
only be assigned if it passes the casting rule. Values can be:
- "no": No casting is allowed
- "equiv": Only byte-order changes are allowed
- "safe": Only allow casting that fully preserves values (e.g.
"float32" to "float64")
- "same-kind": Only safe casts or casts within a kind (e.g. "float64"
to "float32") are allowed
- "unsafe": Any casting is allowed
Default is "unsafe".
"""

def __init__(
self, dtype=None, shape=None, value=None, typecode=None, **metadata
self,
dtype=None,
shape=None,
value=None,
typecode=None,
*,
casting="unsafe",
**metadata
):
super(Array, self).__init__(
dtype, shape, value, False, typecode=typecode, **metadata
dtype,
shape,
value,
False,
typecode=typecode,
casting=casting,
**metadata
)


Expand Down Expand Up @@ -351,13 +376,39 @@ class CArray(AbstractArray):
second dimension must be at least 2.)
value : numpy array
A default value for the array.
casting : str
Casting rule for the array's dtype. If ``dtype`` is set, a value can
only be assigned if it passes the casting rule. Values can be:
- "no": No casting is allowed
- "equiv": Only byte-order changes are allowed
- "safe": Only allow casting that fully preserves values (e.g.
"float32" to "float64")
- "same-kind": Only safe casts or casts within a kind (e.g. "float64"
to "float32") are allowed
- "unsafe": Any casting is allowed
Default is "unsafe".
"""

def __init__(
self, dtype=None, shape=None, value=None, typecode=None, **metadata
self,
dtype=None,
shape=None,
value=None,
typecode=None,
*,
casting="unsafe",
**metadata
):
super(CArray, self).__init__(
dtype, shape, value, True, typecode=typecode, **metadata
dtype,
shape,
value,
True,
typecode=typecode,
casting=casting,
**metadata
)


Expand Down

0 comments on commit 91cd08c

Please sign in to comment.