Skip to content

Commit

Permalink
Merge pull request #2166 from igchor/memory_migration
Browse files Browse the repository at this point in the history
[L0 v2] make device allocation resident and support multi-device buffers
  • Loading branch information
pbalcer authored Oct 10, 2024
2 parents c043566 + 28db1fd commit e3910da
Show file tree
Hide file tree
Showing 22 changed files with 582 additions and 361 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/multi_device.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ jobs:
strategy:
matrix:
adapter: [
{name: L0}
{name: L0},
{name: L0_V2}
]
build_type: [Debug, Release]
compiler: [{c: gcc, cxx: g++}] # TODO: investigate why memory-adapter-level_zero hangs with clang
Expand Down
7 changes: 7 additions & 0 deletions source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@

#include <cuda.h>

namespace umf {
ur_result_t getProviderNativeError(const char *, int32_t) {
// TODO: implement when UMF supports CUDA
return UR_RESULT_ERROR_UNKNOWN;
}
} // namespace umf

/// USM: Implements USM Host allocations using CUDA Pinned Memory
/// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#page-locked-host-memory
UR_APIEXPORT ur_result_t UR_APICALL
Expand Down
7 changes: 7 additions & 0 deletions source/adapters/hip/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
#include "ur_util.hpp"
#include "usm.hpp"

namespace umf {
ur_result_t getProviderNativeError(const char *, int32_t) {
// TODO: implement when UMF supports HIP
return UR_RESULT_ERROR_UNKNOWN;
}
} // namespace umf

/// USM: Implements USM Host allocations using HIP Pinned Memory
UR_APIEXPORT ur_result_t UR_APICALL
urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
Expand Down
11 changes: 11 additions & 0 deletions source/adapters/level_zero/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@

#include <umf_helpers.hpp>

namespace umf {
ur_result_t getProviderNativeError(const char *providerName,
int32_t nativeError) {
if (strcmp(providerName, "Level Zero") == 0) {
return ze2urResult(static_cast<ze_result_t>(nativeError));
}

return UR_RESULT_ERROR_UNKNOWN;
}
} // namespace umf

usm::DisjointPoolAllConfigs DisjointPoolConfigInstance =
InitializeDisjointPoolConfig();

Expand Down
7 changes: 0 additions & 7 deletions source/adapters/level_zero/v2/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ ur_result_t urMemImageCreateWithNativeHandle(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

ur_result_t urMemGetInfo(ur_mem_handle_t hMemory, ur_mem_info_t propName,
size_t propSize, void *pPropValue,
size_t *pPropSizeRet) {
logger::error("{} function not implemented!", __FUNCTION__);
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

ur_result_t urMemImageGetInfo(ur_mem_handle_t hMemory, ur_image_info_t propName,
size_t propSize, void *pPropValue,
size_t *pPropSizeRet) {
Expand Down
42 changes: 40 additions & 2 deletions source/adapters/level_zero/v2/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,41 @@
#include "context.hpp"
#include "event_provider_normal.hpp"

static std::vector<ur_device_handle_t>
filterP2PDevices(ur_device_handle_t hSourceDevice,
const std::vector<ur_device_handle_t> &devices) {
std::vector<ur_device_handle_t> p2pDevices;
for (auto &device : devices) {
if (device == hSourceDevice) {
continue;
}

ze_bool_t p2p;
ZE2UR_CALL_THROWS(zeDeviceCanAccessPeer,
(hSourceDevice->ZeDevice, device->ZeDevice, &p2p));

if (p2p) {
p2pDevices.push_back(device);
}
}
return p2pDevices;
}

static std::vector<std::vector<ur_device_handle_t>>
populateP2PDevices(size_t maxDevices,
const std::vector<ur_device_handle_t> &devices) {
std::vector<std::vector<ur_device_handle_t>> p2pDevices(maxDevices);
for (auto &device : devices) {
p2pDevices[device->Id.value()] = filterP2PDevices(device, devices);
}
return p2pDevices;
}

ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
uint32_t numDevices,
const ur_device_handle_t *phDevices,
bool ownZeContext)
: hContext(hContext, ownZeContext),
hDevices(phDevices, phDevices + numDevices), commandListCache(hContext),
: commandListCache(hContext),
eventPoolCache(phDevices[0]->Platform->getNumDevices(),
[context = this,
platform = phDevices[0]->Platform](DeviceId deviceId) {
Expand All @@ -28,6 +57,10 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
context, device, v2::EVENT_COUNTER,
v2::QUEUE_IMMEDIATE);
}),
hContext(hContext, ownZeContext),
hDevices(phDevices, phDevices + numDevices),
p2pAccessDevices(populateP2PDevices(
phDevices[0]->Platform->getNumDevices(), this->hDevices)),
defaultUSMPool(this, nullptr) {}

ur_result_t ur_context_handle_t_::retain() {
Expand Down Expand Up @@ -65,6 +98,11 @@ ur_usm_pool_handle_t ur_context_handle_t_::getDefaultUSMPool() {
return &defaultUSMPool;
}

const std::vector<ur_device_handle_t> &
ur_context_handle_t_::getP2PDevices(ur_device_handle_t hDevice) const {
return p2pAccessDevices[hDevice->Id.value()];
}

namespace ur::level_zero {
ur_result_t urContextCreate(uint32_t deviceCount,
const ur_device_handle_t *phDevices,
Expand Down
12 changes: 10 additions & 2 deletions source/adapters/level_zero/v2/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,22 @@ struct ur_context_handle_t_ : _ur_object {
ur_platform_handle_t getPlatform() const;
const std::vector<ur_device_handle_t> &getDevices() const;
ur_usm_pool_handle_t getDefaultUSMPool();
const std::vector<ur_device_handle_t> &
getP2PDevices(ur_device_handle_t hDevice) const;

// Checks if Device is covered by this context.
// For that the Device or its root devices need to be in the context.
bool isValidDevice(ur_device_handle_t Device) const;

const v2::raii::ze_context_handle_t hContext;
const std::vector<ur_device_handle_t> hDevices;
v2::command_list_cache_t commandListCache;
v2::event_pool_cache eventPoolCache;

private:
const v2::raii::ze_context_handle_t hContext;
const std::vector<ur_device_handle_t> hDevices;

// P2P devices for each device in the context, indexed by device id.
const std::vector<std::vector<ur_device_handle_t>> p2pAccessDevices;

ur_usm_pool_handle_t_ defaultUSMPool;
};
15 changes: 10 additions & 5 deletions source/adapters/level_zero/v2/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,16 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
auto zePtr = hArgValue->getPtr(kernelDevices.front());
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
} else {
// TODO: Implement this for multi-device kernels.
// Do this the same way as in legacy (keep a pending Args vector and
// do actual allocation on kernel submission) or allocate the memory
// immediately (only for small allocations?)
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
// TODO: if devices do not have p2p capabilities, we need to have allocation
// on each device. Do this the same way as in legacy (keep a pending Args
// vector and do actual allocation on kernel submission) or allocate the
// memory immediately (only for small allocations?).

// Get memory that is accessible by the first device.
// If kernel is submitted to a different device the memory
// will be accessed trough the link or migrated in enqueueKernelLaunch.
auto zePtr = hArgValue->getPtr(kernelDevices.front());
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
}
}

Expand Down
121 changes: 84 additions & 37 deletions source/adapters/level_zero/v2/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
}

if (!hostPtrImported) {
// TODO: use UMF
ZeStruct<ze_host_mem_alloc_desc_t> hostDesc;
ZE2UR_CALL_THROWS(zeMemAllocHost, (hContext->getZeHandle(), &hostDesc, size,
0, &this->ptr));
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &this->ptr));

if (hostPtr) {
std::memcpy(this->ptr, hostPtr, size);
Expand All @@ -40,9 +38,11 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
}

ur_host_mem_handle_t::~ur_host_mem_handle_t() {
// TODO: use UMF API here
if (ptr) {
ZE_CALL_NOCHECK(zeMemFree, (hContext->getZeHandle(), ptr));
auto ret = hContext->getDefaultUSMPool()->free(ptr);
if (ret != UR_RESULT_SUCCESS) {
logger::error("Failed to free host memory: {}", ret);
}
}
}

Expand All @@ -51,55 +51,80 @@ void *ur_host_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
return ptr;
}

ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
void *src, size_t size) {
auto Id = hDevice->Id.value();

if (!deviceAllocations[Id]) {
UR_CALL(hContext->getDefaultUSMPool()->allocate(hContext, hDevice, nullptr,
UR_USM_TYPE_DEVICE, size,
&deviceAllocations[Id]));
}

auto commandList = hContext->commandListCache.getImmediateCommandList(
hDevice->ZeDevice, true,
hDevice
->QueueGroup[ur_device_handle_t_::queue_group_info_t::type::Compute]
.ZeOrdinal,
ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
std::nullopt);

ZE2UR_CALL(zeCommandListAppendMemoryCopy,
(commandList.get(), deviceAllocations[Id], src, size, nullptr, 0,
nullptr));

activeAllocationDevice = hDevice;

return UR_RESULT_SUCCESS;
}

ur_device_mem_handle_t::ur_device_mem_handle_t(ur_context_handle_t hContext,
void *hostPtr, size_t size)
: ur_mem_handle_t_(hContext, size),
deviceAllocations(hContext->getPlatform()->getNumDevices()) {
// Legacy adapter allocated the memory directly on a device (first on the
// contxt) and if the buffer is used on another device, memory is migrated
// (depending on an env var setting).
//
// TODO: port this behavior or figure out if it makes sense to keep the memory
// in a host buffer (e.g. for smaller sizes).
deviceAllocations(hContext->getPlatform()->getNumDevices()),
activeAllocationDevice(nullptr) {
if (hostPtr) {
buffer.assign(reinterpret_cast<char *>(hostPtr),
reinterpret_cast<char *>(hostPtr) + size);
auto initialDevice = hContext->getDevices()[0];
UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size));
}
}

ur_device_mem_handle_t::~ur_device_mem_handle_t() {
// TODO: use UMF API here
for (auto &ptr : deviceAllocations) {
if (ptr) {
ZE_CALL_NOCHECK(zeMemFree, (hContext->getZeHandle(), ptr));
auto ret = hContext->getDefaultUSMPool()->free(ptr);
if (ret != UR_RESULT_SUCCESS) {
logger::error("Failed to free device memory: {}", ret);
}
}
}
}

void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
std::lock_guard lock(this->Mutex);

auto &ptr = deviceAllocations[hDevice->Id.value()];
if (!ptr) {
ZeStruct<ze_device_mem_alloc_desc_t> deviceDesc;
ZE2UR_CALL_THROWS(zeMemAllocDevice, (hContext->getZeHandle(), &deviceDesc,
size, 0, hDevice->ZeDevice, &ptr));

if (!buffer.empty()) {
auto commandList = hContext->commandListCache.getImmediateCommandList(
hDevice->ZeDevice, true,
hDevice
->QueueGroup
[ur_device_handle_t_::queue_group_info_t::type::Compute]
.ZeOrdinal,
ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
std::nullopt);
ZE2UR_CALL_THROWS(
zeCommandListAppendMemoryCopy,
(commandList.get(), ptr, buffer.data(), size, nullptr, 0, nullptr));
}
if (!activeAllocationDevice) {
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
hContext, hDevice, nullptr, UR_USM_TYPE_DEVICE, getSize(),
&deviceAllocations[hDevice->Id.value()]));
activeAllocationDevice = hDevice;
}
return ptr;

if (activeAllocationDevice == hDevice) {
return deviceAllocations[hDevice->Id.value()];
}

auto &p2pDevices = hContext->getP2PDevices(hDevice);
auto p2pAccessible = std::find(p2pDevices.begin(), p2pDevices.end(),
activeAllocationDevice) != p2pDevices.end();

if (!p2pAccessible) {
// TODO: migrate buffer through the host
throw UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

// TODO: see if it's better to migrate the memory to the specified device
return deviceAllocations[activeAllocationDevice->Id.value()];
}

namespace ur::level_zero {
Expand Down Expand Up @@ -166,6 +191,28 @@ ur_result_t urMemBufferCreateWithNativeHandle(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

ur_result_t urMemGetInfo(ur_mem_handle_t hMemory, ur_mem_info_t propName,
size_t propSize, void *pPropValue,
size_t *pPropSizeRet) {
std::shared_lock<ur_shared_mutex> Lock(hMemory->Mutex);
UrReturnHelper returnValue(propSize, pPropValue, pPropSizeRet);

switch (propName) {
case UR_MEM_INFO_CONTEXT: {
return returnValue(hMemory->getContext());
}
case UR_MEM_INFO_SIZE: {
// Get size of the allocation
return returnValue(size_t{hMemory->getSize()});
}
default: {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
}

return UR_RESULT_SUCCESS;
}

ur_result_t urMemRetain(ur_mem_handle_t hMem) {
hMem->RefCount.increment();
return UR_RESULT_SUCCESS;
Expand Down
11 changes: 9 additions & 2 deletions source/adapters/level_zero/v2/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <ur_api.h>

#include "../device.hpp"
#include "common.hpp"

struct ur_mem_handle_t_ : _ur_object {
Expand All @@ -21,6 +22,7 @@ struct ur_mem_handle_t_ : _ur_object {
virtual void *getPtr(ur_device_handle_t) = 0;

inline size_t getSize() { return size; }
inline ur_context_handle_t getContext() { return hContext; }

protected:
const ur_context_handle_t hContext;
Expand Down Expand Up @@ -48,8 +50,13 @@ struct ur_device_mem_handle_t : public ur_mem_handle_t_ {
void *getPtr(ur_device_handle_t) override;

private:
std::vector<char> buffer;

// Vector of per-device allocations indexed by device->Id
std::vector<void *> deviceAllocations;

// Specifies device on which the latest allocation resides.
// If null, there is no allocation.
ur_device_handle_t activeAllocationDevice;

ur_result_t migrateBufferTo(ur_device_handle_t hDevice, void *src,
size_t size);
};
4 changes: 4 additions & 0 deletions source/adapters/level_zero/v2/queue_create.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ ur_result_t urQueueCreate(ur_context_handle_t hContext,
ur_device_handle_t hDevice,
const ur_queue_properties_t *pProperties,
ur_queue_handle_t *phQueue) {
if (!hContext->isValidDevice(hDevice)) {
return UR_RESULT_ERROR_INVALID_DEVICE;
}

// TODO: For now, always use immediate, in-order
*phQueue =
new v2::ur_queue_immediate_in_order_t(hContext, hDevice, pProperties);
Expand Down
Loading

0 comments on commit e3910da

Please sign in to comment.