Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeviceASAN] Fix ASAN with kernel assert #2415

Merged
merged 8 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 12 additions & 29 deletions source/loader/layers/sanitizer/asan/asan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
return UR_RESULT_SUCCESS;
}

bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
auto hProgram = GetProgram(hKernel);
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
return PI->isKernelInstrumented(hKernel);
}

} // namespace

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -465,12 +459,6 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(

getContext()->logger.debug("==== urEnqueueKernelLaunch");

if (!isInstrumentedKernel(hKernel)) {
return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

USMLaunchInfo LaunchInfo(GetContext(hKernel), GetDevice(hQueue),
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
workDim);
Expand Down Expand Up @@ -1362,9 +1350,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreate(
getContext()->logger.debug("==== urKernelCreate");

UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
if (isInstrumentedKernel(*phKernel)) {
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));
}
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));

return UR_RESULT_SUCCESS;
}
Expand All @@ -1385,9 +1371,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
UR_CALL(pfnRetain(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
KernelInfo->RefCount++;
}
KernelInfo->RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1407,10 +1391,8 @@ __urdlllocal ur_result_t urKernelRelease(
UR_CALL(pfnRelease(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
}
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1439,8 +1421,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue(
std::shared_ptr<KernelInfo> KernelInfo;
if (argSize == sizeof(ur_mem_handle_t) &&
(MemBuffer = getAsanInterceptor()->getMemBuffer(
*ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
(KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) {
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
pbalcer marked this conversation as resolved.
Show resolved Hide resolved
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
} else {
Expand Down Expand Up @@ -1470,8 +1452,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue)) &&
(KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) {
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
pbalcer marked this conversation as resolved.
Show resolved Hide resolved
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
} else {
Expand Down Expand Up @@ -1501,7 +1483,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
"==== urKernelSetArgLocal (argIndex={}, argSize={})", argIndex,
argSize);

if (auto KI = getAsanInterceptor()->getKernelInfo(hKernel)) {
{
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
// TODO: get local variable alignment
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
Expand Down Expand Up @@ -1538,8 +1521,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer(
pArgValue);

std::shared_ptr<KernelInfo> KI;
if (getAsanInterceptor()->getOptions().DetectKernelArguments &&
(KI = getAsanInterceptor()->getKernelInfo(hKernel))) {
if (getAsanInterceptor()->getOptions().DetectKernelArguments) {
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
pbalcer marked this conversation as resolved.
Show resolved Hide resolved
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
}
Expand Down
Loading
Loading