Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add casting keyword to numeric array types. #547

Merged
merged 11 commits into from
Aug 3, 2020
Merged
20 changes: 18 additions & 2 deletions traits/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import unittest

from traits.api import Array, Bool, HasTraits
from traits.api import Array, Bool, HasTraits, TraitError
from traits.testing.optional_dependencies import numpy, requires_numpy


Expand All @@ -25,10 +25,10 @@ def _a_changed(self):
self.event_fired = True


@requires_numpy
class ArrayTestCase(unittest.TestCase):
""" Test cases for delegated traits. """

@requires_numpy
def test_zero_to_one_element(self):
""" Test that an event fires when an Array trait changes from zero to
one element.
Expand All @@ -43,3 +43,19 @@ def test_zero_to_one_element(self):

# Confirm that the static trait handler was invoked.
self.assertEqual(f.event_fired, True)

def test_safe_casting(self):
class Bar(HasTraits):
unsafe_f32 = Array(dtype="float32")
safe_f32 = Array(dtype="float32", casting="safe")

f64 = numpy.array([1], dtype="float64")
f32 = numpy.array([1], dtype="float32")

b = Bar()

b.unsafe_f32 = f32
b.unsafe_f32 = f64
b.safe_f32 = f32
with self.assertRaises(TraitError):
b.safe_f32 = f64
16 changes: 16 additions & 0 deletions traits/tests/test_array_or_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,19 @@ class FooBar(HasTraits):
foo = foo_bar.foo
foo += 1729.0
self.assertFalse((foo_bar.foo == foo_bar.bar).all())

def test_safe_casting(self):
class Bar(HasTraits):
unsafe_f32 = ArrayOrNone(dtype="float32")
safe_f32 = ArrayOrNone(dtype="float32", casting="safe")

f64 = numpy.array([1], dtype="float64")
f32 = numpy.array([1], dtype="float32")

b = Bar()

b.unsafe_f32 = f32
b.unsafe_f32 = f64
b.safe_f32 = f32
with self.assertRaises(TraitError):
b.safe_f32 = f64
69 changes: 60 additions & 9 deletions traits/trait_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(
value=None,
coerce=False,
typecode=None,
*,
casting="unsafe",
**metadata
):
global ndarray, asarray
Expand Down Expand Up @@ -130,6 +132,7 @@ def __init__(
self.dtype = dtype
self.shape = shape
self.coerce = coerce
self.casting = casting

super(AbstractArray, self).__init__(value, **metadata)

Expand All @@ -148,11 +151,7 @@ def validate(self, object, name, value):

# Make sure the array is of the right type:
if (self.dtype is not None) and (value.dtype != self.dtype):
if self.coerce:
value = value.astype(self.dtype)
else:
# XXX: this also coerces.
value = asarray(value, self.dtype)
value = value.astype(self.dtype, casting=self.casting)

# If no shape requirements, then return the value:
trait_shape = self.shape
Expand Down Expand Up @@ -310,13 +309,39 @@ class Array(AbstractArray):
second dimension must be at least 2.)
value : numpy array
A default value for the array.
casting : str
Casting rule for the array's dtype. If ``dtype`` is set, a value can
only be assigned if it passes the casting rule. Values can be:

- "no": No casting is allowed
- "equiv": Only byte-order changes are allowed
- "safe": Only allow casting that fully preserves values (e.g.
"float32" to "float64")
- "same-kind": Only safe casts or casts within a kind (e.g. "float64"
to "float32") are allowed
- "unsafe": Any casting is allowed

Default is "unsafe".
"""

def __init__(
self, dtype=None, shape=None, value=None, typecode=None, **metadata
self,
dtype=None,
shape=None,
value=None,
typecode=None,
*,
casting="unsafe",
**metadata
):
super(Array, self).__init__(
dtype, shape, value, False, typecode=typecode, **metadata
dtype,
shape,
value,
False,
typecode=typecode,
casting=casting,
**metadata
)


Expand Down Expand Up @@ -351,13 +376,39 @@ class CArray(AbstractArray):
second dimension must be at least 2.)
value : numpy array
A default value for the array.
casting : str
Casting rule for the array's dtype. If ``dtype`` is set, a value can
only be assigned if it passes the casting rule. Values can be:

- "no": No casting is allowed
- "equiv": Only byte-order changes are allowed
- "safe": Only allow casting that fully preserves values (e.g.
"float32" to "float64")
- "same-kind": Only safe casts or casts within a kind (e.g. "float64"
to "float32") are allowed
- "unsafe": Any casting is allowed

Default is "unsafe".
"""

def __init__(
self, dtype=None, shape=None, value=None, typecode=None, **metadata
self,
dtype=None,
shape=None,
value=None,
typecode=None,
*,
casting="unsafe",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making this parameter keyword-only (here and elsewhere)? I don't think there's any good case for passing this by position.

**metadata
):
super(CArray, self).__init__(
dtype, shape, value, True, typecode=typecode, **metadata
dtype,
shape,
value,
True,
typecode=typecode,
casting=casting,
**metadata
)


Expand Down