diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index b87532debe4bc..6ea3f93cdea12 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
(Optional) Hardware architecture.
main_context : int
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
+max_size : int
+Maximum size in the context. Usage depends on the EP.
notes : string
(Optional) Some notes for the model
onnx_model_filename : string
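As a reader aid (not part of the patch): a minimal sketch of inspecting the new `max_size` attribute on EPContext nodes with the ONNX protobuf API. The program and model path are illustrative only.

```cpp
// Illustrative only: print max_size for every EPContext node in a compiled model.
#include <fstream>
#include <iostream>
#include "onnx/onnx_pb.h"

int main() {
  onnx::ModelProto model;
  std::ifstream in("model_ctx.onnx", std::ios::binary);  // placeholder path
  if (!model.ParseFromIstream(&in)) return 1;
  for (const auto& node : model.graph().node()) {
    if (node.op_type() != "EPContext") continue;
    for (const auto& attr : node.attribute()) {
      if (attr.name() == "max_size")
        std::cout << node.name() << " max_size=" << attr.i() << "\n";
    }
  }
  return 0;
}
```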
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index b1a79f5921328..8e881c757f9ac 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3667,6 +3667,9 @@ struct OrtApi {
* execution provider (typically CPU EP).
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
* - "1": Enabled.
+ * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary.
+ * - "0": Default. Disabled.
+ * - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 09a4a77780916..c7a0793c4748f 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3335,6 +3335,11 @@ void RegisterContribSchemas() {
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
+ .Attr(
+ "max_size",
+ "max size in the context. Usage depend on the EP.",
+ AttributeProto::INT,
+ static_cast<int64_t>(0))
.AllowUncheckedAttributes()
.Input(
0,
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 57ae8c354abb7..79674fd706151 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
- QnnModelLookupTable& qnn_models) {
+ QnnModelLookupTable& qnn_models,
+ int64_t max_spill_fill_size) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
main_context_node.Name(),
- qnn_models);
+ qnn_models,
+ max_spill_fill_size);
}
std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
@@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
main_context_node.Name(),
- qnn_models);
+ qnn_models,
+ max_spill_fill_size);
+}
+
+Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+ uint32_t total_context_size,
+ int64_t& max_spill_fill_size,
+ std::vector<int>& main_context_pos_list) {
+ max_spill_fill_size = 0;
+ int max_size_index = 0;
+ for (uint32_t i = 0; i < total_context_size; ++i) {
+ auto index = main_context_pos_list[i];
+ const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph);
+ ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should have only one EPContext node!");
+ const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin();
+ NodeAttrHelper node_helper(*ep_context_node);
+ int64_t max_size = node_helper.Get(MAX_SIZE, static_cast<int64_t>(0));
+ if (max_size > max_spill_fill_size) {
+ max_spill_fill_size = max_size;
+ max_size_index = i;
+ }
+ }
+ if (0 != max_size_index) {
+ int tmp_index = main_context_pos_list[0];
+ main_context_pos_list[0] = main_context_pos_list[max_size_index];
+ main_context_pos_list[max_size_index] = tmp_index;
+ }
+
+ return Status::OK();
}
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
- const logging::Logger& logger) {
+ const logging::Logger& logger,
+ int64_t max_spill_fill_size) {
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
- qnn_models);
+ qnn_models, max_spill_fill_size);
// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
@@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model,
const QnnModelLookupTable& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
+ uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger) {
auto& graph = model->MainGraph();
@@ -238,6 +270,7 @@ Status CreateEPContextNodes(Model* model,
}
of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
+ ep_node.AddAttribute(MAX_SIZE, static_cast<int64_t>(max_spill_fill_buffer_size));
}
} else {
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
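The heart of `TryGetMaxSpillFillSize` above is a scan-and-swap over the main-context positions; here is a standalone sketch of just that reordering, with a plain vector of sizes standing in for the EPContext `max_size` attributes:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Move the position whose max_size is largest to the front so that context is
// loaded first and later contexts can attach to its spill fill buffer group.
void MoveMaxSpillFillToFront(const std::vector<int64_t>& max_sizes,
                             std::vector<int>& main_context_pos_list) {
  int64_t max_spill_fill_size = 0;
  size_t max_size_index = 0;
  for (size_t i = 0; i < main_context_pos_list.size(); ++i) {
    const int64_t size = max_sizes[main_context_pos_list[i]];
    if (size > max_spill_fill_size) {
      max_spill_fill_size = size;
      max_size_index = i;
    }
  }
  if (max_size_index != 0)
    std::swap(main_context_pos_list[0], main_context_pos_list[max_size_index]);
}
```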
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index f308a7456d46c..92c5391b40f09 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";
+static const std::string MAX_SIZE = "max_size";
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);
@@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
- QnnModelLookupTable& qnn_models);
+ QnnModelLookupTable& qnn_models,
+ int64_t max_spill_fill_size);
+
+Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+ uint32_t total_context_size,
+ int64_t& max_spill_fill_size,
+ std::vector<int>& main_context_pos_list);
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
- const logging::Logger& logger);
+ const logging::Logger& logger,
+ int64_t max_spill_fill_size);
Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
@@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
+ uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger);
} // namespace qnn
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index f37c91aa0413b..8a717c3f29ff9 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -8,6 +8,7 @@
#include
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
+#include "HTP/QnnHtpSystemContext.h"
#include "CPU/QnnCpuCommon.h"
// TODO: not exist for Windows yet
// #include "GPU/QnnGpuCommon.h"
@@ -532,11 +533,11 @@ Status QnnBackendManager::CreateContext() {
}
QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
- QnnHtpContext_CustomConfig_t customConfig;
- customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
- customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
+ QnnHtpContext_CustomConfig_t custom_config;
+ custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
+ custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
- context_config_weight_sharing.customConfig = &customConfig;
+ context_config_weight_sharing.customConfig = &custom_config;
QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
@@ -615,9 +616,71 @@ std::unique_ptr<char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
return context_buffer;
}
+Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
+ uint64_t buffer_length,
+ uint64_t& max_spill_fill_buffer_size) {
+ bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
+ nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
+ nullptr == qnn_sys_interface_.systemContextFree;
+ ORT_RETURN_IF(result, "Failed to get valid function pointer.");
+
+ QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
+ auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
+ ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
+
+ const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
+ Qnn_ContextBinarySize_t binary_info_size{0};
+ rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
+ static_cast<void*>(buffer),
+ buffer_length,
+ &binary_info,
+ &binary_info_size);
+ ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
+
+ // binary_info's life cycle is scoped to this call
+ // Retrieve Qnn graph info from the binary info
+ ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
+ uint32_t graph_count = 0;
+ QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
+ if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
+ graph_count = binary_info->contextBinaryInfoV3.numGraphs;
+ graphs_info = binary_info->contextBinaryInfoV3.graphs;
+ } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
+ graph_count = binary_info->contextBinaryInfoV2.numGraphs;
+ graphs_info = binary_info->contextBinaryInfoV2.graphs;
+ } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
+ graph_count = binary_info->contextBinaryInfoV1.numGraphs;
+ graphs_info = binary_info->contextBinaryInfoV1.graphs;
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version.");
+ }
+
+ for (uint32_t i = 0; i < graph_count; ++i) {
+ if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
+ auto htp_graph_info = reinterpret_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(graphs_info[i].graphInfoV3.graphBlobInfo);
+ if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
+ auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize;
+ max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size;
+ } else {
+ LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version.";
+ }
+ } else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 ||
+ graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
+ LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2.";
+ } else {
+ LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version.";
+ }
+ }
+
+ LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed.";
+ return Status::OK();
+}
+
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
- QnnModelLookupTable& qnn_models) {
+ QnnModelLookupTable& qnn_models,
+ int64_t max_spill_fill_size) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
@@ -638,7 +701,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
// binary_info life cycle is here
// Binary info to graph info
- // retrieve Qnn graph infor from binary info
+ // retrieve Qnn graph info from binary info
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
uint32_t graph_count = 0;
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
@@ -658,13 +721,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;
- ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
- "Invalid function pointer for contextCreateFromBinary.");
-
QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
- const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
+ // Register spill fill buffer for multi context
+ QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
+
+ // The spill fill buffer feature is available since QNN SDK 2.28 (API version 2.21)
+#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21)
+ QnnHtpContext_CustomConfig_t custom_config;
+ custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
+ QnnHtpContext_GroupRegistration_t group_info;
+ size_t current_contexts_size = GetQnnContextSize();
+ // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
+ // note that we already moved the context with the max spill fill size to the beginning of the list
+ group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
+ group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0
+ custom_config.groupRegistration = group_info;
+ spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
+ spill_fill_config.customConfig = &custom_config;
+#endif
+ QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
+ LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size;
+
+ const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};
+
+ ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
+ "Invalid function pointer for contextCreateFromBinary.");
Qnn_ContextHandle_t context = nullptr;
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
@@ -673,7 +756,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
buffer_length,
&context,
profile_backend_handle_);
- ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
+ ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
contexts_.push_back(context);
if (1 == graph_count) {
// in case the EPContext node is generated from script
@@ -699,7 +782,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
return Status::OK();
}
-Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
+// Need to load the QNN system lib when loading from a QNN context binary,
+// or when QNN context binary generation is enabled -- to get the max spill fill buffer size
+Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
+ bool load_from_cached_context,
+ bool need_load_system_lib) {
std::lock_guard lock(logger_mutex_);
if (backend_setup_completed_) {
LOGS(logger, VERBOSE) << "Backend setup already!";
@@ -714,7 +801,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
LOGS(logger, VERBOSE) << "LoadBackend succeed.";
- if (load_from_cached_context) {
+ if (load_from_cached_context || need_load_system_lib) {
ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
}
@@ -933,20 +1020,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_
return Status::OK();
}
-void QnnBackendManager::Split(std::vector<std::string>& split_string,
- const std::string& tokenized_string,
- const char separator) {
- split_string.clear();
- std::istringstream tokenized_string_stream(tokenized_string);
- while (!tokenized_string_stream.eof()) {
- std::string value;
- getline(tokenized_string_stream, value, separator);
- if (!value.empty()) {
- split_string.push_back(value);
- }
- }
-}
-
Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
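`GetMaxSpillFillBufferSize` repeats the binary-info version dispatch that `LoadCachedQnnContextFromBuffer` already performs; a possible follow-up (not part of this patch) is to factor it into a helper along these lines, using the QNN system-context types included above:

```cpp
// Sketch only: shared version dispatch over QnnSystemContext_BinaryInfo_t.
static bool GetGraphsFromBinaryInfo(const QnnSystemContext_BinaryInfo_t* binary_info,
                                    uint32_t& graph_count,
                                    QnnSystemContext_GraphInfo_t*& graphs_info) {
  if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
    graph_count = binary_info->contextBinaryInfoV3.numGraphs;
    graphs_info = binary_info->contextBinaryInfoV3.graphs;
  } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
    graph_count = binary_info->contextBinaryInfoV2.numGraphs;
    graphs_info = binary_info->contextBinaryInfoV2.graphs;
  } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
    graph_count = binary_info->contextBinaryInfoV1.numGraphs;
    graphs_info = binary_info->contextBinaryInfoV1.graphs;
  } else {
    return false;  // unsupported binary info version
  }
  return true;
}
```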
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index 43007d4a5c244..b145f2a2cd724 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -93,9 +93,10 @@ class QnnBackendManager {
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
- std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
+ std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
+ int64_t max_spill_fill_size);
- Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
+ Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);
Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
@@ -112,6 +113,10 @@ class QnnBackendManager {
return contexts_[index];
}
+ size_t GetQnnContextSize() {
+ return contexts_.size();
+ }
+
const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; }
const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
@@ -145,8 +150,6 @@ class QnnBackendManager {
void ReleaseResources();
- void Split(std::vector<std::string>& split_string, const std::string& tokenized_string, const char separator);
-
Status ExtractBackendProfilingInfo();
Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
@@ -163,6 +166,10 @@ class QnnBackendManager {
Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id);
+ Status GetMaxSpillFillBufferSize(unsigned char* buffer,
+ uint64_t buffer_length,
+ uint64_t& max_spill_fill_buffer_size);
+
private:
void* LoadLib(const char* file_name, int flags, std::string& error_msg);
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 6735528bebbf9..3bb069196e31c 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -363,20 +363,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
}
+ bool enable_htp_weight_sharing = false;
static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing";
auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED);
if (htp_weight_sharing_enabled_pos != provider_options_map.end()) {
if ("1" == htp_weight_sharing_enabled_pos->second) {
- enable_htp_weight_sharing_ = true;
+ enable_htp_weight_sharing = true;
} else if ("0" == htp_weight_sharing_enabled_pos->second) {
- enable_htp_weight_sharing_ = false;
+ enable_htp_weight_sharing = false;
} else {
- LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_
+ LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing
<< " only 0 or 1 allowed. Set to 0.";
}
- LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_;
+ LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing;
}
+ // Add this option because this feature requires the QnnSystem lib and it's not supported on the Windows x86_64 platform
+ enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map);
+
model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false,
provider_options_map);
@@ -396,7 +400,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
device_id_,
htp_arch,
soc_model,
- enable_htp_weight_sharing_);
+ enable_htp_weight_sharing);
#ifdef _WIN32
auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance();
@@ -686,7 +690,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
// It will load the QnnSystem lib if is_qnn_ctx_model=true, and
// delay the Qnn context creation to Compile() using the cached context binary
- auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model);
+ // or when context cache generation is enabled, which needs the QnnSystem lib to parse the binary and get the max spill fill buffer size
+ auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_);
if (Status::OK() != rt) {
LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
return result;
@@ -934,6 +939,16 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
std::vector main_context_pos_list;
ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list));
+ uint32_t total_context_size = SafeInt<uint32_t>(main_context_pos_list.size());
+
+ int64_t max_spill_fill_size = 0;
+
+ // Adjust main_context_pos_list so the entry with the max spill fill buffer size comes first
+ // The HTP spill fill buffer only works across multiple QNN contexts generated with QNN SDK 2.28 or later
+ if (total_context_size > 1) {
+ ORT_RETURN_IF_ERROR(qnn::TryGetMaxSpillFillSize(fused_nodes_and_graphs, total_context_size,
+ max_spill_fill_size, main_context_pos_list));
+ }
for (auto main_context_pos : main_context_pos_list) {
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
@@ -942,7 +957,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
context_cache_path,
qnn_backend_manager_.get(),
qnn_models,
- logger));
+ logger,
+ max_spill_fill_size));
}
for (auto fused_node_and_graph : fused_nodes_and_graphs) {
@@ -984,6 +1000,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
// All partitioned graph share single QNN context, included in the same context binary
uint64_t buffer_size(0);
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
+ // Get max spill fill buffer size
+ uint64_t max_spill_fill_buffer_size = 0;
+ if (enable_spill_fill_buffer_) {
+ ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(),
+ buffer_size,
+ max_spill_fill_buffer_size));
+ }
qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger);
ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
context_buffer.get(),
@@ -993,6 +1016,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
qnn_models_,
context_cache_path,
qnn_context_embed_mode_,
+ max_spill_fill_buffer_size,
logger));
}
return Status::OK();
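`enable_spill_fill_buffer_` is parsed with the file's existing `ParseBoolOption` helper; for readers without the surrounding file, here is a hypothetical stand-in with the same observable behavior ("1" enables, "0" disables, anything else keeps the default):

```cpp
#include <string>
#include <unordered_map>

// Hypothetical stand-in for ParseBoolOption, not the actual helper.
static bool ParseBoolOptionSketch(const std::string& key, bool default_value,
                                  const std::unordered_map<std::string, std::string>& options) {
  auto it = options.find(key);
  if (it == options.end()) return default_value;
  if (it->second == "1") return true;
  if (it->second == "0") return false;
  return default_value;  // invalid values fall back to the default
}
```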
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
index 35c061de6132c..a0577e8fd87f2 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider {
std::string context_node_name_prefix_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
bool qnn_context_embed_mode_ = true;
- bool enable_htp_weight_sharing_ = false;
int32_t vtcm_size_in_mb_ = 0;
std::unique_ptr qnn_ep_context_model_;
ModelMetadefIdGenerator metadef_id_generator_;
@@ -150,6 +149,7 @@ class QNNExecutionProvider : public IExecutionProvider {
uint32_t default_rpc_control_latency_ = 0;
bool enable_HTP_FP16_precision_ = true;
bool share_ep_contexts_ = false;
+ bool enable_spill_fill_buffer_ = false;
#ifdef _WIN32
onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr;
#endif
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index e406405464d99..3f2c2cb7f761c 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -101,6 +101,7 @@ namespace perftest {
"\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n"
"\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n"
"\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n"
+ "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary."
"\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n"
"\n"
"\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n"
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 02768b8c08e85..5db1894a5074b 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -202,7 +202,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
{"backend_path", "profiling_file_path", "profiling_level", "rpc_control_latency",
"vtcm_mb", "soc_model", "device_id", "htp_performance_mode", "qnn_saver_path",
"htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch",
- "enable_htp_fp16_precision", "offload_graph_io_quantization"});
+ "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer"});
for (const auto& provider_option : provider_options) {
const std::string& key = provider_option.first;
const std::string& value = provider_option.second;
@@ -253,7 +253,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
std::string str = str_stream.str();
ORT_THROW("Wrong value for htp_arch. select from: " + str);
}
- } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") {
+ } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") {
std::unordered_set<std::string> supported_options = {"0", "1"};
if (supported_options.find(value) == supported_options.end()) {
std::ostringstream str_stream;
diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc
index 5b3720992c542..24c343c7b9541 100644
--- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc
+++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc
@@ -50,6 +50,7 @@ namespace qnnctxgen {
"\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n"
"\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n"
"\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n"
+ "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary."
"\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n"
"\n"
"\t-h: help\n");
@@ -146,7 +147,7 @@ static bool ParseSessionConfigs(const std::string& configs_string,
ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str);
}
} else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing" ||
- key == "offload_graph_io_quantization") {
+ key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") {
std::unordered_set<std::string> supported_options = {"0", "1"};
if (supported_options.find(value) == supported_options.end()) {
std::ostringstream str_stream;
@@ -158,7 +159,7 @@ static bool ParseSessionConfigs(const std::string& configs_string,
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'vtcm_mb', 'htp_performance_mode',
'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing',
- 'offload_graph_io_quantization'])");
+ 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])");
}
test_config.run_config.qnn_options[key] = value;
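Mirroring the `[Example]` line in the help text above, a context-generation run with the new option might look like this (tool binary name and model path are assumptions):

```
onnxruntime_qnn_ctx_gen -i "backend_path|QnnHtp.dll enable_htp_spill_fill_buffer|1" model.onnx
```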