From 7695abd3132516e3427ea34dbc5d9e3baa19472f Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Wed, 18 Oct 2023 14:33:44 -0400 Subject: [PATCH] Add sycl event wait overload --- numba_dpex/__init__.py | 1 + numba_dpex/core/datamodel/models.py | 1 + numba_dpex/dpctl_iface/_intrinsic.py | 49 ++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+) create mode 100644 numba_dpex/dpctl_iface/_intrinsic.py diff --git a/numba_dpex/__init__.py b/numba_dpex/__init__.py index d7722a1abf..802f502ed4 100644 --- a/numba_dpex/__init__.py +++ b/numba_dpex/__init__.py @@ -107,6 +107,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]: # Re-export all type names from numba_dpex.core.types import * # noqa E402 +from numba_dpex.dpctl_iface import _intrinsic # noqa E402 from numba_dpex.dpnp_iface import dpnpimpl # noqa E402 if config.HAS_NON_HOST_DEVICE: diff --git a/numba_dpex/core/datamodel/models.py b/numba_dpex/core/datamodel/models.py index 0c85b2bb7d..8f4edf3f6f 100644 --- a/numba_dpex/core/datamodel/models.py +++ b/numba_dpex/core/datamodel/models.py @@ -280,6 +280,7 @@ def _init_data_model_manager() -> datamodel.DataModelManager: # Register the DpctlSyclEvent type register_model(DpctlSyclEvent)(SyclEventModel) + # Register the RangeType type register_model(RangeType)(RangeModel) diff --git a/numba_dpex/dpctl_iface/_intrinsic.py b/numba_dpex/dpctl_iface/_intrinsic.py new file mode 100644 index 0000000000..35ed77191b --- /dev/null +++ b/numba_dpex/dpctl_iface/_intrinsic.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from numba import types +from numba.core import cgutils +from numba.core.datamodel import default_manager +from numba.extending import intrinsic, overload_method + +import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl +from numba_dpex.core import types as dpex_types + + +@intrinsic +def sycl_event_wait(typingctx, ty_event): + # check for accepted types + if not isinstance(ty_event, dpex_types.DpctlSyclEvent): + return + + result_type = types.void + sig = result_type(ty_event) + + # defines the custom code generation + def codegen(context, builder, signature, args): + event_struct_proxy = cgutils.create_struct_proxy(ty_event)( + context, builder + ) + event_struct_ptr = event_struct_proxy._getpointer() + + event_struct = builder.load(event_struct_ptr) + sycl_event_dm = default_manager.lookup(ty_event) + event_ref = builder.extract_value( + event_struct, + sycl_event_dm.get_field_position("event_ref"), + ) + + sycl.dpctl_event_wait(builder, event_ref) + + return sig, codegen + + +@overload_method(dpex_types.DpctlSyclEvent, "wait") +def ol_dpctl_sycl_event_wait( + event, +): + """Implementation of an overload to support dpctl.SyclEvent() inside + a dpjit function. + """ + return lambda event: sycl_event_wait(event)