diff --git a/numba_dpex/__init__.py b/numba_dpex/__init__.py index d7722a1abf..802f502ed4 100644 --- a/numba_dpex/__init__.py +++ b/numba_dpex/__init__.py @@ -107,6 +107,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]: # Re-export all type names from numba_dpex.core.types import * # noqa E402 +from numba_dpex.dpctl_iface import _intrinsic # noqa E402 from numba_dpex.dpnp_iface import dpnpimpl # noqa E402 if config.HAS_NON_HOST_DEVICE: diff --git a/numba_dpex/core/datamodel/models.py b/numba_dpex/core/datamodel/models.py index 0c85b2bb7d..8f4edf3f6f 100644 --- a/numba_dpex/core/datamodel/models.py +++ b/numba_dpex/core/datamodel/models.py @@ -280,6 +280,7 @@ def _init_data_model_manager() -> datamodel.DataModelManager: # Register the DpctlSyclEvent type register_model(DpctlSyclEvent)(SyclEventModel) + # Register the RangeType type register_model(RangeType)(RangeModel) diff --git a/numba_dpex/core/types/dpctl_types.py b/numba_dpex/core/types/dpctl_types.py index be66f28df3..ed03439453 100644 --- a/numba_dpex/core/types/dpctl_types.py +++ b/numba_dpex/core/types/dpctl_types.py @@ -123,10 +123,7 @@ def box_sycl_queue(typ, val, c): class DpctlSyclEvent(types.Type): """A Numba type to represent a dpctl.SyclEvent PyObject.""" - def __init__(self, sycl_event): - if not isinstance(sycl_event, SyclEvent): - raise TypeError("The argument sycl_event is not of type SyclEvent.") - + def __init__(self): super(DpctlSyclEvent, self).__init__(name="DpctlSyclEvent") @property diff --git a/numba_dpex/core/typing/typeof.py b/numba_dpex/core/typing/typeof.py index 5ef88a8d0c..8c8d826418 100644 --- a/numba_dpex/core/typing/typeof.py +++ b/numba_dpex/core/typing/typeof.py @@ -121,7 +121,7 @@ def typeof_dpctl_sycl_event(val, c): Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclEvent instance. """ - return DpctlSyclEvent(val) + return DpctlSyclEvent() @typeof_impl.register(Range) diff --git a/numba_dpex/dpctl_iface/_intrinsic.py b/numba_dpex/dpctl_iface/_intrinsic.py new file mode 100644 index 0000000000..e4dec22902 --- /dev/null +++ b/numba_dpex/dpctl_iface/_intrinsic.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from numba import types +from numba.core import cgutils +from numba.core.datamodel import default_manager +from numba.extending import intrinsic, overload_method + +import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl +from numba_dpex.core import types as dpex_types + + +@intrinsic +def sycl_event_wait(typingctx, ty_event: dpex_types.DpctlSyclEvent): + # x = types.uint8 + sig = types.void(dpex_types.DpctlSyclEvent()) + + # defines the custom code generation + def codegen(context, builder, signature, args): + sycl_event_dm = default_manager.lookup(ty_event) + event_ref = builder.extract_value( + args[0], + sycl_event_dm.get_field_position("event_ref"), + ) + + sycl.dpctl_event_wait(builder, event_ref) + + return sig, codegen + + +@overload_method(dpex_types.DpctlSyclEvent, "wait") +def ol_dpctl_sycl_event_wait( + event, +): + """Implementation of an overload to support dpctl.SyclEvent() inside + a dpjit function. + """ + return lambda event: sycl_event_wait(event) + + +# We don't want user to call sycl_event_wait(event), instead it must be called +# with event.wait(). In that way we guarantee the argument type by the +# @overload_method. +__all__ = [] diff --git a/numba_dpex/tests/core/types/DpctlSyclEvent/test_overloads.py b/numba_dpex/tests/core/types/DpctlSyclEvent/test_overloads.py new file mode 100644 index 0000000000..2dfbe2d00c --- /dev/null +++ b/numba_dpex/tests/core/types/DpctlSyclEvent/test_overloads.py @@ -0,0 +1,16 @@ +import dpctl + +from numba_dpex import dpjit + + +@dpjit +def wait_call(a): + a.wait() + return None + + +def test_wait_DpctlSyclEvent(): + """Test the dpctl.SyclEvent.wait() call overload.""" + + e = dpctl.SyclEvent() + wait_call(e)