Skip to content

Commit

Permalink
forcing contiguous allocation in device API, for now
Browse files Browse the repository at this point in the history
  • Loading branch information
adstraw committed Dec 10, 2021
1 parent 2cea7b5 commit 643e9ad
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/runtime/hexagon/hexagon/hexagon_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
if (args.type_codes[i] == kTVMDLTensorHandle) {
DLTensor* tensor = static_cast<DLTensor*>(arg_values[i].v_handle);
buffer_args.emplace_back(i, static_cast<HexagonBuffer*>(tensor->data));
// Assumes a single contiguous allocation
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
tensor->data = buffer_args.back().second->GetPointer()[0];
}
}
Expand Down
16 changes: 9 additions & 7 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ void HexagonDeviceAPIv2::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* r
void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* shape,
DLDataType dtype, Optional<String> mem_scope) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon);
CHECK(ndim == 1 || ndim == 2);

size_t typesize = (dtype.bits / 8) * dtype.lanes;

// Forcing contiguous allocation, for now
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
size_t nallocs = 1;
size_t nbytes = shape[0] * typesize;
if (ndim == 2) {
nallocs = shape[0];
nbytes = shape[1] * typesize;
size_t nbytes = 1;
for (int i = 0; i < ndim; ++i) {
nbytes *= shape[i];
}
size_t typesize = (dtype.bits / 8) * dtype.lanes;
nbytes *= typesize;

size_t alignment = typesize;
if (alignment < kHexagonAllocAlignment) {
Expand Down Expand Up @@ -98,6 +98,8 @@ void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType typ
auto* hexbuf = static_cast<HexagonBuffer*>(
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size));

// Assumes a single contiguous allocation
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
void* ptr = hexbuf->GetPointer()[0];
workspace_allocations_.insert({ptr, hexbuf});
return ptr;
Expand Down

0 comments on commit 643e9ad

Please sign in to comment.