From 751da72a715b9f7b6e32357f29bd04a7562a7783 Mon Sep 17 00:00:00 2001 From: khaled Date: Thu, 6 Apr 2023 13:31:20 -0500 Subject: [PATCH] Overload implementation for dpnp.full() - Adds an overload implementation for dpnp.full - Removes the `like` kwarg from dpnp.empty() overload - Unit test cases --- numba_dpex/core/runtime/_dpexrt_python.c | 122 ++++++-- numba_dpex/core/runtime/context.py | 45 ++- numba_dpex/core/types/usm_ndarray_type.py | 2 +- numba_dpex/dpnp_iface/_intrinsic.py | 198 +++++++++--- numba_dpex/dpnp_iface/arrayobj.py | 289 ++++++++++++++---- .../tests/dpjit_tests/dpnp/test_dpnp_empty.py | 20 +- .../tests/dpjit_tests/dpnp/test_dpnp_full.py | 69 +++++ .../tests/dpjit_tests/dpnp/test_dpnp_ones.py | 6 +- .../dpjit_tests/dpnp/test_dpnp_ones_like.py | 8 +- .../tests/dpjit_tests/dpnp/test_dpnp_zeros.py | 10 +- .../dpjit_tests/dpnp/test_dpnp_zeros_like.py | 10 +- 11 files changed, 629 insertions(+), 150 deletions(-) create mode 100644 numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_full.py diff --git a/numba_dpex/core/runtime/_dpexrt_python.c b/numba_dpex/core/runtime/_dpexrt_python.c index 4ca1993e9f..51511448cb 100644 --- a/numba_dpex/core/runtime/_dpexrt_python.c +++ b/numba_dpex/core/runtime/_dpexrt_python.c @@ -34,6 +34,12 @@ static NRT_ExternalAllocator * NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type); static void *DPEXRTQueue_CreateFromFilterString(const char *device); static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner); +static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi, + size_t itemsize, + bool dest_is_float, + bool value_is_float, + int64_t value, + const char *device); static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj, void *data, npy_intp nitems, @@ -510,25 +516,47 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device) * This function takes an allocated memory as NRT_MemInfo and fills it with * the value specified by `value`. * - * @param mi An NRT_MemInfo object, should be found from memory - * allocation. - * @param itemsize The itemsize, the size of each item in the array. - * @param is_float Flag to specify if the data being float or not. - * @param value The value to be used to fill an array. - * @param device The device on which the memory was allocated. - * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo - * object could be created. + * @param mi An NRT_MemInfo object, should be found from memory + * allocation. + * @param itemsize The itemsize, the size of each item in the array. + * @param dest_is_float True if the destination array's dtype is float. + * @param value_is_float True if the value to be filled is float. + * @param value The value to be used to fill an array. + * @param device The device on which the memory was allocated. + * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo + * object could be created. */ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi, size_t itemsize, - bool is_float, - uint8_t value, + bool dest_is_float, + bool value_is_float, + int64_t value, const char *device) { DPCTLSyclQueueRef qref = NULL; DPCTLSyclEventRef eref = NULL; size_t count = 0, size = 0, exp = 0; + /** + * @brief A union for bit conversion from the input int64_t value + * to a uintX_t bit-pattern with appropriate type conversion when the + * input value represents a float. + */ + typedef union + { + float f_; /**< The float to be represented. */ + double d_; + int8_t i8_; + int16_t i16_; + int32_t i32_; + int64_t i64_; + uint8_t ui8_; + uint16_t ui16_; + uint32_t ui32_; /**< The bit representation. */ + uint64_t ui64_; /**< The bit representation. */ + } bitcaster_t; + + bitcaster_t bc; size = mi->size; while (itemsize >>= 1) exp++; @@ -552,40 +580,86 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi, switch (exp) { case 3: { - uint64_t value_assign = (uint64_t)value; - if (is_float) { - double const_val = (double)value; + if (dest_is_float && value_is_float) { + double *p = (double *)(&value); + bc.d_ = *p; + } + else if (dest_is_float && !value_is_float) { // To stop warning: dereferencing type-punned pointer // will break strict-aliasing rules [-Wstrict-aliasing] - double *p = &const_val; - value_assign = *((uint64_t *)(p)); + double cd = (double)value; + bc.d_ = *((double *)(&cd)); + } + else if (!dest_is_float && value_is_float) { + double *p = (double *)&value; + bc.i64_ = *p; + } + else { + bc.i64_ = value; } - if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count))) + + if (!(eref = DPCTLQueue_Fill64(qref, mi->data, bc.ui64_, count))) goto error; break; } case 2: { - uint32_t value_assign = (uint32_t)value; - if (is_float) { - float const_val = (float)value; + if (dest_is_float && value_is_float) { + double *p = (double *)(&value); + bc.f_ = *p; + } + else if (dest_is_float && !value_is_float) { // To stop warning: dereferencing type-punned pointer // will break strict-aliasing rules [-Wstrict-aliasing] - float *p = &const_val; - value_assign = *((uint32_t *)(p)); + float cf = (float)value; + bc.f_ = *((float *)(&cf)); + } + else if (!dest_is_float && value_is_float) { + double *p = (double *)&value; + bc.i32_ = *p; } - if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count))) + else { + bc.i32_ = (int32_t)value; + } + + if (!(eref = DPCTLQueue_Fill32(qref, mi->data, bc.ui32_, count))) goto error; break; } case 1: - if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count))) + { + if (dest_is_float) + goto error; + + if (value_is_float) { + double *p = (double *)&value; + bc.i16_ = *p; + } + else { + bc.i16_ = (int16_t)value; + } + + if (!(eref = DPCTLQueue_Fill16(qref, mi->data, bc.ui16_, count))) goto error; break; + } case 0: - if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count))) + { + if (dest_is_float) + goto error; + + if (value_is_float) { + double *p = (double *)&value; + bc.i8_ = *p; + } + else { + bc.i8_ = (int8_t)value; + } + + if (!(eref = DPCTLQueue_Fill8(qref, mi->data, bc.ui8_, count))) goto error; break; + } default: goto error; } diff --git a/numba_dpex/core/runtime/context.py b/numba_dpex/core/runtime/context.py index 49d41f906d..e55d64afad 100644 --- a/numba_dpex/core/runtime/context.py +++ b/numba_dpex/core/runtime/context.py @@ -29,14 +29,35 @@ def wrap(self, builder, *args, **kwargs): @_check_null_result def meminfo_alloc(self, builder, size, usm_type, device): - """A wrapped caller for meminfo_alloc_unchecked() with null check.""" + """ + Wrapper to call :func:`~context.DpexRTContext.meminfo_alloc_unchecked` + with null checking of the returned value. + """ return self.meminfo_alloc_unchecked(builder, size, usm_type, device) @_check_null_result - def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device): - """A wrapped caller for meminfo_fill_unchecked() with null check.""" + def meminfo_fill( + self, + builder, + meminfo, + itemsize, + dest_is_float, + value_is_float, + value, + device, + ): + """ + Wrapper to call :func:`~context.DpexRTContext.meminfo_fill_unchecked` + with null checking of the returned value. + """ return self.meminfo_fill_unchecked( - builder, meminfo, itemsize, is_float, value, device + builder, + meminfo, + itemsize, + dest_is_float, + value_is_float, + value, + device, ) def meminfo_alloc_unchecked(self, builder, size, usm_type, device): @@ -71,7 +92,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device): return ret def meminfo_fill_unchecked( - self, builder, meminfo, itemsize, is_float, value, device + self, + builder, + meminfo, + itemsize, + dest_is_float, + value_is_float, + value, + device, ): """Fills an allocated `MemInfo` with the value specified. @@ -96,12 +124,15 @@ def meminfo_fill_unchecked( b = llvmir.IntType(1) fnty = llvmir.FunctionType( cgutils.voidptr_t, - [cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t], + [cgutils.voidptr_t, u64, b, b, u64, cgutils.voidptr_t], ) fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill") fn.return_value.add_attribute("noalias") - ret = builder.call(fn, [meminfo, itemsize, is_float, value, device]) + ret = builder.call( + fn, + [meminfo, itemsize, dest_is_float, value_is_float, value, device], + ) return ret diff --git a/numba_dpex/core/types/usm_ndarray_type.py b/numba_dpex/core/types/usm_ndarray_type.py index bd4d12eaf3..e3d1101633 100644 --- a/numba_dpex/core/types/usm_ndarray_type.py +++ b/numba_dpex/core/types/usm_ndarray_type.py @@ -74,7 +74,7 @@ def __init__( if not dtype: dummy_tensor = dpctl.tensor.empty( - shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue + 1, order=layout, usm_type=usm_type, sycl_queue=self.queue ) # convert dpnp type to numba/numpy type _dtype = dummy_tensor.dtype diff --git a/numba_dpex/dpnp_iface/_intrinsic.py b/numba_dpex/dpnp_iface/_intrinsic.py index 3c2ad0ac98..a013b329f7 100644 --- a/numba_dpex/dpnp_iface/_intrinsic.py +++ b/numba_dpex/dpnp_iface/_intrinsic.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from llvmlite import ir as llvmir from numba import types from numba.core.typing import signature from numba.extending import intrinsic @@ -76,14 +77,29 @@ def fill_arrayobj(context, builder, sig, llargs, value, is_like=False): types.intp, get_itemsize(context, arrtype[0]) ) device = context.insert_const_string(builder.module, arrtype[0].device) - value = context.get_constant(types.int8, value) + + # Do a bitcast of the input to a 64-bit int. + value = builder.bitcast(value, llvmir.IntType(64)) + + if isinstance(sig.args[1], types.scalars.Float): + value_is_float = context.get_constant(types.boolean, 1) + else: + value_is_float = context.get_constant(types.boolean, 0) + if isinstance(arrtype[0].dtype, types.scalars.Float): - is_float = context.get_constant(types.boolean, 1) + dest_is_float = context.get_constant(types.boolean, 1) else: - is_float = context.get_constant(types.boolean, 0) + dest_is_float = context.get_constant(types.boolean, 0) + dpexrtCtx = dpexrt.DpexRTContext(context) dpexrtCtx.meminfo_fill( - builder, ary.meminfo, itemsize, is_float, value, device + builder, + ary.meminfo, + itemsize, + dest_is_float, + value_is_float, + value, + device, ) return ary, arrtype @@ -109,6 +125,7 @@ def impl_dpnp_empty( ty_shape, ty_dtype, ty_order, + # ty_like, # see issue https://github.com/IntelPython/numba-dpex/issues/998 ty_device, ty_usm_type, ty_sycl_queue, @@ -119,12 +136,14 @@ def impl_dpnp_empty( Args: ty_context (numba.core.typing.context.Context): The typing context for the codegen. - ty_shape (numba.core.types.abstract): One of the numba defined - abstract types. - ty_dtype (numba.core.types.functions.NumberClass): Type class for - number classes (e.g. "np.float64"). + ty_shape (numba.core.types.scalars.Integer or + numba.core.types.containers.UniTuple): Numba type for the shape + of the array. + ty_dtype (numba.core.types.functions.NumberClass): Numba type for + dtype. ty_order (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. + ty_like (numba.core.types.npytypes.Array): Numba type for array. ty_device (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType @@ -144,6 +163,7 @@ def impl_dpnp_empty( ty_shape, ty_dtype, ty_order, + # ty_like, # see issue https://github.com/IntelPython/numba-dpex/issues/998 ty_device, ty_usm_type, ty_sycl_queue, @@ -163,6 +183,7 @@ def impl_dpnp_zeros( ty_shape, ty_dtype, ty_order, + ty_like, ty_device, ty_usm_type, ty_sycl_queue, @@ -173,12 +194,14 @@ def impl_dpnp_zeros( Args: ty_context (numba.core.typing.context.Context): The typing context for the codegen. - ty_shape (numba.core.types.abstract): One of the numba defined - abstract types. - ty_dtype (numba.core.types.functions.NumberClass): Type class for - number classes (e.g. "np.float64"). + ty_shape (numba.core.types.scalars.Integer or + numba.core.types.containers.UniTuple): Numba type for the shape + of the array. + ty_dtype (numba.core.types.functions.NumberClass): Numba type for + dtype. ty_order (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. + ty_like (numba.core.types.npytypes.Array): Numba type for array. ty_device (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType @@ -198,6 +221,7 @@ def impl_dpnp_zeros( ty_shape, ty_dtype, ty_order, + ty_like, ty_device, ty_usm_type, ty_sycl_queue, @@ -205,7 +229,8 @@ def impl_dpnp_zeros( ) def codegen(context, builder, sig, llargs): - ary, _ = fill_arrayobj(context, builder, sig, llargs, 0) + fill_value = context.get_constant(types.intp, 0) + ary, _ = fill_arrayobj(context, builder, sig, llargs, fill_value) return ary._getvalue() return sig, codegen @@ -217,6 +242,7 @@ def impl_dpnp_ones( ty_shape, ty_dtype, ty_order, + ty_like, ty_device, ty_usm_type, ty_sycl_queue, @@ -227,12 +253,14 @@ def impl_dpnp_ones( Args: ty_context (numba.core.typing.context.Context): The typing context for the codegen. - ty_shape (numba.core.types.abstract): One of the numba defined - abstract types. - ty_dtype (numba.core.types.functions.NumberClass): Type class for - number classes (e.g. "np.float64"). + ty_shape (numba.core.types.scalars.Integer or + numba.core.types.containers.UniTuple): Numba type for the shape + of the array. + ty_dtype (numba.core.types.functions.NumberClass): Numba type for + dtype. ty_order (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. + ty_like (numba.core.types.npytypes.Array): Numba type for array. ty_device (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType @@ -252,6 +280,7 @@ def impl_dpnp_ones( ty_shape, ty_dtype, ty_order, + ty_like, ty_device, ty_usm_type, ty_sycl_queue, @@ -259,7 +288,8 @@ def impl_dpnp_ones( ) def codegen(context, builder, sig, llargs): - ary, _ = fill_arrayobj(context, builder, sig, llargs, 1) + fill_value = context.get_constant(types.intp, 1) + ary, _ = fill_arrayobj(context, builder, sig, llargs, fill_value) return ary._getvalue() return sig, codegen @@ -268,9 +298,11 @@ def codegen(context, builder, sig, llargs): @intrinsic def impl_dpnp_empty_like( ty_context, - ty_x, + ty_x1, ty_dtype, ty_order, + ty_subok, + ty_shape, ty_device, ty_usm_type, ty_sycl_queue, @@ -281,11 +313,16 @@ def impl_dpnp_empty_like( Args: ty_context (numba.core.typing.context.Context): The typing context for the codegen. - ty_x (numba.core.types.npytypes.Array): Numba type class for ndarray. - ty_dtype (numba.core.types.functions.NumberClass): Type class for - number classes (e.g. "np.float64"). + ty_x1 (numba.core.types.npytypes.Array): Numba type class for ndarray. + ty_dtype (numba.core.types.functions.NumberClass): Numba type for + dtype. ty_order (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. + ty_subok (numba.core.types.scalars.Boolean): Numba type class for + subok. + ty_shape (numba.core.types.scalars.Integer or + numba.core.types.containers.UniTuple): Numba type for the shape + of the array. Not supported. ty_device (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType @@ -302,9 +339,11 @@ def impl_dpnp_empty_like( ty_retty = ty_retty_ref.instance_type sig = ty_retty( - ty_x, + ty_x1, ty_dtype, ty_order, + ty_subok, + ty_shape, ty_device, ty_usm_type, ty_sycl_queue, @@ -323,9 +362,11 @@ def codegen(context, builder, sig, llargs): @intrinsic def impl_dpnp_zeros_like( ty_context, - ty_x, + ty_x1, ty_dtype, ty_order, + ty_subok, + ty_shape, ty_device, ty_usm_type, ty_sycl_queue, @@ -336,11 +377,16 @@ def impl_dpnp_zeros_like( Args: ty_context (numba.core.typing.context.Context): The typing context for the codegen. - ty_x (numba.core.types.npytypes.Array): Numba type class for ndarray. - ty_dtype (numba.core.types.functions.NumberClass): Type class for - number classes (e.g. "np.float64"). + ty_x1 (numba.core.types.npytypes.Array): Numba type class for ndarray. + ty_dtype (numba.core.types.functions.NumberClass): Numba type for + dtype. ty_order (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. + ty_subok (numba.core.types.scalars.Boolean): Numba type class for + subok. + ty_shape (numba.core.types.scalars.Integer or + numba.core.types.containers.UniTuple): Numba type for the shape + of the array. Not supported. ty_device (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType @@ -357,9 +403,11 @@ def impl_dpnp_zeros_like( ty_retty = ty_retty_ref.instance_type sig = ty_retty( - ty_x, + ty_x1, ty_dtype, ty_order, + ty_subok, + ty_shape, ty_device, ty_usm_type, ty_sycl_queue, @@ -367,7 +415,10 @@ def impl_dpnp_zeros_like( ) def codegen(context, builder, sig, llargs): - ary, _ = fill_arrayobj(context, builder, sig, llargs, 0, is_like=True) + fill_value = context.get_constant(types.intp, 0) + ary, _ = fill_arrayobj( + context, builder, sig, llargs, fill_value, is_like=True + ) return ary._getvalue() return sig, codegen @@ -376,9 +427,11 @@ def codegen(context, builder, sig, llargs): @intrinsic def impl_dpnp_ones_like( ty_context, - ty_x, + ty_x1, ty_dtype, ty_order, + ty_subok, + ty_shape, ty_device, ty_usm_type, ty_sycl_queue, @@ -389,11 +442,16 @@ def impl_dpnp_ones_like( Args: ty_context (numba.core.typing.context.Context): The typing context for the codegen. - ty_x (numba.core.types.npytypes.Array): Numba type class for ndarray. - ty_dtype (numba.core.types.functions.NumberClass): Type class for - number classes (e.g. "np.float64"). + ty_x1 (numba.core.types.npytypes.Array): Numba type class for ndarray. + ty_dtype (numba.core.types.functions.NumberClass): Numba type for + dtype. ty_order (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. + ty_subok (numba.core.types.scalars.Boolean): Numba type class for + subok. + ty_shape (numba.core.types.scalars.Integer or + numba.core.types.containers.UniTuple): Numba type for the shape + of the array. Not supported. ty_device (numba.core.types.misc.UnicodeType): UnicodeType from numba for strings. ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType @@ -410,9 +468,11 @@ def impl_dpnp_ones_like( ty_retty = ty_retty_ref.instance_type sig = ty_retty( - ty_x, + ty_x1, ty_dtype, ty_order, + ty_subok, + ty_shape, ty_device, ty_usm_type, ty_sycl_queue, @@ -420,7 +480,75 @@ def impl_dpnp_ones_like( ) def codegen(context, builder, sig, llargs): - ary, _ = fill_arrayobj(context, builder, sig, llargs, 1, is_like=True) + fill_value = context.get_constant(types.intp, 1) + ary, _ = fill_arrayobj( + context, builder, sig, llargs, fill_value, is_like=True + ) return ary._getvalue() return sig, codegen + + +@intrinsic +def impl_dpnp_full( + ty_context, + ty_shape, + ty_fill_value, + ty_dtype, + ty_order, + ty_like, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, +): + """A numba "intrinsic" function to inject code for dpnp.full(). + + Args: + ty_context (numba.core.typing.context.Context): The typing context + for the codegen. + ty_shape (numba.core.types.scalars.Integer or + numba.core.types.containers.UniTuple): Numba type for the shape + of the array. + ty_fill_value (numba.core.types.scalars): One of the Numba scalar + types. + ty_dtype (numba.core.types.functions.NumberClass): Numba type for + dtype. + ty_order (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_like (numba.core.types.npytypes.Array): Numba type for array. + ty_device (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to + a type from numba, used when a type is passed as a value. + + Returns: + tuple(numba.core.typing.templates.Signature, function): A tuple of + numba function signature type and a function object. + """ + + ty_retty = ty_retty_ref.instance_type + signature = ty_retty( + ty_shape, + ty_fill_value, + ty_dtype, + ty_order, + ty_like, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, + ) + + def codegen(context, builder, sig, args): + fill_value = context.get_argument_value(builder, sig.args[1], args[1]) + ary, _ = fill_arrayobj( + context, builder, sig, args, fill_value, is_like=False + ) + return ary._getvalue() + + return signature, codegen diff --git a/numba_dpex/dpnp_iface/arrayobj.py b/numba_dpex/dpnp_iface/arrayobj.py index 694ac37db5..fe7fab6b4a 100644 --- a/numba_dpex/dpnp_iface/arrayobj.py +++ b/numba_dpex/dpnp_iface/arrayobj.py @@ -14,6 +14,7 @@ from ._intrinsic import ( impl_dpnp_empty, impl_dpnp_empty_like, + impl_dpnp_full, impl_dpnp_ones, impl_dpnp_ones_like, impl_dpnp_zeros, @@ -79,7 +80,7 @@ def _parse_usm_type(usm_type): Raises: errors.NumbaValueError: If an invalid usm_type is specified. TypeError: If the parameter is neither a 'str' - nor a 'types.StringLiteral' + nor a 'types.StringLiteral'. Returns: str: The stringized usm_type. @@ -112,7 +113,7 @@ def _parse_device_filter_string(device): Raises: TypeError: If the parameter is neither a 'str' - nor a 'types.StringLiteral' + nor a 'types.StringLiteral'. Returns: str: The stringized device. @@ -142,11 +143,11 @@ def build_dpnp_ndarray( Args: ndim (int): The dimension of the array. - layout ("C", or F"): memory layout for the array. Default: "C" + layout ("C", or F"): memory layout for the array. Default: "C". dtype (numba.core.types.functions.NumberClass, optional): Data type of the array. Can be typestring, a `numpy.dtype` object, `numpy` char string, or a numpy scalar type. - Default: None + Default: None. usm_type (numba.core.types.misc.StringLiteral, optional): The type of SYCL USM allocation for the output array. Allowed values are "device"|"shared"|"host". @@ -208,6 +209,8 @@ def ol_dpnp_empty( shape, dtype=None, order="C", + # like=None, # this gets lost when dpnp.empty() is called outside dpjit, + # see issue https://github.com/IntelPython/numba-dpex/issues/998 device=None, usm_type="device", sycl_queue=None, @@ -216,13 +219,21 @@ def ol_dpnp_empty( a jit function. Args: - shape (tuple): Dimensions of the array to be created. + shape (numba.core.types.containers.UniTuple or + numba.core.types.scalars.IntegerLiteral): Dimensions + of the array to be created. dtype (numba.core.types.functions.NumberClass, optional): Data type of the array. Can be typestring, a `numpy.dtype` object, `numpy` char string, or a numpy scalar type. - Default: None + Default: None. order (str, optional): memory layout for the array "C" or "F". - Default: "C" + Default: "C". + like (numba.core.types.npytypes.Array, optional): A type for + reference object to allow the creation of arrays which are not + `NumPy` arrays. If an array-like passed in as `like` supports the + `__array_function__` protocol, the result will be defined by it. + In this case, it ensures the creation of an array object + compatible with that passed in via this argument. device (numba.core.types.misc.StringLiteral, optional): array API concept of device where the output array is created. `device` can be `None`, a oneAPI filter selector string, an instance of @@ -241,7 +252,7 @@ def ol_dpnp_empty( errors.TypingError: If couldn't parse input types to dpnp.empty(). Returns: - function: Local function `impl_dpnp_empty()` + function: Local function `impl_dpnp_empty()`. """ _ndim = _ty_parse_shape(shape) @@ -266,12 +277,20 @@ def impl( shape, dtype=None, order="C", + # like=None, see issue https://github.com/IntelPython/numba-dpex/issues/998 device=None, usm_type="device", sycl_queue=None, ): return impl_dpnp_empty( - shape, _dtype, order, _device, _usm_type, sycl_queue, ret_ty + shape, + _dtype, + order, + # like, see issue https://github.com/IntelPython/numba-dpex/issues/998 + _device, + _usm_type, + sycl_queue, + ret_ty, ) return impl @@ -289,6 +308,7 @@ def ol_dpnp_zeros( shape, dtype=None, order="C", + like=None, device=None, usm_type="device", sycl_queue=None, @@ -297,13 +317,21 @@ def ol_dpnp_zeros( a jit function. Args: - shape (tuple): Dimensions of the array to be created. + shape (numba.core.types.containers.UniTuple or + numba.core.types.scalars.IntegerLiteral): Dimensions + of the array to be created. dtype (numba.core.types.functions.NumberClass, optional): Data type of the array. Can be typestring, a `numpy.dtype` object, `numpy` char string, or a numpy scalar type. - Default: None + Default: None. order (str, optional): memory layout for the array "C" or "F". - Default: "C" + Default: "C". + like (numba.core.types.npytypes.Array, optional): A type for + reference object to allow the creation of arrays which are not + `NumPy` arrays. If an array-like passed in as `like` supports the + `__array_function__` protocol, the result will be defined by it. + In this case, it ensures the creation of an array object + compatible with that passed in via this argument. device (numba.core.types.misc.StringLiteral, optional): array API concept of device where the output array is created. `device` can be `None`, a oneAPI filter selector string, an instance of @@ -322,7 +350,7 @@ def ol_dpnp_zeros( errors.TypingError: If couldn't parse input types to dpnp.zeros(). Returns: - function: Local function `impl_dpnp_zeros()` + function: Local function `impl_dpnp_zeros()`. """ _ndim = _ty_parse_shape(shape) @@ -347,12 +375,20 @@ def impl( shape, dtype=None, order="C", + like=None, device=None, usm_type="device", sycl_queue=None, ): return impl_dpnp_zeros( - shape, _dtype, order, _device, _usm_type, sycl_queue, ret_ty + shape, + _dtype, + order, + like, + _device, + _usm_type, + sycl_queue, + ret_ty, ) return impl @@ -370,6 +406,7 @@ def ol_dpnp_ones( shape, dtype=None, order="C", + like=None, device=None, usm_type="device", sycl_queue=None, @@ -378,13 +415,21 @@ def ol_dpnp_ones( a jit function. Args: - shape (tuple): Dimensions of the array to be created. + shape (numba.core.types.containers.UniTuple or + numba.core.types.scalars.IntegerLiteral): Dimensions + of the array to be created. dtype (numba.core.types.functions.NumberClass, optional): Data type of the array. Can be typestring, a `numpy.dtype` object, `numpy` char string, or a numpy scalar type. - Default: None + Default: None. order (str, optional): memory layout for the array "C" or "F". - Default: "C" + Default: "C". + like (numba.core.types.npytypes.Array, optional): A type for + reference object to allow the creation of arrays which are not + `NumPy` arrays. If an array-like passed in as `like` supports the + `__array_function__` protocol, the result will be defined by it. + In this case, it ensures the creation of an array object + compatible with that passed in via this argument. device (numba.core.types.misc.StringLiteral, optional): array API concept of device where the output array is created. `device` can be `None`, a oneAPI filter selector string, an instance of @@ -403,7 +448,7 @@ def ol_dpnp_ones( errors.TypingError: If couldn't parse input types to dpnp.ones(). Returns: - function: Local function `impl_dpnp_ones()` + function: Local function `impl_dpnp_ones()`. """ _ndim = _ty_parse_shape(shape) @@ -428,12 +473,20 @@ def impl( shape, dtype=None, order="C", + like=None, device=None, usm_type="device", sycl_queue=None, ): return impl_dpnp_ones( - shape, _dtype, order, _device, _usm_type, sycl_queue, ret_ty + shape, + _dtype, + order, + like, + _device, + _usm_type, + sycl_queue, + ret_ty, ) return impl @@ -448,9 +501,10 @@ def impl( @overload(dpnp.empty_like, prefer_literal=True) def ol_dpnp_empty_like( - x, + x1, dtype=None, order="C", + subok=False, shape=None, device=None, usm_type=None, @@ -461,14 +515,19 @@ def ol_dpnp_empty_like( This is an overloaded function implementation for dpnp.empty_like(). Args: - x (numba.core.types.npytypes.Array): Input array from which to + x1 (numba.core.types.npytypes.Array): Input array from which to derive the output array shape. dtype (numba.core.types.functions.NumberClass, optional): Data type of the array. Can be typestring, a `numpy.dtype` object, `numpy` char string, or a numpy scalar type. - Default: None + Default: None. order (str, optional): memory layout for the array "C" or "F". - Default: "C" + Default: "C". + subok ('numba.core.types.scalars.BooleanLiteral', optional): A + boolean literal type for the `subok` parameter defined in + NumPy. If True, then the newly created array will use the + sub-class type of prototype, otherwise it will be a + base-class array. Defaults to False. shape (numba.core.types.containers.UniTuple, optional): The shape to override the shape of the given array. Not supported. Default: `None` @@ -490,7 +549,7 @@ def ol_dpnp_empty_like( errors.TypingError: If shape is provided. Returns: - function: Local function `impl_dpnp_empty_like()` + function: Local function `impl_dpnp_empty_like()`. """ if shape: @@ -498,9 +557,9 @@ def ol_dpnp_empty_like( "The parameter shape is not supported " + "inside overloaded dpnp.empty_like() function." ) - _ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0 - _dtype = _parse_dtype(dtype, data=x) - _order = x.layout if order is None else order + _ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim is not None else 0 + _dtype = _parse_dtype(dtype, data=x1) + _order = x1.layout if order is None else order _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" _device = ( _parse_device_filter_string(device) if device is not None else "unknown" @@ -516,18 +575,21 @@ def ol_dpnp_empty_like( if ret_ty: def impl( - x, + x1, dtype=None, order="C", + subok=False, shape=None, device=None, usm_type=None, sycl_queue=None, ): return impl_dpnp_empty_like( - x, + x1, _dtype, _order, + subok, + shape, _device, _usm_type, sycl_queue, @@ -538,15 +600,16 @@ def impl( else: raise errors.TypingError( "Cannot parse input types to " - + f"function dpnp.empty_like({x}, {dtype}, ...)." + + f"function dpnp.empty_like({x1}, {dtype}, ...)." ) @overload(dpnp.zeros_like, prefer_literal=True) def ol_dpnp_zeros_like( - x, + x1, dtype=None, order="C", + subok=None, shape=None, device=None, usm_type=None, @@ -557,14 +620,19 @@ def ol_dpnp_zeros_like( This is an overloaded function implementation for dpnp.zeros_like(). Args: - x (numba.core.types.npytypes.Array): Input array from which to + x1 (numba.core.types.npytypes.Array): Input array from which to derive the output array shape. dtype (numba.core.types.functions.NumberClass, optional): Data type of the array. Can be typestring, a `numpy.dtype` object, `numpy` char string, or a numpy scalar type. - Default: None + Default: None. order (str, optional): memory layout for the array "C" or "F". - Default: "C" + Default: "C". + subok ('numba.core.types.scalars.BooleanLiteral', optional): A + boolean literal type for the `subok` parameter defined in + NumPy. If True, then the newly created array will use the + sub-class type of prototype, otherwise it will be a + base-class array. Defaults to False. shape (numba.core.types.containers.UniTuple, optional): The shape to override the shape of the given array. Not supported. Default: `None` @@ -586,7 +654,7 @@ def ol_dpnp_zeros_like( errors.TypingError: If shape is provided. Returns: - function: Local function `impl_dpnp_zeros_like()` + function: Local function `impl_dpnp_zeros_like()`. """ if shape: @@ -594,9 +662,9 @@ def ol_dpnp_zeros_like( "The parameter shape is not supported " + "inside overloaded dpnp.zeros_like() function." ) - _ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0 - _dtype = _parse_dtype(dtype, data=x) - _order = x.layout if order is None else order + _ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim is not None else 0 + _dtype = _parse_dtype(dtype, data=x1) + _order = x1.layout if order is None else order _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" _device = ( _parse_device_filter_string(device) if device is not None else "unknown" @@ -612,18 +680,21 @@ def ol_dpnp_zeros_like( if ret_ty: def impl( - x, + x1, dtype=None, order="C", + subok=None, shape=None, device=None, usm_type=None, sycl_queue=None, ): return impl_dpnp_zeros_like( - x, + x1, _dtype, _order, + subok, + shape, _device, _usm_type, sycl_queue, @@ -634,15 +705,16 @@ def impl( else: raise errors.TypingError( "Cannot parse input types to " - + f"function dpnp.empty_like({x}, {dtype}, ...)." + + f"function dpnp.empty_like({x1}, {dtype}, ...)." ) @overload(dpnp.ones_like, prefer_literal=True) def ol_dpnp_ones_like( - x, + x1, dtype=None, order="C", + subok=None, shape=None, device=None, usm_type=None, @@ -653,14 +725,19 @@ def ol_dpnp_ones_like( This is an overloaded function implementation for dpnp.ones_like(). Args: - x (numba.core.types.npytypes.Array): Input array from which to + x1 (numba.core.types.npytypes.Array): Input array from which to derive the output array shape. dtype (numba.core.types.functions.NumberClass, optional): Data type of the array. Can be typestring, a `numpy.dtype` object, `numpy` char string, or a numpy scalar type. - Default: None + Default: None. order (str, optional): memory layout for the array "C" or "F". - Default: "C" + Default: "C". + subok ('numba.core.types.scalars.BooleanLiteral', optional): A + boolean literal type for the `subok` parameter defined in + NumPy. If True, then the newly created array will use the + sub-class type of prototype, otherwise it will be a + base-class array. Defaults to False. shape (numba.core.types.containers.UniTuple, optional): The shape to override the shape of the given array. Not supported. Default: `None` @@ -682,7 +759,7 @@ def ol_dpnp_ones_like( errors.TypingError: If shape is provided. Returns: - function: Local function `impl_dpnp_ones_like()` + function: Local function `impl_dpnp_ones_like()`. """ if shape: @@ -690,9 +767,9 @@ def ol_dpnp_ones_like( "The parameter shape is not supported " + "inside overloaded dpnp.ones_like() function." ) - _ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0 - _dtype = _parse_dtype(dtype, data=x) - _order = x.layout if order is None else order + _ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim is not None else 0 + _dtype = _parse_dtype(dtype, data=x1) + _order = x1.layout if order is None else order _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" _device = ( _parse_device_filter_string(device) if device is not None else "unknown" @@ -708,17 +785,21 @@ def ol_dpnp_ones_like( if ret_ty: def impl( - x, + x1, dtype=None, order="C", + subok=None, + shape=None, device=None, usm_type=None, sycl_queue=None, ): return impl_dpnp_ones_like( - x, + x1, _dtype, _order, + subok, + shape, _device, _usm_type, sycl_queue, @@ -729,5 +810,109 @@ def impl( else: raise errors.TypingError( "Cannot parse input types to " - + f"function dpnp.empty_like({x}, {dtype}, ...)." + + f"function dpnp.empty_like({x1}, {dtype}, ...)." ) + + +@overload(dpnp.full, prefer_literal=True) +def ol_dpnp_full( + shape, + fill_value, + dtype=None, + order="C", + like=None, + device=None, + usm_type=None, + sycl_queue=None, +): + """Implementation of an overload to support dpnp.full() inside + a jit function. + + Args: + shape (numba.core.types.containers.UniTuple or + numba.core.types.scalars.IntegerLiteral): Dimensions + of the array to be created. + fill_value (numba.core.types.scalars): One of the + numba.core.types.scalar types for the value to + be filled. + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None. + order (str, optional): memory layout for the array "C" or "F". + Default: "C". + like (numba.core.types.npytypes.Array, optional): A type for + reference object to allow the creation of arrays which are not + `NumPy` arrays. If an array-like passed in as `like` supports the + `__array_function__` protocol, the result will be defined by it. + In this case, it ensures the creation of an array object + compatible with that passed in via this argument. + device (numba.core.types.misc.StringLiteral, optional): array API + concept of device where the output array is created. `device` + can be `None`, a oneAPI filter selector string, an instance of + :class:`dpctl.SyclDevice` corresponding to a non-partitioned + SYCL device, an instance of :class:`dpctl.SyclQueue`, or a + `Device` object returnedby`dpctl.tensor.usm_array.device`. + Default: `None`. + usm_type (numba.core.types.misc.StringLiteral or str, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. + sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported. + + Raises: + errors.TypingError: If rank of the ndarray couldn't be inferred. + errors.TypingError: If couldn't parse input types to dpnp.full(). + + Returns: + function: Local function `impl_dpnp_full()`. + """ + + _ndim = _ty_parse_shape(shape) + _dtype = _parse_dtype(dtype) + _layout = _parse_layout(order) + _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" + _device = ( + _parse_device_filter_string(device) if device is not None else "unknown" + ) + if _ndim: + ret_ty = build_dpnp_ndarray( + _ndim, + layout=_layout, + dtype=_dtype, + usm_type=_usm_type, + device=_device, + queue=sycl_queue, + ) + if ret_ty: + + def impl( + shape, + fill_value, + dtype=None, + order="C", + like=None, + device=None, + usm_type=None, + sycl_queue=None, + ): + return impl_dpnp_full( + shape, + fill_value, + _dtype, + order, + like, + _device, + _usm_type, + sycl_queue, + ret_ty, + ) + + return impl + else: + raise errors.TypingError( + "Cannot parse input types to " + + f"function dpnp.full({shape}, {fill_value}, {dtype}, ...)." + ) + else: + raise errors.TypingError("Could not infer the rank of the ndarray.") diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py index 798f03b7d8..775b337109 100644 --- a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py @@ -10,7 +10,7 @@ from numba_dpex import dpjit -shapes = [10, (2, 5)] +shapes = [11, (2, 5)] dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64] usm_types = ["device", "shared", "host"] devices = ["cpu", "unknown"] @@ -22,14 +22,12 @@ @pytest.mark.parametrize("device", devices) def test_dpnp_empty(shape, dtype, usm_type, device): @dpjit - def func1(shape): - c = dpnp.empty( - shape=shape, dtype=dtype, usm_type=usm_type, device=device - ) + def func(shape): + c = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, device=device) return c try: - c = func1(shape) + c = func(shape) except Exception: pytest.fail("Calling dpnp.empty inside dpjit failed") @@ -52,12 +50,12 @@ def func1(shape): @pytest.mark.parametrize("shape", shapes) def test_dpnp_empty_default_dtype(shape): @dpjit - def func1(shape): - c = dpnp.empty(shape=shape) + def func(shape): + c = dpnp.empty(shape) return c try: - c = func1(shape) + c = func(shape) except Exception: pytest.fail("Calling dpnp.empty inside dpjit failed") @@ -66,10 +64,6 @@ def func1(shape): else: assert c.shape == shape - dummy_tensor = dpctl.tensor.empty(shape=1) - - assert c.dtype == dummy_tensor.dtype - dummy_tensor = dpctl.tensor.empty(shape) assert c.dtype == dummy_tensor.dtype diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_full.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_full.py new file mode 100644 index 0000000000..b52e307187 --- /dev/null +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_full.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dpnp ndarray constructors.""" + +import math + +import dpctl +import dpctl.tensor as dpt +import dpnp +import numpy +import pytest + +from numba_dpex import dpjit + +shapes = [11, (3, 7)] +dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64] +usm_types = ["device", "shared", "host"] +devices = ["cpu", "unknown"] +fill_values = [ + 7, + -7, + 7.1, + -7.1, + math.pi, + math.e, + 4294967295, + 4294967295.0, + 3.4028237e38, +] + + +@pytest.mark.parametrize("shape", shapes) +@pytest.mark.parametrize("fill_value", fill_values) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("usm_type", usm_types) +@pytest.mark.parametrize("device", devices) +def test_dpnp_full(shape, fill_value, dtype, usm_type, device): + @dpjit + def func(shape, fill_value): + c = dpnp.full( + shape, fill_value, dtype=dtype, usm_type=usm_type, device=device + ) + return c + + a = numpy.full(shape, fill_value, dtype=dtype) + + try: + c = func(shape, fill_value) + except Exception: + pytest.fail("Calling dpnp.full inside dpjit failed") + + if len(c.shape) == 1: + assert c.shape[0] == shape + else: + assert c.shape == shape + + assert c.dtype == dtype + assert c.usm_type == usm_type + if device != "unknown": + assert ( + c.sycl_device.filter_string + == dpctl.SyclDevice(device).filter_string + ) + else: + c.sycl_device.filter_string == dpctl.SyclDevice().filter_string + + assert numpy.array_equal(dpt.asnumpy(c._array_obj), a) diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py index fd69038728..34dbcaf457 100644 --- a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("device", devices) def test_dpnp_ones(shape, dtype, usm_type, device): @dpjit - def func1(shape): + def func(shape): c = dpnp.ones( shape=shape, dtype=dtype, usm_type=usm_type, device=device ) @@ -33,9 +33,9 @@ def func1(shape): a = numpy.ones(shape, dtype=dtype) try: - c = func1(shape) + c = func(shape) except Exception: - pytest.fail("Calling dpnp.empty inside dpjit failed") + pytest.fail("Calling dpnp.ones inside dpjit failed") if len(c.shape) == 1: assert c.shape[0] == shape diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones_like.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones_like.py index 64d1accba6..d360a65ffe 100644 --- a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones_like.py +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones_like.py @@ -23,10 +23,10 @@ @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("usm_type", usm_types) @pytest.mark.parametrize("device", devices) -def test_dpnp_ones(shape, dtype, usm_type, device): +def test_dpnp_ones_like(shape, dtype, usm_type, device): @dpjit def func1(a): - c = dpnp.ones(a, dtype=dtype, usm_type=usm_type, device=device) + c = dpnp.ones_like(a, dtype=dtype, usm_type=usm_type, device=device) return c if isinstance(shape, int): @@ -35,9 +35,9 @@ def func1(a): NZ = numpy.random.rand(*shape) try: - c = func1(shape) + c = func1(NZ) except Exception: - pytest.fail("Calling dpnp.empty inside dpjit failed") + pytest.fail("Calling dpnp.ones_like inside dpjit failed") if len(c.shape) == 1: assert c.shape[0] == NZ.shape[0] diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py index 854cceb9bc..e63fee390d 100644 --- a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py @@ -24,18 +24,16 @@ @pytest.mark.parametrize("device", devices) def test_dpnp_zeros(shape, dtype, usm_type, device): @dpjit - def func1(shape): - c = dpnp.zeros( - shape=shape, dtype=dtype, usm_type=usm_type, device=device - ) + def func(shape): + c = dpnp.zeros(shape, dtype=dtype, usm_type=usm_type, device=device) return c a = numpy.zeros(shape, dtype=dtype) try: - c = func1(shape) + c = func(shape) except Exception: - pytest.fail("Calling dpnp.empty inside dpjit failed") + pytest.fail("Calling dpnp.zeros inside dpjit failed") if len(c.shape) == 1: assert c.shape[0] == shape diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros_like.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros_like.py index 2f8cdd89d6..a1fe81e611 100644 --- a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros_like.py +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros_like.py @@ -23,10 +23,10 @@ @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("usm_type", usm_types) @pytest.mark.parametrize("device", devices) -def test_dpnp_zeros(shape, dtype, usm_type, device): +def test_dpnp_zeros_like(shape, dtype, usm_type, device): @dpjit - def func1(a): - c = dpnp.zeros(a, dtype=dtype, usm_type=usm_type, device=device) + def func(a): + c = dpnp.zeros_like(a, dtype=dtype, usm_type=usm_type, device=device) return c if isinstance(shape, int): @@ -35,9 +35,9 @@ def func1(a): NZ = numpy.random.rand(*shape) try: - c = func1(shape) + c = func(NZ) except Exception: - pytest.fail("Calling dpnp.empty inside dpjit failed") + pytest.fail("Calling dpnp.zeros_like inside dpjit failed") if len(c.shape) == 1: assert c.shape[0] == NZ.shape[0]