Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix complex scalar checking and test_square failure #204

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test element-wise functions/operators against reference implementations.
"""
import cmath
import math
import operator
from copy import copy
Expand Down Expand Up @@ -48,7 +49,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
def isclose(
a: float,
b: float,
M: float,
maximum: float,
*,
rel_tol: float = 0.25,
abs_tol: float = 1,
Expand All @@ -61,12 +62,30 @@ def isclose(
if math.isnan(a) or math.isnan(b):
raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
if math.isinf(a):
return math.isinf(b) or abs(b) > math.log(M)
return math.isinf(b) or abs(b) > math.log(maximum)
elif math.isinf(b):
return math.isinf(a) or abs(a) > math.log(M)
return math.isinf(a) or abs(a) > math.log(maximum)
return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)


def isclose_complex(
a: complex,
b: complex,
maximum: float,
*,
rel_tol: float = 0.25,
abs_tol: float = 1,
) -> bool:
"""Like isclose() but specifically for complex values."""
if cmath.isnan(a) or cmath.isnan(b):
raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
if cmath.isinf(a):
return cmath.isinf(b) or abs(b) > cmath.log(maximum)
elif cmath.isinf(b):
return cmath.isinf(a) or abs(a) > cmath.log(maximum)
return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)


def default_filter(s: Scalar) -> bool:
"""Returns False when s is a non-finite or a signed zero.

Expand Down Expand Up @@ -254,8 +273,7 @@ def unary_assert_against_refimpl(
f"{f_i}={scalar_i}"
)
if res.dtype in dh.complex_dtypes:
assert isclose(scalar_o.real, expected.real, M), msg
assert isclose(scalar_o.imag, expected.imag, M), msg
assert isclose_complex(scalar_o, expected, M), msg
else:
assert isclose(scalar_o, expected, M), msg
else:
Expand Down Expand Up @@ -330,8 +348,7 @@ def binary_assert_against_refimpl(
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
)
if res.dtype in dh.complex_dtypes:
assert isclose(scalar_o.real, expected.real, M), msg
assert isclose(scalar_o.imag, expected.imag, M), msg
assert isclose_complex(scalar_o, expected, M), msg
else:
assert isclose(scalar_o, expected, M), msg
else:
Expand Down Expand Up @@ -403,8 +420,7 @@ def right_scalar_assert_against_refimpl(
f"{f_l}={scalar_l}"
)
if res.dtype in dh.complex_dtypes:
assert isclose(scalar_o.real, expected.real, M), msg
assert isclose(scalar_o.imag, expected.imag, M), msg
assert isclose_complex(scalar_o, expected, M), msg
else:
assert isclose(scalar_o, expected, M), msg
else:
Expand Down Expand Up @@ -1394,7 +1410,7 @@ def test_square(x):
ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("square", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl(
"square", x, out, lambda s: s**2, expr_template="{}²={}"
"square", x, out, lambda s: s*s, expr_template="{}²={}"
)


Expand Down