Skip to content

Commit

Permalink
Added more tests for dpnp.arange()
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Nov 7, 2023
1 parent a7782fd commit 6ba0e17
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 76 deletions.
142 changes: 87 additions & 55 deletions numba_dpex/dpnp_iface/array_sequence_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import dpnp
import numba
import numpy as np
from llvmlite import ir as llvmir
from llvmlite.ir import types as llvmirtypes
from numba import errors, types
Expand Down Expand Up @@ -34,44 +33,42 @@
)


def _is_any_float_type(value):
def _is_float_type(start, stop, step):
return (
type(value) == float
or isinstance(value, np.floating)
or isinstance(value, Float)
type(start) == float
or type(stop) == float
or type(step) == float
or isinstance(start, Float)
or isinstance(stop, Float)
or isinstance(step, Float)
)


def _is_any_int_type(value):
def _is_int_type(start, stop, step):
return (
type(value) == int
or isinstance(value, np.integer)
or isinstance(value, Integer)
type(start) == int
or type(stop) == int
or type(step) == int
or isinstance(start, Integer)
or isinstance(stop, Integer)
or isinstance(step, Integer)
)


def _is_any_complex_type(value):
return np.iscomplex(value) or isinstance(value, Complex)
def _is_complex_type(start, stop, step):
return (
isinstance(start, Complex)
or isinstance(stop, Complex)
or isinstance(step, Complex)
)


def _parse_dtype_from_range(start, stop, step):
if (
_is_any_complex_type(start)
or _is_any_complex_type(stop)
or _is_any_complex_type(step)
):
numba.from_dtype(dpnp.complex128)
elif (
_is_any_float_type(start)
or _is_any_float_type(stop)
or _is_any_float_type(step)
):
if _is_complex_type(start, stop, step):
return numba.from_dtype(dpnp.complex128)
elif _is_float_type(start, stop, step):
return numba.from_dtype(dpnp.float64)
elif (
_is_any_int_type(start)
or _is_any_int_type(stop)
or _is_any_int_type(step)
):
elif _is_int_type(start, stop, step):
return numba.from_dtype(dpnp.int64)
else:
msg = (
Expand Down Expand Up @@ -155,7 +152,42 @@ def _get_dst_typeid(dtype):
raise errors.NumbaTypeError(msg)


def _normalize(builder, src, src_type, dest_type):
def _round(builder, src, src_type):
return_type = (
llvmirtypes.DoubleType()
if src_type.bitwidth == 64
else (
llvmirtypes.FloatType()
if src_type.bitwidth == 32
else llvmirtypes.HalfType()
)
)
round = builder.module.declare_intrinsic("llvm.round", [return_type])
src = builder.call(round, [src])
return src


def _is_fraction(builder, src, src_type):
if isinstance(src_type, Float):
return_type = (
llvmirtypes.DoubleType()
if src_type.bitwidth == 64
else (
llvmirtypes.FloatType()
if src_type.bitwidth == 32
else llvmirtypes.HalfType()
)
)
llvm_fabs = builder.module.declare_intrinsic("llvm.fabs", [return_type])
src_abs = builder.call(llvm_fabs, [src])
ret = True
is_lto = builder.fcmp_ordered(">=", src_abs, src.type(1.0))
with builder.if_then(is_lto):
ret = False
return ret


def _normalize(builder, src, src_type, dest_type, rounding=False):
dest_llvm_type = _get_llvm_type(dest_type)
if isinstance(src_type, Integer) and isinstance(dest_type, Integer):
if src_type.bitwidth < dest_type.bitwidth:
Expand All @@ -170,19 +202,8 @@ def _normalize(builder, src, src_type, dest_type):
else:
return builder.uitofp(src, dest_llvm_type)
elif isinstance(src_type, Float) and isinstance(dest_type, Integer):
# src_gtz = builder.fcmp_ordered(">", src, src.type(0.0)) # noqa: E800
# with builder.if_then(src_gtz):
return_type = (
llvmirtypes.DoubleType()
if src_type.bitwidth == 64
else (
llvmirtypes.FloatType()
if src_type.bitwidth == 32
else llvmirtypes.HalfType()
)
)
rint = builder.module.declare_intrinsic("llvm.round", [return_type])
src = builder.call(rint, [src])
if rounding:
src = _round(builder, src, src_type)
if dest_type.signed:
return builder.fptosi(src, dest_llvm_type)
else:
Expand All @@ -196,7 +217,8 @@ def _normalize(builder, src, src_type, dest_type):
return src
else:
msg = (
f"{src}[{src_type}] is neither a "
"dpnp_iface.array_sequence_ops._normalize(): "
+ f"{src}[{src_type}] is neither a "
+ "'numba.core.types.scalars.Float' "
+ "nor an 'numba.core.types.scalars.Integer'."
)
Expand All @@ -216,16 +238,14 @@ def _compute_array_length_ir(
ub = _normalize(builder, stop_ir, stop_arg_type, types.float64)
n = _normalize(builder, step_ir, step_arg_type, types.float64)

ceil = builder.module.declare_intrinsic(
llvm_ceil = builder.module.declare_intrinsic(
"llvm.ceil", [llvmirtypes.DoubleType()]
)
fabs = builder.module.declare_intrinsic(
"llvm.fabs", [llvmirtypes.DoubleType()]
)

array_length_ir = builder.fptosi(
builder.call(
ceil, [builder.fdiv(builder.call(fabs, [builder.fsub(ub, lb)]), n)]
llvm_ceil,
[builder.fdiv(builder.fsub(ub, lb), n)],
),
llvmir.IntType(64),
)
Expand Down Expand Up @@ -304,8 +324,14 @@ def codegen(context, builder, sig, args):
step_ir = context.get_constant(start_arg_type, 1)
step_arg_type = start_arg_type

# Keep note if either start or stop is in (-1.0, 0.0] or [0.0, 1.0)
round_step = not (
_is_fraction(builder, start_ir, start_arg_type)
and _is_fraction(builder, stop_ir, stop_arg_type)
)

# Allocate an empty array
t = _compute_array_length_ir(
len = _compute_array_length_ir(
builder,
start_ir,
stop_ir,
Expand All @@ -315,7 +341,7 @@ def codegen(context, builder, sig, args):
step_arg_type,
)
ary = _empty_nd_impl(
context, builder, sig.return_type, [t], qref_payload.queue_ref
context, builder, sig.return_type, [len], qref_payload.queue_ref
)
# Convert into void*
arrystruct_vptr = builder.bitcast(ary._getpointer(), cgutils.voidptr_t)
Expand All @@ -324,14 +350,20 @@ def codegen(context, builder, sig, args):
start_ir = _normalize(
builder, start_ir, start_arg_type, dtype_arg_type.dtype
)
start_arg_type = dtype_arg_type.dtype
stop_ir = _normalize(
builder, stop_ir, stop_arg_type, dtype_arg_type.dtype
)
stop_arg_type = dtype_arg_type.dtype
step_ir = _normalize(
builder, step_ir, step_arg_type, dtype_arg_type.dtype
builder,
step_ir,
step_arg_type,
dtype_arg_type.dtype,
rounding=round_step,
)

# After normalization, their arg_types will change
start_arg_type = dtype_arg_type.dtype
stop_arg_type = dtype_arg_type.dtype
step_arg_type = dtype_arg_type.dtype

# Construct function parameters
Expand Down Expand Up @@ -409,10 +441,10 @@ def ol_dpnp_arange(
)

if is_nonelike(stop):
start = 0
stop = 1
start = 0 if type(start) == int or isinstance(start, Integer) else 0.0
stop = 1 if type(start) == int or isinstance(start, Integer) else 1.0
if is_nonelike(step):
step = 1
step = 1 if type(start) == int or isinstance(start, Integer) else 1.0

_dtype = (
_parse_dtype(dtype)
Expand Down
101 changes: 80 additions & 21 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,40 +24,37 @@ def get_xfail_test(param, reason):
dtypes = get_all_dtypes(
no_bool=True, no_float16=True, no_none=False, no_complex=True
)
dtypes_except_none = get_all_dtypes(
no_bool=True, no_float16=True, no_none=True, no_complex=True
)
usm_types = ["device", "shared", "host"]
ranges = [
[1, None, None],
[1, None, None], # 0
[1, 10, None],
[1, 10, 1],
[-10, -1, 1],
[11, 41, 7],
[1, 10, 1.0],
[1, 10, 1.0], # 5
[1, 10.0, 1],
[0.7, 0.91, 0.03],
[-1003.345, -987.44, 0.73],
get_xfail_test([-1.0, None, None], "can't allocate an empty array"),
get_xfail_test([-1.0, 10, -2], "impossible range"),
get_xfail_test([-10, -1, -1], "impossible range"),
[1.15, 2.75, 0.05],
[0.75, 10.23, 0.95], # 10
[10.23, 0.75, -0.95],
get_xfail_test([-1.0, None, None], "Can't allocate an empty array"),
get_xfail_test([-1.0, 10, -2], "Impossible range"),
get_xfail_test([-10, -1, -1], "Impossible range"),
]


@pytest.mark.parametrize("range", ranges)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_arange_basic(range, dtype, usm_type):
device = dpctl.SyclDevice().filter_string
def test_dpnp_arange_default(range, dtype):
start, stop, step = range

@dpjit
def func():
x = dpnp.arange(
start,
stop=stop,
step=step,
dtype=dtype,
usm_type=usm_type,
device=device,
)
x = dpnp.arange(start, stop=stop, step=step, dtype=dtype)
return x

try:
Expand All @@ -70,18 +67,48 @@ def func():
stop=stop,
step=step,
dtype=dtype,
usm_type=usm_type,
device=device,
)

print(a)
print(c)

assert a.dtype == c.dtype
assert a.shape == c.shape
if a.dtype in [dpnp.float, dpnp.float16, dpnp.float32, dpnp.float64]:
assert np.allclose(a.asnumpy(), c.asnumpy())
else:
assert np.array_equal(a.asnumpy(), c.asnumpy())
if c.sycl_queue != a.sycl_queue:
pytest.xfail(
"Returned queue does not have the same queue as in the dummy array."
)
assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue(
a.sycl_device
)


@pytest.mark.parametrize("range", ranges[0:3])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_arange_from_device(range, dtype, usm_type):
device = dpctl.SyclDevice().filter_string

start, stop, step = range

@dpjit
def func():
x = dpnp.arange(
start,
stop=stop,
step=step,
dtype=dtype,
usm_type=usm_type,
device=device,
)
return x

try:
c = func()
except Exception:
pytest.fail("Calling dpnp.arange() inside dpjit failed.")

assert c.usm_type == usm_type
assert c.sycl_device.filter_string == device
if c.sycl_queue != dpctl._sycl_queue_manager.get_device_cached_queue(
Expand All @@ -90,3 +117,35 @@ def func():
pytest.xfail(
"Returned queue does not have the same queue as cached against the device."
)


@pytest.mark.parametrize("range", ranges[0:3])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_arange_from_queue(range, dtype, usm_type):
start, stop, step = range

@dpjit
def func(queue):
x = dpnp.arange(
start,
stop=stop,
step=step,
dtype=dtype,
usm_type=usm_type,
sycl_queue=queue,
)
return x

try:
queue = dpctl.SyclQueue()
c = func(queue)
except Exception:
pytest.fail("Calling dpnp.arange() inside dpjit failed.")

assert c.usm_type == usm_type
assert c.sycl_device == queue.sycl_device
if c.sycl_queue != queue:
pytest.xfail(
"Returned queue does not have the same queue as the one passed to the dpnp function."
)

0 comments on commit 6ba0e17

Please sign in to comment.