Skip to content

Commit

Permalink
Addressing review comments.
Browse files: browse the repository at this point in the history.
Renaming queue to queue_ref, dref to device_ref and so on.

Better device info printing.

Adding SYCL queue-copying interfaces to numba_dpex/dpctl_iface.

Calling NRT_ExternalAllocator_new_for_usm(DPCTLQueue_Copy(queue_ref), usm_type), i.e. passing a copy of the queue reference instead of queue_ref itself, to fix the segfault.
  • Loading branch information
khaled committed May 6, 2023
1 parent d9c9587 commit c0d5b0e
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 97 deletions.
36 changes: 30 additions & 6 deletions numba_dpex/_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple

import numpy
from llvmlite import ir as llvmir
from llvmlite.ir import Constant
Expand All @@ -27,6 +29,10 @@
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue

# from numba_dpex import utils
# from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder


# Numpy array constructors


Expand Down Expand Up @@ -195,8 +201,9 @@ def make_queue(context, builder, arrtype):
function for details on how to construct this argument.
Returns:
ret (tuple): A tuple containing `llvmlite.ir.instructions.ExtractValue`,
`llvmlite.ir.instructions.CastInstr` and `numba.core.pythonapi.PythonAPI`.
ret (namedtuple): A namedtuple containing `llvmlite.ir.instructions.ExtractValue`
as `queue_ref`, `llvmlite.ir.instructions.CastInstr` as `queue_address_ptr`
and `numba.core.pythonapi.PythonAPI` as `pyapi`.
"""

pyapi = context.get_python_api(builder)
Expand Down Expand Up @@ -229,7 +236,10 @@ def make_queue(context, builder, arrtype):
queue_struct = builder.load(queue_struct_ptr)
queue_ref = builder.extract_value(queue_struct, 1)

ret = (queue_ref, queue_address_ptr, pyapi)
return_values = namedtuple(
"return_values", "queue_ref queue_address_ptr pyapi"
)
ret = return_values(queue_ref, queue_address_ptr, pyapi)

return ret

Expand Down Expand Up @@ -294,7 +304,17 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
)

if isinstance(arrtype, DpnpNdArray):
(queue, queue_ptr, pyapi) = make_queue(context, builder, arrtype)
(queue_ref, queue_ptr, pyapi) = make_queue(context, builder, arrtype)
# This might fix the segfault
# sycl_queue_val = cgutils.alloca_once(
# builder,
# utils.get_llvm_type(context=context, type=types.voidptr),
# )
# fn = DpctlCAPIFnBuilder.get_dpctl_queue_copy(
# builder=builder, context=context
# )
# builder.store(builder.call(fn, []), sycl_queue_val)

usm_ty = arrtype.usm_type
usm_ty_map = {"device": 1, "shared": 2, "host": 3}
usm_type = context.get_constant(
Expand All @@ -305,7 +325,7 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
context.get_dummy_value(),
allocsize,
usm_type,
queue,
queue_ref,
)
mip = types.MemInfoPointer(types.voidptr)
arytypeclass = types.TypeRef(type(arrtype))
Expand Down Expand Up @@ -355,7 +375,11 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
meminfo=meminfo,
)

ret = (ary, queue) if isinstance(arrtype, DpnpNdArray) else ary
if isinstance(arrtype, DpnpNdArray):
return_values = namedtuple("return_values", "ary queue_ref")
ret = return_values(ary, queue_ref)
else:
ret = ary

return ret

Expand Down
4 changes: 3 additions & 1 deletion numba_dpex/core/runtime/_dbg_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

#pragma once

#include <stdarg.h>
#include <stdio.h>

/* Debugging facilities - enabled at compile-time */
/* #undef NDEBUG */
#if 0
#include <stdio.h>
#define DPEXRT_DEBUG(X) \
{ \
X; \
Expand Down
Loading

0 comments on commit c0d5b0e

Please sign in to comment.