diff --git a/numba_dpex/core/descriptor.py b/numba_dpex/core/descriptor.py index 39ac5fe816..369838ed25 100644 --- a/numba_dpex/core/descriptor.py +++ b/numba_dpex/core/descriptor.py @@ -96,6 +96,10 @@ def typing_context(self): """ return self._toplevel_typing_context + @property + def target_name(self): + return self._target_name + class DpexTarget(TargetDescriptor): """ diff --git a/numba_dpex/core/parfors/kernel_builder.py b/numba_dpex/core/parfors/kernel_builder.py index 0266ca43a4..a6dbcb5d66 100644 --- a/numba_dpex/core/parfors/kernel_builder.py +++ b/numba_dpex/core/parfors/kernel_builder.py @@ -21,10 +21,12 @@ rename_labels, replace_var_names, ) +from numba.core.target_extension import target_override from numba.core.typing import signature from numba.parfors import parfor from numba_dpex.core import config +from numba_dpex.core.types.kernel_api.index_space_ids import ItemType from numba_dpex.kernel_api_impl.spirv import spirv_generator from ..descriptor import dpex_kernel_target @@ -66,18 +68,18 @@ def _print_body(body_dict): def _compile_kernel_parfor( sycl_queue, kernel_name, func_ir, argtypes, debug=False ): - - cres = compile_numba_ir_with_dpex( - pyfunc=func_ir, - pyfunc_name=kernel_name, - args=argtypes, - return_type=None, - debug=debug, - is_kernel=True, - typing_context=dpex_kernel_target.typing_context, - target_context=dpex_kernel_target.target_context, - extra_compile_flags=None, - ) + with target_override(dpex_kernel_target.target_context.target_name): + cres = compile_numba_ir_with_dpex( + pyfunc=func_ir, + pyfunc_name=kernel_name, + args=argtypes, + return_type=None, + debug=debug, + is_kernel=True, + typing_context=dpex_kernel_target.typing_context, + target_context=dpex_kernel_target.target_context, + extra_compile_flags=None, + ) cres.library.inline_threshold = config.INLINE_THRESHOLD cres.library._optimize_final_module() func = cres.library.get_function(cres.fndesc.llvm_func_name) @@ -420,6 +422,13 @@ def create_kernel_for_parfor( print("kernel_ir after remove dead") kernel_ir.dump() + # The first argument to a range kernel is a kernel_api.Item object. The + # ``Item`` object is used by the kernel_api.spirv backend to generate the + # correct SPIR-V indexing instructions. Since, the argument is not something + # available originally in the kernel_param_types, we add it at this point to + # make sure the kernel signature matches the actual generated code. + ty_item = ItemType(parfor_dim) + kernel_param_types = (ty_item, *kernel_param_types) kernel_sig = signature(types.none, *kernel_param_types) if config.DEBUG_ARRAY_OPT: diff --git a/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py b/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py index 026a707ba4..ecbfa5df4c 100644 --- a/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py +++ b/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py @@ -63,7 +63,7 @@ def _generate_kernel_stub_as_string(self): # Create the dpex kernel function. kernel_txt += "def " + self._kernel_name - kernel_txt += "(" + (", ".join(self._kernel_params)) + "):\n" + kernel_txt += "(item, " + (", ".join(self._kernel_params)) + "):\n" global_id_dim = 0 for_loop_dim = self._kernel_rank global_id_dim = self._kernel_rank @@ -71,7 +71,7 @@ def _generate_kernel_stub_as_string(self): for dim in range(global_id_dim): dimstr = str(dim) kernel_txt += ( - f" {self._ivar_names[dim]} = dpex.get_global_id({dimstr})\n" + f" {self._ivar_names[dim]} = item.get_id({dimstr})\n" ) for dim in range(global_id_dim, for_loop_dim):