Skip to content

Commit

Permalink
Merge pull request oneapi-src#1770 from igchor/command_list_cache_v2
Browse files Browse the repository at this point in the history
[L0] implement command list cache for queue
  • Loading branch information
pbalcer authored Jun 26, 2024
2 parents f040b1f + 81620ab commit ba4e49d
Show file tree
Hide file tree
Showing 8 changed files with 438 additions and 3 deletions.
3 changes: 0 additions & 3 deletions source/adapters/level_zero/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ ur_result_t ze2urResult(ze_result_t ZeResult) {
}
}

usm::DisjointPoolAllConfigs DisjointPoolConfigInstance =
InitializeDisjointPoolConfig();

// This function will ensure compatibility with both Linux and Windows for
// setting environment variables.
bool setEnvVar(const char *name, const char *value) {
Expand Down
8 changes: 8 additions & 0 deletions source/adapters/level_zero/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,14 @@ ur_result_t ze2urResult(ze_result_t ZeResult);
return ze2urResult(Result); \
}

// Trace a call to Level-Zero RT, throw on error.
//
// Evaluates `ZeName ZeArgs`, then hands the resulting ze_result_t to
// ZeCall().doCall() together with the stringified function name and argument
// list. If doCall() returns a value that converts to true, the macro throws
// the value produced by ze2urResult(), i.e. the thrown object is a
// ur_result_t and not a std::exception.
//
// The expansion is a braced block that declares locals named `ZeResult` and
// `Result`; because the arguments are stringified with `#`, their spelling at
// the call site is what doCall() receives as text.
#define ZE2UR_CALL_THROWS(ZeName, ZeArgs) \
{ \
ze_result_t ZeResult = ZeName ZeArgs; \
if (auto Result = ZeCall().doCall(ZeResult, #ZeName, #ZeArgs, true)) \
throw ze2urResult(Result); \
}

// Perform traced call to L0 without checking for errors.
//
// Expands to a single expression: the value returned by ZeCall().doCall(),
// which receives the result of `ZeName ZeArgs` plus the stringified name and
// arguments. Nothing is returned or thrown by the macro itself, so the caller
// decides what to do with the result.
// NOTE(review): the trailing `false` is presumably what makes doCall() skip
// its error reporting -- confirm against ZeCall::doCall.
#define ZE_CALL_NOCHECK(ZeName, ZeArgs) \
ZeCall().doCall(ZeName ZeArgs, #ZeName, #ZeArgs, false)
Expand Down
3 changes: 3 additions & 0 deletions source/adapters/level_zero/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

#include <umf_helpers.hpp>

// Namespace-scope disjoint pool configuration object. It is initialized
// during static initialization of this translation unit from the value
// returned by InitializeDisjointPoolConfig().
usm::DisjointPoolAllConfigs DisjointPoolConfigInstance =
InitializeDisjointPoolConfig();

ur_result_t umf2urResult(umf_result_t umfResult) {
if (umfResult == UMF_RESULT_SUCCESS)
return UR_RESULT_SUCCESS;
Expand Down
154 changes: 154 additions & 0 deletions source/adapters/level_zero/v2/command_list_cache.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
//===--------- command_list_cache.cpp - Level Zero Adapter ---------------===//
//
// Copyright (C) 2024 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "command_list_cache.hpp"

#include "context.hpp"
#include "device.hpp"

// Equality for immediate command list descriptors.
//
// The descriptor is the key of the command list cache, so two descriptors may
// only compare equal when a command list created for one is valid for the
// other. Every field that createCommandList() forwards to
// zeCommandListCreateImmediate is therefore compared, including Ordinal (the
// command queue group ordinal). Ordinal is also part of the hash computed by
// command_list_descriptor_hash_t, and the hash and equality used by an
// unordered_map must agree on what makes a key.
bool v2::immediate_command_list_descriptor_t::operator==(
    const immediate_command_list_descriptor_t &rhs) const {
  return ZeDevice == rhs.ZeDevice && Ordinal == rhs.Ordinal &&
         IsInOrder == rhs.IsInOrder && Mode == rhs.Mode &&
         Priority == rhs.Priority && Index == rhs.Index;
}

// Equality for regular (non-immediate) command list descriptors: the device,
// the queue group ordinal and the in-order flag must all match.
bool v2::regular_command_list_descriptor_t::operator==(
    const regular_command_list_descriptor_t &rhs) const {
  if (ZeDevice != rhs.ZeDevice)
    return false;
  if (Ordinal != rhs.Ordinal)
    return false;
  return IsInOrder == rhs.IsInOrder;
}

namespace v2 {
// Hashes a cache key by folding together the fields of whichever alternative
// the variant currently holds (immediate or regular descriptor).
inline size_t command_list_descriptor_hash_t::operator()(
    const command_list_descriptor_t &desc) const {
  const auto *Immediate =
      std::get_if<immediate_command_list_descriptor_t>(&desc);
  if (Immediate != nullptr) {
    return combine_hashes(0, Immediate->ZeDevice, Immediate->Ordinal,
                          Immediate->IsInOrder, Immediate->Mode,
                          Immediate->Priority, Immediate->Index);
  }

  // Not an immediate descriptor, so the variant holds the regular one.
  const auto &Regular = std::get<regular_command_list_descriptor_t>(desc);
  return combine_hashes(0, Regular.ZeDevice, Regular.IsInOrder,
                        Regular.Ordinal);
}

// Builds an empty cache. The stored context handle is the one later passed to
// the Level Zero command list creation calls in createCommandList().
command_list_cache_t::command_list_cache_t(ze_context_handle_t ZeContext)
    : ZeContext(ZeContext) {}

// Creates a brand-new Level Zero command list described by `desc` and returns
// it wrapped in an owning handle whose deleter is zeCommandListDestroy.
//
// An immediate descriptor is turned into a ze_command_queue_desc_t and passed
// to zeCommandListCreateImmediate; a regular descriptor is turned into a
// ze_command_list_desc_t and passed to zeCommandListCreate. Either call goes
// through ZE2UR_CALL_THROWS, so a failure is reported by throwing.
//
// The local names used inside the ZE2UR_CALL_THROWS argument lists are
// stringified by that macro, so they are part of the trace text.
raii::ze_command_list_t
command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) {
  const auto *ImmCmdDesc =
      std::get_if<immediate_command_list_descriptor_t>(&desc);

  if (ImmCmdDesc != nullptr) {
    ZeStruct<ze_command_queue_desc_t> QueueDesc;
    QueueDesc.ordinal = ImmCmdDesc->Ordinal;
    QueueDesc.mode = ImmCmdDesc->Mode;
    QueueDesc.priority = ImmCmdDesc->Priority;

    QueueDesc.flags = 0;
    if (ImmCmdDesc->IsInOrder)
      QueueDesc.flags |= ZE_COMMAND_QUEUE_FLAG_IN_ORDER;
    // An explicit queue index is only forwarded when the caller supplied one.
    if (ImmCmdDesc->Index) {
      QueueDesc.flags |= ZE_COMMAND_QUEUE_FLAG_EXPLICIT_ONLY;
      QueueDesc.index = *ImmCmdDesc->Index;
    }

    ze_command_list_handle_t ZeCommandList;
    ZE2UR_CALL_THROWS(
        zeCommandListCreateImmediate,
        (ZeContext, ImmCmdDesc->ZeDevice, &QueueDesc, &ZeCommandList));
    return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy);
  }

  // Not immediate, so the variant holds a regular descriptor.
  const auto &RegCmdDesc = std::get<regular_command_list_descriptor_t>(desc);

  ZeStruct<ze_command_list_desc_t> CmdListDesc;
  CmdListDesc.commandQueueGroupOrdinal = RegCmdDesc.Ordinal;
  CmdListDesc.flags = 0;
  if (RegCmdDesc.IsInOrder)
    CmdListDesc.flags |= ZE_COMMAND_LIST_FLAG_IN_ORDER;

  ze_command_list_handle_t ZeCommandList;
  ZE2UR_CALL_THROWS(zeCommandListCreate, (ZeContext, RegCmdDesc.ZeDevice,
                                          &CmdListDesc, &ZeCommandList));
  return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy);
}

// Returns an immediate command list with the requested properties, taken from
// the cache when one is available and freshly created otherwise (see
// getCommandList()).
raii::ze_command_list_t command_list_cache_t::getImmediateCommandList(
    ze_device_handle_t ZeDevice, bool IsInOrder, uint32_t Ordinal,
    ze_command_queue_mode_t Mode, ze_command_queue_priority_t Priority,
    std::optional<uint32_t> Index) {
  // Aggregate initialization: the initializers follow the member declaration
  // order of immediate_command_list_descriptor_t.
  immediate_command_list_descriptor_t Key{ZeDevice, IsInOrder, Ordinal,
                                          Mode,     Priority,  Index};
  return getCommandList(Key);
}

// Returns a regular command list with the requested properties, taken from
// the cache when one is available and freshly created otherwise (see
// getCommandList()).
raii::ze_command_list_t
command_list_cache_t::getRegularCommandList(ze_device_handle_t ZeDevice,
                                            bool IsInOrder, uint32_t Ordinal) {
  // Aggregate initialization: the initializers follow the member declaration
  // order of regular_command_list_descriptor_t.
  regular_command_list_descriptor_t Key{ZeDevice, IsInOrder, Ordinal};
  return getCommandList(Key);
}

// Hands an immediate command list back to the cache. The remaining arguments
// form the key under which the list is stored, so they are what a later
// getImmediateCommandList() call has to ask for to get this list back.
void command_list_cache_t::addImmediateCommandList(
    raii::ze_command_list_t cmdList, ze_device_handle_t ZeDevice,
    bool IsInOrder, uint32_t Ordinal, ze_command_queue_mode_t Mode,
    ze_command_queue_priority_t Priority, std::optional<uint32_t> Index) {
  // Aggregate initialization: the initializers follow the member declaration
  // order of immediate_command_list_descriptor_t.
  immediate_command_list_descriptor_t Key{ZeDevice, IsInOrder, Ordinal,
                                          Mode,     Priority,  Index};
  addCommandList(Key, std::move(cmdList));
}

// Hands a regular command list back to the cache. The remaining arguments
// form the key under which the list is stored, so they are what a later
// getRegularCommandList() call has to ask for to get this list back.
void command_list_cache_t::addRegularCommandList(
    raii::ze_command_list_t cmdList, ze_device_handle_t ZeDevice,
    bool IsInOrder, uint32_t Ordinal) {
  // Aggregate initialization: the initializers follow the member declaration
  // order of regular_command_list_descriptor_t.
  regular_command_list_descriptor_t Key{ZeDevice, IsInOrder, Ordinal};
  addCommandList(Key, std::move(cmdList));
}

// Looks up a cached command list for `desc`.
//
// On a hit, the most recently added list for that descriptor is removed from
// the cache and returned; the map entry is erased once its stack runs empty,
// so a present entry always has at least one element. On a miss, the mutex is
// released first and a new command list is created outside the lock.
raii::ze_command_list_t
command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) {
  std::unique_lock<ur_mutex> Guard(ZeCommandListCacheMutex);

  auto Entry = ZeCommandListCache.find(desc);
  if (Entry != ZeCommandListCache.end()) {
    auto &Available = Entry->second;
    assert(!Available.empty());

    raii::ze_command_list_t Reused = std::move(Available.top());
    Available.pop();

    // Never leave an empty stack behind in the map.
    if (Available.empty())
      ZeCommandListCache.erase(Entry);

    return Reused;
  }

  // Cache miss: creation does not touch the cache, so do it unlocked.
  Guard.unlock();
  return createCommandList(desc);
}

// Stores `cmdList` in the cache under `desc`, creating the per-descriptor
// stack on first use. The cache mutex is held for the whole operation.
void command_list_cache_t::addCommandList(const command_list_descriptor_t &desc,
                                          raii::ze_command_list_t cmdList) {
  // TODO: add a limit?
  std::lock_guard<ur_mutex> Guard(ZeCommandListCacheMutex);
  // operator[] default-constructs an empty stack when the key is new.
  ZeCommandListCache[desc].push(std::move(cmdList));
}
} // namespace v2
86 changes: 86 additions & 0 deletions source/adapters/level_zero/v2/command_list_cache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===--------- command_list_cache.hpp - Level Zero Adapter ---------------===//
//
// Copyright (C) 2024 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once

#include <stack>

#include <ur/ur.hpp>
#include <ur_api.h>
#include <ze_api.h>

#include "common.hpp"

namespace v2 {
// Owning wrappers around raw Level Zero handles.
namespace raii {
// Owning handle for a Level Zero command list: a std::unique_ptr whose
// deleter type is that of a pointer to zeCommandListDestroy. The deleter
// itself must be supplied when the handle is constructed.
using ze_command_list_t = std::unique_ptr<::_ze_command_list_handle_t,
decltype(&zeCommandListDestroy)>;
}

// Properties identifying an immediate command list; used as a cache key.
// The member order matters to any code that aggregate-initializes this type.
struct immediate_command_list_descriptor_t {
// Device the command list is created on.
ze_device_handle_t ZeDevice;
// Whether the list is created with the in-order queue flag.
bool IsInOrder;
// Written to the `ordinal` field of the queue descriptor at creation.
uint32_t Ordinal;
// Written to the `mode` field of the queue descriptor at creation.
ze_command_queue_mode_t Mode;
// Written to the `priority` field of the queue descriptor at creation.
ze_command_queue_priority_t Priority;
// Optional explicit queue index; when unset no index is requested.
std::optional<uint32_t> Index;
// Equality used when the descriptor serves as an unordered_map key.
bool operator==(const immediate_command_list_descriptor_t &rhs) const;
};

// Properties identifying a regular (non-immediate) command list; used as a
// cache key. The member order matters to any code that aggregate-initializes
// this type.
struct regular_command_list_descriptor_t {
// Device the command list is created on.
ze_device_handle_t ZeDevice;
// Whether the list is created with the in-order command list flag.
bool IsInOrder;
// Written to `commandQueueGroupOrdinal` of the list descriptor at creation.
uint32_t Ordinal;
// Equality used when the descriptor serves as an unordered_map key.
bool operator==(const regular_command_list_descriptor_t &rhs) const;
};

// Cache key type: holds either an immediate or a regular command list
// descriptor, so both kinds of list can share one map.
using command_list_descriptor_t =
std::variant<immediate_command_list_descriptor_t,
regular_command_list_descriptor_t>;

// Hash functor for command_list_descriptor_t, used as the hasher of the
// cache's unordered_map.
// NOTE(review): operator() is declared inline here but no definition is
// visible in this header -- confirm every translation unit that odr-uses it
// can see the definition.
struct command_list_descriptor_hash_t {
inline size_t operator()(const command_list_descriptor_t &desc) const;
};

// Cache of Level Zero command lists, keyed by the properties they were
// created with. A `get*` call hands out a list with the requested properties,
// and an `add*` call puts a list back under the properties supplied by the
// caller. Access to the underlying map is guarded by ZeCommandListCacheMutex.
struct command_list_cache_t {
// ZeContext is the context handle stored for later command list creation.
command_list_cache_t(ze_context_handle_t ZeContext);

// Return an immediate / regular command list matching the given properties.
// Ownership of the returned list passes to the caller.
raii::ze_command_list_t
getImmediateCommandList(ze_device_handle_t ZeDevice, bool IsInOrder,
uint32_t Ordinal, ze_command_queue_mode_t Mode,
ze_command_queue_priority_t Priority,
std::optional<uint32_t> Index = std::nullopt);
raii::ze_command_list_t getRegularCommandList(ze_device_handle_t ZeDevice,
bool IsInOrder,
uint32_t Ordinal);

// Give a command list to the cache. `cmdList` is taken by value, so ownership
// moves into the cache; the other arguments describe the list and form its
// key.
void addImmediateCommandList(raii::ze_command_list_t cmdList,
ze_device_handle_t ZeDevice, bool IsInOrder,
uint32_t Ordinal, ze_command_queue_mode_t Mode,
ze_command_queue_priority_t Priority,
std::optional<uint32_t> Index = std::nullopt);
void addRegularCommandList(raii::ze_command_list_t cmdList,
ze_device_handle_t ZeDevice, bool IsInOrder,
uint32_t Ordinal);

private:
// Context handle supplied at construction.
ze_context_handle_t ZeContext;
// Descriptor -> stack of cached command lists created for that descriptor.
std::unordered_map<command_list_descriptor_t,
std::stack<raii::ze_command_list_t>,
command_list_descriptor_hash_t>
ZeCommandListCache;
// Guards ZeCommandListCache.
ur_mutex ZeCommandListCacheMutex;

// Shared implementation behind the public get*/add* entry points, plus the
// helper that creates a new command list for a descriptor.
raii::ze_command_list_t getCommandList(const command_list_descriptor_t &desc);
void addCommandList(const command_list_descriptor_t &desc,
raii::ze_command_list_t cmdList);
raii::ze_command_list_t
createCommandList(const command_list_descriptor_t &desc);
};
} // namespace v2
2 changes: 2 additions & 0 deletions test/adapters/level_zero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,5 @@ if(NOT WIN32)

target_link_libraries(test-adapter-level_zero_multi_queue PRIVATE zeCallMap)
endif()

# Process the CMakeLists.txt in the v2 subdirectory.
add_subdirectory(v2)
32 changes: 32 additions & 0 deletions test/adapters/level_zero/v2/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (C) 2024 Intel Corporation
# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
# See LICENSE.TXT
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# add_unittest(<name> <source>...)
#
# Registers an adapter test called <name> through add_adapter_test(). The test
# uses the DEVICES fixture, runs with UR_ADAPTERS_FORCE_LOAD set to the built
# ur_adapter_level_zero file, and is compiled from the adapter's common.cpp
# and ur_level_zero.cpp plus every extra source passed after <name>.
# The adapter source directories (including v2) are then added to the include
# path of the test target and it is linked against the project's common
# library and the Level Zero loader.
# NOTE(review): assumes add_adapter_test() creates a target named
# test-adapter-<name>; the calls below rely on that name -- confirm.
function(add_unittest name)
set(target test-adapter-${name})
add_adapter_test(${name}
FIXTURE DEVICES
ENVIRONMENT
"UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_level_zero>\""
SOURCES
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/common.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/ur_level_zero.cpp
${ARGN})

target_include_directories(${target} PUBLIC
${PROJECT_SOURCE_DIR}/source
${PROJECT_SOURCE_DIR}/source/adapters/level_zero
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2
LevelZeroLoader-Headers)

target_link_libraries(${target} PRIVATE
${PROJECT_NAME}::common
LevelZeroLoader
LevelZeroLoader-Headers
)
endfunction()

# Unit test for the v2 command list cache: the test source is compiled together
# with the cache implementation file itself.
add_unittest(level_zero_command_list_cache
command_list_cache_test.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/command_list_cache.cpp)
Loading

0 comments on commit ba4e49d

Please sign in to comment.