Skip to content

Commit

Permalink
Add sycl event wait overload
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Oct 20, 2023
1 parent 9e6c224 commit 1ded0a6
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 6 deletions.
1 change: 1 addition & 0 deletions numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:

# Re-export all type names
from numba_dpex.core.types import * # noqa E402
from numba_dpex.dpctl_iface import _intrinsic # noqa E402
from numba_dpex.dpnp_iface import dpnpimpl # noqa E402

if config.HAS_NON_HOST_DEVICE:
Expand Down
1 change: 1 addition & 0 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def _init_data_model_manager() -> datamodel.DataModelManager:

# Register the DpctlSyclEvent type
register_model(DpctlSyclEvent)(SyclEventModel)

# Register the RangeType type
register_model(RangeType)(RangeModel)

Expand Down
5 changes: 1 addition & 4 deletions numba_dpex/core/types/dpctl_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,7 @@ def box_sycl_queue(typ, val, c):
class DpctlSyclEvent(types.Type):
"""A Numba type to represent a dpctl.SyclEvent PyObject."""

def __init__(self, sycl_event):
if not isinstance(sycl_event, SyclEvent):
raise TypeError("The argument sycl_event is not of type SyclEvent.")

def __init__(self):
super(DpctlSyclEvent, self).__init__(name="DpctlSyclEvent")

@property
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def typeof_dpctl_sycl_event(val, c):
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclEvent instance.
"""
return DpctlSyclEvent(val)
return DpctlSyclEvent()


@typeof_impl.register(Range)
Expand Down
43 changes: 43 additions & 0 deletions numba_dpex/dpctl_iface/_intrinsic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba import types
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload_method

import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
from numba_dpex.core import types as dpex_types


@intrinsic
def sycl_event_wait(typingctx, ty_event: dpex_types.DpctlSyclEvent):
sig = types.void(dpex_types.DpctlSyclEvent())

# defines the custom code generation
def codegen(context, builder, signature, args):
sycl_event_dm = default_manager.lookup(ty_event)
event_ref = builder.extract_value(
args[0],
sycl_event_dm.get_field_position("event_ref"),
)

sycl.dpctl_event_wait(builder, event_ref)

return sig, codegen


@overload_method(dpex_types.DpctlSyclEvent, "wait")
def ol_dpctl_sycl_event_wait(
event,
):
"""Implementation of an overload to support dpctl.SyclEvent() inside
a dpjit function.
"""
return lambda event: sycl_event_wait(event)


# We don't want user to call sycl_event_wait(event), instead it must be called
# with event.wait(). In that way we guarantee the argument type by the
# @overload_method.
__all__ = []
2 changes: 1 addition & 1 deletion numba_dpex/tests/core/types/DpctlSyclEvent/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_model_for_DpctlSyclEvent():
"""Test the data model for DpctlSyclEvent that is registered with numba's
default data model manager.
"""
sycl_event = DpctlSyclEvent(dpctl.SyclEvent())
sycl_event = DpctlSyclEvent()
default_model = default_manager.lookup(sycl_event)
assert isinstance(default_model, SyclEventModel)

Expand Down
16 changes: 16 additions & 0 deletions numba_dpex/tests/core/types/DpctlSyclEvent/test_overloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import dpctl

from numba_dpex import dpjit


@dpjit
def wait_call(a):
a.wait()
return None


def test_wait_DpctlSyclEvent():
"""Test the dpctl.SyclEvent.wait() call overload."""

e = dpctl.SyclEvent()
wait_call(e)

0 comments on commit 1ded0a6

Please sign in to comment.