Skip to content

Commit

Permalink
Merge pull request oneapi-src#1246 from 0x12CC/cooperative_kernel_fun…
Browse files Browse the repository at this point in the history
…ctions

[UR] Add default implementation for cooperative kernel functions
  • Loading branch information
kbenzie authored Feb 19, 2024
2 parents 24078c2 + 8a8d704 commit 3fd11f1
Show file tree
Hide file tree
Showing 21 changed files with 176 additions and 24 deletions.
10 changes: 8 additions & 2 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -8692,8 +8692,12 @@ urEnqueueCooperativeKernelLaunchExp(
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
UR_APIEXPORT ur_result_t UR_APICALL
urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
size_t localWorkSize, ///< [in] number of local work-items that will form a work-group when the
///< kernel is launched
size_t dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
///< that will be used when the kernel is launched
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
);

#if !defined(__GNUC__)
Expand Down Expand Up @@ -9641,6 +9645,8 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
ur_kernel_handle_t *phKernel;
size_t *plocalWorkSize;
size_t *pdynamicSharedMemorySize;
uint32_t **ppGroupCountRet;
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;

Expand Down
2 changes: 2 additions & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,8 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
ur_kernel_handle_t,
size_t,
size_t,
uint32_t *);

///////////////////////////////////////////////////////////////////////////////
Expand Down
10 changes: 10 additions & 0 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11399,6 +11399,16 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
ur::details::printPtr(os,
*(params->phKernel));

os << ", ";
os << ".localWorkSize = ";

os << *(params->plocalWorkSize);

os << ", ";
os << ".dynamicSharedMemorySize = ";

os << *(params->pdynamicSharedMemorySize);

os << ", ";
os << ".pGroupCountRet = ";

Expand Down
6 changes: 6 additions & 0 deletions scripts/core/exp-cooperative-kernels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ params:
- type: $x_kernel_handle_t
name: hKernel
desc: "[in] handle of the kernel object"
- type: size_t
name: localWorkSize
desc: "[in] number of local work-items that will form a work-group when the kernel is launched"
- type: size_t
name: dynamicSharedMemorySize
desc: "[in] size of dynamic shared memory, for each work-group, in bytes, that will be used when the kernel is launched"
- type: "uint32_t*"
name: "pGroupCountRet"
desc: "[out] pointer to maximum number of groups"
Expand Down
10 changes: 10 additions & 0 deletions source/adapters/cuda/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

/// Set parameters for general 3D memory copy.
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
/// must be a pointer to a CUdeviceptr
Expand Down
10 changes: 10 additions & 0 deletions source/adapters/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, size_t localWorkSize,
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
(void)hKernel;
(void)localWorkSize;
(void)dynamicSharedMemorySize;
*pGroupCountRet = 1;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
const ur_kernel_arg_value_properties_t *pProperties,
Expand Down
6 changes: 4 additions & 2 deletions source/adapters/cuda/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
return result;
}

pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
pDdiTable->pfnCooperativeKernelLaunchExp =
urEnqueueCooperativeKernelLaunchExp;

return UR_RESULT_SUCCESS;
}
Expand All @@ -416,7 +417,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
return result;
}

pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
urKernelSuggestMaxCooperativeGroupCountExp;

return UR_RESULT_SUCCESS;
}
Expand Down
10 changes: 10 additions & 0 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

/// Enqueues a wait on the given queue for all events.
/// See \ref enqueueEventWait
///
Expand Down
10 changes: 10 additions & 0 deletions source/adapters/hip/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, size_t localWorkSize,
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
(void)hKernel;
(void)localWorkSize;
(void)dynamicSharedMemorySize;
*pGroupCountRet = 1;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
const ur_kernel_arg_value_properties_t *, const void *pArgValue) {
Expand Down
6 changes: 4 additions & 2 deletions source/adapters/hip/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
return result;
}

pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
pDdiTable->pfnCooperativeKernelLaunchExp =
urEnqueueCooperativeKernelLaunchExp;

return UR_RESULT_SUCCESS;
}
Expand All @@ -386,7 +387,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
return result;
}

pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
urKernelSuggestMaxCooperativeGroupCountExp;

return UR_RESULT_SUCCESS;
}
Expand Down
20 changes: 20 additions & 0 deletions source/adapters/level_zero/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
ur_queue_handle_t Queue, ///< [in] handle of the queue to submit to.
ur_program_handle_t Program, ///< [in] handle of the program containing the
Expand Down Expand Up @@ -787,6 +797,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, size_t localWorkSize,
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
(void)hKernel;
(void)localWorkSize;
(void)dynamicSharedMemorySize;
*pGroupCountRet = 1;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
ur_native_handle_t NativeKernel, ///< [in] the native handle of the kernel.
ur_context_handle_t Context, ///< [in] handle of the context object
Expand Down
6 changes: 4 additions & 2 deletions source/adapters/level_zero/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
return result;
}

pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
pDdiTable->pfnCooperativeKernelLaunchExp =
urEnqueueCooperativeKernelLaunchExp;

return UR_RESULT_SUCCESS;
}
Expand All @@ -463,7 +464,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
return result;
}

pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
urKernelSuggestMaxCooperativeGroupCountExp;

return UR_RESULT_SUCCESS;
}
Expand Down
11 changes: 9 additions & 2 deletions source/adapters/null/ur_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5443,15 +5443,22 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
size_t
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
///< kernel is launched
size_t
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
///< that will be used when the kernel is launched
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
) try {
ur_result_t result = UR_RESULT_SUCCESS;

// if the driver has created a custom function, then call it instead of using the generic path
auto pfnSuggestMaxCooperativeGroupCountExp =
d_context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
if (nullptr != pfnSuggestMaxCooperativeGroupCountExp) {
result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
result = pfnSuggestMaxCooperativeGroupCountExp(
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
} else {
// generic implementation
}
Expand Down
10 changes: 10 additions & 0 deletions source/adapters/opencl/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
Expand Down
11 changes: 11 additions & 0 deletions source/adapters/opencl/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common.hpp"

#include <algorithm>
#include <cstddef>
#include <memory>

UR_APIEXPORT ur_result_t UR_APICALL
Expand Down Expand Up @@ -376,6 +377,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, size_t localWorkSize,
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
(void)hKernel;
(void)localWorkSize;
(void)dynamicSharedMemorySize;
*pGroupCountRet = 1;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
ur_native_handle_t hNativeKernel, ur_context_handle_t, ur_program_handle_t,
const ur_kernel_native_properties_t *pProperties,
Expand Down
6 changes: 4 additions & 2 deletions source/adapters/opencl/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
return result;
}

pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
pDdiTable->pfnCooperativeKernelLaunchExp =
urEnqueueCooperativeKernelLaunchExp;

return UR_RESULT_SUCCESS;
}
Expand All @@ -407,7 +408,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
return result;
}

pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
urKernelSuggestMaxCooperativeGroupCountExp;

return UR_RESULT_SUCCESS;
}
Expand Down
14 changes: 10 additions & 4 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6037,7 +6037,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
size_t
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
///< kernel is launched
size_t
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
///< that will be used when the kernel is launched
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
) {
auto pfnSuggestMaxCooperativeGroupCountExp =
context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
Expand All @@ -6047,13 +6053,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
}

ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
&hKernel, &pGroupCountRet};
&hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet};
uint64_t instance = context.notify_begin(
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,
"urKernelSuggestMaxCooperativeGroupCountExp", &params);

ur_result_t result =
pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
ur_result_t result = pfnSuggestMaxCooperativeGroupCountExp(
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);

context.notify_end(
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,
Expand Down
12 changes: 9 additions & 3 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8827,7 +8827,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
size_t
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
///< kernel is launched
size_t
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
///< that will be used when the kernel is launched
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
) {
auto pfnSuggestMaxCooperativeGroupCountExp =
context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
Expand All @@ -8851,8 +8857,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
refCountContext.logInvalidReference(hKernel);
}

ur_result_t result =
pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
ur_result_t result = pfnSuggestMaxCooperativeGroupCountExp(
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);

return result;
}
Expand Down
11 changes: 9 additions & 2 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7571,7 +7571,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
size_t
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
///< kernel is launched
size_t
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
///< that will be used when the kernel is launched
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
) {
ur_result_t result = UR_RESULT_SUCCESS;

Expand All @@ -7587,7 +7593,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
hKernel = reinterpret_cast<ur_kernel_object_t *>(hKernel)->handle;

// forward to device-platform
result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
result = pfnSuggestMaxCooperativeGroupCountExp(
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);

return result;
}
Expand Down
Loading

0 comments on commit 3fd11f1

Please sign in to comment.