From d3d3f6e519e804d109117e99ce6d84d10dba9131 Mon Sep 17 00:00:00 2001
From: Michael Aziz
Date: Fri, 3 May 2024 19:06:23 -0700
Subject: [PATCH] Implement L0 cooperative kernel functions

Defines `urKernelSuggestMaxCooperativeGroupCountExp` and
`urEnqueueCooperativeKernelLaunchExp` to enable cooperative kernels with
more than one work group.

Signed-off-by: Michael Aziz
---
 source/adapters/level_zero/kernel.cpp | 272 +++++++++++++++++++++++++-
 1 file changed, 263 insertions(+), 9 deletions(-)

diff --git a/source/adapters/level_zero/kernel.cpp b/source/adapters/level_zero/kernel.cpp
index 65feaae511..b8b7a56bac 100644
--- a/source/adapters/level_zero/kernel.cpp
+++ b/source/adapters/level_zero/kernel.cpp
@@ -271,13 +271,264 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
-    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
-    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
-    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
-    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
-                               pGlobalWorkSize, pLocalWorkSize,
-                               numEventsInWaitList, phEventWaitList, phEvent);
+    ur_queue_handle_t Queue,   ///< [in] handle of the queue object
+    ur_kernel_handle_t Kernel, ///< [in] handle of the kernel object
+    uint32_t WorkDim, ///< [in] number of dimensions, from 1 to 3, to specify
+                      ///< the global and work-group work-items
+    const size_t
+        *GlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned
+                           ///< values that specify the offset used to
+                           ///< calculate the global ID of a work-item
+    const size_t *GlobalWorkSize, ///< [in] pointer to an array of workDim
+                                  ///< unsigned values that specify the number
+                                  ///< of global work-items in workDim that
+                                  ///< will execute the kernel function
+    const size_t
+        *LocalWorkSize, ///< [in][optional] pointer to an array of workDim
+                        ///< unsigned values that specify the number of local
+                        ///< work-items forming a work-group that will execute
+                        ///< the kernel function. If nullptr, the runtime
+                        ///< implementation will choose the work-group size.
+    uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
+    const ur_event_handle_t
+        *EventWaitList, ///< [in][optional][range(0, numEventsInWaitList)]
+                        ///< pointer to a list of events that must be complete
+                        ///< before the kernel execution. If nullptr,
+                        ///< numEventsInWaitList must be 0, indicating that
+                        ///< there are no wait events.
+    ur_event_handle_t
+        *OutEvent ///< [in,out][optional] return an event object that
+                  ///< identifies this particular kernel execution instance.
+) {
+  auto ZeDevice = Queue->Device->ZeDevice;
+
+  ze_kernel_handle_t ZeKernel{};
+  if (Kernel->ZeKernelMap.empty()) {
+    ZeKernel = Kernel->ZeKernel;
+  } else {
+    auto It = Kernel->ZeKernelMap.find(ZeDevice);
+    if (It == Kernel->ZeKernelMap.end()) {
+      /* kernel and queue don't match */
+      return UR_RESULT_ERROR_INVALID_QUEUE;
+    }
+    ZeKernel = It->second;
+  }
+  // Lock automatically releases when this goes out of scope.
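+  // The queue, kernel, and program mutexes are all taken here so that
+  // flushing pending arguments and appending the launch happen atomically.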
+  std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
+      Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);
+  if (GlobalWorkOffset != NULL) {
+    if (!Queue->Device->Platform->ZeDriverGlobalOffsetExtensionFound) {
+      logger::error("No global offset extension found on this driver");
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+
+    ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
+               (ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
+                GlobalWorkOffset[2]));
+  }
+
+  // If there are any pending arguments, set them now.
+  for (auto &Arg : Kernel->PendingArguments) {
+    // The ArgValue may be a NULL pointer, in which case a NULL value is used
+    // for the kernel argument declared as a pointer to global or constant
+    // memory.
+    char **ZeHandlePtr = nullptr;
+    if (Arg.Value) {
+      UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
+                                        Queue->Device));
+    }
+    ZE2UR_CALL(zeKernelSetArgumentValue,
+               (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
+  }
+  Kernel->PendingArguments.clear();
+
+  ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
+  uint32_t WG[3]{};
+
+  // A new variable is needed because the GlobalWorkSize parameter might not
+  // be of size 3.
+  size_t GlobalWorkSize3D[3]{1, 1, 1};
+  std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
+
+  if (LocalWorkSize) {
+    // L0 group sizes are expressed as 32-bit values.
+    UR_ASSERT(LocalWorkSize[0] < (std::numeric_limits<uint32_t>::max)(),
+              UR_RESULT_ERROR_INVALID_VALUE);
+    UR_ASSERT(LocalWorkSize[1] < (std::numeric_limits<uint32_t>::max)(),
+              UR_RESULT_ERROR_INVALID_VALUE);
+    UR_ASSERT(LocalWorkSize[2] < (std::numeric_limits<uint32_t>::max)(),
+              UR_RESULT_ERROR_INVALID_VALUE);
+    WG[0] = static_cast<uint32_t>(LocalWorkSize[0]);
+    WG[1] = static_cast<uint32_t>(LocalWorkSize[1]);
+    WG[2] = static_cast<uint32_t>(LocalWorkSize[2]);
+  } else {
+    // We can't call zeKernelSuggestGroupSize if the 64-bit GlobalWorkSize
+    // values do not fit into the 32-bit values that the API currently
+    // supports.
+    bool SuggestGroupSize = true;
+    for (int I : {0, 1, 2}) {
+      if (GlobalWorkSize3D[I] > UINT32_MAX) {
+        SuggestGroupSize = false;
+      }
+    }
+    if (SuggestGroupSize) {
+      ZE2UR_CALL(zeKernelSuggestGroupSize,
+                 (ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1],
+                  GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2]));
+    } else {
+      for (int I : {0, 1, 2}) {
+        // Try to find a group size for dimension I that evenly divides
+        // GlobalWorkSize3D[I]. Start with the max possible size in each
+        // dimension.
+        uint32_t GroupSize[] = {
+            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX,
+            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY,
+            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ};
+        GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]);
+        while (GlobalWorkSize3D[I] % GroupSize[I]) {
+          --GroupSize[I];
+        }
+
+        if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
+          logger::error(
+              "urEnqueueCooperativeKernelLaunchExp: can't find a WG size "
+              "suitable for global work size > UINT32_MAX");
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+        }
+        WG[I] = GroupSize[I];
+      }
+      logger::debug("urEnqueueCooperativeKernelLaunchExp: using computed WG "
+                    "size = {{{}, {}, {}}}",
+                    WG[0], WG[1], WG[2]);
+    }
+  }
+
+  // TODO: assert if sizes do not fit into 32-bit?
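+
+  // Compute per-dimension group counts; dimensions above WorkDim keep the
+  // default group count of 1, and their group size is forced to 1 below.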
+
+  switch (WorkDim) {
+  case 3:
+    ZeThreadGroupDimensions.groupCountX =
+        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
+    ZeThreadGroupDimensions.groupCountY =
+        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
+    ZeThreadGroupDimensions.groupCountZ =
+        static_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
+    break;
+  case 2:
+    ZeThreadGroupDimensions.groupCountX =
+        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
+    ZeThreadGroupDimensions.groupCountY =
+        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
+    WG[2] = 1;
+    break;
+  case 1:
+    ZeThreadGroupDimensions.groupCountX =
+        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
+    WG[1] = WG[2] = 1;
+    break;
+
+  default:
+    logger::error("urEnqueueCooperativeKernelLaunchExp: unsupported work_dim");
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+
+  // Error handling for the non-uniform group size case.
+  if (GlobalWorkSize3D[0] !=
+      size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
+    logger::error("urEnqueueCooperativeKernelLaunchExp: the global work size "
+                  "is not a multiple of the group size in the 1st dimension");
+    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+  }
+  if (GlobalWorkSize3D[1] !=
+      size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
+    logger::error("urEnqueueCooperativeKernelLaunchExp: the global work size "
+                  "is not a multiple of the group size in the 2nd dimension");
+    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+  }
+  if (GlobalWorkSize3D[2] !=
+      size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
+    logger::error("urEnqueueCooperativeKernelLaunchExp: the global work size "
+                  "is not a multiple of the group size in the 3rd dimension");
+    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+  }
+
+  ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
+
+  bool UseCopyEngine = false;
+  _ur_ze_event_list_t TmpWaitList;
+  UR_CALL(TmpWaitList.createAndRetainUrZeEventList(
+      NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
+
+  // Get a new command list to be used on this call.
+  ur_command_list_ptr_t CommandList{};
+  UR_CALL(Queue->Context->getAvailableCommandList(
+      Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
+      true /* AllowBatching */));
+
+  ze_event_handle_t ZeEvent = nullptr;
+  ur_event_handle_t InternalEvent{};
+  bool IsInternal = OutEvent == nullptr;
+  ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
+
+  UR_CALL(createEventAndAssociateQueue(Queue, Event, UR_COMMAND_KERNEL_LAUNCH,
+                                       CommandList, IsInternal, false));
+  UR_CALL(setSignalEvent(Queue, UseCopyEngine, &ZeEvent, Event,
+                         NumEventsInWaitList, EventWaitList,
+                         CommandList->second.ZeQueue));
+  (*Event)->WaitList = TmpWaitList;
+
+  // Save the kernel in the event, so that when the event is signalled
+  // the code can do a urKernelRelease on this kernel.
+  (*Event)->CommandData = (void *)Kernel;
+
+  // Increment the reference count of the kernel and indicate that it is in
+  // use. Once the event has been signalled, the code in
+  // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
+  // reference count on the kernel, using the kernel saved in CommandData.
+  UR_CALL(urKernelRetain(Kernel));
+
+  // Add to the list of kernels to be submitted.
+  if (IndirectAccessTrackingEnabled)
+    Queue->KernelsToBeSubmitted.push_back(Kernel);
+
+  if (Queue->UsingImmCmdLists && IndirectAccessTrackingEnabled) {
+    // If using immediate command lists, then gathering of indirect
+    // references and appending to the queue (which means submission)
+    // must be done together.
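+    // Note: unlike a regular launch,
+    // zeCommandListAppendLaunchCooperativeKernel lets the launched work
+    // groups synchronize with one another; the total group count should not
+    // exceed what zeKernelSuggestMaxCooperativeGroupCount suggests.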
+    std::unique_lock<ur_shared_mutex> ContextsLock(
+        Queue->Device->Platform->ContextsMutex, std::defer_lock);
+    // We are going to submit kernels for execution. If the indirect access
+    // flag is set for a kernel, then we need to make a snapshot of the
+    // existing memory allocations in all contexts in the platform. We need
+    // to lock the mutex guarding the list of contexts in the platform to
+    // prevent the creation of new memory allocations in any context before
+    // we submit the kernel for execution.
+    ContextsLock.lock();
+    Queue->CaptureIndirectAccesses();
+    // Add the command to the command list, which implies submission.
+    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
+               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
+                (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
+  } else {
+    // Add the command to the command list for later submission.
+    // No lock is needed here, unlike the immediate command list case above,
+    // because the kernels are not actually submitted yet. Kernels will be
+    // submitted only when the command list is closed. Then, a lock is held.
+    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
+               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
+                (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
+  }
+
+  logger::debug("calling zeCommandListAppendLaunchCooperativeKernel() with"
+                " ZeEvent {}",
+                ur_cast<std::uintptr_t>(ZeEvent));
+  printZeEventList((*Event)->WaitList);
+
+  // Execute the command list asynchronously, as the event will be used
+  // to track its completion.
+  UR_CALL(Queue->executeCommandList(CommandList, false, true));
+
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
@@ -818,10 +1069,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
 
 UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
     ur_kernel_handle_t hKernel, size_t localWorkSize,
    size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
-  (void)hKernel;
   (void)localWorkSize;
   (void)dynamicSharedMemorySize;
-  *pGroupCountRet = 1;
+  std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
+  uint32_t TotalGroupCount = 0;
+  ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
+             (hKernel->ZeKernel, &TotalGroupCount));
+  *pGroupCountRet = TotalGroupCount;
   return UR_RESULT_SUCCESS;
 }
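
A minimal usage sketch of the two new entry points (Queue and Kernel are
hypothetical, already-created UR handles; error checking is omitted):

    // Ask how many work groups can cooperate for this kernel with a
    // work-group size of 256 and no dynamic local memory.
    uint32_t MaxGroupCount = 0;
    urKernelSuggestMaxCooperativeGroupCountExp(
        Kernel, /*localWorkSize=*/256, /*dynamicSharedMemorySize=*/0,
        &MaxGroupCount);

    // Launch at most that many groups so that they are able to synchronize.
    size_t LocalSize = 256;
    size_t GlobalSize = size_t(MaxGroupCount) * LocalSize;
    urEnqueueCooperativeKernelLaunchExp(
        Queue, Kernel, /*workDim=*/1, /*pGlobalWorkOffset=*/nullptr,
        &GlobalSize, &LocalSize, /*numEventsInWaitList=*/0,
        /*phEventWaitList=*/nullptr, /*phEvent=*/nullptr);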