Skip to content

Commit

Permalink
Update unit tests for dpnp array constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
khaled authored and Diptorup Deb committed May 10, 2023
1 parent 6eb9cf8 commit a5726e3
Show file tree
Hide file tree
Showing 11 changed files with 1,013 additions and 211 deletions.
20 changes: 17 additions & 3 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
ndim,
layout="C",
dtype=None,
is_fill_value_float=False,
usm_type="device",
device=None,
queue=None,
Expand Down Expand Up @@ -66,9 +67,22 @@ def __init__(
self.device = self.queue.sycl_device.filter_string

if not dtype:
dummy_tensor = dpctl.tensor.empty(
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
)
if is_fill_value_float:
dummy_tensor = dpctl.tensor.empty(
1,
dtype=dpctl.tensor.float64,
order=layout,
usm_type=usm_type,
sycl_queue=self.queue,
)
else:
dummy_tensor = dpctl.tensor.empty(
1,
dtype=dpctl.tensor.int64,
order=layout,
usm_type=usm_type,
sycl_queue=self.queue,
)
# convert dpnp type to numba/numpy type
_dtype = dummy_tensor.dtype
self.dtype = from_dtype(_dtype)
Expand Down
54 changes: 42 additions & 12 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import dpnp
from numba import errors, types
from numba.core.types import scalars
from numba.core.types.containers import UniTuple
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.extending import overload
Expand All @@ -20,15 +22,27 @@
impl_dpnp_ones_like,
impl_dpnp_zeros,
impl_dpnp_zeros_like,
intrin_usm_alloc,
)

# =========================================================================
# Helps to parse dpnp constructor arguments
# =========================================================================


def _parse_dtype(dtype, data=None):
def _parse_dim(x1):
if hasattr(x1, "ndim") and x1.ndim:
return x1.ndim
elif isinstance(x1, scalars.Integer):
r = 1
return r
elif isinstance(x1, UniTuple):
r = len(x1)
return r
else:
return 0


def _parse_dtype(dtype):
"""Resolve dtype parameter.
Resolves the dtype parameter based on the given value
Expand All @@ -44,9 +58,8 @@ class for nd-arrays. Defaults to None.
numba.core.types.functions.NumberClass: Resolved numba type
class for number classes.
"""

_dtype = None
if data and isinstance(data, types.Array):
_dtype = data.dtype
if not is_nonelike(dtype):
_dtype = _ty_parse_dtype(dtype)
return _dtype
Expand All @@ -60,6 +73,9 @@ def _parse_layout(layout):
raise errors.NumbaValueError(msg)
return layout_type_str
elif isinstance(layout, str):
if layout not in ["C", "F", "A"]:
msg = f"Invalid layout specified: '{layout}'"
raise errors.NumbaValueError(msg)
return layout
else:
raise TypeError(
Expand Down Expand Up @@ -94,6 +110,9 @@ def _parse_usm_type(usm_type):
raise errors.NumbaValueError(msg)
return usm_type_str
elif isinstance(usm_type, str):
if usm_type not in ["shared", "device", "host"]:
msg = f"Invalid usm_type specified: '{usm_type}'"
raise errors.NumbaValueError(msg)
return usm_type
else:
raise TypeError(
Expand Down Expand Up @@ -150,6 +169,7 @@ def build_dpnp_ndarray(
ndim,
layout="C",
dtype=None,
is_fill_value_float=False,
usm_type="device",
device=None,
sycl_queue=None,
Expand All @@ -163,6 +183,8 @@ def build_dpnp_ndarray(
Data type of the array. Can be typestring, a `numpy.dtype`
object, `numpy` char string, or a numpy scalar type.
Default: None.
is_fill_value_float (bool): Specify if the fill value is floating
point.
usm_type (numba.core.types.misc.StringLiteral, optional):
The type of SYCL USM allocation for the output array.
Allowed values are "device"|"shared"|"host".
Expand Down Expand Up @@ -198,6 +220,7 @@ def build_dpnp_ndarray(
ndim=ndim,
layout=layout,
dtype=dtype,
is_fill_value_float=is_fill_value_float,
usm_type=usm_type,
device=device,
queue=sycl_queue,
Expand Down Expand Up @@ -280,6 +303,7 @@ def ol_dpnp_empty(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=True,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down Expand Up @@ -384,6 +408,7 @@ def ol_dpnp_zeros(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=True,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down Expand Up @@ -488,6 +513,7 @@ def ol_dpnp_ones(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=True,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down Expand Up @@ -586,6 +612,7 @@ def ol_dpnp_full(

_ndim = _ty_parse_shape(shape)
_dtype = _parse_dtype(dtype)
_is_fill_value_float = isinstance(fill_value, scalars.Float)
_layout = _parse_layout(order)
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand All @@ -596,6 +623,7 @@ def ol_dpnp_full(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=_is_fill_value_float,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down Expand Up @@ -699,8 +727,8 @@ def ol_dpnp_empty_like(
+ "inside overloaded dpnp.empty_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand Down Expand Up @@ -812,8 +840,8 @@ def ol_dpnp_zeros_like(
+ "inside overloaded dpnp.zeros_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand Down Expand Up @@ -924,8 +952,8 @@ def ol_dpnp_ones_like(
+ "inside overloaded dpnp.ones_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand Down Expand Up @@ -1041,8 +1069,9 @@ def ol_dpnp_full_like(
+ "inside overloaded dpnp.full_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_is_fill_value_float = isinstance(fill_value, scalars.Float)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand All @@ -1052,6 +1081,7 @@ def ol_dpnp_full_like(
_ndim,
layout=_order,
dtype=_dtype,
is_fill_value_float=_is_fill_value_float,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down
101 changes: 83 additions & 18 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,61 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for dpnp ndarray constructors."""
"""Tests for the dpnp.empty overload."""

import dpctl
import dpnp
import pytest
from numba import errors

from numba_dpex import dpjit

shapes = [11, (2, 5)]
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
usm_types = ["device", "shared", "host"]
devices = ["cpu", None]


@pytest.mark.parametrize("shape", shapes)
def test_dpnp_empty_default(shape):
"""Test dpnp.empty() with default parameters inside dpjit."""

@dpjit
def func(shape):
c = dpnp.empty(shape)
return c

try:
c = func(shape)
except Exception:
pytest.fail("Calling dpnp.empty() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == shape
else:
assert c.shape == shape

dummy = dpnp.empty(shape)

assert c.dtype == dummy.dtype
assert c.usm_type == dummy.usm_type
assert c.sycl_device == dummy.sycl_device
assert c.sycl_queue == dummy.sycl_queue
if c.sycl_queue != dummy.sycl_queue:
pytest.xfail(
"Returned queue does not have the queue in the dummy array."
)
assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue(
dummy.sycl_device
)


@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
@pytest.mark.parametrize("device", devices)
def test_dpnp_empty(shape, dtype, usm_type, device):
def test_dpnp_empty_from_device(shape, dtype, usm_type):
""" "Use device only in dpnp.emtpy() inside dpjit."""
device = dpctl.SyclDevice().filter_string

@dpjit
def func(shape):
c = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, device=device)
Expand All @@ -29,7 +65,7 @@ def func(shape):
try:
c = func(shape)
except Exception:
pytest.fail("Calling dpnp.empty inside dpjit failed")
pytest.fail("Calling dpnp.empty() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == shape
Expand All @@ -38,32 +74,61 @@ def func(shape):

assert c.dtype == dtype
assert c.usm_type == usm_type
if device is not None:
assert (
c.sycl_device.filter_string
== dpctl.SyclDevice(device).filter_string
assert c.sycl_device.filter_string == device
if c.sycl_queue != dpctl._sycl_queue_manager.get_device_cached_queue(
device
):
pytest.xfail(
"Returned queue does not have the queue cached against the device."
)
else:
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string


@pytest.mark.parametrize("shape", shapes)
def test_dpnp_empty_default_dtype(shape):
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_empty_from_queue(shape, dtype, usm_type):
""" "Use queue only in dpnp.emtpy() inside dpjit."""

@dpjit
def func(shape):
c = dpnp.empty(shape)
def func(shape, queue):
c = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, sycl_queue=queue)
return c

queue = dpctl.SyclQueue()

try:
c = func(shape)
c = func(shape, queue)
except Exception:
pytest.fail("Calling dpnp.empty inside dpjit failed")
pytest.fail("Calling dpnp.empty() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == shape
else:
assert c.shape == shape

dummy_tensor = dpctl.tensor.empty(shape)
assert c.dtype == dtype
assert c.usm_type == usm_type
assert c.sycl_device == queue.sycl_device

if c.sycl_queue != queue:
pytest.xfail(
"Returned queue does not have the queue passed to the dpnp function."
)


def test_dpnp_empty_exceptions():
"""Test if exception is raised when both queue and device are specified."""
device = dpctl.SyclDevice().filter_string

assert c.dtype == dummy_tensor.dtype
@dpjit
def func(shape, queue):
c = dpnp.empty(shape, sycl_queue=queue, device=device)
return c

queue = dpctl.SyclQueue()

try:
func(10, queue)
except Exception as e:
assert isinstance(e, errors.TypingError)
assert "`device` and `sycl_queue` are exclusive keywords" in str(e)
Loading

0 comments on commit a5726e3

Please sign in to comment.