From ddb6e6558862eee1933c0495775813c42323c527 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Fri, 6 Dec 2024 11:36:52 -0800
Subject: [PATCH] Enable QNN HTP spill fill buffer setting to save RAM usage.
 (#22853)

### Description
Enable the QNN HTP spill fill buffer setting to save RAM usage. This feature is available starting with QNN 2.28; the QNN context binary needs to be re-generated.
https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/htp_backend.html#qnn-htp-backend-api

Requirements:
1. Re-generate the Onnx model with the QNN context binary by setting the EP option enable_htp_spill_fill_buffer = 1.
2. Works for a model with multiple context binaries: manually merge the two Onnx models with context binaries into one Onnx model.
3. Requires a Linux platform if the context binary is generated offline, since the QnnSystem lib is not available for the Windows x86_64 platform.

No extra steps are needed while running model inference. The generated EPContext node will have a max_size attribute with the maximum spill fill buffer size for the context binary image (see the usage sketch appended after the diff).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/ContribOperators.md                      |   2 +
 .../core/session/onnxruntime_c_api.h          |   3 +
 .../core/graph/contrib_ops/contrib_defs.cc    |   5 +
 .../qnn/builder/onnx_ctx_model_helper.cc      |  43 +++++-
 .../qnn/builder/onnx_ctx_model_helper.h       |  13 +-
 .../qnn/builder/qnn_backend_manager.cc        | 127 ++++++++++++++----
 .../qnn/builder/qnn_backend_manager.h         |  15 ++-
 .../providers/qnn/qnn_execution_provider.cc   |  38 +++++-
 .../providers/qnn/qnn_execution_provider.h    |   2 +-
 .../test/perftest/command_args_parser.cc      |   1 +
 onnxruntime/test/perftest/ort_test_session.cc |   4 +-
 .../test/qnn_ctx_gen/command_args_parser.cc   |   5 +-
 12 files changed, 208 insertions(+), 50 deletions(-)

diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index b87532debe4bc..6ea3f93cdea12 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
(Optional) Hardware architecture.
main_context : int
Usually each single EPContext associates with a graph partition. But in some cases, like QNN, a single EPContext contains all partitions. In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context. The path is relative to this Onnx file. Default is 1.
+max_size : int
+Max size in the context. Usage depends on the EP.
notes : string
(Optional) Some notes for the model
onnx_model_filename : string
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index b1a79f5921328..8e881c757f9ac 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3667,6 +3667,9 @@ struct OrtApi {
   *      execution provider (typically CPU EP).
   *      - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
   *      - "1": Enabled.
+  *   "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating the context binary.
+  *      - "0": Default. Disabled.
+  *      - "1": Enabled.
   *
   *   SNPE supported keys:
   *   "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 09a4a77780916..c7a0793c4748f 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3335,6 +3335,11 @@ void RegisterContribSchemas() {
           AttributeProto::STRING, OPTIONAL_VALUE)
       .Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
+      .Attr(
+          "max_size",
+          "Max size in the context. Usage depends on the EP.",
+          AttributeProto::INT,
+          static_cast<int64_t>(0))
       .AllowUncheckedAttributes()
       .Input(
           0,
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 57ae8c354abb7..79674fd706151 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
 Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
                                 const onnxruntime::PathString& ctx_onnx_model_path,
                                 QnnBackendManager* qnn_backend_manager,
-                                QnnModelLookupTable& qnn_models) {
+                                QnnModelLookupTable& qnn_models,
+                                int64_t max_spill_fill_size) {
   ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
   NodeAttrHelper node_helper(main_context_node);
   bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
     return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
                                                                static_cast<uint64_t>(context_binary.length()),
                                                                main_context_node.Name(),
-                                                               qnn_models);
+                                                               qnn_models,
+                                                               max_spill_fill_size);
   }

   std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
@@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
   return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
                                                              static_cast<uint64_t>(buffer_size),
                                                              main_context_node.Name(),
-                                                             qnn_models);
+                                                             qnn_models,
+                                                             max_spill_fill_size);
+}
+
+Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                              uint32_t total_context_size,
+                              int64_t& max_spill_fill_size,
+                              std::vector<int>& main_context_pos_list) {
+  max_spill_fill_size = 0;
+  int max_size_index = 0;
+  for (uint32_t i = 0; i < total_context_size; ++i) {
+    auto index = main_context_pos_list[i];
+    const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph);
+    ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should have only one EPContext node!");
+    const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin();
+    NodeAttrHelper node_helper(*ep_context_node);
+    int64_t max_size = node_helper.Get(MAX_SIZE, static_cast<int64_t>(0));
+    if (max_size > max_spill_fill_size) {
+      max_spill_fill_size = max_size;
+      max_size_index = i;
+    }
+  }
+  if (0 != max_size_index) {
+    int tmp_index = main_context_pos_list[0];
+    main_context_pos_list[0] = main_context_pos_list[max_size_index];
+    main_context_pos_list[max_size_index] = tmp_index;
+  }
+
+  return Status::OK();
 }

 Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                                const onnxruntime::PathString& ctx_onnx_model_path,
                                QnnBackendManager* qnn_backend_manager,
                                QnnModelLookupTable& qnn_models,
-                               const logging::Logger& logger) {
+                               const logging::Logger& logger,
+                               int64_t max_spill_fill_size) {
   ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should have only one EPContext node!");
   Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
-                                           qnn_models);
+                                           qnn_models, max_spill_fill_size);

   // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
   if (!status.IsOK()) {
@@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model,
                             const QnnModelLookupTable& qnn_models,
                             const onnxruntime::PathString& context_cache_path,
                             bool qnn_context_embed_mode,
+                            uint64_t max_spill_fill_buffer_size,
                             const logging::Logger& logger) {
   auto& graph = model->MainGraph();
@@ -238,6 +270,7 @@ Status CreateEPContextNodes(Model* model,
         }
         of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
         ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
+        ep_node.AddAttribute(MAX_SIZE, static_cast<int64_t>(max_spill_fill_buffer_size));
       }
     } else {
       ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index f308a7456d46c..92c5391b40f09 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
 static const std::string EP_SDK_VER = "ep_sdk_version";
 static const std::string PARTITION_NAME = "partition_name";
 static const std::string SOURCE = "source";
+static const std::string MAX_SIZE = "max_size";

 bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);
@@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
 Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
                                 const onnxruntime::PathString& ctx_onnx_model_path,
                                 QnnBackendManager* qnn_backend_manager,
-                                QnnModelLookupTable& qnn_models);
+                                QnnModelLookupTable& qnn_models,
+                                int64_t max_spill_fill_size);
+
+Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                              uint32_t total_context_size,
+                              int64_t& max_spill_fill_size,
+                              std::vector<int>& main_context_pos_list);

 Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                                const onnxruntime::PathString& ctx_onnx_model_path,
                                QnnBackendManager* qnn_backend_manager,
                                QnnModelLookupTable& qnn_models,
-                               const logging::Logger& logger);
+                               const logging::Logger& logger,
+                               int64_t max_spill_fill_size);

 Status CreateEPContextNodes(Model* model,
                             unsigned char* buffer,
@@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model,
                             const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
                             const onnxruntime::PathString& context_cache_path,
                             bool qnn_context_embed_mode,
+                            uint64_t max_spill_fill_buffer_size,
                             const logging::Logger& logger);
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index f37c91aa0413b..8a717c3f29ff9 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -8,6 +8,7 @@
 #include <string>
 #include "QnnOpDef.h"
 #include "HTP/QnnHtpPerfInfrastructure.h"
+#include "HTP/QnnHtpSystemContext.h"
 #include "CPU/QnnCpuCommon.h"
 // TODO: not exist for Windows yet
 // #include "GPU/QnnGpuCommon.h"
@@ -532,11 +533,11 @@ Status QnnBackendManager::CreateContext() {
   }

   QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
-  QnnHtpContext_CustomConfig_t customConfig;
-  customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
-  customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
+  QnnHtpContext_CustomConfig_t custom_config;
+  custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
+  custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
   context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
-  context_config_weight_sharing.customConfig = &customConfig;
+  context_config_weight_sharing.customConfig = &custom_config;

   QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
   ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
@@ -615,9 +616,71 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
   return context_buffer;
 }

+Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
+                                                    uint64_t buffer_length,
+                                                    uint64_t& max_spill_fill_buffer_size) {
+  bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
+                nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
+                nullptr == qnn_sys_interface_.systemContextFree;
+  ORT_RETURN_IF(result, "Failed to get valid function pointer.");
+
+  QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
+  auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
+  ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
+
+  const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
+  Qnn_ContextBinarySize_t binary_info_size{0};
+  rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
+                                                     static_cast<void*>(buffer),
+                                                     buffer_length,
+                                                     &binary_info,
+                                                     &binary_info_size);
+  ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
+
+  // binary_info life cycle is here
+  // Binary info to graph info
+  // retrieve Qnn graph info from binary info
+  ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
+  uint32_t graph_count = 0;
+  QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
+  if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
+    graph_count = binary_info->contextBinaryInfoV3.numGraphs;
+    graphs_info = binary_info->contextBinaryInfoV3.graphs;
+  } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
+    graph_count = binary_info->contextBinaryInfoV2.numGraphs;
+    graphs_info = binary_info->contextBinaryInfoV2.graphs;
+  } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
+    graph_count = binary_info->contextBinaryInfoV1.numGraphs;
+    graphs_info = binary_info->contextBinaryInfoV1.graphs;
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version.");
+  }
+
+  for (uint32_t i = 0; i < graph_count; ++i) {
+    if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
+      auto htp_graph_info = reinterpret_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(graphs_info[i].graphInfoV3.graphBlobInfo);
+      if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
+        auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize;
+        max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size;
+      } else {
+        LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version.";
+      }
+    } else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 ||
+               graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
+      LOGS(*logger_, VERBOSE) << "Skip retrieving the spill fill buffer size; it is not supported with graph info v1 & v2.";
+    } else {
+      LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version.";
+    }
+  }
+
+  LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed.";
+  return Status::OK();
+}
+
 Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
                                                          std::string node_name,
-                                                         QnnModelLookupTable& qnn_models) {
+                                                         QnnModelLookupTable& qnn_models,
+                                                         int64_t max_spill_fill_size) {
   bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
                 nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
                 nullptr == qnn_sys_interface_.systemContextFree;
@@ -638,7 +701,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t

   // binary_info life cycle is here
   // Binary info to graph info
-  // retrieve Qnn graph infor from binary info
+  // retrieve Qnn graph info from binary info
   ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
   uint32_t graph_count = 0;
   QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
@@ -658,13 +721,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
   ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
   LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;

-  ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
-                "Invalid function pointer for contextCreateFromBinary.");
-
   QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
   ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
-  const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
+
+  // Register spill fill buffer for multi context
+  QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
+
+  // The spill fill buffer is available since 2.28, API version starts from 2.21
+#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21)
+  QnnHtpContext_CustomConfig_t custom_config;
+  custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
+  QnnHtpContext_GroupRegistration_t group_info;
+  size_t current_contexts_size = GetQnnContextSize();
+  // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
+  // note that we already move the context with max spill fill size to the beginning of the list
+  group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
+  group_info.maxSpillFillBuffer = max_spill_fill_size;  // Max spill fill buffer across contexts. Must be > 0.
+  custom_config.groupRegistration = group_info;
+  spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
+  spill_fill_config.customConfig = &custom_config;
+#endif
+  QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
+  LOGS(*logger_, VERBOSE) << "Max spill fill buffer size: " << max_spill_fill_size;
+
+  const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};
+
+  ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
+                "Invalid function pointer for contextCreateFromBinary.");
+
   Qnn_ContextHandle_t context = nullptr;
   rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
                                               device_handle_,
@@ -673,7 +756,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
                                               buffer_length,
                                               &context,
                                               profile_backend_handle_);
-  ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
+  ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
   contexts_.push_back(context);
   if (1 == graph_count) {
     // in case the EPContext node is generated from script
@@ -699,7 +782,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
   return Status::OK();
 }

-Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
+// Need to load the system lib if loading from a Qnn context binary,
+// or if generating a Qnn context binary is enabled -- to get the max spill fill buffer size
+Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
+                                       bool load_from_cached_context,
+                                       bool need_load_system_lib) {
   std::lock_guard<std::mutex> lock(logger_mutex_);
   if (backend_setup_completed_) {
     LOGS(logger, VERBOSE) << "Backend setup already!";
@@ -714,7 +801,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_

   LOGS(logger, VERBOSE) << "LoadBackend succeed.";

-  if (load_from_cached_context) {
+  if (load_from_cached_context || need_load_system_lib) {
     ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
   }
@@ -933,20 +1020,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_
   return Status::OK();
 }

-void QnnBackendManager::Split(std::vector<std::string>& split_string,
-                              const std::string& tokenized_string,
-                              const char separator) {
-  split_string.clear();
-  std::istringstream tokenized_string_stream(tokenized_string);
-  while (!tokenized_string_stream.eof()) {
-    std::string value;
-    getline(tokenized_string_stream, value, separator);
-    if (!value.empty()) {
-      split_string.push_back(value);
-    }
-  }
-}
-
 Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
   QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
   auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index 43007d4a5c244..b145f2a2cd724 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -93,9 +93,10 @@ class QnnBackendManager {
   Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
                                         std::string node_name,
-                                        std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
+                                        std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
+                                        int64_t max_spill_fill_size);

-  Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
+  Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);

   Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
@@ -112,6 +113,10 @@ class QnnBackendManager {
     return contexts_[index];
   }

+  size_t GetQnnContextSize() {
+    return contexts_.size();
+  }
+
   const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; }

   const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
@@ -145,8 +150,6 @@ class QnnBackendManager {

   void ReleaseResources();

-  void Split(std::vector<std::string>& split_string, const std::string& tokenized_string, const char separator);
-
   Status ExtractBackendProfilingInfo();
   Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
                                    bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
@@ -163,6 +166,10 @@ class QnnBackendManager {

   Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id);

+  Status GetMaxSpillFillBufferSize(unsigned char* buffer,
+                                   uint64_t buffer_length,
+                                   uint64_t& max_spill_fill_buffer_size);
+
  private:
   void* LoadLib(const char* file_name, int flags, std::string& error_msg);
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 6735528bebbf9..3bb069196e31c 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -363,20 +363,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
     LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
   }

+  bool enable_htp_weight_sharing = false;
   static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing";
   auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED);
   if (htp_weight_sharing_enabled_pos != provider_options_map.end()) {
     if ("1" == htp_weight_sharing_enabled_pos->second) {
-      enable_htp_weight_sharing_ = true;
+      enable_htp_weight_sharing = true;
     } else if ("0" == htp_weight_sharing_enabled_pos->second) {
-      enable_htp_weight_sharing_ = false;
+      enable_htp_weight_sharing = false;
     } else {
-      LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_
+      LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing
                             << " only 0 or 1 allowed. Set to 0.";
     }
-    LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_;
+    LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing;
   }

+  // Add this option because this feature requires the QnnSystem lib, which is not supported for the Windows x86_64 platform
+  enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map);
+
   model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false,
                                                                   provider_options_map);

@@ -396,7 +400,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
                                                                device_id_,
                                                                htp_arch,
                                                                soc_model,
-                                                               enable_htp_weight_sharing_);
+                                                               enable_htp_weight_sharing);

 #ifdef _WIN32
   auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance();
@@ -686,7 +690,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer

   // It will load the QnnSystem lib if is_qnn_ctx_model=true, and
   // delay the Qnn context creation to Compile() using the cached context binary,
-  auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model);
+  // or if generating the context cache is enabled; the QnnSystem lib is needed to parse the binary to get the max spill fill buffer size
+  auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_);
   if (Status::OK() != rt) {
     LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
     return result;
@@ -934,6 +939,16 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
   std::vector<int> main_context_pos_list;
   ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list));
+  uint32_t total_context_size = SafeInt<uint32_t>(main_context_pos_list.size());
+
+  int64_t max_spill_fill_size = 0;
+
+  // Adjust the main_context_pos_list, move the one with the max spill fill buffer to the beginning
+  // The HTP spill fill buffer only works for multiple QNN contexts generated with QNN v2.28 or later
+  if (total_context_size > 1) {
+    ORT_RETURN_IF_ERROR(qnn::TryGetMaxSpillFillSize(fused_nodes_and_graphs, total_context_size,
+                                                    max_spill_fill_size, main_context_pos_list));
+  }

   for (auto main_context_pos : main_context_pos_list) {
     const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
@@ -942,7 +957,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
                                               context_cache_path,
                                               qnn_backend_manager_.get(),
                                               qnn_models,
-                                              logger));
+                                              logger,
+                                              max_spill_fill_size));
   }

   for (auto fused_node_and_graph : fused_nodes_and_graphs) {
@@ -984,6 +1000,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
     // All partitioned graph share single QNN context, included in the same context binary
     uint64_t buffer_size(0);
     auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
+    // Get max spill fill buffer size
+    uint64_t max_spill_fill_buffer_size = 0;
+    if (enable_spill_fill_buffer_) {
+      ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(),
+                                                                          buffer_size,
+                                                                          max_spill_fill_buffer_size));
+    }
     qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
     ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
                                                   context_buffer.get(),
@@ -993,6 +1016,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
                                                   qnn_models_,
                                                   context_cache_path,
                                                   qnn_context_embed_mode_,
+                                                  max_spill_fill_buffer_size,
                                                   logger));
   }
   return Status::OK();
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
index 35c061de6132c..a0577e8fd87f2 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider {
   std::string context_node_name_prefix_ = "";
   bool disable_cpu_ep_fallback_ = false;  // True if CPU EP fallback has been disabled for this session.
   bool qnn_context_embed_mode_ = true;
-  bool enable_htp_weight_sharing_ = false;
   int32_t vtcm_size_in_mb_ = 0;
   std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
   ModelMetadefIdGenerator metadef_id_generator_;
@@ -150,6 +149,7 @@ class QNNExecutionProvider : public IExecutionProvider {
   uint32_t default_rpc_control_latency_ = 0;
   bool enable_HTP_FP16_precision_ = true;
   bool share_ep_contexts_ = false;
+  bool enable_spill_fill_buffer_ = false;
 #ifdef _WIN32
   onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr;
 #endif
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index e406405464d99..3f2c2cb7f761c 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -101,6 +101,7 @@ namespace perftest {
      "\t    Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n"
      "\t    [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n"
      "\t    Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n"
+     "\t    [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill fill buffer, used while generating QNN context binary.\n"
      "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n"
      "\n"
      "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n"
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 02768b8c08e85..5db1894a5074b 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -202,7 +202,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
         {"backend_path", "profiling_file_path", "profiling_level", "rpc_control_latency",
          "vtcm_mb", "soc_model", "device_id", "htp_performance_mode", "qnn_saver_path",
          "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch",
-         "enable_htp_fp16_precision", "offload_graph_io_quantization"});
+         "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer"});
     for (const auto& provider_option : provider_options) {
       const std::string& key = provider_option.first;
       const std::string& value = provider_option.second;
@@ -253,7 +253,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
           std::string str = str_stream.str();
           ORT_THROW("Wrong value for htp_arch. select from: " + str);
         }
-      } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") {
+      } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") {
         std::unordered_set<std::string> supported_options = {"0", "1"};
         if (supported_options.find(value) == supported_options.end()) {
           std::ostringstream str_stream;
diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc
index 5b3720992c542..24c343c7b9541 100644
--- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc
+++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc
@@ -50,6 +50,7 @@ namespace qnnctxgen {
      "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n"
      "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n"
      "\t    Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n"
+     "\t [enable_htp_spill_fill_buffer]: Enable HTP spill fill buffer, used while generating QNN context binary.\n"
      "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n"
      "\n"
      "\t-h: help\n");
@@ -146,7 +147,7 @@ static bool ParseSessionConfigs(const std::string& configs_string,
         ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str);
       }
     } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing" ||
-               key == "offload_graph_io_quantization") {
+               key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") {
       std::unordered_set<std::string> supported_options = {"0", "1"};
       if (supported_options.find(value) == supported_options.end()) {
         std::ostringstream str_stream;
@@ -158,7 +159,7 @@ static bool ParseSessionConfigs(const std::string& configs_string,
     } else {
       ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'vtcm_mb', 'htp_performance_mode',
 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing',
-'offload_graph_io_quantization'])");
+'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])");
     }

     test_config.run_config.qnn_options[key] = value;
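
### Usage sketch (not part of the patch)

A minimal C++ sketch of how the new option would be exercised when generating the context binary offline. The model path and backend library name are illustrative assumptions; `ep.context_enable` is the existing session config key that turns on EPContext model generation.

```cpp
// Sketch: compile a model with QNN EP and dump an EPContext model
// with the HTP spill fill buffer setting enabled.
#include <onnxruntime_cxx_api.h>

#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_spill_fill_gen");

  Ort::SessionOptions so;
  so.AddConfigEntry("ep.context_enable", "1");  // dump the EPContext model during compile

  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "libQnnHtp.so"},       // HTP backend; offline generation requires Linux
      {"enable_htp_spill_fill_buffer", "1"},  // the option added by this patch
  };
  so.AppendExecutionProvider("QNN", qnn_options);

  // Session creation compiles the graph and writes the context-binary model
  // next to the input model; its main EPContext node carries the max_size attribute.
  Ort::Session session(env, "model.onnx", so);  // "model.onnx" is a placeholder
  return 0;
}
```

With the test tools touched above, the same option is passed as a key|value pair, e.g. `-i "enable_htp_spill_fill_buffer|1"`.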
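
### Checking the max_size attribute (sketch)

A companion sketch for verifying the result: reading the max_size attribute back from the generated EPContext node(s) via the ONNX protobuf API. The file name is a placeholder.

```cpp
// Sketch: print the max_size attribute of every EPContext node in a model.
#include <fstream>
#include <iostream>

#include "onnx/onnx_pb.h"  // ONNX protobuf definitions

int main() {
  onnx::ModelProto model;
  std::ifstream input("model_ctx.onnx", std::ios::binary);  // placeholder path
  if (!model.ParseFromIstream(&input)) {
    std::cerr << "Failed to parse model\n";
    return 1;
  }
  for (const auto& node : model.graph().node()) {
    if (node.op_type() != "EPContext") continue;
    for (const auto& attr : node.attribute()) {
      if (attr.name() == "max_size") {
        std::cout << node.name() << ": max_size = " << attr.i() << "\n";
      }
    }
  }
  return 0;
}
```

A nonzero max_size on the main EPContext node indicates the spill fill buffer size was recorded.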