Skip to content

Commit

Permalink
Fix complex scalar checking and test_square failure
Browse files Browse the repository at this point in the history
Closes gh-190
  • Loading branch information
rgommers committed Nov 14, 2023
1 parent b6370dc commit 3b4a954
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test element-wise functions/operators against reference implementations.
"""
import cmath
import math
import operator
from copy import copy
Expand Down Expand Up @@ -67,6 +68,29 @@ def isclose(
return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)


def isclose_complex(
a: complex,
b: complex,
M: float,
*,
rel_tol: float = 0.25,
abs_tol: float = 1,
) -> bool:
"""Wraps math.isclose with very generous defaults.
This is useful for many floating-point operations where the spec does not
make accuracy requirements.
"""
if cmath.isnan(a) or cmath.isnan(b):
raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
if cmath.isinf(a):
return cmath.isinf(b) or abs(b) > cmath.log(M)
elif cmath.isinf(b):
return cmath.isinf(a) or abs(a) > cmath.log(M)
return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)



def default_filter(s: Scalar) -> bool:
"""Returns False when s is a non-finite or a signed zero.
Expand Down Expand Up @@ -254,8 +278,7 @@ def unary_assert_against_refimpl(
f"{f_i}={scalar_i}"
)
if res.dtype in dh.complex_dtypes:
assert isclose(scalar_o.real, expected.real, M), msg
assert isclose(scalar_o.imag, expected.imag, M), msg
assert isclose_complex(scalar_o, expected, M), msg
else:
assert isclose(scalar_o, expected, M), msg
else:
Expand Down Expand Up @@ -330,8 +353,7 @@ def binary_assert_against_refimpl(
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
)
if res.dtype in dh.complex_dtypes:
assert isclose(scalar_o.real, expected.real, M), msg
assert isclose(scalar_o.imag, expected.imag, M), msg
assert isclose_complex(scalar_o, expected, M), msg
else:
assert isclose(scalar_o, expected, M), msg
else:
Expand Down Expand Up @@ -403,8 +425,7 @@ def right_scalar_assert_against_refimpl(
f"{f_l}={scalar_l}"
)
if res.dtype in dh.complex_dtypes:
assert isclose(scalar_o.real, expected.real, M), msg
assert isclose(scalar_o.imag, expected.imag, M), msg
assert isclose_complex(scalar_o, expected, M), msg
else:
assert isclose(scalar_o, expected, M), msg
else:
Expand Down Expand Up @@ -1394,7 +1415,7 @@ def test_square(x):
ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("square", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl(
"square", x, out, lambda s: s**2, expr_template="{}²={}"
"square", x, out, lambda s: s*s, expr_template="{}²={}"
)


Expand Down

0 comments on commit 3b4a954

Please sign in to comment.