Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeviceASAN] Fix ASAN with kernel assert #2415

Merged
merged 8 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/loader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ if(UR_ENABLE_SANITIZER)
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.hpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_ddi.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_ddi.hpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.hpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_libdevice.hpp
Expand Down
47 changes: 13 additions & 34 deletions source/loader/layers/sanitizer/asan/asan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
return UR_RESULT_SUCCESS;
}

bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
auto hProgram = GetProgram(hKernel);
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
if (PI == nullptr) {
return false;
}
return PI->isKernelInstrumented(hKernel);
}

} // namespace

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -470,15 +461,10 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(

getContext()->logger.debug("==== urEnqueueKernelLaunch");

if (!isInstrumentedKernel(hKernel)) {
return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

LaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
workDim);
UR_CALL(LaunchInfo.Data.syncToDevice(hQueue));

UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));

Expand Down Expand Up @@ -1366,9 +1352,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreate(
getContext()->logger.debug("==== urKernelCreate");

UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
if (isInstrumentedKernel(*phKernel)) {
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));
}
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));

return UR_RESULT_SUCCESS;
}
Expand All @@ -1389,9 +1373,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
UR_CALL(pfnRetain(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
KernelInfo->RefCount++;
}
KernelInfo->RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1411,10 +1393,8 @@ __urdlllocal ur_result_t urKernelRelease(
UR_CALL(pfnRelease(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
}
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
}

return UR_RESULT_SUCCESS;
Expand All @@ -1440,11 +1420,10 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue(
getContext()->logger.debug("==== urKernelSetArgValue");

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if (argSize == sizeof(ur_mem_handle_t) &&
(MemBuffer = getAsanInterceptor()->getMemBuffer(
*ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
(KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) {
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
pbalcer marked this conversation as resolved.
Show resolved Hide resolved
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
} else {
Expand Down Expand Up @@ -1473,9 +1452,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
getContext()->logger.debug("==== urKernelSetArgMemObj");

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue)) &&
(KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) {
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
pbalcer marked this conversation as resolved.
Show resolved Hide resolved
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
} else {
Expand Down Expand Up @@ -1505,7 +1483,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
"==== urKernelSetArgLocal (argIndex={}, argSize={})", argIndex,
argSize);

if (auto KI = getAsanInterceptor()->getKernelInfo(hKernel)) {
{
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
// TODO: get local variable alignment
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
Expand Down Expand Up @@ -1542,8 +1521,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer(
pArgValue);

std::shared_ptr<KernelInfo> KI;
if (getAsanInterceptor()->getOptions().DetectKernelArguments &&
(KI = getAsanInterceptor()->getKernelInfo(hKernel))) {
if (getAsanInterceptor()->getOptions().DetectKernelArguments) {
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
pbalcer marked this conversation as resolved.
Show resolved Hide resolved
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
}
Expand Down
51 changes: 33 additions & 18 deletions source/loader/layers/sanitizer/asan/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,13 @@ ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) {
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return UR_RESULT_SUCCESS;
}
m_KernelMap.emplace(Kernel, std::make_shared<KernelInfo>(Kernel));

auto hProgram = GetProgram(Kernel);
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
bool IsInstrumented = PI->isKernelInstrumented(Kernel);

m_KernelMap.emplace(Kernel,
std::make_shared<KernelInfo>(Kernel, IsInstrumented));
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -685,9 +691,19 @@ ur_result_t AsanInterceptor::prepareLaunch(
std::shared_ptr<ContextInfo> &ContextInfo,
std::shared_ptr<DeviceInfo> &DeviceInfo, ur_queue_handle_t Queue,
ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo) {

auto KernelInfo = getKernelInfo(Kernel);
assert(KernelInfo && "Kernel should be instrumented");

auto ArgNums = GetKernelNumArgs(Kernel);
auto LocalMemoryUsage =
GetKernelLocalMemorySize(Kernel, DeviceInfo->Handle);
auto PrivateMemoryUsage =
GetKernelPrivateMemorySize(Kernel, DeviceInfo->Handle);

getContext()->logger.info(
"KernelInfo {} (Name={}, ArgNums={}, IsInstrumented={}, "
"LocalMemory={}, PrivateMemory={})",
(void *)Kernel, GetKernelName(Kernel), ArgNums,
KernelInfo->IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);

// Validate pointer arguments
if (getOptions().DetectKernelArguments) {
Expand Down Expand Up @@ -719,11 +735,17 @@ ur_result_t AsanInterceptor::prepareLaunch(
}
}

auto ArgNums = GetKernelNumArgs(Kernel);
if (!KernelInfo->IsInstrumented) {
return UR_RESULT_SUCCESS;
}

// We must prepare all kernel args before call
// urKernelGetSuggestedLocalWorkSize, otherwise the call will fail on
// CPU device.
if (ArgNums) {
{
assert(ArgNums >= 1 &&
"Sanitized Kernel should have at least one argument");

ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
Kernel, ArgNums - 1, nullptr, LaunchInfo.Data.getDevicePtr());
if (URes != UR_RESULT_SUCCESS) {
Expand Down Expand Up @@ -763,15 +785,6 @@ ur_result_t AsanInterceptor::prepareLaunch(
LaunchInfo.Data.Host.DeviceTy = DeviceInfo->Type;
LaunchInfo.Data.Host.Debug = getOptions().Debug ? 1 : 0;

auto LocalMemoryUsage =
GetKernelLocalMemorySize(Kernel, DeviceInfo->Handle);
auto PrivateMemoryUsage =
GetKernelPrivateMemorySize(Kernel, DeviceInfo->Handle);

getContext()->logger.info(
"KernelInfo {} (LocalMemory={}, PrivateMemory={})", (void *)Kernel,
LocalMemoryUsage, PrivateMemoryUsage);

// Write shadow memory offset for local memory
if (getOptions().DetectLocals) {
if (DeviceInfo->Shadow->AllocLocalShadow(
Expand Down Expand Up @@ -831,10 +844,12 @@ ur_result_t AsanInterceptor::prepareLaunch(
// sync asan runtime data to device side
UR_CALL(LaunchInfo.Data.syncToDevice(Queue));

getContext()->logger.debug("launch_info {} (numLocalArgs={}, localArgs={})",
(void *)LaunchInfo.Data.getDevicePtr(),
LaunchInfo.Data.Host.NumLocalArgs,
(void *)LaunchInfo.Data.Host.LocalArgs);
getContext()->logger.info(
"LaunchInfo {} (device={}, debug={}, numLocalArgs={}, localArgs={})",
(void *)LaunchInfo.Data.getDevicePtr(),
ToString(LaunchInfo.Data.Host.DeviceTy), LaunchInfo.Data.Host.Debug,
LaunchInfo.Data.Host.NumLocalArgs,
(void *)LaunchInfo.Data.Host.LocalArgs);

return UR_RESULT_SUCCESS;
}
Expand Down
12 changes: 7 additions & 5 deletions source/loader/layers/sanitizer/asan/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ struct KernelInfo {
ur_kernel_handle_t Handle;
std::atomic<int32_t> RefCount = 1;

// sanitized kernel
bool IsInstrumented = false;

// lock this mutex if following fields are accessed
ur_shared_mutex Mutex;
std::unordered_map<uint32_t, std::shared_ptr<MemBuffer>> BufferArgs;
Expand All @@ -94,7 +97,8 @@ struct KernelInfo {
// Need preserve the order of local arguments
std::map<uint32_t, LocalArgsInfo> LocalArgs;

explicit KernelInfo(ur_kernel_handle_t Kernel) : Handle(Kernel) {
explicit KernelInfo(ur_kernel_handle_t Kernel, bool IsInstrumented)
: Handle(Kernel), IsInstrumented(IsInstrumented) {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Kernel.pfnRetain(Kernel);
assert(Result == UR_RESULT_SUCCESS);
Expand Down Expand Up @@ -348,10 +352,8 @@ class AsanInterceptor {

std::shared_ptr<KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return m_KernelMap[Kernel];
}
return nullptr;
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
return m_KernelMap[Kernel];
}

const AsanOptions &getOptions() { return m_Options; }
Expand Down
2 changes: 1 addition & 1 deletion source/loader/layers/sanitizer/asan/asan_libdevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct AsanRuntimeData {
uint32_t Debug = 0;

int ReportFlag = 0;
AsanErrorReport Report[ASAN_MAX_NUM_REPORTS];
AsanErrorReport Report[ASAN_MAX_NUM_REPORTS] = {};
};

constexpr unsigned ASAN_SHADOW_SCALE = 4;
Expand Down
Loading