diff --git a/source/adapters/level_zero/v2/api.cpp b/source/adapters/level_zero/v2/api.cpp index 8a2153e0a5..129db02594 100644 --- a/source/adapters/level_zero/v2/api.cpp +++ b/source/adapters/level_zero/v2/api.cpp @@ -170,53 +170,6 @@ ur_result_t urBindlessImagesReleaseExternalSemaphoreExp( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } -ur_result_t urCommandBufferAppendUSMFillExp( - ur_exp_command_buffer_handle_t hCommandBuffer, void *pMemory, - const void *pPattern, size_t patternSize, size_t size, - uint32_t numSyncPointsInWaitList, - const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, - uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList, - ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, - ur_exp_command_buffer_command_handle_t *phCommand) { - logger::error("{} function not implemented!", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; -} - -ur_result_t urCommandBufferAppendMemBufferFillExp( - ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer, - const void *pPattern, size_t patternSize, size_t offset, size_t size, - uint32_t numSyncPointsInWaitList, - const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, - uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList, - ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, - ur_exp_command_buffer_command_handle_t *phCommand) { - logger::error("{} function not implemented!", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; -} - -ur_result_t urCommandBufferAppendUSMPrefetchExp( - ur_exp_command_buffer_handle_t hCommandBuffer, const void *pMemory, - size_t size, ur_usm_migration_flags_t flags, - uint32_t numSyncPointsInWaitList, - const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, - uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList, - ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, - ur_exp_command_buffer_command_handle_t *phCommand) { - logger::error("{} function not implemented!", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; -} - -ur_result_t urCommandBufferAppendUSMAdviseExp( - ur_exp_command_buffer_handle_t hCommandBuffer, const void *pMemory, - size_t size, ur_usm_advice_flags_t advice, uint32_t numSyncPointsInWaitList, - const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, - uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList, - ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, - ur_exp_command_buffer_command_handle_t *phCommand) { - logger::error("{} function not implemented!", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; -} - ur_result_t urCommandBufferUpdateKernelLaunchExp( ur_exp_command_buffer_command_handle_t hCommand, const ur_exp_command_buffer_update_kernel_launch_desc_t diff --git a/source/adapters/level_zero/v2/command_buffer.cpp b/source/adapters/level_zero/v2/command_buffer.cpp index 8253527efe..24e7fc0be6 100644 --- a/source/adapters/level_zero/v2/command_buffer.cpp +++ b/source/adapters/level_zero/v2/command_buffer.cpp @@ -346,6 +346,114 @@ ur_result_t urCommandBufferAppendMemBufferReadRectExp( return exceptionToResult(std::current_exception()); } +ur_result_t urCommandBufferAppendUSMFillExp( + ur_exp_command_buffer_handle_t hCommandBuffer, void *pMemory, + const void *pPattern, size_t patternSize, size_t size, + uint32_t numSyncPointsInWaitList, + const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, + ur_exp_command_buffer_command_handle_t *phCommand) try { + + // the same issue as in urCommandBufferAppendKernelLaunchExp + std::ignore = numEventsInWaitList; + std::ignore = phEventWaitList; + std::ignore = phEvent; + // sync mechanic can be ignored, because all lists are in-order + std::ignore = numSyncPointsInWaitList; + std::ignore = pSyncPointWaitList; + std::ignore = pSyncPoint; + + std::ignore = phCommand; + + UR_CALL(hCommandBuffer->commandListManager.appendUSMFill( + pMemory, patternSize, pPattern, size, 0, nullptr, nullptr)); + return UR_RESULT_SUCCESS; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + +ur_result_t urCommandBufferAppendMemBufferFillExp( + ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer, + const void *pPattern, size_t patternSize, size_t offset, size_t size, + uint32_t numSyncPointsInWaitList, + const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, + ur_exp_command_buffer_command_handle_t *phCommand) try { + + // the same issue as in urCommandBufferAppendKernelLaunchExp + std::ignore = numEventsInWaitList; + std::ignore = phEventWaitList; + std::ignore = phEvent; + // sync mechanic can be ignored, because all lists are in-order + std::ignore = numSyncPointsInWaitList; + std::ignore = pSyncPointWaitList; + std::ignore = pSyncPoint; + + std::ignore = phCommand; + + UR_CALL(hCommandBuffer->commandListManager.appendMemBufferFill( + hBuffer, pPattern, patternSize, offset, size, 0, nullptr, nullptr)); + return UR_RESULT_SUCCESS; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + +ur_result_t urCommandBufferAppendUSMPrefetchExp( + ur_exp_command_buffer_handle_t hCommandBuffer, const void *pMemory, + size_t size, ur_usm_migration_flags_t flags, + uint32_t numSyncPointsInWaitList, + const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, + ur_exp_command_buffer_command_handle_t *phCommand) try { + + // the same issue as in urCommandBufferAppendKernelLaunchExp + std::ignore = numEventsInWaitList; + std::ignore = phEventWaitList; + std::ignore = phEvent; + // sync mechanic can be ignored, because all lists are in-order + std::ignore = numSyncPointsInWaitList; + std::ignore = pSyncPointWaitList; + std::ignore = pSyncPoint; + + std::ignore = phCommand; + + UR_CALL(hCommandBuffer->commandListManager.appendUSMPrefetch( + pMemory, size, flags, 0, nullptr, nullptr)); + + return UR_RESULT_SUCCESS; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + +ur_result_t urCommandBufferAppendUSMAdviseExp( + ur_exp_command_buffer_handle_t hCommandBuffer, const void *pMemory, + size_t size, ur_usm_advice_flags_t advice, uint32_t numSyncPointsInWaitList, + const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_event_handle_t *phEvent, + ur_exp_command_buffer_command_handle_t *phCommand) try { + + // the same issue as in urCommandBufferAppendKernelLaunchExp + std::ignore = numEventsInWaitList; + std::ignore = phEventWaitList; + std::ignore = phEvent; + // sync mechanic can be ignored, because all lists are in-order + std::ignore = numSyncPointsInWaitList; + std::ignore = pSyncPointWaitList; + std::ignore = pSyncPoint; + + std::ignore = phCommand; + + UR_CALL(hCommandBuffer->commandListManager.appendUSMAdvise(pMemory, size, + advice, nullptr)); + + return UR_RESULT_SUCCESS; +} catch (...) { + return exceptionToResult(std::current_exception()); +} ur_result_t urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, ur_exp_command_buffer_info_t propName, diff --git a/source/adapters/level_zero/v2/command_list_manager.cpp b/source/adapters/level_zero/v2/command_list_manager.cpp index b1510d0a97..5c6d2330d7 100644 --- a/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/source/adapters/level_zero/v2/command_list_manager.cpp @@ -31,6 +31,50 @@ ur_command_list_manager::~ur_command_list_manager() { ur::level_zero::urDeviceRelease(device); } +ur_result_t ur_command_list_manager::appendGenericFillUnlocked( + ur_mem_buffer_t *dst, size_t offset, size_t patternSize, + const void *pPattern, size_t size, uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent, + ur_command_t commandType) { + + auto zeSignalEvent = getSignalEvent(phEvent, commandType); + + auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList); + + auto pDst = ur_cast(dst->getDevicePtr( + device, ur_mem_buffer_t::device_access_mode_t::read_only, offset, size, + [&](void *src, void *dst, size_t size) { + ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy, + (zeCommandList.get(), dst, src, size, nullptr, + waitListView.num, waitListView.handles)); + waitListView.clear(); + })); + + // PatternSize must be a power of two for zeCommandListAppendMemoryFill. + // When it's not, the fill is emulated with zeCommandListAppendMemoryCopy. + if (isPowerOf2(patternSize)) { + ZE2UR_CALL(zeCommandListAppendMemoryFill, + (zeCommandList.get(), pDst, pPattern, patternSize, size, + zeSignalEvent, waitListView.num, waitListView.handles)); + } else { + // Copy pattern into every entry in memory array pointed by Ptr. + uint32_t numOfCopySteps = size / patternSize; + const void *src = pPattern; + + for (uint32_t step = 0; step < numOfCopySteps; ++step) { + void *dst = reinterpret_cast(reinterpret_cast(pDst) + + step * patternSize); + ZE2UR_CALL(zeCommandListAppendMemoryCopy, + (zeCommandList.get(), dst, src, patternSize, + step == numOfCopySteps - 1 ? zeSignalEvent : nullptr, + waitListView.num, waitListView.handles)); + waitListView.clear(); + } + } + + return UR_RESULT_SUCCESS; +} + ur_result_t ur_command_list_manager::appendGenericCopyUnlocked( ur_mem_buffer_t *src, ur_mem_buffer_t *dst, bool blocking, size_t srcOffset, size_t dstOffset, size_t size, uint32_t numEventsInWaitList, @@ -209,6 +253,96 @@ ur_result_t ur_command_list_manager::appendUSMMemcpy( return UR_RESULT_SUCCESS; } +ur_result_t ur_command_list_manager::appendMemBufferFill( + ur_mem_handle_t hMem, const void *pPattern, size_t patternSize, + size_t offset, size_t size, uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { + TRACK_SCOPE_LATENCY("ur_command_list_manager::appendMemBufferFill"); + + auto hBuffer = hMem->getBuffer(); + UR_ASSERT(offset + size <= hBuffer->getSize(), UR_RESULT_ERROR_INVALID_SIZE); + + std::scoped_lock lock(this->Mutex, + hBuffer->getMutex()); + + return appendGenericFillUnlocked(hBuffer, offset, patternSize, pPattern, size, + numEventsInWaitList, phEventWaitList, + phEvent, UR_COMMAND_MEM_BUFFER_FILL); +} + +ur_result_t ur_command_list_manager::appendUSMFill( + void *pMem, size_t patternSize, const void *pPattern, size_t size, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + TRACK_SCOPE_LATENCY("ur_command_list_manager::appendUSMFill"); + + std::scoped_lock lock(this->Mutex); + + ur_usm_handle_t dstHandle(context, size, pMem); + return appendGenericFillUnlocked(&dstHandle, 0, patternSize, pPattern, size, + numEventsInWaitList, phEventWaitList, + phEvent, UR_COMMAND_USM_FILL); +} + +ur_result_t ur_command_list_manager::appendUSMPrefetch( + const void *pMem, size_t size, ur_usm_migration_flags_t flags, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + TRACK_SCOPE_LATENCY("ur_command_list_manager::appendUSMPrefetch"); + + std::ignore = flags; + + std::scoped_lock lock(this->Mutex); + + auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_USM_PREFETCH); + + auto [pWaitEvents, numWaitEvents] = + getWaitListView(phEventWaitList, numEventsInWaitList); + + if (pWaitEvents) { + ZE2UR_CALL(zeCommandListAppendWaitOnEvents, + (zeCommandList.get(), numWaitEvents, pWaitEvents)); + } + // TODO: figure out how to translate "flags" + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, + (zeCommandList.get(), pMem, size)); + if (zeSignalEvent) { + ZE2UR_CALL(zeCommandListAppendSignalEvent, + (zeCommandList.get(), zeSignalEvent)); + } + + return UR_RESULT_SUCCESS; +} + +ur_result_t +ur_command_list_manager::appendUSMAdvise(const void *pMem, size_t size, + ur_usm_advice_flags_t advice, + ur_event_handle_t *phEvent) { + TRACK_SCOPE_LATENCY("ur_command_list_manager::appendUSMAdvise"); + + std::scoped_lock lock(this->Mutex); + + auto zeAdvice = ur_cast(advice); + + auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_USM_ADVISE); + + auto [pWaitEvents, numWaitEvents] = getWaitListView(nullptr, 0); + + if (pWaitEvents) { + ZE2UR_CALL(zeCommandListAppendWaitOnEvents, + (zeCommandList.get(), numWaitEvents, pWaitEvents)); + } + + ZE2UR_CALL(zeCommandListAppendMemAdvise, + (zeCommandList.get(), device->ZeDevice, pMem, size, zeAdvice)); + + if (zeSignalEvent) { + ZE2UR_CALL(zeCommandListAppendSignalEvent, + (zeCommandList.get(), zeSignalEvent)); + } + return UR_RESULT_SUCCESS; +} + ur_result_t ur_command_list_manager::appendMemBufferRead( ur_mem_handle_t hMem, bool blockingRead, size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList, diff --git a/source/adapters/level_zero/v2/command_list_manager.hpp b/source/adapters/level_zero/v2/command_list_manager.hpp index e85d9b9049..23cefbd9b7 100644 --- a/source/adapters/level_zero/v2/command_list_manager.hpp +++ b/source/adapters/level_zero/v2/command_list_manager.hpp @@ -98,6 +98,27 @@ struct ur_command_list_manager : public _ur_object { size_t height, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent); + ur_result_t appendMemBufferFill(ur_mem_handle_t hBuffer, const void *pPattern, + size_t patternSize, size_t offset, + size_t size, uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent); + + ur_result_t appendUSMFill(void *pMem, size_t patternSize, + const void *pPattern, size_t size, + uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent); + + ur_result_t appendUSMPrefetch(const void *pMem, size_t size, + ur_usm_migration_flags_t flags, + uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent); + + ur_result_t appendUSMAdvise(const void *pMem, size_t size, + ur_usm_advice_flags_t advice, + ur_event_handle_t *phEvent); ze_command_list_handle_t getZeCommandList(); @@ -107,6 +128,12 @@ struct ur_command_list_manager : public _ur_object { ur_command_t commandType); private: + ur_result_t appendGenericFillUnlocked( + ur_mem_buffer_t *hBuffer, size_t offset, size_t patternSize, + const void *pPattern, size_t size, uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent, + ur_command_t commandType); + ur_result_t appendGenericCopyUnlocked( ur_mem_buffer_t *src, ur_mem_buffer_t *dst, bool blocking, size_t srcOffset, size_t dstOffset, size_t size, diff --git a/source/adapters/level_zero/v2/queue_immediate_in_order.cpp b/source/adapters/level_zero/v2/queue_immediate_in_order.cpp index d33ac12f7e..f5eb9b1eca 100644 --- a/source/adapters/level_zero/v2/queue_immediate_in_order.cpp +++ b/source/adapters/level_zero/v2/queue_immediate_in_order.cpp @@ -372,16 +372,11 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueMemBufferFill( const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { TRACK_SCOPE_LATENCY("ur_queue_immediate_in_order_t::enqueueMemBufferFill"); - auto hBuffer = hMem->getBuffer(); - - UR_ASSERT(offset + size <= hBuffer->getSize(), UR_RESULT_ERROR_INVALID_SIZE); - - std::scoped_lock lock(this->Mutex, - hBuffer->getMutex()); + UR_CALL(commandListManager.appendMemBufferFill( + hMem, pPattern, patternSize, offset, size, numEventsInWaitList, + phEventWaitList, phEvent)); - return enqueueGenericFillUnlocked(hBuffer, offset, patternSize, pPattern, - size, numEventsInWaitList, phEventWaitList, - phEvent, UR_COMMAND_MEM_BUFFER_FILL); + return UR_RESULT_SUCCESS; } ur_result_t ur_queue_immediate_in_order_t::enqueueMemImageRead( @@ -601,7 +596,9 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMFill( ur_event_handle_t *phEvent) { TRACK_SCOPE_LATENCY("ur_queue_immediate_in_order_t::enqueueUSMFill"); - std::scoped_lock lock(this->Mutex); + UR_CALL(commandListManager.appendUSMFill(pMem, patternSize, pPattern, size, + numEventsInWaitList, phEventWaitList, + phEvent)); ur_usm_handle_t dstHandle(hContext, size, pMem); return enqueueGenericFillUnlocked(&dstHandle, 0, patternSize, pPattern, size, @@ -629,27 +626,8 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMPrefetch( ur_event_handle_t *phEvent) { TRACK_SCOPE_LATENCY("ur_queue_immediate_in_order_t::enqueueUSMPrefetch"); - std::ignore = flags; - - std::scoped_lock lock(this->Mutex); - - auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_USM_PREFETCH); - - auto [pWaitEvents, numWaitEvents] = - getWaitListView(phEventWaitList, numEventsInWaitList); - - if (pWaitEvents) { - ZE2UR_CALL( - zeCommandListAppendWaitOnEvents, - (commandListManager.getZeCommandList(), numWaitEvents, pWaitEvents)); - } - // TODO: figure out how to translate "flags" - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, - (commandListManager.getZeCommandList(), pMem, size)); - if (zeSignalEvent) { - ZE2UR_CALL(zeCommandListAppendSignalEvent, - (commandListManager.getZeCommandList(), zeSignalEvent)); - } + UR_CALL(commandListManager.appendUSMPrefetch( + pMem, size, flags, numEventsInWaitList, phEventWaitList, phEvent)); return UR_RESULT_SUCCESS; } @@ -660,31 +638,8 @@ ur_queue_immediate_in_order_t::enqueueUSMAdvise(const void *pMem, size_t size, ur_event_handle_t *phEvent) { TRACK_SCOPE_LATENCY("ur_queue_immediate_in_order_t::enqueueUSMAdvise"); - std::ignore = flags; - - std::scoped_lock lock(this->Mutex); - - auto zeAdvice = ur_cast(advice); - - auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_USM_ADVISE); + UR_CALL(commandListManager.appendUSMAdvise(pMem, size, advice, phEvent)); - auto [pWaitEvents, numWaitEvents] = getWaitListView(nullptr, 0); - - if (pWaitEvents) { - ZE2UR_CALL( - zeCommandListAppendWaitOnEvents, - (commandListManager.getZeCommandList(), numWaitEvents, pWaitEvents)); - } - - // TODO: figure out how to translate "flags" - ZE2UR_CALL(zeCommandListAppendMemAdvise, - (commandListManager.getZeCommandList(), this->hDevice->ZeDevice, - pMem, size, zeAdvice)); - - if (zeSignalEvent) { - ZE2UR_CALL(zeCommandListAppendSignalEvent, - (commandListManager.getZeCommandList(), zeSignalEvent)); - } return UR_RESULT_SUCCESS; }