Commit
Adding ref-count for DLPack for clients to let local tensors go out of scope
cliffburdick committed Mar 20, 2023
1 parent 9390806 commit 3728f96
Showing 2 changed files with 25 additions and 15 deletions.
38 changes: 23 additions & 15 deletions include/matx/core/tensor.h
@@ -1034,7 +1034,6 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
return storage_.use_count();
}


/**
* Create an overlapping tensor view
*
@@ -1752,6 +1751,10 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
* returns a DLPack structure based on a tensor_t. The caller is responsible for freeing the memory
* by calling ->deleter(self).
*
* **Note**: This function increments the reference count of the tensor. Once a tensor has been
* converted to DLPack, the holder of the DLManagedTensor is expected to eventually call deleter().
* If that never happens, the underlying memory is leaked.
*
* @returns Pointer to new DLManagedTensorVersioned pointer. The caller must call the deleter function when finished.
*/
DLManagedTensor *GetDLPackTensor() const {
@@ -1762,14 +1765,14 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
CUpointer_attribute attr[] = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL};
CUmemorytype mem_type;
int dev_ord;
void *data[2] = {&mem_type, &dev_ord};

t->data = static_cast<void*>(this->ldata_);
t->device.device_id = 0;

// Determine where this memory resides
auto kind = GetPointerKind(this->ldata_);
auto mem_res = cuPointerGetAttributes(sizeof(attr)/sizeof(attr[0]), attr, data, reinterpret_cast<CUdeviceptr>(this->ldata_));
MATX_ASSERT_STR_EXP(mem_res, CUDA_SUCCESS, matxCudaError, "Error returned from cuPointerGetAttributes");
if (kind == MATX_INVALID_MEMORY) {
if (mem_type == CU_MEMORYTYPE_DEVICE) {
@@ -1802,28 +1805,33 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
}
}

t->ndim = RANK;
t->dtype = detail::TypeToDLPackType<T>();
t->shape = new int64_t[RANK];
t->strides = new int64_t[RANK];
for (int r = 0; r < RANK; r++) {
t->shape[r] = this->Size(r);
t->strides[r] = this->Stride(r);
}
t->byte_offset = 0;

// Increment the reference count by heap-allocating a copy of this tensor (which copies its
// shared_ptr storage) and using it as the DLPack manager context. Previously manager_ctx
// was simply left as nullptr.
auto t_copy = new self_type{*this};
mt->manager_ctx = t_copy;
//mt->flags = 0; // Only for v1.0

//auto deleter = [](struct DLManagedTensorVersioned *mtv) { // v1.0
auto deleter = [](struct DLManagedTensor *mtv) {
delete [] mtv->dl_tensor.shape;
delete [] mtv->dl_tensor.strides;
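// Drops the extra reference taken in GetDLPackTensor() by deleting the heap-allocated tensor copy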
delete static_cast<self_type *>(mtv->manager_ctx);
delete mtv;

mtv->dl_tensor.shape = nullptr;
mtv->dl_tensor.strides = nullptr;
mtv = nullptr;
};

mt->deleter = deleter;
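As the new doc comment notes, each call to GetDLPackTensor() now takes an extra reference on the tensor's storage, and that reference is released only when the returned structure's deleter runs. Below is a minimal usage sketch of that contract; the make_tensor call and matx.h header follow the public MatX API, but the shape and surrounding code are illustrative assumptions, not part of this commit.

#include "matx.h"   // MatX umbrella header (assumed include path)

void export_example() {
  // Local tensor backed by reference-counted storage
  auto t = matx::make_tensor<float>({4, 8});

  // Takes an additional reference so the data outlives `t` if `t` goes out of scope first
  DLManagedTensor *dl = t.GetDLPackTensor();

  // ... hand `dl` to a DLPack consumer (NumPy, PyTorch, CuPy, ...) ...

  // Whoever ends up owning `dl` must call the deleter exactly once; otherwise the extra
  // reference is never dropped and the storage leaks.
  dl->deleter(dl);
}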
2 changes: 2 additions & 0 deletions test/00_tensor/BasicTensorTests.cu
@@ -471,9 +471,11 @@ TYPED_TEST(BasicTensorTestsAll, DLPack)
ASSERT_EQ(dl->dl_tensor.strides[0], t.Stride(0));
ASSERT_EQ(dl->dl_tensor.strides[1], t.Stride(1));
ASSERT_EQ(dl->dl_tensor.strides[2], t.Stride(2));
ASSERT_EQ(t.GetRefCount(), 2);
dl->deleter(dl);
ASSERT_EQ(dl->dl_tensor.shape, nullptr);
ASSERT_EQ(dl->dl_tensor.strides, nullptr);
ASSERT_EQ(t.GetRefCount(), 1);

MATX_EXIT_HANDLER();
}
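The ownership scheme the test exercises (ref count 2 after export, back to 1 after the deleter) is the usual DLPack capsule idiom: whatever keeps the data alive is parked in manager_ctx, and the deleter releases it along with the shape/stride arrays. A stripped-down, MatX-independent sketch of that idiom follows; the shared_ptr<float> payload here is only an illustrative stand-in for the tensor's reference-counted storage.

#include <dlpack/dlpack.h>
#include <memory>

// Wrap a shared_ptr-managed buffer in a DLManagedTensor. The heap-allocated copy of the
// shared_ptr stored in manager_ctx holds the extra reference until the deleter runs.
DLManagedTensor *to_dlpack(std::shared_ptr<float> data, int64_t len) {
  auto *mt = new DLManagedTensor{};
  mt->dl_tensor.data = data.get();
  mt->dl_tensor.device = {kDLCPU, 0};
  mt->dl_tensor.ndim = 1;
  mt->dl_tensor.dtype.code = kDLFloat;
  mt->dl_tensor.dtype.bits = 32;
  mt->dl_tensor.dtype.lanes = 1;
  mt->dl_tensor.shape = new int64_t[1]{len};
  mt->dl_tensor.strides = nullptr;        // nullptr means compact, row-major layout
  mt->dl_tensor.byte_offset = 0;

  mt->manager_ctx = new std::shared_ptr<float>(std::move(data));  // extra reference lives here
  mt->deleter = [](DLManagedTensor *self) {
    delete [] self->dl_tensor.shape;
    delete static_cast<std::shared_ptr<float> *>(self->manager_ctx);  // drop the reference
    delete self;
  };
  return mt;
}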
