Skip to content

Commit

Permalink
Merge pull request oneapi-src#1669 from JackAKirk/fix-usmptr-get-dev
Browse files Browse the repository at this point in the history
[CUDA][HIP] Fix urUSMGetMemAllocInfo impl to use single platform.
  • Loading branch information
kbenzie authored Jun 5, 2024
2 parents 42c0b02 + 9868e3b commit f06bc02
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
11 changes: 4 additions & 7 deletions source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,13 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
(CUdeviceptr)pMem));

// currently each device is in its own platform, so find the platform at
// the same index
std::vector<ur_platform_handle_t> Platforms;
Platforms.resize(DeviceIndex + 1);
// cuda backend has only one platform containing all devices
ur_platform_handle_t platform;
ur_adapter_handle_t AdapterHandle = &adapter;
Result = urPlatformGet(&AdapterHandle, 1, DeviceIndex + 1,
Platforms.data(), nullptr);
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);

// get the device from the platform
ur_device_handle_t Device = Platforms[DeviceIndex]->Devices[0].get();
ur_device_handle_t Device = platform->Devices[DeviceIndex].get();
return ReturnValue(Device);
}
case UR_USM_ALLOC_INFO_POOL: {
Expand Down
11 changes: 4 additions & 7 deletions source/adapters/hip/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,13 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,

int DeviceIdx = hipPointerAttributeType.device;

// currently each device is in its own platform, so find the platform at
// the same index
std::vector<ur_platform_handle_t> Platforms;
Platforms.resize(DeviceIdx + 1);
// hip backend has only one platform containing all devices
ur_platform_handle_t platform;
ur_adapter_handle_t AdapterHandle = &adapter;
Result = urPlatformGet(&AdapterHandle, 1, DeviceIdx + 1, Platforms.data(),
nullptr);
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);

// get the device from the platform
ur_device_handle_t Device = Platforms[DeviceIdx]->Devices[0].get();
ur_device_handle_t Device = platform->Devices[DeviceIdx].get();
return ReturnValue(Device);
}
case UR_USM_ALLOC_INFO_POOL: {
Expand Down

0 comments on commit f06bc02

Please sign in to comment.