Skip to content

Commit

Permalink
Merge pull request #1329 from IntelPython/refactor/make_kernel_api_im…
Browse files Browse the repository at this point in the history
…pl_public

Rename _kernel_api_impl and pylint changes.
  • Loading branch information
Diptorup Deb authored Feb 12, 2024
2 parents c20d53a + 73c8fb3 commit 12f89a4
Show file tree
Hide file tree
Showing 18 changed files with 150 additions and 114 deletions.
2 changes: 1 addition & 1 deletion numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from numba_dpex.core.kernel_interface.launcher import call_kernel

from ._kernel_api_impl.spirv import target as spirv_kernel_target
from .kernel_api_impl.spirv import target as spirv_kernel_target
from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc


Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from numba.core.cpu import CPUTargetOptions
from numba.core.descriptors import TargetDescriptor

from numba_dpex._kernel_api_impl.spirv.target import (
from numba_dpex.core import config
from numba_dpex.kernel_api_impl.spirv.target import (
SPIRV_TARGET_NAME,
CompilationMode,
SPIRVTargetContext,
SPIRVTypingContext,
)
from numba_dpex.core import config

from .targets.dpjit_target import (
DPEX_TARGET_NAME,
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from numba.core import ir

from numba_dpex._kernel_api_impl.spirv import spirv_generator
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.core import config
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
from numba_dpex.kernel_api_impl.spirv import spirv_generator
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext

from .kernel_base import KernelInterface

Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from numba.np.arrayobj import make_array
from numba.np.numpy_support import is_nonelike

from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.core.kernel_interface.arrayobj import (
_getitem_array_generic as kernel_getitem_array_generic,
)
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext

from ._intrinsic import (
impl_dpnp_empty,
Expand Down
3 changes: 1 addition & 2 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from numba.core.imputils import Registry

from numba_dpex._kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher

# Temporary so that Range and NdRange work in experimental call_kernel
from numba_dpex.core.boxing import *
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher

from ._kernel_dpcpp_spirv_overloads import (
_atomic_ref_overloads,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
from numba.core import cgutils, types
from numba.extending import intrinsic, overload, overload_method

from numba_dpex._kernel_api_impl.spirv.target import (
CC_SPIR_FUNC,
LLVM_SPIRV_ARGS,
)
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.core.types import USMNdArray
from numba_dpex.kernel_api import (
Expand All @@ -24,6 +20,10 @@
MemoryScope,
)
from numba_dpex.kernel_api.flag_enum import FlagEnum
from numba_dpex.kernel_api_impl.spirv.target import (
CC_SPIR_FUNC,
LLVM_SPIRV_ARGS,
)

from ..dpcpp_types import AtomicRefType
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload_method

from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext

from ..target import DPEX_KERNEL_EXP_TARGET_NAME

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from llvmlite import ir as llvmir
from numba.core import cgutils, types

from numba_dpex._kernel_api_impl.spirv.target import CC_SPIR_FUNC
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.kernel_api_impl.spirv.target import CC_SPIR_FUNC


def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
target_registry,
)

from numba_dpex._kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
from numba_dpex._kernel_api_impl.spirv.target import CompilationMode
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
from numba_dpex.kernel_api_impl.spirv.target import CompilationMode

from .target import DPEX_KERNEL_EXP_TARGET_NAME

Expand Down
10 changes: 5 additions & 5 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
from numba.extending import intrinsic

from numba_dpex import dpjit
from numba_dpex._kernel_api_impl.spirv.dispatcher import (
SPIRVKernelDispatcher,
_SPIRVKernelCompileResult,
)
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
from numba_dpex.core.types import DpctlSyclEvent, NdRangeType, RangeType
from numba_dpex.core.utils import kernel_launcher as kl
Expand All @@ -33,6 +28,11 @@
ItemType,
NdItemType,
)
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
SPIRVKernelDispatcher,
_SPIRVKernelCompileResult,
)
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext


class LLRange(NamedTuple):
Expand Down
6 changes: 3 additions & 3 deletions numba_dpex/experimental/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from numba.core.descriptors import TargetDescriptor
from numba.core.target_extension import GPU, target_registry

from numba_dpex._kernel_api_impl.spirv.target import (
from numba_dpex.core.descriptor import DpexTargetOptions
from numba_dpex.experimental.models import exp_dmm
from numba_dpex.kernel_api_impl.spirv.target import (
SPIRVTargetContext,
SPIRVTypingContext,
)
from numba_dpex.core.descriptor import DpexTargetOptions
from numba_dpex.experimental.models import exp_dmm


# pylint: disable=R0903
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Code generator for a LLVM module for SPIR-V kernels.
"""

import warnings

from llvmlite import binding as ll
Expand All @@ -26,6 +29,10 @@


class SPIRVCodeLibrary(CPUCodeLibrary):
"""A Numba code library that stores a spir_kernel function and all the
internally defined spir_func functions called from the spir_kernel function.
"""

def _optimize_functions(self, ll_module):
pass

Expand All @@ -36,8 +43,8 @@ def inline_threshold(self):
"""
if hasattr(self, "_inline_threshold"):
return self._inline_threshold
else:
return 0

return 0

@inline_threshold.setter
def inline_threshold(self, value: int):
Expand Down Expand Up @@ -105,6 +112,7 @@ def get_asm_str(self):

@property
def final_module(self):
"""Return the final SPIR-V module after it has been finalized."""
return self._final_module


Expand All @@ -116,7 +124,7 @@ class JITSPIRVCodegen(CPUCodegen):
_library_class = SPIRVCodeLibrary

def _init(self, llvm_module):
assert list(llvm_module.global_variables) == [], "Module isn't empty"
assert not list(llvm_module.global_variables), "Module isn't empty"
self._data_layout = SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
self._target_data = ll.create_target_data(self._data_layout)
self._tm_features = (
Expand All @@ -130,10 +138,15 @@ def _create_empty_module(self, name):
ir_module.data_layout = self._data_layout
return ir_module

def _module_pass_manager(self):
def create_empty_spirv_module(self, name):
"""Public method to create an empty LLVM Module with SPIR-V layout."""

return self._create_empty_module(name)

def _module_pass_manager(self, **kwargs):
raise NotImplementedError

def _function_pass_manager(self, llvm_module):
def _function_pass_manager(self, llvm_module, **kwargs):
raise NotImplementedError

def _add_module(self, module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
from numba.core.typing.typeof import Purpose, typeof

from numba_dpex import config, numba_sem_version
from numba_dpex._kernel_api_impl.spirv import spirv_generator
from numba_dpex._kernel_api_impl.spirv.codegen import SPIRVCodeLibrary
from numba_dpex._kernel_api_impl.spirv.target import (
CompilationMode,
SPIRVTargetContext,
)
from numba_dpex.core.exceptions import (
ExecutionQueueInferenceError,
InvalidKernelSpecializationError,
Expand All @@ -45,6 +39,12 @@
DPEX_KERNEL_EXP_TARGET_NAME,
dpex_exp_kernel_target,
)
from numba_dpex.kernel_api_impl.spirv import spirv_generator
from numba_dpex.kernel_api_impl.spirv.codegen import SPIRVCodeLibrary
from numba_dpex.kernel_api_impl.spirv.target import (
CompilationMode,
SPIRVTargetContext,
)

_SPIRVKernelCompileResult = namedtuple(
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
Expand Down Expand Up @@ -314,6 +314,7 @@ def __init__(
self._kernel_name = pyfunc.__name__

if numba_sem_version < (0, 59, 0):
# pylint: disable=unexpected-keyword-arg
super().__init__(
py_func=pyfunc,
locals=local_vars_to_numba_types,
Expand Down
Loading

0 comments on commit 12f89a4

Please sign in to comment.