Skip to content

Commit

Permalink
Merge pull request IntelPython#1247 from IntelPython/feature/typing_i…
Browse files Browse the repository at this point in the history
…mprovement

Feature/typing improvement
  • Loading branch information
Diptorup Deb authored Dec 15, 2023
2 parents 3f8cdf9 + 2ec9e7e commit ea5b3bd
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 22 deletions.
6 changes: 3 additions & 3 deletions numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from numba_dpex import config, spirv_generator
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext

from .kernel_base import KernelInterface

Expand Down Expand Up @@ -133,9 +134,8 @@ def compile(
)

func = cres.library.get_function(cres.fndesc.llvm_func_name)
kernel = cres.target_context.prepare_spir_kernel(
func, cres.signature.args
)
kernel_targetctx: DpexKernelTargetContext = cres.target_context
kernel = kernel_targetctx.prepare_spir_kernel(func, cres.signature.args)

# XXX: Setting the inline_threshold in the following way is a temporary
# workaround till the JitKernel dispatcher is replaced by
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/parfors/parfor_lowerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _submit_parfor_kernel(
kl_builder.set_arguments(
kernel_fn.kernel_arg_types, kernel_args=kernel_args
)
kl_builder.set_dependant_event_list(dep_events=[])
kl_builder.set_dependant_event_list([])
event_ref = kl_builder.submit()

sycl.dpctl_event_wait(lowerer.builder, event_ref)
Expand Down
3 changes: 1 addition & 2 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,8 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
)

def prepare_spir_kernel(self, func, argtypes):
module = func.module
func.linkage = "linkonce_odr"
module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
func.module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
wrapper = self._generate_spir_kernel_wrapper(func, argtypes)
return wrapper

Expand Down
47 changes: 31 additions & 16 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,33 @@
from typing import Tuple

import numba.core.event as ev
from llvmlite.binding.value import ValueRef
from numba.core import errors, sigutils, types
from numba.core.compiler import CompileResult, Flags
from numba.core.compiler_lock import global_compiler_lock
from numba.core.dispatcher import Dispatcher, _FunctionCompiler
from numba.core.funcdesc import PythonFunctionDescriptor
from numba.core.target_extension import dispatcher_registry, target_registry
from numba.core.types import void
from numba.core.typing.typeof import Purpose, typeof

from numba_dpex import config, spirv_generator
from numba_dpex.core.codegen import SPIRVCodeLibrary
from numba_dpex.core.exceptions import (
ExecutionQueueInferenceError,
KernelHasReturnValueError,
UnsupportedKernelArgumentError,
)
from numba_dpex.core.pipelines import kernel_compiler
from numba_dpex.core.targets.kernel_target import CompilationMode
from numba_dpex.core.targets.kernel_target import (
CompilationMode,
DpexKernelTargetContext,
)
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.core.utils import kernel_launcher as kl

from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target

_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])

_KernelCompileResult = namedtuple(
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
)
Expand Down Expand Up @@ -76,9 +81,14 @@ def check_queue_equivalence_of_args(
)

def _compile_to_spirv(
self, kernel_library, kernel_fndesc, kernel_targetctx
self,
kernel_library: SPIRVCodeLibrary,
kernel_fndesc: PythonFunctionDescriptor,
kernel_targetctx: DpexKernelTargetContext,
):
kernel_func = kernel_library.get_function(kernel_fndesc.llvm_func_name)
kernel_func: ValueRef = kernel_library.get_function(
kernel_fndesc.llvm_func_name
)

# Create a spir_kernel wrapper function
kernel_fn = kernel_targetctx.prepare_spir_kernel(
Expand All @@ -103,11 +113,11 @@ def _compile_to_spirv(
kernel_library.final_module,
kernel_library.final_module.as_bitcode(),
)
return _KernelModule(
return kl.SPIRVKernelModule(
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
)

def compile(self, args, return_type):
def compile(self, args, return_type) -> _KernelCompileResult:
status, kcres = self._compile_cached(args, return_type)
if status:
return kcres
Expand Down Expand Up @@ -160,8 +170,10 @@ def _compile_cached(
self.targetoptions["_compilation_mode"]
== CompilationMode.KERNEL
):
kernel_device_ir_module: _KernelModule = self._compile_to_spirv(
cres.library, cres.fndesc, cres.target_context
kernel_device_ir_module: kl.SPIRVKernelModule = (
self._compile_to_spirv(
cres.library, cres.fndesc, cres.target_context
)
)
else:
kernel_device_ir_module = None
Expand Down Expand Up @@ -329,14 +341,17 @@ def cb_llvm(dur):
# Add code to enable on disk caching of a binary spirv kernel.
# Refer: https://github.com/IntelPython/numba-dpex/issues/1197
self._cache_misses[sig] += 1
ev_details = {
"dispatcher": self,
"args": args,
"return_type": return_type,
}
with ev.trigger_event("numba_dpex:compile", data=ev_details):
with ev.trigger_event(
"numba_dpex:compile",
data={
"dispatcher": self,
"args": args,
"return_type": return_type,
},
):
try:
kcres: _KernelCompileResult = self._compiler.compile(
compiler: _KernelCompiler = self._compiler
kcres: _KernelCompileResult = compiler.compile(
args, return_type
)
except errors.ForceLiteralArg as e:
Expand Down

0 comments on commit ea5b3bd

Please sign in to comment.