From ed8cb707f30934cdcbfda973182c6146574c3517 Mon Sep 17 00:00:00 2001 From: khaled Date: Tue, 31 Oct 2023 02:54:59 -0500 Subject: [PATCH] Functional dpnp.arange for int and double types --- numba_dpex/core/runtime/kernels/api.h | 18 +- numba_dpex/core/runtime/kernels/sequences.cpp | 39 +- numba_dpex/core/runtime/kernels/sequences.hpp | 26 +- numba_dpex/core/runtime/kernels/types.hpp | 64 +++ numba_dpex/dpnp_iface/array_sequence_ops.py | 365 +++++++++++++++--- 5 files changed, 429 insertions(+), 83 deletions(-) diff --git a/numba_dpex/core/runtime/kernels/api.h b/numba_dpex/core/runtime/kernels/api.h index e3b0fff4de..8ee819bba8 100644 --- a/numba_dpex/core/runtime/kernels/api.h +++ b/numba_dpex/core/runtime/kernels/api.h @@ -5,21 +5,35 @@ #include "dpctl_capi.h" #include "dpctl_sycl_interface.h" +#pragma once + #ifdef __cplusplus extern "C" { #endif - + // Dispatch vector initializer functions. void NUMBA_DPEX_SYCL_KERNEL_init_sequence_step_dispatch_vectors(); void NUMBA_DPEX_SYCL_KERNEL_init_affine_sequence_dispatch_vectors(); + + // Call linear sequences dispatch functions. uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence( void *start, void *dt, arystruct_t *dst, int ndim, u_int8_t is_c_contiguous, + int dst_typeid, + const DPCTLSyclQueueRef exec_q); + + // Call linear affine sequences dispatch functions. + uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_affine_sequence( + void *start, + void *end, + arystruct_t *dst, + u_int8_t include_endpoint, + int ndim, + u_int8_t is_c_contiguous, const DPCTLSyclQueueRef exec_q); - // const DPCTLEventVectorRef depends = std::vector()); #ifdef __cplusplus } diff --git a/numba_dpex/core/runtime/kernels/sequences.cpp b/numba_dpex/core/runtime/kernels/sequences.cpp index 422c6f77f2..9bbc44dce0 100644 --- a/numba_dpex/core/runtime/kernels/sequences.cpp +++ b/numba_dpex/core/runtime/kernels/sequences.cpp @@ -42,14 +42,20 @@ extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence( arystruct_t *dst, int ndim, u_int8_t is_c_contiguous, + int dst_typeid, const DPCTLSyclQueueRef exec_q) -// const DPCTLEventVectorRef depends = std::vector()) { std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:" - << " start = " << *(reinterpret_cast(start)) << std::endl; + << " start = " + << ndpx::runtime::kernel::types::caste_using_typeid(start, + dst_typeid) + << std::endl; std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:" - << " dt = " << *(reinterpret_cast(dt)) << std::endl; + << " dt = " + << ndpx::runtime::kernel::types::caste_using_typeid(dt, + dst_typeid) + << std::endl; if (ndim != 1) { throw std::logic_error( @@ -60,26 +66,15 @@ extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence( "populate_arystruct_linseq(): array must be c-contiguous."); } - /** - auto array_types = td_ns::usm_ndarray_types(); - int dst_typenum = dst.get_typenum(); - int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); - */ - - size_t len = static_cast(dst->nitems); // dst.get_shape(0); - if (len == 0) { - // nothing to do - // return std::make_pair(sycl::event{}, sycl::event{}); + size_t len = static_cast(dst->nitems); + if (len == 0) return 0; - } std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:" << " len = " << len << std::endl; char *dst_data = reinterpret_cast(dst->data); - const int dst_typeid = 7; // 7 = int64_t, 10 = float, 11 = double - // int64_t *_start = reinterpret_cast(start); - // int64_t *_dt = reinterpret_cast(dt); + // int dst_typeid = 7; // 7 = int64_t, 10 = float, 11 = double auto fn = sequence_step_dispatch_vector[dst_typeid]; sycl::queue *queue = reinterpret_cast(exec_q); @@ -87,10 +82,14 @@ extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence( sycl::event linspace_step_event = fn(*queue, len, start, dt, dst_data, depends); - /*return std::make_pair(keep_args_alive(exec_q, {dst}, - {linspace_step_event}), linspace_step_event);*/ + linspace_step_event.wait_and_throw(); - return 1; + if (linspace_step_event + .get_info() == + sycl::info::event_command_status::complete) + return 0; + else + return 1; } // uint ndpx::runtime::kernel::tensor::populate_arystruct_affine_sequence( diff --git a/numba_dpex/core/runtime/kernels/sequences.hpp b/numba_dpex/core/runtime/kernels/sequences.hpp index 2085179bac..482de83543 100644 --- a/numba_dpex/core/runtime/kernels/sequences.hpp +++ b/numba_dpex/core/runtime/kernels/sequences.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -111,7 +112,17 @@ sycl::event sequence_step_kernel(sycl::queue exec_q, char *array_data, const std::vector &depends) { + std::cout << "sequqnce_step_kernel<" + << ndpx::runtime::kernel::types::demangle() + << ">(): nelems = " << nelems << ", start_v = " << start_v + << ", step_v = " << step_v << std::endl; + ndpx::runtime::kernel::types::validate_type_for_device(exec_q); + + std::cout << "sequqnce_step_kernel<" + << ndpx::runtime::kernel::types::demangle() + << ">(): validate_type_for_device(exec_q) = done" << std::endl; + sycl::event seq_step_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.parallel_for>( @@ -170,6 +181,11 @@ sycl::event sequence_step(sycl::queue &exec_q, std::cerr << e.what() << std::endl; } + std::cout << "sequqnce_step()<" + << ndpx::runtime::kernel::types::demangle() + << ">: nelems = " << nelems << ", *start_v = " << (*start_v) + << ", *step_v = " << (*step_v) << std::endl; + auto sequence_step_event = sequence_step_kernel( exec_q, nelems, *start_v, *step_v, array_data, depends); @@ -232,16 +248,6 @@ typedef sycl::event (*affine_sequence_ptr_t)(sycl::queue &, bool, // include_endpoint char *, // dst_data_ptr const std::vector &); - -// uint populate_arystruct_affine_sequence(void *start, -// void *end, -// arystruct_t *dst, -// int include_endpoint, -// int ndim, -// int is_c_contiguous, -// const DPCTLSyclQueueRef exec_q, -// const DPCTLEventVectorRef depends); - } // namespace tensor } // namespace kernel } // namespace runtime diff --git a/numba_dpex/core/runtime/kernels/types.hpp b/numba_dpex/core/runtime/kernels/types.hpp index 9cff097bd7..4fc2d7b70e 100644 --- a/numba_dpex/core/runtime/kernels/types.hpp +++ b/numba_dpex/core/runtime/kernels/types.hpp @@ -1,9 +1,12 @@ #ifndef __TYPES_HPP__ #define __TYPES_HPP__ +#include #include #include #include +#include +#include #include namespace ndpx @@ -43,6 +46,58 @@ enum class typenum_t : int constexpr int num_types = 14; // number of elements in typenum_t +template std::string demangle() +{ + char const *mangled = typeid(T).name(); + char *c_demangled; + int status = 0; + c_demangled = abi::__cxa_demangle(mangled, nullptr, nullptr, &status); + + std::string res; + if (c_demangled) { + res = c_demangled; + free(c_demangled); + } + else { + res = mangled; + free(c_demangled); + } + return res; +} + +std::string caste_using_typeid(void *value, int _typeid) +{ + switch (_typeid) { + case 0: + return std::to_string(*(reinterpret_cast(value))); + case 1: + return std::to_string(*(reinterpret_cast(value))); + case 2: + return std::to_string(*(reinterpret_cast(value))); + case 3: + return std::to_string(*(reinterpret_cast(value))); + case 4: + return std::to_string(*(reinterpret_cast(value))); + case 5: + return std::to_string(*(reinterpret_cast(value))); + case 6: + return std::to_string(*(reinterpret_cast(value))); + case 7: + return std::to_string(*(reinterpret_cast(value))); + case 8: + return std::to_string(*(reinterpret_cast(value))); + case 9: + return std::to_string(*(reinterpret_cast(value))); + case 10: + return std::to_string(*(reinterpret_cast(value))); + case 11: + return std::to_string(*(reinterpret_cast(value))); + default: + throw std::runtime_error(std::to_string(_typeid) + + " could't be mapped to valid data type."); + } +} + template dstTy convert_impl(const srcTy &v) { if constexpr (std::is_same::value) { @@ -75,6 +130,9 @@ template dstTy convert_impl(const srcTy &v) template void validate_type_for_device(const sycl::device &d) { if constexpr (std::is_same_v) { + std::cout + << "ndpx::runtime::kernel::types::validate_type_for_device(): here0" + << std::endl; if (!d.has(sycl::aspect::fp64)) { throw std::runtime_error("Device " + d.get_info() + @@ -82,6 +140,9 @@ template void validate_type_for_device(const sycl::device &d) } } else if constexpr (std::is_same_v>) { + std::cout + << "ndpx::runtime::kernel::types::validate_type_for_device(): here1" + << std::endl; if (!d.has(sycl::aspect::fp64)) { throw std::runtime_error("Device " + d.get_info() + @@ -89,6 +150,9 @@ template void validate_type_for_device(const sycl::device &d) } } else if constexpr (std::is_same_v) { + std::cout + << "ndpx::runtime::kernel::types::validate_type_for_device(): here2" + << std::endl; if (!d.has(sycl::aspect::fp16)) { throw std::runtime_error("Device " + d.get_info() + diff --git a/numba_dpex/dpnp_iface/array_sequence_ops.py b/numba_dpex/dpnp_iface/array_sequence_ops.py index a389556040..b63ebd4e9b 100644 --- a/numba_dpex/dpnp_iface/array_sequence_ops.py +++ b/numba_dpex/dpnp_iface/array_sequence_ops.py @@ -9,8 +9,14 @@ from llvmlite import ir as llvmir from numba import errors, types from numba.core import cgutils -from numba.core.types.misc import UnicodeType -from numba.core.types.scalars import Complex, Float, Integer, IntegerLiteral +from numba.core.types.misc import NoneType, UnicodeType +from numba.core.types.scalars import ( + Boolean, + Complex, + Float, + Integer, + IntegerLiteral, +) from numba.core.typing.templates import Signature from numba.extending import intrinsic, overload @@ -35,21 +41,69 @@ ) +def _is_any_float_type(value): + return ( + type(value) == float + or isinstance(value, np.floating) + or isinstance(value, Float) + ) + + +def _is_any_int_type(value): + return ( + type(value) == int + or isinstance(value, np.integer) + or isinstance(value, Integer) + ) + + +def _is_any_complex_type(value): + return np.iscomplex(value) or isinstance(value, Complex) + + +def _compute_bitwidth(value): + print("_compute_bitwidth(): type(value) =", type(value)) + if ( + isinstance(value, Float) + or isinstance(value, Integer) + or isinstance(value, Complex) + ): + return value.bitwidth + elif ( + isinstance(value, np.floating) + or isinstance(value, np.integer) + or np.iscomplex(value) + ): + return value.itemsize * 8 + elif type(value) == float or type(value) == int: + return 64 + elif type(value) == complex: + return 128 + else: + msg = "dpnp_iface.array_sequence_ops._compute_bitwidth(): Unknwon type." + raise errors.NumbaValueError(msg) + + def _parse_dtype_from_range(start, stop, step): - max_bw = max(start.bitwidth, stop.bitwidth, step.bitwidth) + max_bw = max( + _compute_bitwidth(start), + _compute_bitwidth(stop), + _compute_bitwidth(step), + ) if ( - isinstance(start, Complex) - or isinstance(stop, Complex) - or isinstance(step, Complex) + _is_any_complex_type(start) + or _is_any_complex_type(stop) + or _is_any_complex_type(step) ): - if max_bw == 128: - return numba.from_dtype(dpnp.complex128) - else: - return numba.from_dtype(dpnp.complex64) + return ( + numba.from_dtype(dpnp.complex128) + if max_bw == 128 + else numba.from_dtype(dpnp.complex64) + ) elif ( - isinstance(start, Float) - or isinstance(stop, Float) - or isinstance(step, Float) + _is_any_float_type(start) + or _is_any_float_type(stop) + or _is_any_float_type(step) ): if max_bw == 64: return numba.from_dtype(dpnp.float64) @@ -60,9 +114,9 @@ def _parse_dtype_from_range(start, stop, step): else: return numba.from_dtype(dpnp.float) elif ( - isinstance(start, Integer) - or isinstance(stop, Integer) - or isinstance(step, Integer) + _is_any_int_type(start) + or _is_any_int_type(stop) + or _is_any_int_type(step) ): if max_bw == 64: return numba.from_dtype(dpnp.int64) @@ -71,10 +125,117 @@ def _parse_dtype_from_range(start, stop, step): else: return numba.from_dtype(dpnp.int) else: - msg = "Type couldn't be inferred from (start, stop, step)." + msg = ( + "dpnp_iface.array_sequence_ops._parse_dtype_from_range(): " + + "Types couldn't be inferred for (start, stop, step)." + ) raise errors.NumbaValueError(msg) +def _get_llvm_type(numba_type): + if isinstance(numba_type, Integer): + return llvmir.IntType(numba_type.bitwidth) + elif isinstance(numba_type, Float): + if numba_type.bitwidth == 64: + return llvmir.DoubleType() + elif numba_type.bitwidth == 32: + return llvmir.FloatType() + elif numba_type.bitwidth == 16: + return llvmir.HalfType() + else: + msg = ( + "dpnp_iface.array_sequence_ops._get_llvm_type(): " + + f"Incompatible bitwidth in {numba_type}." + ) + raise errors.NumbaTypeError(msg) + else: + msg = ( + "dpnp_iface.array_sequence_ops._get_llvm_type(): " + + "Incompatible numba type." + ) + raise errors.NumbaTypeError(msg) + + +def _get_constant(context, dtype, bitwidth, value): + if isinstance(dtype, Integer): + if bitwidth == 64: + return context.get_constant(types.int64, value) + elif bitwidth == 32: + return context.get_constant(types.int32, value) + elif bitwidth == 16: + return context.get_constant(types.int16, value) + elif bitwidth == 8: + return context.get_constant(types.int8, value) + elif isinstance(dtype, Float): + if bitwidth == 64: + return context.get_constant(types.float64, value) + elif bitwidth == 32: + return context.get_constant(types.float32, value) + elif bitwidth == 16: + return context.get_constant(types.float16, value) + elif isinstance(dtype, Complex): + if bitwidth == 128: + return context.get_constant(types.complex128, value) + elif bitwidth == 64: + return context.get_constant(types.complex64, value) + else: + msg = ( + "dpnp_iface.array_sequence_ops._get_constant():" + + " Couldn't infer type for the requested constant." + ) + raise errors.NumbaTypeError(msg) + + +def _get_dst_typeid(dtype): + if isinstance(dtype, Boolean): + return 0 + elif isinstance(dtype, Integer): + if dtype.bitwidth == 8: + return 1 if dtype.signed else 2 + elif dtype.bitwidth == 16: + return 3 if dtype.signed else 4 + elif dtype.bitwidth == 32: + return 5 if dtype.signed else 6 + elif dtype.bitwidth == 64: + return 7 if dtype.signed else 8 + else: + msg = ( + "dpnp_iface.array_sequence_ops._get_dst_typeid(): " + + f"Couldn't map {dtype} to dst_index." + ) + raise errors.NumbaValueError(msg) + elif isinstance(dtype, Float): + if dtype.bitwidth == 16: + return 9 + elif dtype.bitwidth == 32: + return 10 + elif dtype.bitwidth == 64: + return 11 + else: + msg = ( + "dpnp_iface.array_sequence_ops._get_dst_typeid(): " + + f"Couldn't map {dtype} to dst_index." + ) + raise errors.NumbaValueError(msg) + elif isinstance(dtype, Complex): + if dtype.bitwidth == 64: + return 12 + elif dtype.bitwidth == 128: + return 13 + else: + msg = ( + "dpnp_iface.array_sequence_ops._get_dst_typeid(): " + + f"Couldn't map {dtype} to dst_index." + ) + raise errors.NumbaValueError(msg) + else: + msg = ( + "dpnp_iface.array_sequence_ops._get_dst_typeid(): " + + f"Unknown numba type {dtype}" + ) + raise errors.NumbaTypeError(msg) + + @intrinsic def impl_dpnp_arange( ty_context, @@ -102,19 +263,117 @@ def impl_dpnp_arange( sycl_queue_arg_pos = -2 def codegen(context, builder, sig, args): - start_ir, stop_ir, step_ir, queue_ir = ( + mod = builder.module + + start_ir, stop_ir, step_ir, dtype_ir, queue_ir = ( args[0], args[1], args[2], + args[3], args[sycl_queue_arg_pos], ) - queue_arg_type = sig.args[sycl_queue_arg_pos] + ( + start_arg_type, + stop_arg_type, + step_arg_type, + dtype_arg_type, + queue_arg_type, + ) = ( + sig.args[0], + sig.args[1], + sig.args[2], + sig.args[3], + sig.args[sycl_queue_arg_pos], + ) + # b = llvmir.IntType(1) # noqa: E800 + # u32 = llvmir.IntType(32) # noqa: E800 u64 = llvmir.IntType(64) - b = llvmir.IntType(1) - # f64 = llvmir.DoubleType() # noqa: E800 - mod = builder.module + # f32 = llvmir.FloatType() # noqa: E800 + f64 = llvmir.DoubleType() # noqa: E800 + # zero_u32 = context.get_constant(types.int32, 0) # noqa: E800 + # zero_u64 = context.get_constant(types.int64, 0) # noqa: E800 + # zero_f32 = context.get_constant(types.float32, 0) # noqa: E800 + zero_f64 = context.get_constant(types.float64, 0) + # one_u32 = context.get_constant(types.int32, 1) # noqa: E800 + # one_u64 = context.get_constant(types.int64, 1) # noqa: E800 + # one_f32 = context.get_constant(types.float32, 1) # noqa: E800 + one_f64 = context.get_constant(types.float64, 1) + + # ftype = _get_llvm_type(dtype_arg_type.dtype) # noqa: E800 + # utype = _get_llvm_type(dtype_arg_type.dtype) # noqa: E800 + # one = _get_constant( # noqa: E800 + # context, dtype_arg_type.dtype, dtype_arg_type.dtype.bitwidth, 1 # noqa: E800 + # ) # noqa: E800 + # zero = _get_constant( # noqa: E800 + # context, dtype_arg_type.dtype, dtype_arg_type.dtype.bitwidth, 0 # noqa: E800 + # ) # noqa: E800 + + print( + f"start_ir = {start_ir}, " + + f"start_ir.type = {start_ir.type}, " + + f"type(start_ir.type) = {type(start_ir.type)}" + ) + print( + f"step_ir = {step_ir}, " + + f"step_ir.type = {step_ir.type}, " + + f"type(step_ir.type) = {type(step_ir.type)}" + ) + print( + f"stop_ir = {stop_ir}, " + + f"stop_ir.type = {stop_ir.type}, " + + f"type(stop_ir.type) = {type(stop_ir.type)}" + ) + + # Sanity check: + # if stop is pointing to a null + # start <- 0 + # stop <- 1 + # if step is pointing to a null + # step <- 1 + # TODO: do this either in LLVMIR or outside of intrinsic + print("type(stop_arg_type) =", type(stop_arg_type)) + print("type(step_arg_type) =", type(step_arg_type)) + if isinstance(stop_arg_type, NoneType): + start_ir = zero_f64 + stop_ir = one_f64 + if isinstance(step_arg_type, NoneType): + step_ir = one_f64 + + if isinstance(start_arg_type, Integer) and isinstance( + dtype_arg_type.dtype, Float + ): + if start_arg_type.signed: + start_ir = builder.sitofp(start_ir, f64) + step_ir = builder.sitofp(step_ir, f64) + else: + start_ir = builder.uitofp(start_ir, f64) + step_ir = builder.uitofp(step_ir, f64) + + print( + f"start_ir = {start_ir}, " + + f"start_ir.type = {start_ir.type}, " + + f"type(start_ir.type) = {type(start_ir.type)}" + ) + print( + f"step_ir = {step_ir}, " + + f"step_ir.type = {step_ir.type}, " + + f"type(step_ir.type) = {type(step_ir.type)}" + ) + print( + f"stop_ir = {stop_ir}, " + + f"stop_ir.type = {stop_ir.type}, " + + f"type(stop_ir.type) = {type(stop_ir.type)}" + ) + print( + f"dtype_ir = {dtype_ir}, " + + f"dtype_ir.type = {dtype_ir.type}, " + + f"dtype_arg_type = {dtype_arg_type}, " + + f"dtype_arg_type.dtype = {dtype_arg_type.dtype}, " + + f"dtype_arg_type.dtype.bitwidth = {dtype_arg_type.dtype.bitwidth}" + ) + # Get SYCL Queue ref sycl_queue_arg = _ArgTyAndValue(queue_arg_type, queue_ir) qref_payload: _QueueRefPayload = _get_queue_ref( context=context, @@ -123,47 +382,58 @@ def codegen(context, builder, sig, args): sycl_queue_arg=sycl_queue_arg, ) - from numba.core.cpu import CPUContext - from numba.np.arrayobj import make_array - - from numba_dpex.core.datamodel.models import DpnpNdArrayModel - - # dt = builder.bitcast(builder.sdiv(t, builder.bitcast(step_ir, u64)), f64) # noqa: E800 - # dt = builder.sdiv(t, builder.bitcast(step_ir, u64)) # noqa: E800 - with builder.goto_entry_block(): start_ptr = cgutils.alloca_once(builder, start_ir.type) step_ptr = cgutils.alloca_once(builder, step_ir.type) - # dt_ptr = cgutils.alloca_once(builder, dt.type) # noqa: E800 builder.store(start_ir, start_ptr) builder.store(step_ir, step_ptr) - # builder.store(dt, dt_ptr) # noqa: E800 start_vptr = builder.bitcast(start_ptr, cgutils.voidptr_t) step_vptr = builder.bitcast(step_ptr, cgutils.voidptr_t) - # dt_vptr = builder.bitcast(dt_ptr, cgutils.voidptr_t) # noqa: E800 - t = builder.sub(stop_ir, start_ir) + ll = builder.sitofp(start_ir, f64) + ul = builder.sitofp(stop_ir, f64) + d = builder.sitofp(step_ir, f64) + + # Doing ceil(a,b) = (a-1)/b + 1 to avoid overflow + t = builder.fptosi( + builder.fadd( + builder.fdiv(builder.fsub(builder.fsub(ul, ll), one_f64), d), + one_f64, + ), + u64, + ) + + # Allocate an empty array ary = _empty_nd_impl( context, builder, sig.return_type, [t], qref_payload.queue_ref ) + + # Convert into void* arrystruct_vptr = builder.bitcast(ary._getpointer(), cgutils.voidptr_t) + # Function parameters ndim = context.get_constant(types.intp, 1) - is_c_contguous = context.get_constant(types.boolean, 1) + is_c_contguous = context.get_constant(types.int8, 1) + typeid_index = _get_dst_typeid(dtype_arg_type.dtype) + dst_typeid = context.get_constant(types.intp, typeid_index) + # Function signature fnty = llvmir.FunctionType( utils.LLVMTypes.int64_ptr_t, [ cgutils.voidptr_t, cgutils.voidptr_t, cgutils.voidptr_t, - u64, - b, + _get_llvm_type(types.intp), + _get_llvm_type(types.int8), + _get_llvm_type(types.intp), cgutils.voidptr_t, ], ) + + # Kernel call fn = cgutils.get_or_insert_function( mod, fnty, "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence" ) @@ -175,6 +445,7 @@ def codegen(context, builder, sig, args): arrystruct_vptr, ndim, is_c_contguous, + dst_typeid, qref_payload.queue_ref, ], ) @@ -198,25 +469,26 @@ def ol_dpnp_arange( print("stop =", stop, ", type(stop) =", type(stop)) print("step =", step, ", type(step) =", type(step)) print("dtype =", dtype, ", type(dtype) =", type(dtype)) - print("device =", device, ", type(device) =", type(device)) - print("usm_type =", usm_type, ", type(usm_type) =", type(usm_type)) - print("sycl_queue =", sycl_queue, ", type(sycl_queue) =", type(sycl_queue)) print("---") if stop is None: - stop = start start = 0 + stop = 1 if step is None: step = 1 + print("start =", start, ", type(start) =", type(start)) print("stop =", stop, ", type(stop) =", type(stop)) print("step =", step, ", type(step) =", type(step)) - print("-*-") + print("***") + _dtype = ( _parse_dtype(dtype) if dtype is not None else _parse_dtype_from_range(start, stop, step) ) + print("_dtype =", _dtype, ", type(_dtype) =", type(_dtype)) + _device = _parse_device_filter_string(device) if device else None _usm_type = _parse_usm_type(usm_type) if usm_type else "device" @@ -229,15 +501,6 @@ def ol_dpnp_arange( queue=sycl_queue, ) - print("start =", start, ", type(start) =", type(start)) - print("stop =", stop, ", type(stop) =", type(stop)) - print("step =", step, ", type(step) =", type(step)) - print("_dtype =", _dtype, ", type(_dtype) =", type(_dtype)) - print("_device =", _device, ", type(_device) =", type(_device)) - print("_usm_type =", _usm_type, ", type(_usm_type) =", type(_usm_type)) - print("sycl_queue =", sycl_queue, ", type(sycl_queue) =", type(sycl_queue)) - print("***") - if ret_ty: def impl(