Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lifetime management for sycl event #1188

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ class SyclEventModel(StructModel):
def __init__(self, dmm, fe_type):
members = [
(
"parent",
types.CPointer(types.int8),
"meminfo",
types.MemInfoPointer(types.pyobject),
),
(
"event_ref",
Expand Down
27 changes: 22 additions & 5 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "_queuestruct.h"
#include "_usmarraystruct.h"

#include "numba/core/runtime/nrt_external.h"

// forward declarations
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
Expand Down Expand Up @@ -64,9 +66,12 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
PyArray_Descr *descr);
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
queuestruct_t *queue_struct);
static int DPEXRT_sycl_event_from_python(PyObject *obj,
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
PyObject *obj,
eventstruct_t *event_struct);
static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct);
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
eventstruct_t *eventstruct);

/** An NRT_external_malloc_func implementation using DPCTLmalloc_device.
*
Expand Down Expand Up @@ -1306,7 +1311,8 @@ static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct)
* represent a dpctl.SyclEvent inside Numba.
* @return {return} Return code indicating success (0) or failure (-1).
*/
static int DPEXRT_sycl_event_from_python(PyObject *obj,
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
PyObject *obj,
eventstruct_t *event_struct)
{
struct PySyclEventObject *event_obj = NULL;
Expand All @@ -1328,7 +1334,13 @@ static int DPEXRT_sycl_event_from_python(PyObject *obj,
goto error;
}

event_struct->parent = obj;
// We incref here to ensure Python does not release the object while NRT
// references it. The corresponding decref is called by NRT in
// NRT_MemInfo_pyobject_dtor once there is no reference to this object by
// the code managed by NRT.
Py_INCREF(event_obj);
ZzEeKkAa marked this conversation as resolved.
Show resolved Hide resolved
event_struct->meminfo =
nrt->manage_memory(event_obj, NRT_MemInfo_pyobject_dtor);
event_struct->event_ref = event_ref;

return 0;
Expand All @@ -1355,12 +1367,13 @@ static int DPEXRT_sycl_event_from_python(PyObject *obj,
* @return {return} A PyObject obtained from the eventstruct->meminfo, if
* the PyObject could not be created return NULL.
*/
static PyObject *DPEXRT_sycl_event_to_python(eventstruct_t *eventstruct)
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
eventstruct_t *eventstruct)
{
PyObject *orig_event = NULL;
PyGILState_STATE gstate;

orig_event = eventstruct->parent;
orig_event = nrt->get_data(eventstruct->meminfo);
// FIXME: Better error checking is needed to enforce the boxing of the event
// object. For now, only the minimal is done as the returning of SyclEvent
// from a dpjit function should not be used often and the dpctl C API for
Expand All @@ -1375,9 +1388,13 @@ static PyObject *DPEXRT_sycl_event_to_python(eventstruct_t *eventstruct)
DPEXRT_DEBUG(
drt_debug_print("DPEXRT-DEBUG: In DPEXRT_sycl_event_to_python.\n"););

// TODO: is there any way to release meminfo without calling dtor so we don't
// call incref, decref one after another?
// We need to increase reference count because we are returning new
// reference to the same event.
Py_INCREF(orig_event);
// We need to release meminfo since we are taking ownership back.
nrt->release(eventstruct->meminfo);

return orig_event;
}
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/core/runtime/_eventstruct.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

#pragma once

#include <Python.h>
#include "numba/core/runtime/nrt_external.h"

typedef struct
{
PyObject *parent;
NRT_MemInfo *meminfo;
void *event_ref;
} eventstruct_t;
12 changes: 12 additions & 0 deletions numba_dpex/core/runtime/_nrt_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,15 @@ void NRT_MemInfo_destroy(NRT_MemInfo *mi)
TheMSys.stats.mi_free++;
}
}

void NRT_MemInfo_pyobject_dtor(void *data)
{
PyGILState_STATE gstate;
PyObject *ownerobj = data;

gstate = PyGILState_Ensure(); /* ensure the GIL */
Py_DECREF(data); /* release the python object */
PyGILState_Release(gstate); /* release the GIL */

DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: pyobject destructor\n"););
}
1 change: 1 addition & 0 deletions numba_dpex/core/runtime/_nrt_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ size_t NRT_MemInfo_refcount(NRT_MemInfo *mi);
void NRT_Free(void *ptr);
void NRT_dealloc(NRT_MemInfo *mi);
void NRT_MemInfo_destroy(NRT_MemInfo *mi);
void NRT_MemInfo_pyobject_dtor(void *data);

#endif /* _NRT_HELPER_H_ */
15 changes: 11 additions & 4 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import functools

import numba.core.unsafe.nrt
from llvmlite import ir as llvmir
from numba.core import cgutils, types

Expand Down Expand Up @@ -206,26 +207,32 @@ def queuestruct_to_python(self, pyapi, val):
def eventstruct_from_python(self, pyapi, obj, ptr):
"""Calls the c function DPEXRT_sycl_event_from_python"""
fnty = llvmir.FunctionType(
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
)
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)

fn = pyapi._get_function(fnty, "DPEXRT_sycl_event_from_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")
fn.args[2].add_attribute("nocapture")

self.error = pyapi.builder.call(fn, (obj, ptr))
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))
return self.error

def eventstruct_to_python(self, pyapi, val):
"""Calls the c function DPEXRT_sycl_event_to_python"""

fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr])
fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr, pyapi.voidptr])
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)

fn = pyapi._get_function(fnty, "DPEXRT_sycl_event_to_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")

qptr = cgutils.alloca_once_value(pyapi.builder, val)
ptr = pyapi.builder.bitcast(qptr, pyapi.voidptr)
self.error = pyapi.builder.call(fn, [ptr])

self.error = pyapi.builder.call(fn, [nrt_api, ptr])

return self.error

Expand Down
6 changes: 3 additions & 3 deletions numba_dpex/dpctl_iface/libsyclinterface_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def dpctl_event_wait(builder: llvmir.IRBuilder, *args):
mod = builder.module
fn = _build_dpctl_function(
llvm_module=mod,
return_ty=cgutils.voidptr_t,
return_ty=llvmir.VoidType(),
ZzEeKkAa marked this conversation as resolved.
Show resolved Hide resolved
arg_list=[cgutils.voidptr_t],
func_name="DPCTLEvent_Wait",
)
Expand All @@ -85,7 +85,7 @@ def dpctl_event_delete(builder: llvmir.IRBuilder, *args):
mod = builder.module
fn = _build_dpctl_function(
llvm_module=mod,
return_ty=cgutils.voidptr_t,
return_ty=llvmir.VoidType(),
arg_list=[cgutils.voidptr_t],
func_name="DPCTLEvent_Delete",
)
Expand All @@ -99,7 +99,7 @@ def dpctl_queue_delete(builder: llvmir.IRBuilder, *args):
mod = builder.module
fn = _build_dpctl_function(
llvm_module=mod,
return_ty=cgutils.voidptr_t,
return_ty=llvmir.VoidType(),
arg_list=[cgutils.voidptr_t],
func_name="DPCTLQueue_Delete",
)
Expand Down