Skip to content

Commit

Permalink
Moved calls to createSyncPointBetweenCopyAndCompute() inside createSyncPointAndGetZeEvents()
Browse files Browse the repository at this point in the history
  • Loading branch information
konradkusiak97 committed Feb 24, 2025
1 parent 3a1c96a commit 060afd5
Showing 1 changed file with 55 additions and 100 deletions.
155 changes: 55 additions & 100 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ ur_result_t createSyncPointBetweenCopyAndCompute(
ze_command_list_handle_t ZeCommandList,
std::vector<ze_event_handle_t> &WaitEventList) {

if (!CommandBuffer->IsInOrderCmdList || !CommandBuffer->ZeCopyCommandList) {
if (!CommandBuffer->ZeCopyCommandList) {
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -229,32 +229,6 @@ ur_result_t createSyncPointBetweenCopyAndCompute(
return UR_RESULT_SUCCESS;
}

/**
 * Makes the compute command list wait on the synchronization event (if any)
 * produced by createSyncPointBetweenCopyAndCompute().
 *
 * Nothing is done when the command buffer has no copy command list. Otherwise
 * createSyncPointBetweenCopyAndCompute() is invoked for ZeComputeCommandList
 * and, if it put anything into the event list, a wait on the first event of
 * that list is appended to ZeComputeCommandList.
 *
 * @param[in] CommandBuffer The CommandBuffer where the command is appended.
 * @return UR_RESULT_SUCCESS or an error code on failure
 */
ur_result_t
appendWaitBetweenCopyAndCompute(ur_exp_command_buffer_handle_t CommandBuffer) {

  // Synchronization is only relevant when a copy command list exists.
  if (CommandBuffer->ZeCopyCommandList) {
    std::vector<ze_event_handle_t> ZeEventList;
    UR_CALL(createSyncPointBetweenCopyAndCompute(
        CommandBuffer, CommandBuffer->ZeComputeCommandList, ZeEventList));

    // An empty list means no event was produced, so there is nothing to
    // wait on. Only a single event is passed to Level Zero here.
    if (ZeEventList.size() > 0) {
      ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
                 (CommandBuffer->ZeComputeCommandList, 1, ZeEventList.data()));
    }
  }

  return UR_RESULT_SUCCESS;
}

/**
* If needed, creates a sync point for a given command and returns the L0
* events associated with the sync point.
Expand All @@ -275,7 +249,7 @@ appendWaitBetweenCopyAndCompute(ur_exp_command_buffer_handle_t CommandBuffer) {
*/
ur_result_t createSyncPointAndGetZeEvents(
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
uint32_t NumSyncPointsInWaitList,
ze_command_list_handle_t ZeCommandList, uint32_t NumSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
bool HostVisible, ur_exp_command_buffer_sync_point_t *RetSyncPoint,
std::vector<ze_event_handle_t> &ZeEventList,
Expand All @@ -284,6 +258,11 @@ ur_result_t createSyncPointAndGetZeEvents(
ZeLaunchEvent = nullptr;

if (CommandBuffer->IsInOrderCmdList) {
UR_CALL(createSyncPointBetweenCopyAndCompute(CommandBuffer, ZeCommandList,
ZeEventList));
if (!ZeEventList.empty()) {
NumSyncPointsInWaitList = ZeEventList.size();
}
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -320,17 +299,14 @@ ur_result_t enqueueCommandBufferMemCopyHelper(
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
ur_exp_command_buffer_sync_point_t *RetSyncPoint) {

std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
false, RetSyncPoint, ZeEventList, ZeLaunchEvent));

ze_command_list_handle_t ZeCommandList =
CommandBuffer->chooseCommandList(PreferCopyEngine);

UR_CALL(createSyncPointBetweenCopyAndCompute(CommandBuffer, ZeCommandList,
ZeEventList));
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));

ZE2UR_CALL(zeCommandListAppendMemoryCopy,
(ZeCommandList, Dst, Src, Size, ZeLaunchEvent, ZeEventList.size(),
Expand Down Expand Up @@ -381,17 +357,14 @@ ur_result_t enqueueCommandBufferMemCopyRectHelper(
const ze_copy_region_t ZeDstRegion = {DstOriginX, DstOriginY, DstOriginZ,
Width, Height, Depth};

std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
false, RetSyncPoint, ZeEventList, ZeLaunchEvent));

ze_command_list_handle_t ZeCommandList =
CommandBuffer->chooseCommandList(PreferCopyEngine);

UR_CALL(createSyncPointBetweenCopyAndCompute(CommandBuffer, ZeCommandList,
ZeEventList));
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));

ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
(ZeCommandList, Dst, &ZeDstRegion, DstPitch, DstSlicePitch, Src,
Expand All @@ -412,21 +385,18 @@ ur_result_t enqueueCommandBufferFillHelper(
UR_ASSERT((PatternSize > 0) && ((PatternSize & (PatternSize - 1)) == 0),
UR_RESULT_ERROR_INVALID_VALUE);

std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
true, RetSyncPoint, ZeEventList, ZeLaunchEvent));

bool PreferCopyEngine;
UR_CALL(
preferCopyEngineForFill(CommandBuffer, PatternSize, PreferCopyEngine));

ze_command_list_handle_t ZeCommandList =
CommandBuffer->chooseCommandList(PreferCopyEngine);

UR_CALL(createSyncPointBetweenCopyAndCompute(CommandBuffer, ZeCommandList,
ZeEventList));
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));

ZE2UR_CALL(zeCommandListAppendMemoryFill,
(ZeCommandList, Ptr, Pattern, PatternSize, Size, ZeLaunchEvent,
Expand Down Expand Up @@ -1167,12 +1137,10 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
UR_COMMAND_KERNEL_LAUNCH, CommandBuffer, NumSyncPointsInWaitList,
UR_COMMAND_KERNEL_LAUNCH, CommandBuffer,
CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));

UR_CALL(createSyncPointBetweenCopyAndCompute(
CommandBuffer, CommandBuffer->ZeComputeCommandList, ZeEventList));

ZE2UR_CALL(zeCommandListAppendLaunchKernel,
(CommandBuffer->ZeComputeCommandList, ZeKernel,
&ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size(),
Expand Down Expand Up @@ -1403,32 +1371,25 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
std::ignore = Command;
std::ignore = Flags;

if (CommandBuffer->IsInOrderCmdList) {
// Sync if copy engine is used in in-order path.
UR_CALL(appendWaitBetweenCopyAndCompute(CommandBuffer));
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
UR_COMMAND_USM_PREFETCH, CommandBuffer,
CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));

// Add the prefetch command to the command buffer.
// Note that L0 does not handle migration flags.
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
(CommandBuffer->ZeComputeCommandList, Mem, Size));
} else {
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
UR_COMMAND_USM_PREFETCH, CommandBuffer, NumSyncPointsInWaitList,
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));

if (NumSyncPointsInWaitList) {
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
ZeEventList.data()));
}
if (NumSyncPointsInWaitList) {
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
ZeEventList.data()));
}

// Add the prefetch command to the command-buffer.
// Note that L0 does not handle migration flags.
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
(CommandBuffer->ZeComputeCommandList, Mem, Size));
// Add the prefetch command to the command-buffer.
// Note that L0 does not handle migration flags.
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
(CommandBuffer->ZeComputeCommandList, Mem, Size));

if (!CommandBuffer->IsInOrderCmdList) {
// Level Zero does not have a completion "event" with the prefetch API,
// so manually add command to signal our event.
ZE2UR_CALL(zeCommandListAppendSignalEvent,
Expand Down Expand Up @@ -1476,30 +1437,24 @@ ur_result_t urCommandBufferAppendUSMAdviseExp(

ze_memory_advice_t ZeAdvice = static_cast<ze_memory_advice_t>(Value);

if (CommandBuffer->IsInOrderCmdList) {
// Sync if copy engine is used in in-order path.
UR_CALL(appendWaitBetweenCopyAndCompute(CommandBuffer));
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
UR_COMMAND_USM_ADVISE, CommandBuffer, CommandBuffer->ZeComputeCommandList,
NumSyncPointsInWaitList, SyncPointWaitList, true, RetSyncPoint,
ZeEventList, ZeLaunchEvent));

ZE2UR_CALL(zeCommandListAppendMemAdvise,
(CommandBuffer->ZeComputeCommandList,
CommandBuffer->Device->ZeDevice, Mem, Size, ZeAdvice));
} else {
std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
UR_COMMAND_USM_ADVISE, CommandBuffer, NumSyncPointsInWaitList,
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));

if (NumSyncPointsInWaitList) {
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
ZeEventList.data()));
}
if (NumSyncPointsInWaitList) {
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
(CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
ZeEventList.data()));
}

ZE2UR_CALL(zeCommandListAppendMemAdvise,
(CommandBuffer->ZeComputeCommandList,
CommandBuffer->Device->ZeDevice, Mem, Size, ZeAdvice));
ZE2UR_CALL(zeCommandListAppendMemAdvise,
(CommandBuffer->ZeComputeCommandList,
CommandBuffer->Device->ZeDevice, Mem, Size, ZeAdvice));

if (!CommandBuffer->IsInOrderCmdList) {
// Level Zero does not have a completion "event" with the advise API,
// so manually add command to signal our event.
ZE2UR_CALL(zeCommandListAppendSignalEvent,
Expand Down

0 comments on commit 060afd5

Please sign in to comment.