Merge pull request oneapi-src#1581 from 0x12CC/l0_cooperative_kernels
Implement L0 cooperative kernel functions
kbenzie authored May 27, 2024
2 parents 0a11fb4 + d3d3f6e commit 905804c
Showing 1 changed file with 263 additions and 9 deletions.
272 changes: 263 additions & 9 deletions source/adapters/level_zero/kernel.cpp
@@ -271,13 +271,264 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
-    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
-    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
-    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
-    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
-                               pGlobalWorkSize, pLocalWorkSize,
-                               numEventsInWaitList, phEventWaitList, phEvent);
ur_queue_handle_t Queue, ///< [in] handle of the queue object
ur_kernel_handle_t Kernel, ///< [in] handle of the kernel object
uint32_t WorkDim, ///< [in] number of dimensions, from 1 to 3, to specify
///< the global and work-group work-items
const size_t
*GlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned
///< values that specify the offset used to
///< calculate the global ID of a work-item
const size_t *GlobalWorkSize, ///< [in] pointer to an array of workDim
///< unsigned values that specify the number
///< of global work-items in workDim that
///< will execute the kernel function
const size_t
*LocalWorkSize, ///< [in][optional] pointer to an array of workDim
///< unsigned values that specify the number of local
///< work-items forming a work-group that will execute
///< the kernel function. If nullptr, the runtime
///< implementation will choose the work-group size.
uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
const ur_event_handle_t
*EventWaitList, ///< [in][optional][range(0, numEventsInWaitList)]
///< pointer to a list of events that must be complete
///< before the kernel execution. If nullptr, the
///< numEventsInWaitList must be 0, indicating that no
                      ///< wait events are required.
ur_event_handle_t
*OutEvent ///< [in,out][optional] return an event object that identifies
///< this particular kernel execution instance.
) {
auto ZeDevice = Queue->Device->ZeDevice;

ze_kernel_handle_t ZeKernel{};
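  // ZeKernelMap is empty when the program was built for a single device;
  // otherwise it maps each device to its device-specific kernel handle,
  // so look up the handle matching this queue's device.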
if (Kernel->ZeKernelMap.empty()) {
ZeKernel = Kernel->ZeKernel;
} else {
auto It = Kernel->ZeKernelMap.find(ZeDevice);
if (It == Kernel->ZeKernelMap.end()) {
/* kernel and queue don't match */
return UR_RESULT_ERROR_INVALID_QUEUE;
}
ZeKernel = It->second;
}
// Lock automatically releases when this goes out of scope.
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);
if (GlobalWorkOffset != NULL) {
if (!Queue->Device->Platform->ZeDriverGlobalOffsetExtensionFound) {
logger::error("No global offset extension found on this driver");
return UR_RESULT_ERROR_INVALID_VALUE;
}

ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
(ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
GlobalWorkOffset[2]));
}

// If there are any pending arguments set them now.
for (auto &Arg : Kernel->PendingArguments) {
// The ArgValue may be a NULL pointer in which case a NULL value is used for
// the kernel argument declared as a pointer to global or constant memory.
char **ZeHandlePtr = nullptr;
if (Arg.Value) {
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
Queue->Device));
}
ZE2UR_CALL(zeKernelSetArgumentValue,
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
}
Kernel->PendingArguments.clear();

ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
uint32_t WG[3]{};

// New variable needed because GlobalWorkSize parameter might not be of size 3
size_t GlobalWorkSize3D[3]{1, 1, 1};
std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
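  // For example, WorkDim == 2 with GlobalWorkSize == {64, 48} yields
  // GlobalWorkSize3D == {64, 48, 1}, leaving the unused Z dimension at 1.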

if (LocalWorkSize) {
    // Level Zero group sizes must fit in 32 bits.
UR_ASSERT(LocalWorkSize[0] < (std::numeric_limits<uint32_t>::max)(),
UR_RESULT_ERROR_INVALID_VALUE);
UR_ASSERT(LocalWorkSize[1] < (std::numeric_limits<uint32_t>::max)(),
UR_RESULT_ERROR_INVALID_VALUE);
UR_ASSERT(LocalWorkSize[2] < (std::numeric_limits<uint32_t>::max)(),
UR_RESULT_ERROR_INVALID_VALUE);
WG[0] = static_cast<uint32_t>(LocalWorkSize[0]);
WG[1] = static_cast<uint32_t>(LocalWorkSize[1]);
WG[2] = static_cast<uint32_t>(LocalWorkSize[2]);
} else {
    // We can't call zeKernelSuggestGroupSize if the 64-bit GlobalWorkSize
    // values do not fit into the 32-bit values that the API currently
    // supports.
bool SuggestGroupSize = true;
for (int I : {0, 1, 2}) {
if (GlobalWorkSize3D[I] > UINT32_MAX) {
SuggestGroupSize = false;
}
}
if (SuggestGroupSize) {
ZE2UR_CALL(zeKernelSuggestGroupSize,
(ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1],
GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2]));
} else {
for (int I : {0, 1, 2}) {
      // Try to find an I-dimension WG size that evenly divides
      // GlobalWorkSize[I]. Start with the max possible size in
      // each dimension.
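      // For example, GlobalWorkSize3D[I] == 5'000'000'000 with a maximum
      // group size of 1024: the search below settles on 1000, the largest
      // divisor of the global size that does not exceed 1024.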
uint32_t GroupSize[] = {
Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX,
Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY,
Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ};
GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]);
while (GlobalWorkSize3D[I] % GroupSize[I]) {
--GroupSize[I];
}

if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
logger::error(
"urEnqueueCooperativeKernelLaunchExp: can't find a WG size "
"suitable for global work size > UINT32_MAX");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
WG[I] = GroupSize[I];
}
logger::debug("urEnqueueCooperativeKernelLaunchExp: using computed WG "
"size = {{{}, {}, {}}}",
WG[0], WG[1], WG[2]);
}
}

// TODO: assert if sizes do not fit into 32-bit?

switch (WorkDim) {
case 3:
ZeThreadGroupDimensions.groupCountX =
static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
ZeThreadGroupDimensions.groupCountY =
static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
ZeThreadGroupDimensions.groupCountZ =
static_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
break;
case 2:
ZeThreadGroupDimensions.groupCountX =
static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
ZeThreadGroupDimensions.groupCountY =
static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
WG[2] = 1;
break;
case 1:
ZeThreadGroupDimensions.groupCountX =
static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
WG[1] = WG[2] = 1;
break;

default:
logger::error("urEnqueueCooperativeKernelLaunchExp: unsupported work_dim");
return UR_RESULT_ERROR_INVALID_VALUE;
}

// Error handling for non-uniform group size case
if (GlobalWorkSize3D[0] !=
size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work group "
                  "size. The global range is not a multiple of the group "
                  "size in the 1st dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[1] !=
size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work group "
                  "size. The global range is not a multiple of the group "
                  "size in the 2nd dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[2] !=
size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work group "
                  "size. The global range is not a multiple of the group "
                  "size in the 3rd dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
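  // For example, GlobalWorkSize3D == {1024, 48, 1} with WG == {256, 16, 1}
  // passes these checks and yields group counts {4, 3, 1}, while a first
  // dimension of 1000 with WG[0] == 256 would have been rejected above.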

ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));

bool UseCopyEngine = false;
_ur_ze_event_list_t TmpWaitList;
UR_CALL(TmpWaitList.createAndRetainUrZeEventList(
NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));

// Get a new command list to be used on this call
ur_command_list_ptr_t CommandList{};
UR_CALL(Queue->Context->getAvailableCommandList(
Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
true /* AllowBatching */));

ze_event_handle_t ZeEvent = nullptr;
ur_event_handle_t InternalEvent{};
bool IsInternal = OutEvent == nullptr;
ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;

UR_CALL(createEventAndAssociateQueue(Queue, Event, UR_COMMAND_KERNEL_LAUNCH,
CommandList, IsInternal, false));
UR_CALL(setSignalEvent(Queue, UseCopyEngine, &ZeEvent, Event,
NumEventsInWaitList, EventWaitList,
CommandList->second.ZeQueue));
(*Event)->WaitList = TmpWaitList;

// Save the kernel in the event, so that when the event is signalled
// the code can do a urKernelRelease on this kernel.
(*Event)->CommandData = (void *)Kernel;

// Increment the reference count of the Kernel and indicate that the Kernel
// is in use. Once the event has been signalled, the code in
// CleanupCompletedEvent(Event) will do a urKernelRelease to update the
// reference count on the kernel, using the kernel saved in CommandData.
UR_CALL(urKernelRetain(Kernel));

// Add to list of kernels to be submitted
if (IndirectAccessTrackingEnabled)
Queue->KernelsToBeSubmitted.push_back(Kernel);

if (Queue->UsingImmCmdLists && IndirectAccessTrackingEnabled) {
// If using immediate commandlists then gathering of indirect
// references and appending to the queue (which means submission)
// must be done together.
std::unique_lock<ur_shared_mutex> ContextsLock(
Queue->Device->Platform->ContextsMutex, std::defer_lock);
// We are going to submit kernels for execution. If indirect access flag is
// set for a kernel then we need to make a snapshot of existing memory
// allocations in all contexts in the platform. We need to lock the mutex
// guarding the list of contexts in the platform to prevent creation of new
    // memory allocations in any context before we submit the kernel for
// execution.
ContextsLock.lock();
Queue->CaptureIndirectAccesses();
// Add the command to the command list, which implies submission.
ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
} else {
// Add the command to the command list for later submission.
// No lock is needed here, unlike the immediate commandlist case above,
// because the kernels are not actually submitted yet. Kernels will be
    // submitted only when the command list is closed. Then, a lock is held.
ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
}

logger::debug("calling zeCommandListAppendLaunchCooperativeKernel() with"
" ZeEvent {}",
ur_cast<std::uintptr_t>(ZeEvent));
printZeEventList((*Event)->WaitList);

// Execute command list asynchronously, as the event will be used
// to track down its completion.
UR_CALL(Queue->executeCommandList(CommandList, false, true));

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
@@ -829,10 +1080,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, size_t localWorkSize,
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
-  (void)hKernel;
-  (void)localWorkSize;
-  (void)dynamicSharedMemorySize;
-  *pGroupCountRet = 1;
std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
uint32_t TotalGroupCount = 0;
ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
(hKernel->ZeKernel, &TotalGroupCount));
*pGroupCountRet = TotalGroupCount;
return UR_RESULT_SUCCESS;
}
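For context, a minimal caller-side sketch (not part of this commit) of how the two entry points compose: query the maximum co-resident group count, then clamp the launch so that grid-wide synchronization inside the kernel cannot deadlock. The helper name and the 1-D sizing are illustrative assumptions, and error handling is reduced to early returns.

#include <algorithm>
#include <ur_api.h>

// Hypothetical helper, for illustration only: launch Kernel cooperatively
// with at most as many work-groups as the adapter reports can be resident
// simultaneously on the device.
ur_result_t launchCooperative1D(ur_queue_handle_t Queue,
                                ur_kernel_handle_t Kernel,
                                size_t DesiredGroupCount, size_t GroupSize,
                                ur_event_handle_t *Event) {
  uint32_t MaxGroupCount = 0;
  // After this commit, backed by zeKernelSuggestMaxCooperativeGroupCount.
  ur_result_t Res = urKernelSuggestMaxCooperativeGroupCountExp(
      Kernel, GroupSize, /*dynamicSharedMemorySize=*/0, &MaxGroupCount);
  if (Res != UR_RESULT_SUCCESS)
    return Res;

  // More groups than can be co-resident may deadlock kernels that
  // synchronize across the whole grid, so clamp the request.
  size_t GroupCount = (std::min)(DesiredGroupCount, size_t(MaxGroupCount));
  size_t GlobalWorkSize = GroupCount * GroupSize;
  return urEnqueueCooperativeKernelLaunchExp(
      Queue, Kernel, /*workDim=*/1, /*pGlobalWorkOffset=*/nullptr,
      &GlobalWorkSize, &GroupSize, /*numEventsInWaitList=*/0,
      /*phEventWaitList=*/nullptr, Event);
}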
