Skip to content

Commit

Permalink
Overload dpnp array's sycl_queue attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Oct 30, 2023
1 parent 425e64a commit 00e8bfc
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
59 changes: 58 additions & 1 deletion numba_dpex/dpnp_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from dpctl import get_device_cached_queue
from llvmlite import ir as llvmir
from llvmlite.ir import Constant
from llvmlite.ir import Constant, IRBuilder
from llvmlite.ir.types import DoubleType, FloatType
from numba import types
from numba.core import cgutils
from numba.core import config as numba_config
from numba.core import imputils
from numba.core.typing import signature
from numba.extending import intrinsic, overload_classmethod
from numba.np.arrayobj import (
Expand Down Expand Up @@ -1077,3 +1078,59 @@ def codegen(context, builder, sig, args):
return ary._getvalue()

return signature, codegen


@intrinsic
def ol_dpnp_nd_array_sycl_queue(
ty_context,
ty_dpnp_nd_array: DpnpNdArray,
):
if not isinstance(ty_dpnp_nd_array, DpnpNdArray):
return

ty_queue: DpctlSyclQueue = ty_dpnp_nd_array.queue

sig = ty_queue(ty_dpnp_nd_array)

def codegen(context, builder: IRBuilder, sig, args: list):
array_proxy = cgutils.create_struct_proxy(ty_dpnp_nd_array)(
context,
builder,
value=args[0],
)

queue_ref = array_proxy.sycl_queue

queue_struct_proxy = cgutils.create_struct_proxy(ty_queue)(
context, builder
)

queue_struct_proxy.queue_ref = queue_ref
queue_struct_proxy.meminfo = array_proxy.meminfo

# Warning: current implementation prevents whole object from being
# destroyed as long as sycl_queue attribute is being used. It should be
# okay since anywere we use it as an argument callee creates a copy
# so it does not steel reference.
#
# We can avoid it by:
# queue_ref_copy = sycl.dpctl_queue_copy(builder, queue_ref) #noqa E800
# queue_struct_proxy.queue_ref = queue_ref_copy #noqa E800
# queue_struct->meminfo =
# nrt->manage_memory(queue_ref_copy, DPCTLEvent_Delete);
# but it will allocate new meminfo object which can negatively affect
# performance.
# Speaking philosophically attribute is a part of the object and as long
# as nobody can still the reference it is a part of the owner object
# and lifetime is tied to it.

queue_value = queue_struct_proxy._getvalue()

# We need to incref meminfo so that queue model is preventing parent
# ndarray from being destroyed, that can destroy queue that we are
# using.
return imputils.impl_ret_borrowed(
context, builder, ty_queue, queue_value
)

return sig, codegen
11 changes: 10 additions & 1 deletion numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numba.core.types.containers import UniTuple
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.extending import overload
from numba.extending import overload, overload_attribute
from numba.np.arrayobj import getitem_arraynd_intp as np_getitem_arraynd_intp
from numba.np.numpy_support import is_nonelike

Expand All @@ -27,6 +27,7 @@
impl_dpnp_ones_like,
impl_dpnp_zeros,
impl_dpnp_zeros_like,
ol_dpnp_nd_array_sycl_queue,
)

# =========================================================================
Expand Down Expand Up @@ -1085,3 +1086,11 @@ def getitem_arraynd_intp(context, builder, sig, args):
ret = builder.insert_value(ret, sycl_queue_attr, sycl_queue_attr_pos)

return ret


@overload_attribute(DpnpNdArray, "sycl_queue")
def dpnp_nd_array_sycl_queue(arr):
def get(arr):
return ol_dpnp_nd_array_sycl_queue(arr)

return get

0 comments on commit 00e8bfc

Please sign in to comment.