diff --git a/CODEOWNERS b/CODEOWNERS index d09058ed94f7f..e1f79a504ad8b 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,6 +1,5 @@ # Mobile -/onnxruntime/test/testdata/kernel_def_hashes/ @microsoft/onnxruntime-mobile -/onnxruntime/core/framework/kernel_def_hash_helpers.* @microsoft/onnxruntime-mobile +/onnxruntime/core/flatbuffers/schema/ort.fbs @microsoft/onnxruntime-mobile # Contrib Ops /onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc @microsoft/onnxruntime-mlas diff --git a/docs/ORT_Format_Update_in_1.13.md b/docs/ORT_Format_Update_in_1.13.md new file mode 100644 index 0000000000000..fa67da927a1ef --- /dev/null +++ b/docs/ORT_Format_Update_in_1.13.md @@ -0,0 +1,12 @@ +# ORT Format Update in 1.13 + +In ONNX Runtime 1.13, there was a breaking change to the +[ORT format](https://onnxruntime.ai/docs/reference/ort-format-models.html) in order to enable additional execution +providers with statically registered kernels in a minimal build. +More details can be found [here](../onnxruntime/core/flatbuffers/schema/README.md#version-5). + +Unfortunately, this means that any older models (prior to ORT format version 5) will no longer work with ONNX Runtime +1.13 or later and must be re-converted. +Please refer +[here](https://onnxruntime.ai/docs/reference/ort-format-models.html#convert-onnx-models-to-ort-format) for instructions +on how to convert an ONNX model to ORT format. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 680979ca8a2cf..dc8f312190163 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -418,7 +418,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**X** = tensor(float)| +|MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcMaxPool|*in* x:**T**
*out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)| diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index 501634bf32509..d411ed2451cf3 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -17,9 +17,10 @@ #pragma once -#include -#include #include +#include +#include +#include #include #include #include @@ -28,8 +29,8 @@ #include #include #include +#include #include -#include #include "core/common/code_location.h" #include "core/common/exceptions.h" @@ -279,9 +280,10 @@ constexpr size_t kMaxStrLen = 2048; // Returns whether `key` is in `container`. // Like C++20's map/set contains() member function. template typename AssociativeContainer> -inline bool Contains(const AssociativeContainer& container, const Key& key) { - return container.find(key) != container.end(); + template typename AssociativeContainer, + typename LookupKey> +inline bool Contains(const AssociativeContainer& container, LookupKey&& key) { + return container.find(std::forward(key)) != container.end(); } } // namespace onnxruntime diff --git a/include/onnxruntime/core/common/hash_combine.h b/include/onnxruntime/core/common/hash_combine.h new file mode 100644 index 0000000000000..5662a329ea77f --- /dev/null +++ b/include/onnxruntime/core/common/hash_combine.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { + +// Combine hash value `seed` with hash value `h`, updating `seed` in place. +// TODO(edgchen1) find a better implementation? e.g., see a more recent version of boost::hash_combine() +inline void HashCombineWithHashValue(size_t h, size_t& seed) { + seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +// Combine hash value `seed` with the hash value of `value`, updating `seed` in place. +// The hash value computation is specified by the `Hash` template parameter. +template > +inline void HashCombine(const T& value, size_t& seed) { + HashCombineWithHashValue(Hash{}(value), seed); +} + +} // namespace onnxruntime diff --git a/include/onnxruntime/core/common/parse_string.h b/include/onnxruntime/core/common/parse_string.h index edb34724f1929..941e3f3377ecc 100644 --- a/include/onnxruntime/core/common/parse_string.h +++ b/include/onnxruntime/core/common/parse_string.h @@ -5,6 +5,7 @@ #include #include +#include #include #include "core/common/common.h" @@ -15,7 +16,7 @@ namespace onnxruntime { * Tries to parse a value from an entire string. */ template -bool TryParseStringWithClassicLocale(const std::string& str, T& value) { +bool TryParseStringWithClassicLocale(std::string_view str, T& value) { if constexpr (std::is_integral::value && std::is_unsigned::value) { // if T is unsigned integral type, reject negative values which will wrap if (!str.empty() && str[0] == '-') { @@ -28,7 +29,7 @@ bool TryParseStringWithClassicLocale(const std::string& str, T& value) { return false; } - std::istringstream is{str}; + std::istringstream is{std::string{str}}; is.imbue(std::locale::classic()); T parsed_value{}; @@ -43,12 +44,12 @@ bool TryParseStringWithClassicLocale(const std::string& str, T& value) { return true; } -inline bool TryParseStringWithClassicLocale(const std::string& str, std::string& value) { +inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) { value = str; return true; } -inline bool TryParseStringWithClassicLocale(const std::string& str, bool& value) { +inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) { if (str == "0" || str == "False" || str == "false") { value = false; return true; @@ -66,7 +67,7 @@ inline bool TryParseStringWithClassicLocale(const std::string& str, bool& value) * Parses a value from an entire string. */ template -Status ParseStringWithClassicLocale(const std::string& s, T& value) { +Status ParseStringWithClassicLocale(std::string_view s, T& value) { ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\""); return Status::OK(); } @@ -75,7 +76,7 @@ Status ParseStringWithClassicLocale(const std::string& s, T& value) { * Parses a value from an entire string. */ template -T ParseStringWithClassicLocale(const std::string& s) { +T ParseStringWithClassicLocale(std::string_view s) { T value{}; ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value)); return value; diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index 80086a31379d5..f4ca87eb9a5ef 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -458,6 +458,7 @@ class TensorType : public TensorTypeBase { #if defined(DISABLE_OPTIONAL_TYPE) +// TODO is this still needed after removing kernel def hashes? /// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build /// with disabled types to keep the ORT format model kernel hashes stable. class DisabledTypeBase : public DataTypeImpl { diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 6d3fa92f9f7d1..57e9d5e564206 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -15,11 +15,10 @@ namespace onnxruntime { class GraphViewer; -class Node; struct ComputeCapability; class KernelRegistry; -class KernelRegistryManager; - +struct KernelCreateInfo; +class Node; } // namespace onnxruntime #else #include @@ -89,6 +88,19 @@ class IExecutionProvider { return nullptr; } + /** + * Interface for performing kernel lookup within kernel registries. + * Abstracts away lower-level details about kernel registries and kernel matching. + */ + class IKernelLookup { + public: + /** + * Given `node`, try to find a matching kernel for this EP. + * The return value is non-null if and only if a matching kernel was found. + */ + virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0; + }; + /** Get execution provider's capability for the specified . Return a bunch of IndexedSubGraphs <*this> execution provider can run if @@ -96,22 +108,24 @@ class IExecutionProvider { contains more than one node. The node indexes contained in sub-graphs may have overlap, and it's ONNXRuntime's responsibility to do the partition and decide whether a node will be assigned to <*this> execution provider. + For kernels registered in a kernel registry, `kernel_lookup` must be used + to find a matching kernel for this EP. */ virtual std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& kernel_registries) const; + const IKernelLookup& kernel_lookup) const; /** Get kernel registry per execution provider type. The KernelRegistry share pointer returned is shared across sessions. - NOTE: this is a tricky but final solution to achieve following goals, + NOTE: this approach was taken to achieve the following goals, 1. The execution provider type based kernel registry should be shared across sessions. Only one copy of this kind of kernel registry exists in ONNXRuntime with multiple sessions/models. 2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime - frameowrk/session code. + framework/session code. 3. onnxruntime (framework/session) does not depend on any specific execution provider lib. */ diff --git a/include/onnxruntime/core/framework/kernel_def_builder.h b/include/onnxruntime/core/framework/kernel_def_builder.h index 8fc45889c8ad3..8b6a2571a5587 100644 --- a/include/onnxruntime/core/framework/kernel_def_builder.h +++ b/include/onnxruntime/core/framework/kernel_def_builder.h @@ -53,13 +53,15 @@ class KernelDef { return provider_type_; } + // TODO(edgchen1) do we need both TypeConstraints() and EnabledTypeConstraints()? + // type constraints with types supported by default - const std::map>& TypeConstraints() const { + const std::unordered_map>& TypeConstraints() const { return default_type_constraints_; } // type constraints with types supported in this build - const std::map>& EnabledTypeConstraints() const { + const std::unordered_map>& EnabledTypeConstraints() const { return enabled_type_constraints_; } @@ -108,19 +110,9 @@ class KernelDef { bool IsConflict(const KernelDef& other) const; - HashValue GetHash() const noexcept { - // if we need to support different hash versions we can update CalculateHash to take a version number - // and calculate any non-default versions dynamically. we only use this during kernel lookup so - // it's not performance critical - return hash_; - } - private: friend class KernelDefBuilder; - // called once by KernelDefBuilder::Build - void CalculateHash(); - // The operator name supported by <*this> kernel.. std::string op_name_; @@ -139,18 +131,11 @@ class KernelDef { std::string provider_type_; // The data types that are supported by default for inputs/outputs. - // Key is input/output name defined in op schema, Value are supported types. - // note: std::map as we need the order to be deterministic for the hash - // Note: default_type_constraints_ are used to calculate the kernel hash so that the hash is - // stable across builds with and without kernel type reduction enabled. - std::map> default_type_constraints_; + // Key is input/output/type constraint name defined in op schema, Value are supported types. + std::unordered_map> default_type_constraints_; // the type constraints that are supported in this build (enabled) for the kernel - std::map> enabled_type_constraints_; - - // optional alternate type constraints to use to calculate the hash instead of default_type_constraints_ - // note: this provides a way to update the default type constraints while preserving the hash value - optional>> hash_type_constraints_; + std::unordered_map> enabled_type_constraints_; // An element means that output j reuses the memory of input i. std::vector> inplace_map_; @@ -186,9 +171,6 @@ class KernelDef { OrtMemType default_inputs_mem_type_{OrtMemTypeDefault}; // Default memory type for all outputs OrtMemType default_outputs_mem_type_{OrtMemTypeDefault}; - - // hash of kernel definition for lookup in minimal build - HashValue hash_ = 0; }; class KernelDefBuilder { @@ -259,17 +241,6 @@ class KernelDefBuilder { KernelDefBuilder& TypeConstraint(const std::string& arg_name, MLDataType default_type); KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType default_type); - /** - Specify the original set of types that this kernel supports by default to use when computing the kernel def hash. - The set of types supported by default may change over time, but the hash should stay the same. - */ - KernelDefBuilder& FixedTypeConstraintForHash( - const std::string& arg_name, - const std::vector& default_types_for_hash); - KernelDefBuilder& FixedTypeConstraintForHash( - const char* arg_name, - const std::vector& default_types_for_hash); - /** Inplace mapping from inputs to outputs allowed. It means that uplayer runtime could do memory in-place optimization @@ -392,7 +363,6 @@ class KernelDefBuilder { Return the kernel definition, passing ownership of the KernelDef to the caller */ std::unique_ptr Build() { - kernel_def_->CalculateHash(); return std::move(kernel_def_); } diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 224ce00bc7028..68b610ac4b278 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/framework/op_kernel.h" namespace onnxruntime { @@ -10,9 +12,10 @@ namespace onnxruntime { using KernelCreateMap = std::multimap; using KernelDefHashes = std::vector>; +class IKernelTypeStrResolver; + /** * Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider. - * */ class KernelRegistry { public: @@ -23,37 +26,28 @@ class KernelRegistry { Status Register(KernelCreateInfo&& create_info); -#if !defined(ORT_MINIMAL_BUILD) - static bool HasImplementationOf(const KernelRegistry& r, const Node& node, - ProviderType exec_provider) { - const KernelCreateInfo* info; - Status st = r.TryFindKernel(node, exec_provider, &info); - return st.IsOK(); - } - - // factory functions should always return a unique_ptr for maximum flexibility - // for its clients unless the factory is managing the lifecycle of the pointer - // itself. - // TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent - Status TryCreateKernel(const Node& node, const IExecutionProvider& execution_provider, - const std::unordered_map& constant_initialized_tensors, - const OrtValueNameIdxMap& mlvalue_name_idx_map, FuncManager& funcs_mgr, - const DataTransferManager& data_transfer_mgr, - std::unique_ptr& op_kernel) const; + // TODO(edgchen1) for TryFindKernel(), consider using `out` != nullptr as indicator of whether kernel was found and + // Status as an indication of failure // Check if an execution provider can create kernel for a node and return the kernel if so Status TryFindKernel(const Node& node, ProviderType exec_provider, + const IKernelTypeStrResolver& kernel_type_str_resolver, const KernelCreateInfo** out) const; + static bool HasImplementationOf(const KernelRegistry& r, const Node& node, + ProviderType exec_provider, + const IKernelTypeStrResolver& kernel_type_str_resolver) { + const KernelCreateInfo* info; + Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info); + return st.IsOK(); + } + +#if !defined(ORT_MINIMAL_BUILD) // Find KernelCreateInfo in instant mode Status TryFindKernel(const std::string& op_name, const std::string& domain, const int& version, const std::unordered_map& type_constraints, ProviderType exec_provider, const KernelCreateInfo** out) const; - -#endif - - // Try to find the kernel given a kernel def hash. - bool TryFindKernelByHash(HashValue kernel_def_hash, const KernelCreateInfo** out) const; +#endif // !defined(ORT_MINIMAL_BUILD) bool IsEmpty() const { return kernel_creator_fn_map_.empty(); } @@ -64,11 +58,7 @@ class KernelRegistry { } #endif - // Get sorted kernel def key and hash pairs. - KernelDefHashes ExportKernelDefHashes() const; - private: -#if !defined(ORT_MINIMAL_BUILD) // Check whether the types of inputs/outputs of the given node match the extra // type-constraints of the given kernel. This serves two purposes: first, to // select the right kernel implementation based on the types of the arguments @@ -79,16 +69,12 @@ class KernelRegistry { // // Note that this is not intended for type-checking the node against the ONNX // type specification of the corresponding op, which is done before this check. - // - // if this function is called before graph partition, then node.provider is not set. - // In this case, kernel_def.provider must equal to exec_provider - // otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored. static bool VerifyKernelDef(const Node& node, const KernelDef& kernel_def, + const IKernelTypeStrResolver& kernel_type_str_resolver, std::string& error_str); -#endif - static std::string GetMapKey(const std::string& op_name, const std::string& domain, const std::string& provider) { + static std::string GetMapKey(std::string_view op_name, std::string_view domain, std::string_view provider) { std::string key(op_name); // use the kOnnxDomainAlias of 'ai.onnx' instead of kOnnxDomain's empty string key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider); @@ -101,8 +87,5 @@ class KernelRegistry { // Kernel create function map from op name to kernel creation info. // key is opname+domain_name+provider_name KernelCreateMap kernel_creator_fn_map_; - - // map from kernel def hash to entry in kernel_creator_fn_map_ - std::unordered_map kernel_def_hash_lookup_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/ortmemoryinfo.h b/include/onnxruntime/core/framework/ortmemoryinfo.h index 32b7aab7c4118..6dd259ddc15e4 100644 --- a/include/onnxruntime/core/framework/ortmemoryinfo.h +++ b/include/onnxruntime/core/framework/ortmemoryinfo.h @@ -3,6 +3,10 @@ #pragma once +#include + +#include "core/common/hash_combine.h" + struct OrtMemoryInfo { OrtMemoryInfo() = default; // to allow default construction of Tensor @@ -38,17 +42,13 @@ struct OrtMemoryInfo { return strcmp(name, other.name) < 0; } - static void HashCombine(size_t h, size_t& seed) { - seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - // This is to make OrtMemoryInfo a valid key in hash tables // we ignore device id size_t Hash() const { auto h = std::hash()(alloc_type); - HashCombine(std::hash()(mem_type), h); - HashCombine(std::hash()(id), h); - HashCombine(std::hash()(name), h); + onnxruntime::HashCombine(mem_type, h); + onnxruntime::HashCombine(id, h); + onnxruntime::HashCombine(name, h); return h; } diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 8e930ce5a8146..7dda32de41b94 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1333,20 +1333,6 @@ class Graph { RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() { return runtime_optimizations_; } - - // Stores information collected during the replay of loaded runtime optimizations - struct RuntimeOptimizationReplayContext { - std::unordered_map produced_node_index_to_kernel_def_hash{}; - size_t num_replayed_optimizations{}; - }; - - const RuntimeOptimizationReplayContext& RuntimeOptimizationReplayCtx() const { - return runtime_optimization_replay_context_; - } - - RuntimeOptimizationReplayContext& MutableRuntimeOptimizationReplayCtx() { - return runtime_optimization_replay_context_; - } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // This friendship relationship should only be used to call Graph::Graph and @@ -1588,8 +1574,6 @@ class Graph { // Note: runtime_optimizations_ == *runtime_optimizations_ptr_ and must be initialized std::unique_ptr runtime_optimizations_ptr_; RuntimeOptimizationRecordContainer& runtime_optimizations_; - - RuntimeOptimizationReplayContext runtime_optimization_replay_context_; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) diff --git a/java/src/test/android/README.md b/java/src/test/android/README.md index c5658be38660e..874ba7ba729fd 100644 --- a/java/src/test/android/README.md +++ b/java/src/test/android/README.md @@ -1,17 +1,18 @@ # Android Test Application for ORT-Mobile -This directory contains a simple android application for testing [ONNX Runtime AAR package](https://www.onnxruntime.ai/docs/how-to/build.html#build-android-archive-aar). +This directory contains a simple android application for testing [ONNX Runtime AAR package](https://onnxruntime.ai/docs/build/android.html#build-android-archive-aar). ## Background -For general usage and build purpose of ORT-Mobile Android, please see the [documentation](https://www.onnxruntime.ai/docs/how-to/build.html#android) here. +For general usage and build purpose of ORT-Mobile Android, please see the [documentation](https://onnxruntime.ai/docs/tutorials/mobile/) here. ### Test Android Application Overview This android application is mainly aimed for testing: - Model used: A simple [sigmoid ONNX model](https://github.com/onnx/onnx/blob/f9b0cc99344869c246b8f4011b8586a39841284c/onnx/backend/test/data/node/test_sigmoid/model.onnx) (converted to ORT format under `app\src\androidTest\assets` folder). - - Here's a [documentation](https://github.com/microsoft/onnxruntime/blob/main/docs/ONNX_Runtime_for_Mobile_Platforms.md#1-create-ort-format-model-and-configuration-file-with-required-operators) about how you can convert an ONNX model into ORT format. + - Here's [documentation](https://onnxruntime.ai/docs/reference/ort-format-models.html#convert-onnx-models-to-ort-format) about how you can convert an ONNX model into ORT format. + - Run `python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed /path/to/model.onnx` and rename the resulting .ort file accordingly. - Main test file: An android instrumentation test under `app\src\androidtest\java\ai.onnxruntime.example.javavalidator\SimpleTest.kt` - The main dependency of this application is `onnxruntime` aar package under `app\libs`. - The MainActivity of this application is set to be empty. @@ -23,7 +24,7 @@ This android application is mainly aimed for testing: ### Building -Use the android's [build instructions](https://www.onnxruntime.ai/docs/how-to/build.html#android-build-instructions) with `--build_java` and `--android_run_emulator` option. +Use the android's [build instructions](https://onnxruntime.ai/docs/build/android.html) with `--build_java` and `--android_run_emulator` option. Please note that you may need to set the `--android_abi=x86_64` (the default option is `arm64-v8a`). This is because android instrumentation test is run on an android emulator which requires an abi of `x86_64`. diff --git a/java/src/test/android/app/src/androidTest/assets/sigmoid.ort b/java/src/test/android/app/src/androidTest/assets/sigmoid.ort index 6336fed141a5e..70d7659bbee25 100644 Binary files a/java/src/test/android/app/src/androidTest/assets/sigmoid.ort and b/java/src/test/android/app/src/androidTest/assets/sigmoid.ort differ diff --git a/js/README.md b/js/README.md index d5b55af6e1452..9f1a6150a70ae 100644 --- a/js/README.md +++ b/js/README.md @@ -408,7 +408,7 @@ By default, ONNX Runtime React Native leverages ONNX Runtime Mobile package with yarn bootstrap ``` - When testing with a custom built ONNX Runtime Android package, copy `/aar_out/MinSizeRel/com/microsoft/onnxruntime/onnxruntime-mobile//onnxruntime-mobile-.aar` into `/js/react_native/e2e/node_modules/onnxruntime-react-native/android/libs` directory. Using a custom built ONNX Runtime iOS package, copy `onnxruntime-mobile-c.zip` into `/js/react_native/local_pods` directory if it's not already done. + When testing with a custom built ONNX Runtime Android package, copy `/aar_out/MinSizeRel/com/microsoft/onnxruntime/onnxruntime-mobile//onnxruntime-mobile-.aar` into `/js/react_native/e2e/android/app/libs` directory. Using a custom built ONNX Runtime iOS package, copy `onnxruntime-mobile-c.zip` into `/js/react_native/local_pods` directory if it's not already done. From `/js/react_native/e2e/android`, run e2e Android tests as follows, diff --git a/js/node/test/e2e/simple-e2e-tests.ts b/js/node/test/e2e/simple-e2e-tests.ts index dbbcdfcf8df09..70ac6ca1e0f94 100644 --- a/js/node/test/e2e/simple-e2e-tests.ts +++ b/js/node/test/e2e/simple-e2e-tests.ts @@ -11,73 +11,73 @@ import {assertDataEqual, TEST_DATA_ROOT} from '../test-utils'; const MODEL_TEST_TYPES_CASES: Array<{model: string; type: Tensor.Type; input0: Tensor.DataType; expectedOutput0: Tensor.DataType}> = [ { - model: path.join(TEST_DATA_ROOT, 'test_types_BOOL.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_bool.onnx'), type: 'bool', input0: Uint8Array.from([1, 0, 0, 1, 0]), expectedOutput0: Uint8Array.from([1, 0, 0, 1, 0]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_DOUBLE.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_double.onnx'), type: 'float64', input0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]), expectedOutput0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_FLOAT.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_float.onnx'), type: 'float32', input0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]), expectedOutput0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_INT8.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_int8.onnx'), type: 'int8', input0: Int8Array.from([1, -2, 3, 4, -5]), expectedOutput0: Int8Array.from([1, -2, 3, 4, -5]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_INT16.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_int16.onnx'), type: 'int16', input0: Int16Array.from([1, -2, 3, 4, -5]), expectedOutput0: Int16Array.from([1, -2, 3, 4, -5]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_INT32.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_int32.onnx'), type: 'int32', input0: Int32Array.from([1, -2, 3, 4, -5]), expectedOutput0: Int32Array.from([1, -2, 3, 4, -5]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_INT64.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_int64.onnx'), type: 'int64', input0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]), expectedOutput0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_STRING.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_string.onnx'), type: 'string', input0: ['a', 'b', 'c', 'd', 'e'], expectedOutput0: ['a', 'b', 'c', 'd', 'e'] }, { - model: path.join(TEST_DATA_ROOT, 'test_types_UINT8.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_uint8.onnx'), type: 'uint8', input0: Uint8Array.from([1, 2, 3, 4, 5]), expectedOutput0: Uint8Array.from([1, 2, 3, 4, 5]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_UINT16.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_uint16.onnx'), type: 'uint16', input0: Uint16Array.from([1, 2, 3, 4, 5]), expectedOutput0: Uint16Array.from([1, 2, 3, 4, 5]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_UINT32.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_uint32.onnx'), type: 'uint32', input0: Uint32Array.from([1, 2, 3, 4, 5]), expectedOutput0: Uint32Array.from([1, 2, 3, 4, 5]) }, { - model: path.join(TEST_DATA_ROOT, 'test_types_UINT64.pb'), + model: path.join(TEST_DATA_ROOT, 'test_types_uint64.onnx'), type: 'uint64', input0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]), expectedOutput0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]) diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index d2da45566d698..ffb19c7bb2173 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -82,7 +82,7 @@ export function warmup(): void { this.timeout(0); // we have test cases to verify correctness in other place, so do no check here. try { - const session = await InferenceSession.create(path.join(TEST_DATA_ROOT, 'test_types_INT32.pb')); + const session = await InferenceSession.create(path.join(TEST_DATA_ROOT, 'test_types_int32.onnx')); await session.run({input: new Tensor(new Float32Array(5), [1, 5])}, {output: null}, {}); } catch (e) { } diff --git a/js/node/test/testdata/test_types_BOOL.pb b/js/node/test/testdata/test_types_BOOL.pb deleted file mode 100644 index 2c58b06d0aa6d..0000000000000 Binary files a/js/node/test/testdata/test_types_BOOL.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_DOUBLE.pb b/js/node/test/testdata/test_types_DOUBLE.pb deleted file mode 100644 index 65ebf0f848a35..0000000000000 Binary files a/js/node/test/testdata/test_types_DOUBLE.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_FLOAT.pb b/js/node/test/testdata/test_types_FLOAT.pb deleted file mode 100644 index b4ad9807834ed..0000000000000 Binary files a/js/node/test/testdata/test_types_FLOAT.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_FLOAT16.pb b/js/node/test/testdata/test_types_FLOAT16.pb deleted file mode 100644 index f671bb7ce3253..0000000000000 Binary files a/js/node/test/testdata/test_types_FLOAT16.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_INT16.pb b/js/node/test/testdata/test_types_INT16.pb deleted file mode 100644 index f297ec9c54940..0000000000000 Binary files a/js/node/test/testdata/test_types_INT16.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_INT32.pb b/js/node/test/testdata/test_types_INT32.pb deleted file mode 100644 index 73bb539cf44c6..0000000000000 Binary files a/js/node/test/testdata/test_types_INT32.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_INT64.pb b/js/node/test/testdata/test_types_INT64.pb deleted file mode 100644 index ccf8df0033278..0000000000000 Binary files a/js/node/test/testdata/test_types_INT64.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_INT8.pb b/js/node/test/testdata/test_types_INT8.pb deleted file mode 100644 index 72698a779578d..0000000000000 Binary files a/js/node/test/testdata/test_types_INT8.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_STRING.pb b/js/node/test/testdata/test_types_STRING.pb deleted file mode 100644 index 7c8b3e7e2eb82..0000000000000 Binary files a/js/node/test/testdata/test_types_STRING.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_UINT16.pb b/js/node/test/testdata/test_types_UINT16.pb deleted file mode 100644 index 0a9c6fe3770ce..0000000000000 Binary files a/js/node/test/testdata/test_types_UINT16.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_UINT32.pb b/js/node/test/testdata/test_types_UINT32.pb deleted file mode 100644 index 90efef3e7f171..0000000000000 Binary files a/js/node/test/testdata/test_types_UINT32.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_UINT64.pb b/js/node/test/testdata/test_types_UINT64.pb deleted file mode 100644 index 53214a1a2e0e6..0000000000000 Binary files a/js/node/test/testdata/test_types_UINT64.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_UINT8.pb b/js/node/test/testdata/test_types_UINT8.pb deleted file mode 100644 index 8b6a9c42197ef..0000000000000 Binary files a/js/node/test/testdata/test_types_UINT8.pb and /dev/null differ diff --git a/js/node/test/testdata/test_types_bool.onnx b/js/node/test/testdata/test_types_bool.onnx new file mode 100644 index 0000000000000..dc6753a4a0c72 Binary files /dev/null and b/js/node/test/testdata/test_types_bool.onnx differ diff --git a/js/node/test/testdata/test_types_double.onnx b/js/node/test/testdata/test_types_double.onnx new file mode 100644 index 0000000000000..c99dd3facf0fb Binary files /dev/null and b/js/node/test/testdata/test_types_double.onnx differ diff --git a/js/node/test/testdata/test_types_float.onnx b/js/node/test/testdata/test_types_float.onnx new file mode 100644 index 0000000000000..91bdef98910ec Binary files /dev/null and b/js/node/test/testdata/test_types_float.onnx differ diff --git a/js/node/test/testdata/test_types_float16.onnx b/js/node/test/testdata/test_types_float16.onnx new file mode 100644 index 0000000000000..b7dd3dd0c97fd Binary files /dev/null and b/js/node/test/testdata/test_types_float16.onnx differ diff --git a/js/node/test/testdata/test_types_int16.onnx b/js/node/test/testdata/test_types_int16.onnx new file mode 100644 index 0000000000000..df14aef71b97f Binary files /dev/null and b/js/node/test/testdata/test_types_int16.onnx differ diff --git a/js/node/test/testdata/test_types_int32.onnx b/js/node/test/testdata/test_types_int32.onnx new file mode 100644 index 0000000000000..3b0d8c3d677c8 Binary files /dev/null and b/js/node/test/testdata/test_types_int32.onnx differ diff --git a/js/node/test/testdata/test_types_int64.onnx b/js/node/test/testdata/test_types_int64.onnx new file mode 100644 index 0000000000000..5d35b7d74c63e Binary files /dev/null and b/js/node/test/testdata/test_types_int64.onnx differ diff --git a/js/node/test/testdata/test_types_int8.onnx b/js/node/test/testdata/test_types_int8.onnx new file mode 100644 index 0000000000000..8a557e44d5272 Binary files /dev/null and b/js/node/test/testdata/test_types_int8.onnx differ diff --git a/js/node/test/testdata/test_types_string.onnx b/js/node/test/testdata/test_types_string.onnx new file mode 100644 index 0000000000000..8adebf144b89e Binary files /dev/null and b/js/node/test/testdata/test_types_string.onnx differ diff --git a/js/node/test/testdata/test_types_uint16.onnx b/js/node/test/testdata/test_types_uint16.onnx new file mode 100644 index 0000000000000..aadac3651a656 Binary files /dev/null and b/js/node/test/testdata/test_types_uint16.onnx differ diff --git a/js/node/test/testdata/test_types_uint32.onnx b/js/node/test/testdata/test_types_uint32.onnx new file mode 100644 index 0000000000000..c3ad4da3e03e4 Binary files /dev/null and b/js/node/test/testdata/test_types_uint32.onnx differ diff --git a/js/node/test/testdata/test_types_uint64.onnx b/js/node/test/testdata/test_types_uint64.onnx new file mode 100644 index 0000000000000..af7b6378bca3a Binary files /dev/null and b/js/node/test/testdata/test_types_uint64.onnx differ diff --git a/js/node/test/testdata/test_types_uint8.onnx b/js/node/test/testdata/test_types_uint8.onnx new file mode 100644 index 0000000000000..c57f3c8a61366 Binary files /dev/null and b/js/node/test/testdata/test_types_uint8.onnx differ diff --git a/js/node/test/unittests/lib/inference-session.ts b/js/node/test/unittests/lib/inference-session.ts index ffb5e0a48a3cb..d8d961cc94398 100644 --- a/js/node/test/unittests/lib/inference-session.ts +++ b/js/node/test/unittests/lib/inference-session.ts @@ -186,7 +186,7 @@ describe('UnitTests - InferenceSession.run()', () => { }); describe('UnitTests - InferenceSession.SessionOptions', () => { - const modelPath = path.join(__dirname, '../../testdata/test_types_FLOAT.pb'); + const modelPath = path.join(__dirname, '../../testdata/test_types_float.onnx'); const createAny: any = InferenceSession.create; it('BAD CALL - type mismatch', async () => { @@ -323,7 +323,7 @@ describe('UnitTests - InferenceSession.RunOptions', () => { const expectedOutput0 = new Tensor('float32', [1, 2, 3, 4, 5], [1, 5]); before(async () => { - const modelPath = path.join(__dirname, '../../testdata/test_types_FLOAT.pb'); + const modelPath = path.join(__dirname, '../../testdata/test_types_float.onnx'); session = await InferenceSession.create(modelPath); sessionAny = session; }); diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java index 19c27441a3624..f508eccae4468 100644 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java @@ -34,6 +34,7 @@ import java.util.Map; import org.junit.Assert; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.mockito.MockitoSession; @@ -206,6 +207,7 @@ public void createInputTensor_double() throws Exception { } @Test + @Ignore("data type for Slice is not supported in mobile package") public void createOutputTensor_bool() throws Exception { MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); try { @@ -246,6 +248,7 @@ public void createOutputTensor_bool() throws Exception { } @Test + @Ignore("data type for Slice is not supported in mobile package") public void createOutputTensor_double() throws Exception { MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); try { diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_bool.ort b/js/react_native/android/src/androidTest/res/raw/test_types_bool.ort index 9267bee5abf76..e83b233e28255 100644 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_bool.ort and b/js/react_native/android/src/androidTest/res/raw/test_types_bool.ort differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_double.ort b/js/react_native/android/src/androidTest/res/raw/test_types_double.ort index 9030eca3e4ca1..94f20e0f421f4 100644 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_double.ort and b/js/react_native/android/src/androidTest/res/raw/test_types_double.ort differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_float.ort b/js/react_native/android/src/androidTest/res/raw/test_types_float.ort index 01f489572ba4a..e5c40742843d5 100644 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_float.ort and b/js/react_native/android/src/androidTest/res/raw/test_types_float.ort differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort index 8aecad72cfcf7..6135c9a4aca7c 100644 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort and b/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort index b84dc2944bae5..a9892d9ec598d 100644 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort and b/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort index d762f2c5eb22d..f1bf199e488e1 100644 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort and b/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort b/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort index bf2d9ac7f8362..9f5310803323a 100644 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort and b/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort differ diff --git a/js/react_native/e2e/android/app/build.gradle b/js/react_native/e2e/android/app/build.gradle index 3683aa70da648..ea16f2c87f11c 100644 --- a/js/react_native/e2e/android/app/build.gradle +++ b/js/react_native/e2e/android/app/build.gradle @@ -177,6 +177,12 @@ android { } } +repositories { + flatDir { + dir 'libs' + } +} + dependencies { implementation fileTree(dir: "libs", include: ["*.jar"]) //noinspection GradleDynamicVersion @@ -207,6 +213,8 @@ dependencies { androidTestImplementation 'androidx.test:rules:1.4.0' implementation project(':onnxruntime-react-native') + // specify ORT dependency here so it can be found in libs flatDir repository + implementation "com.microsoft.onnxruntime:onnxruntime-mobile:latest.integration@aar" } // Run this once to be able to run the application with BUCK diff --git a/js/react_native/e2e/android/app/src/main/assets/mnist.ort b/js/react_native/e2e/android/app/src/main/assets/mnist.ort index 58dd9e664ac53..a82758ba12de1 100644 Binary files a/js/react_native/e2e/android/app/src/main/assets/mnist.ort and b/js/react_native/e2e/android/app/src/main/assets/mnist.ort differ diff --git a/js/react_native/e2e/android/app/src/main/java/com/example/reactnativeonnxruntimemodule/MNISTDataHandler.java b/js/react_native/e2e/android/app/src/main/java/com/example/reactnativeonnxruntimemodule/MNISTDataHandler.java index 8c9d71b76b34c..a458901b5314c 100644 --- a/js/react_native/e2e/android/app/src/main/java/com/example/reactnativeonnxruntimemodule/MNISTDataHandler.java +++ b/js/react_native/e2e/android/app/src/main/java/com/example/reactnativeonnxruntimemodule/MNISTDataHandler.java @@ -143,6 +143,7 @@ private WritableMap preprocess(String uri) throws Exception { WritableArray dims = Arguments.createArray(); dims.pushInt(batchSize); + dims.pushInt(1); dims.pushInt(imageHeight); dims.pushInt(imageWidth); inputTensorMap.putArray("dims", dims); @@ -155,7 +156,7 @@ private WritableMap preprocess(String uri) throws Exception { String data = Base64.encodeToString(imageByteBuffer.array(), Base64.DEFAULT); inputTensorMap.putString("data", data); - inputDataMap.putMap("flatten_2_input", inputTensorMap); + inputDataMap.putMap("Input3", inputTensorMap); return inputDataMap; } @@ -164,7 +165,7 @@ private WritableMap preprocess(String uri) throws Exception { private WritableMap postprocess(ReadableMap result) throws Exception { String detectionResult = ""; - ReadableMap outputTensor = result.getMap("Identity"); + ReadableMap outputTensor = result.getMap("Plus214_Output_0"); String outputData = outputTensor.getString("data"); FloatBuffer buffer = diff --git a/js/react_native/e2e/ios/MNISTDataHandler.mm b/js/react_native/e2e/ios/MNISTDataHandler.mm index d639ec930daa4..b935a91b63503 100644 --- a/js/react_native/e2e/ios/MNISTDataHandler.mm +++ b/js/react_native/e2e/ios/MNISTDataHandler.mm @@ -117,7 +117,9 @@ - (NSDictionary *)preprocess:(NSString *)uri { // dims NSArray *dims = @[ - [NSNumber numberWithInt:1], [NSNumber numberWithInt:static_cast(height)], + [NSNumber numberWithInt:1], + [NSNumber numberWithInt:1], + [NSNumber numberWithInt:static_cast(height)], [NSNumber numberWithInt:static_cast(width)] ]; inputTensorMap[@"dims"] = dims; @@ -129,7 +131,7 @@ - (NSDictionary *)preprocess:(NSString *)uri { NSString *data = [byteBufferRef base64EncodedStringWithOptions:0]; inputTensorMap[@"data"] = data; - inputDataMap[@"flatten_2_input"] = inputTensorMap; + inputDataMap[@"Input3"] = inputTensorMap; return inputDataMap; } @@ -137,7 +139,7 @@ - (NSDictionary *)preprocess:(NSString *)uri { - (NSDictionary *)postprocess:(NSDictionary *)result { NSMutableString *detectionResult = [NSMutableString string]; - NSDictionary *outputTensor = [result objectForKey:@"Identity"]; + NSDictionary *outputTensor = [result objectForKey:@"Plus214_Output_0"]; NSString *data = [outputTensor objectForKey:@"data"]; NSData *buffer = [[NSData alloc] initWithBase64EncodedString:data options:0]; diff --git a/onnxruntime/test/testdata/mnist.level1_opt.onnx b/js/react_native/e2e/src/mnist.onnx similarity index 92% rename from onnxruntime/test/testdata/mnist.level1_opt.onnx rename to js/react_native/e2e/src/mnist.onnx index 70fd5b6f31a82..30761fc7b7be5 100644 Binary files a/onnxruntime/test/testdata/mnist.level1_opt.onnx and b/js/react_native/e2e/src/mnist.onnx differ diff --git a/js/react_native/e2e/src/mnist.ort b/js/react_native/e2e/src/mnist.ort index 58dd9e664ac53..e342b0202497f 100644 Binary files a/js/react_native/e2e/src/mnist.ort and b/js/react_native/e2e/src/mnist.ort differ diff --git a/js/react_native/e2e/src/mnist.readme.md b/js/react_native/e2e/src/mnist.readme.md new file mode 100644 index 0000000000000..6ba5712dd96e7 --- /dev/null +++ b/js/react_native/e2e/src/mnist.readme.md @@ -0,0 +1,14 @@ +`js/react_native/e2e/src/mnist.onnx` is `onnxruntime/test/testdata/mnist.onnx` updated to opset 15. + +```bash +cd /js +python -m onnxruntime.tools.update_onnx_opset --opset 15 ../onnxruntime/test/testdata/mnist.onnx ./react_native/e2e/src/mnist.onnx +``` + +`js/react_native/e2e/src/mnist.ort` and `js/react_native/e2e/android/app/src/main/assets/mnist.ort` are converted from `js/react_native/e2e/src/mnist.onnx`. + +```bash +cd /js +python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed --output_dir ./react_native/e2e/android/app/src/main/assets ./react_native/e2e/src/mnist.onnx +python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed --output_dir ./react_native/e2e/src ./react_native/e2e/src/mnist.onnx +``` diff --git a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_bool.ort b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_bool.ort index 9267bee5abf76..e83b233e28255 100644 Binary files a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_bool.ort and b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_bool.ort differ diff --git a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_double.ort b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_double.ort index 9030eca3e4ca1..94f20e0f421f4 100644 Binary files a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_double.ort and b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_double.ort differ diff --git a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_float.ort b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_float.ort index 01f489572ba4a..e5c40742843d5 100644 Binary files a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_float.ort and b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_float.ort differ diff --git a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int32.ort b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int32.ort index 8aecad72cfcf7..6135c9a4aca7c 100644 Binary files a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int32.ort and b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int32.ort differ diff --git a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int64.ort b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int64.ort index b84dc2944bae5..a9892d9ec598d 100644 Binary files a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int64.ort and b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int64.ort differ diff --git a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int8.ort b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int8.ort index d762f2c5eb22d..f1bf199e488e1 100644 Binary files a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int8.ort and b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int8.ort differ diff --git a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm index ad7606c7f8118..10922f9ef3ffc 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm +++ b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm @@ -213,6 +213,7 @@ - (void)testCreateOutputTensorFloat { } - (void)testCreateOutputTensorDouble { + XCTSkip(@"data type for Slice is not supported in mobile package"); std::array outValues{std::numeric_limits::min(), 1.0f, 2.0f, 3.0f, std::numeric_limits::max()}; std::function convert = [](double_t value) { return [NSNumber numberWithDouble:value]; }; @@ -220,6 +221,7 @@ - (void)testCreateOutputTensorDouble { } - (void)testCreateOutputTensorBool { + XCTSkip(@"data type for Slice is not supported in mobile package"); std::array outValues{false, true, true, false, true}; std::function convert = [](bool value) { return [NSNumber numberWithBool:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeBool, @"test_types_bool", @"ort"); diff --git a/js/react_native/test_types_models.readme.md b/js/react_native/test_types_models.readme.md new file mode 100644 index 0000000000000..fb247fa6fe026 --- /dev/null +++ b/js/react_native/test_types_models.readme.md @@ -0,0 +1,15 @@ +`js/react_native/android/src/androidTest/res/raw/test_types_*.ort` and +`js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_*.ort` ORT format models are converted from +`js/node/test/testdata/test_types_*.onnx` ONNX models. + +For example, to generate `js/react_native/android/src/androidTest/res/raw/test_types_*.ort`, from the `js` directory, +run: + +```bash +python -m onnxruntime.tools.convert_onnx_models_to_ort \ + --optimization_style Fixed \ + --output_dir ./react_native/android/src/androidTest/res/raw \ + ./node/test/testdata +``` + +Some additional files will be generated. They can be removed. diff --git a/objectivec/test/testdata/single_add.basic.ort b/objectivec/test/testdata/single_add.basic.ort index d85f2d4e6c73d..f622784b35366 100644 Binary files a/objectivec/test/testdata/single_add.basic.ort and b/objectivec/test/testdata/single_add.basic.ort differ diff --git a/onnxruntime/contrib_ops/cpu/maxpool_with_mask.cc b/onnxruntime/contrib_ops/cpu/maxpool_with_mask.cc index 0dee2079d4328..e0c420d951d84 100644 --- a/onnxruntime/contrib_ops/cpu/maxpool_with_mask.cc +++ b/onnxruntime/contrib_ops/cpu/maxpool_with_mask.cc @@ -11,7 +11,7 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( 1, float, KernelDefBuilder() - .TypeConstraint("X", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", DataTypeImpl::GetTensorType()), MaxpoolWithMask); } // namespace contrib diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index 33d76e71c24a5..6e0eb460d2a63 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -7,6 +7,7 @@ #include #include "core/common/common.h" +#include "core/common/inlined_containers.h" namespace onnxruntime { namespace utils { @@ -18,10 +19,10 @@ namespace utils { * @param keep_empty Whether to keep empty substrings. * @return The split substrings. */ -inline std::vector SplitString(std::string_view string_to_split, std::string_view delimiter, - bool keep_empty = false) { +inline InlinedVector SplitString(std::string_view string_to_split, std::string_view delimiter, + bool keep_empty = false) { ORT_ENFORCE(!delimiter.empty(), "delimiter must not be empty"); - std::vector result{}; + InlinedVector result{}; std::string_view::size_type segment_begin_pos = 0; while (segment_begin_pos != std::string_view::npos) { const std::string_view::size_type segment_end_pos = string_to_split.find(delimiter, segment_begin_pos); diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc index 2d926daf3285a..505b79548a1fa 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "flatbuffers_utils.h" -#include "schema/ort.fbs.h" +#include "core/flatbuffers/flatbuffers_utils.h" + +#include "gsl/gsl" #include "core/common/common.h" +#include "core/flatbuffers/schema/ort.fbs.h" #include "core/graph/constants.h" #include "core/graph/onnx_protobuf.h" -#include "gsl/gsl" using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.h b/onnxruntime/core/flatbuffers/flatbuffers_utils.h index 570cec740413c..a04bb60453035 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.h +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.h @@ -5,6 +5,7 @@ #include +#include "core/common/common.h" #include "core/common/path_string.h" #include "core/common/status.h" @@ -32,11 +33,12 @@ struct ValueInfo; namespace utils { +constexpr auto kInvalidOrtFormatModelMessage = "Invalid ORT format model."; + // Will only create string in flatbuffers when has_string is true flatbuffers::Offset SaveStringToOrtFormat(flatbuffers::FlatBufferBuilder& builder, bool has_string, const std::string& src); -// TODO, add ORT_MUST_USE_RESULT when it is moved to a different header onnxruntime::common::Status SaveValueInfoOrtFormat( flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::ValueInfoProto& value_info_proto, flatbuffers::Offset& fbs_value_info); @@ -67,3 +69,7 @@ bool IsOrtFormatModelBytes(const void* bytes, int num_bytes); } // namespace utils } // namespace fbs } // namespace onnxruntime + +#define ORT_FORMAT_RETURN_IF_NULL(expr, expr_description) \ + ORT_RETURN_IF((expr) == nullptr, (expr_description), " is null. ", \ + onnxruntime::fbs::utils::kInvalidOrtFormatModelMessage) diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ArgType.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ArgType.py new file mode 100644 index 0000000000000..a0328a9f469e7 --- /dev/null +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ArgType.py @@ -0,0 +1,8 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: fbs + +class ArgType(object): + INPUT = 0 + OUTPUT = 1 + diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py new file mode 100644 index 0000000000000..32aaa298dd99a --- /dev/null +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py @@ -0,0 +1,44 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: fbs + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class ArgTypeAndIndex(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsArgTypeAndIndex(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ArgTypeAndIndex() + x.Init(buf, n + offset) + return x + + @classmethod + def ArgTypeAndIndexBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) + + # ArgTypeAndIndex + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ArgTypeAndIndex + def ArgType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # ArgTypeAndIndex + def Index(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + +def ArgTypeAndIndexStart(builder): builder.StartObject(2) +def ArgTypeAndIndexAddArgType(builder, argType): builder.PrependInt8Slot(0, argType, 0) +def ArgTypeAndIndexAddIndex(builder, index): builder.PrependUint32Slot(1, index, 0) +def ArgTypeAndIndexEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelCreateInfos.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py similarity index 63% rename from onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelCreateInfos.py rename to onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py index 355d462c50797..9f93bffa499d0 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelCreateInfos.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py @@ -6,25 +6,26 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class KernelCreateInfos(object): +# deprecated: no longer using kernel def hashes +class DeprecatedKernelCreateInfos(object): __slots__ = ['_tab'] @classmethod - def GetRootAsKernelCreateInfos(cls, buf, offset): + def GetRootAsDeprecatedKernelCreateInfos(cls, buf, offset): n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = KernelCreateInfos() + x = DeprecatedKernelCreateInfos() x.Init(buf, n + offset) return x @classmethod - def KernelCreateInfosBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + def DeprecatedKernelCreateInfosBufferHasIdentifier(cls, buf, offset, size_prefixed=False): return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) - # KernelCreateInfos + # DeprecatedKernelCreateInfos def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) - # KernelCreateInfos + # DeprecatedKernelCreateInfos def NodeIndices(self, j): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: @@ -32,26 +33,26 @@ def NodeIndices(self, j): return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) return 0 - # KernelCreateInfos + # DeprecatedKernelCreateInfos def NodeIndicesAsNumpy(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) return 0 - # KernelCreateInfos + # DeprecatedKernelCreateInfos def NodeIndicesLength(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: return self._tab.VectorLen(o) return 0 - # KernelCreateInfos + # DeprecatedKernelCreateInfos def NodeIndicesIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) return o == 0 - # KernelCreateInfos + # DeprecatedKernelCreateInfos def KernelDefHashes(self, j): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: @@ -59,28 +60,28 @@ def KernelDefHashes(self, j): return self._tab.Get(flatbuffers.number_types.Uint64Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) return 0 - # KernelCreateInfos + # DeprecatedKernelCreateInfos def KernelDefHashesAsNumpy(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint64Flags, o) return 0 - # KernelCreateInfos + # DeprecatedKernelCreateInfos def KernelDefHashesLength(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.VectorLen(o) return 0 - # KernelCreateInfos + # DeprecatedKernelCreateInfos def KernelDefHashesIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 -def KernelCreateInfosStart(builder): builder.StartObject(2) -def KernelCreateInfosAddNodeIndices(builder, nodeIndices): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0) -def KernelCreateInfosStartNodeIndicesVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def KernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes), 0) -def KernelCreateInfosStartKernelDefHashesVector(builder, numElems): return builder.StartVector(8, numElems, 8) -def KernelCreateInfosEnd(builder): return builder.EndObject() +def DeprecatedKernelCreateInfosStart(builder): builder.StartObject(2) +def DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0) +def DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes), 0) +def DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems): return builder.StartVector(8, numElems, 8) +def DeprecatedKernelCreateInfosEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/NodeIndexAndKernelDefHash.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py similarity index 54% rename from onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/NodeIndexAndKernelDefHash.py rename to onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py index 35c7301f7d2a5..7137233a9e726 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/NodeIndexAndKernelDefHash.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py @@ -6,39 +6,40 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class NodeIndexAndKernelDefHash(object): +# deprecated: no longer using kernel def hashes +class DeprecatedNodeIndexAndKernelDefHash(object): __slots__ = ['_tab'] @classmethod - def GetRootAsNodeIndexAndKernelDefHash(cls, buf, offset): + def GetRootAsDeprecatedNodeIndexAndKernelDefHash(cls, buf, offset): n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = NodeIndexAndKernelDefHash() + x = DeprecatedNodeIndexAndKernelDefHash() x.Init(buf, n + offset) return x @classmethod - def NodeIndexAndKernelDefHashBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + def DeprecatedNodeIndexAndKernelDefHashBufferHasIdentifier(cls, buf, offset, size_prefixed=False): return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) - # NodeIndexAndKernelDefHash + # DeprecatedNodeIndexAndKernelDefHash def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) - # NodeIndexAndKernelDefHash + # DeprecatedNodeIndexAndKernelDefHash def NodeIndex(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) return 0 - # NodeIndexAndKernelDefHash + # DeprecatedNodeIndexAndKernelDefHash def KernelDefHash(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) return 0 -def NodeIndexAndKernelDefHashStart(builder): builder.StartObject(2) -def NodeIndexAndKernelDefHashAddNodeIndex(builder, nodeIndex): builder.PrependUint32Slot(0, nodeIndex, 0) -def NodeIndexAndKernelDefHashAddKernelDefHash(builder, kernelDefHash): builder.PrependUint64Slot(1, kernelDefHash, 0) -def NodeIndexAndKernelDefHashEnd(builder): return builder.EndObject() +def DeprecatedNodeIndexAndKernelDefHashStart(builder): builder.StartObject(2) +def DeprecatedNodeIndexAndKernelDefHashAddNodeIndex(builder, nodeIndex): builder.PrependUint32Slot(0, nodeIndex, 0) +def DeprecatedNodeIndexAndKernelDefHashAddKernelDefHash(builder, kernelDefHash): builder.PrependUint64Slot(1, kernelDefHash, 0) +def DeprecatedNodeIndexAndKernelDefHashEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/SessionState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedSessionState.py similarity index 53% rename from onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/SessionState.py rename to onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedSessionState.py index 274a20d7d16bd..fbf21a38c2f5d 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/SessionState.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedSessionState.py @@ -6,62 +6,63 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class SessionState(object): +# deprecated: no longer using kernel def hashes +class DeprecatedSessionState(object): __slots__ = ['_tab'] @classmethod - def GetRootAsSessionState(cls, buf, offset): + def GetRootAsDeprecatedSessionState(cls, buf, offset): n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = SessionState() + x = DeprecatedSessionState() x.Init(buf, n + offset) return x @classmethod - def SessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + def DeprecatedSessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False): return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) - # SessionState + # DeprecatedSessionState def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) - # SessionState + # DeprecatedSessionState def Kernels(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: x = self._tab.Indirect(o + self._tab.Pos) - from ort_flatbuffers_py.fbs.KernelCreateInfos import KernelCreateInfos - obj = KernelCreateInfos() + from ort_flatbuffers_py.fbs.DeprecatedKernelCreateInfos import DeprecatedKernelCreateInfos + obj = DeprecatedKernelCreateInfos() obj.Init(self._tab.Bytes, x) return obj return None - # SessionState + # DeprecatedSessionState def SubGraphSessionStates(self, j): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: x = self._tab.Vector(o) x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 x = self._tab.Indirect(x) - from ort_flatbuffers_py.fbs.SubGraphSessionState import SubGraphSessionState - obj = SubGraphSessionState() + from ort_flatbuffers_py.fbs.DeprecatedSubGraphSessionState import DeprecatedSubGraphSessionState + obj = DeprecatedSubGraphSessionState() obj.Init(self._tab.Bytes, x) return obj return None - # SessionState + # DeprecatedSessionState def SubGraphSessionStatesLength(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.VectorLen(o) return 0 - # SessionState + # DeprecatedSessionState def SubGraphSessionStatesIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 -def SessionStateStart(builder): builder.StartObject(2) -def SessionStateAddKernels(builder, kernels): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(kernels), 0) -def SessionStateAddSubGraphSessionStates(builder, subGraphSessionStates): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(subGraphSessionStates), 0) -def SessionStateStartSubGraphSessionStatesVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def SessionStateEnd(builder): return builder.EndObject() +def DeprecatedSessionStateStart(builder): builder.StartObject(2) +def DeprecatedSessionStateAddKernels(builder, kernels): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(kernels), 0) +def DeprecatedSessionStateAddSubGraphSessionStates(builder, subGraphSessionStates): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(subGraphSessionStates), 0) +def DeprecatedSessionStateStartSubGraphSessionStatesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def DeprecatedSessionStateEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py new file mode 100644 index 0000000000000..52b450408632c --- /dev/null +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py @@ -0,0 +1,49 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: fbs + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +# deprecated: no longer using kernel def hashes +class DeprecatedSubGraphSessionState(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsDeprecatedSubGraphSessionState(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DeprecatedSubGraphSessionState() + x.Init(buf, n + offset) + return x + + @classmethod + def DeprecatedSubGraphSessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) + + # DeprecatedSubGraphSessionState + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # DeprecatedSubGraphSessionState + def GraphId(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # DeprecatedSubGraphSessionState + def SessionState(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + from ort_flatbuffers_py.fbs.DeprecatedSessionState import DeprecatedSessionState + obj = DeprecatedSessionState() + obj.Init(self._tab.Bytes, x) + return obj + return None + +def DeprecatedSubGraphSessionStateStart(builder): builder.StartObject(2) +def DeprecatedSubGraphSessionStateAddGraphId(builder, graphId): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(graphId), 0) +def DeprecatedSubGraphSessionStateAddSessionState(builder, sessionState): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(sessionState), 0) +def DeprecatedSubGraphSessionStateEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/InferenceSession.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/InferenceSession.py index d9b7f0d3ec0da..d5a67bf8b8c61 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/InferenceSession.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/InferenceSession.py @@ -43,18 +43,18 @@ def Model(self): return None # InferenceSession - def SessionState(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + def KernelTypeStrResolver(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: x = self._tab.Indirect(o + self._tab.Pos) - from ort_flatbuffers_py.fbs.SessionState import SessionState - obj = SessionState() + from ort_flatbuffers_py.fbs.KernelTypeStrResolver import KernelTypeStrResolver + obj = KernelTypeStrResolver() obj.Init(self._tab.Bytes, x) return obj return None -def InferenceSessionStart(builder): builder.StartObject(3) +def InferenceSessionStart(builder): builder.StartObject(4) def InferenceSessionAddOrtVersion(builder, ortVersion): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(ortVersion), 0) def InferenceSessionAddModel(builder, model): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(model), 0) -def InferenceSessionAddSessionState(builder, sessionState): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(sessionState), 0) +def InferenceSessionAddKernelTypeStrResolver(builder, kernelTypeStrResolver): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(kernelTypeStrResolver), 0) def InferenceSessionEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py new file mode 100644 index 0000000000000..94f37b38481fd --- /dev/null +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py @@ -0,0 +1,63 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: fbs + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class KernelTypeStrArgsEntry(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsKernelTypeStrArgsEntry(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = KernelTypeStrArgsEntry() + x.Init(buf, n + offset) + return x + + @classmethod + def KernelTypeStrArgsEntryBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) + + # KernelTypeStrArgsEntry + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # KernelTypeStrArgsEntry + def KernelTypeStr(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # KernelTypeStrArgsEntry + def Args(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from ort_flatbuffers_py.fbs.ArgTypeAndIndex import ArgTypeAndIndex + obj = ArgTypeAndIndex() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # KernelTypeStrArgsEntry + def ArgsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # KernelTypeStrArgsEntry + def ArgsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +def KernelTypeStrArgsEntryStart(builder): builder.StartObject(2) +def KernelTypeStrArgsEntryAddKernelTypeStr(builder, kernelTypeStr): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(kernelTypeStr), 0) +def KernelTypeStrArgsEntryAddArgs(builder, args): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(args), 0) +def KernelTypeStrArgsEntryStartArgsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def KernelTypeStrArgsEntryEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py new file mode 100644 index 0000000000000..ef2cd95df91f7 --- /dev/null +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py @@ -0,0 +1,55 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: fbs + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class KernelTypeStrResolver(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsKernelTypeStrResolver(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = KernelTypeStrResolver() + x.Init(buf, n + offset) + return x + + @classmethod + def KernelTypeStrResolverBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) + + # KernelTypeStrResolver + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # KernelTypeStrResolver + def OpKernelTypeStrArgs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from ort_flatbuffers_py.fbs.OpIdKernelTypeStrArgsEntry import OpIdKernelTypeStrArgsEntry + obj = OpIdKernelTypeStrArgsEntry() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # KernelTypeStrResolver + def OpKernelTypeStrArgsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # KernelTypeStrResolver + def OpKernelTypeStrArgsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + +def KernelTypeStrResolverStart(builder): builder.StartObject(1) +def KernelTypeStrResolverAddOpKernelTypeStrArgs(builder, opKernelTypeStrArgs): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(opKernelTypeStrArgs), 0) +def KernelTypeStrResolverStartOpKernelTypeStrArgsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def KernelTypeStrResolverEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py new file mode 100644 index 0000000000000..97eea172b786b --- /dev/null +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py @@ -0,0 +1,63 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: fbs + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class OpIdKernelTypeStrArgsEntry(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsOpIdKernelTypeStrArgsEntry(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = OpIdKernelTypeStrArgsEntry() + x.Init(buf, n + offset) + return x + + @classmethod + def OpIdKernelTypeStrArgsEntryBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) + + # OpIdKernelTypeStrArgsEntry + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # OpIdKernelTypeStrArgsEntry + def OpId(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # OpIdKernelTypeStrArgsEntry + def KernelTypeStrArgs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from ort_flatbuffers_py.fbs.KernelTypeStrArgsEntry import KernelTypeStrArgsEntry + obj = KernelTypeStrArgsEntry() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # OpIdKernelTypeStrArgsEntry + def KernelTypeStrArgsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # OpIdKernelTypeStrArgsEntry + def KernelTypeStrArgsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +def OpIdKernelTypeStrArgsEntryStart(builder): builder.StartObject(2) +def OpIdKernelTypeStrArgsEntryAddOpId(builder, opId): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(opId), 0) +def OpIdKernelTypeStrArgsEntryAddKernelTypeStrArgs(builder, kernelTypeStrArgs): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelTypeStrArgs), 0) +def OpIdKernelTypeStrArgsEntryStartKernelTypeStrArgsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def OpIdKernelTypeStrArgsEntryEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py index 488572506975b..7880cc565f69d 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py @@ -45,33 +45,28 @@ def NodesToOptimizeIndices(self): return None # RuntimeOptimizationRecord - def ProducedNodes(self, j): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + def ProducedOpIds(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: - x = self._tab.Vector(o) - x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 - x = self._tab.Indirect(x) - from ort_flatbuffers_py.fbs.NodeIndexAndKernelDefHash import NodeIndexAndKernelDefHash - obj = NodeIndexAndKernelDefHash() - obj.Init(self._tab.Bytes, x) - return obj - return None + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" # RuntimeOptimizationRecord - def ProducedNodesLength(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + def ProducedOpIdsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: return self._tab.VectorLen(o) return 0 # RuntimeOptimizationRecord - def ProducedNodesIsNone(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + def ProducedOpIdsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) return o == 0 -def RuntimeOptimizationRecordStart(builder): builder.StartObject(3) +def RuntimeOptimizationRecordStart(builder): builder.StartObject(4) def RuntimeOptimizationRecordAddActionId(builder, actionId): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(actionId), 0) def RuntimeOptimizationRecordAddNodesToOptimizeIndices(builder, nodesToOptimizeIndices): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(nodesToOptimizeIndices), 0) -def RuntimeOptimizationRecordAddProducedNodes(builder, producedNodes): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(producedNodes), 0) -def RuntimeOptimizationRecordStartProducedNodesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def RuntimeOptimizationRecordAddProducedOpIds(builder, producedOpIds): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(producedOpIds), 0) +def RuntimeOptimizationRecordStartProducedOpIdsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def RuntimeOptimizationRecordEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/SubGraphSessionState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/SubGraphSessionState.py deleted file mode 100644 index dcbabc619d866..0000000000000 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/SubGraphSessionState.py +++ /dev/null @@ -1,48 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: fbs - -import flatbuffers -from flatbuffers.compat import import_numpy -np = import_numpy() - -class SubGraphSessionState(object): - __slots__ = ['_tab'] - - @classmethod - def GetRootAsSubGraphSessionState(cls, buf, offset): - n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = SubGraphSessionState() - x.Init(buf, n + offset) - return x - - @classmethod - def SubGraphSessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False): - return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed) - - # SubGraphSessionState - def Init(self, buf, pos): - self._tab = flatbuffers.table.Table(buf, pos) - - # SubGraphSessionState - def GraphId(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.String(o + self._tab.Pos) - return None - - # SubGraphSessionState - def SessionState(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - x = self._tab.Indirect(o + self._tab.Pos) - from ort_flatbuffers_py.fbs.SessionState import SessionState - obj = SessionState() - obj.Init(self._tab.Bytes, x) - return obj - return None - -def SubGraphSessionStateStart(builder): builder.StartObject(2) -def SubGraphSessionStateAddGraphId(builder, graphId): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(graphId), 0) -def SubGraphSessionStateAddSessionState(builder, sessionState): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(sessionState), 0) -def SubGraphSessionStateEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/ort_format_version.h b/onnxruntime/core/flatbuffers/ort_format_version.h index e48cb6ebb82b2..c6c0ad7c20fa3 100644 --- a/onnxruntime/core/flatbuffers/ort_format_version.h +++ b/onnxruntime/core/flatbuffers/ort_format_version.h @@ -19,7 +19,8 @@ namespace onnxruntime { // Version 2 - add serialization/deserialization of sparse_initializer // Version 3 - add `graph_doc_string` to Model // Version 4 - update kernel def hashing to not depend on ordering of type constraint types (NOT BACKWARDS COMPATIBLE) -constexpr const char* kOrtModelVersion = "4"; +// Version 5 - deprecate kernel def hashes and add KernelTypeStrResolver info to replace them (NOT BACKWARDS COMPATIBLE) +constexpr const char* kOrtModelVersion = "5"; // Check if the given ort model version is supported in this build inline bool IsOrtModelVersionSupported(std::string_view ort_model_version) { diff --git a/onnxruntime/core/flatbuffers/schema/README.md b/onnxruntime/core/flatbuffers/schema/README.md index c24c1a37e972a..4c15f526551e7 100644 --- a/onnxruntime/core/flatbuffers/schema/README.md +++ b/onnxruntime/core/flatbuffers/schema/README.md @@ -1,11 +1,15 @@ # ORT File Format -This directory contains [the ORT file format schema](ort.fbs) and [the generated C++ header file](ort.fbs.h) for the ORT file format. +This directory contains [the ORT file format schema](ort.fbs) and [the generated C++ header file](ort.fbs.h) for the +ORT file format. -[The ORT file format schema](ort.fbs) uses the [FlatBuffers](https://github.com/google/flatbuffers) serialization library. +[The ORT file format schema](ort.fbs) uses the [FlatBuffers](https://github.com/google/flatbuffers) serialization +library. -Please do not directly modify [the generated C++ header file](ort.fbs.h) or [the generated Python binding files](../ort_flatbuffers_py). +Please do not directly modify [the generated C++ header file](ort.fbs.h) or [the generated Python binding +files](../ort_flatbuffers_py). -The flatbuffers compiler (flatc) is built as part of an ONNX Runtime build. It is located in the external/flatbuffers subdirectory of the build output directory. +The flatbuffers compiler (flatc) is built as part of an ONNX Runtime build. It is located in the external/flatbuffers +subdirectory of the build output directory. e.g. - Windows Debug build @@ -13,7 +17,8 @@ e.g. - Linux Debug build - /build/Linux/external/flatbuffers/Debug/flatc -It is possible to use another flatc as well, e.g., from a separate installation. Note that ONNX Runtime uses FlatBuffers 1.12. +It is possible to use another flatc as well, e.g., from a separate installation. Note that ONNX Runtime uses +FlatBuffers 1.12. To update the ORT file format schema and generated files: 1. Modify [the ORT file format schema](ort.fbs). @@ -24,16 +29,31 @@ To update the ORT file format schema and generated files: ``` # ORT FB format version history -In [ort_format_version.h](../ort_format_version.h), see `IsOrtModelVersionSupported()` for version array and `kOrtModelVersion` for currently supported version. +In [ort_format_version.h](../ort_format_version.h), see `IsOrtModelVersionSupported()` for the supported versions and +`kOrtModelVersion` for the current version. -## Version 1. History begins -Initial support for FlatBuffers that includes Model support. Graph support including Attributes, Tensors, Tensor Sequences, Maps and Sequences. Constant initializers are also supported. Constant nodes are converted to constant initializers in the ORT format. +## Version 1 +History begins. -## Version 2. -Support for sparse initializers. Sparse intializers are stored within ORT FlatBuffers format, which includes sparse initializers converted from a Constant node attribute. +Initial support for FlatBuffers that includes Model support. Graph support including Attributes, Tensors, Tensor +Sequences, Maps and Sequences. Constant initializers are also supported. Constant nodes are converted to constant +initializers in the ORT format. -## Version 3. +## Version 2 +Support for sparse initializers. Sparse intializers are stored within ORT FlatBuffers format, which includes sparse +initializers converted from a Constant node attribute. + +## Version 3 Support for storing `graph_doc_string` field in Model (ORT FlatBuffers format). -## Version 4. +## Version 4 Update kernel def hashing to not depend on ordering of type constraint types (NOT BACKWARDS COMPATIBLE). + +## Version 5 +Deprecate kernel def hashes and add KernelTypeStrResolver info to replace them (NOT BACKWARDS COMPATIBLE). +The change to the ORT format itself is not backwards compatibility-breaking, but ORT does not provide backwards +compatibility for processing older models with missing KernelTypeStrResolver info. + +The motivation for this update is to support additional execution providers with statically registered kernels. +The original approach of using kernel def hashes is not so extensible as it requires the execution provider providing +hashes to be enabled at model conversion time. diff --git a/onnxruntime/core/flatbuffers/schema/compile_schema.py b/onnxruntime/core/flatbuffers/schema/compile_schema.py index 9f4372a5f32ad..55d332682b937 100644 --- a/onnxruntime/core/flatbuffers/schema/compile_schema.py +++ b/onnxruntime/core/flatbuffers/schema/compile_schema.py @@ -4,10 +4,10 @@ import argparse import pathlib +import shutil import subprocess import tempfile - SCRIPT_DIR = pathlib.Path(__file__).parent.resolve() @@ -22,16 +22,16 @@ def update_namespace(schema_path: pathlib.Path, updated_schema_path: pathlib.Pat output.write(line.replace('onnxruntime.fbs', 'ort_flatbuffers_py.fbs')) -def generate_python(flatc: pathlib.Path, schema_path: pathlib.Path): +def generate_python(flatc: pathlib.Path, schema_path: pathlib.Path, output_dir: pathlib.Path): # run flatc to generate Python code cmd = [str(flatc), '--python', str(schema_path)] - subprocess.run(cmd, check=True, cwd=SCRIPT_DIR.parent) + subprocess.run(cmd, check=True, cwd=output_dir) -def create_init_py(): +def create_init_py(output_dir: pathlib.Path): # create an __init__.py that imports all the py files so we can just 'import ort_flatbuffers_py.fbs' # in a script that wants to process an ORT format model - init_py_path = SCRIPT_DIR.parent / 'ort_flatbuffers_py/fbs/__init__.py' + init_py_path = output_dir / 'ort_flatbuffers_py/fbs/__init__.py' with open(init_py_path, 'w') as init_py: init_py.write('''from os.path import dirname, basename, isfile, join, splitext import glob @@ -69,10 +69,20 @@ def main(): if 'python' in languages: with tempfile.TemporaryDirectory() as temp_dir_name: - updated_schema_path = pathlib.Path(temp_dir_name, 'ort.py.fbs').resolve() + temp_dir = pathlib.Path(temp_dir_name).resolve() + updated_schema_path = temp_dir / 'ort.py.fbs' update_namespace(schema_path, updated_schema_path) - generate_python(flatc, updated_schema_path) - create_init_py() + + output_dir = temp_dir / 'out' + output_dir.mkdir() + generate_python(flatc, updated_schema_path, output_dir) + create_init_py(output_dir) + + # replace generated files in repo + target_dir = SCRIPT_DIR.parent / 'ort_flatbuffers_py' + if target_dir.is_dir(): + shutil.rmtree(target_dir) + shutil.move(str(output_dir / 'ort_flatbuffers_py'), str(target_dir)) if 'cpp' in languages: generate_cpp(flatc, schema_path) diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs b/onnxruntime/core/flatbuffers/schema/ort.fbs index ec3b09c3a8f94..62f4362938513 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs @@ -22,7 +22,7 @@ enum AttributeType : int32 { // Shape table Shape { -dim:[Dimension]; + dim:[Dimension]; } table Dimension { @@ -63,17 +63,17 @@ enum TensorDataType : int32 { BFLOAT16 = 16, } -table TensorTypeAndShape{ +table TensorTypeAndShape { elem_type:TensorDataType; shape:Shape; } -table MapType{ +table MapType { key_type:TensorDataType; value_type:onnxruntime.fbs.TypeInfo; } -table SequenceType{ +table SequenceType { elem_type:onnxruntime.fbs.TypeInfo; } @@ -151,7 +151,7 @@ table Tensor { raw_data:[uint8]; - // string_data is least used, leave it at the end + // string_data is least used string_data:[string]; } @@ -161,7 +161,7 @@ table SparseTensor { dims:[int64]; } -table Attribute{ +table Attribute { name:string; doc_string:string; @@ -184,7 +184,7 @@ table Attribute{ /// nodes to consider for a runtime optimization /// see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h -table NodesToOptimizeIndices{ +table NodesToOptimizeIndices { node_indices:[uint32]; num_inputs:uint32; num_outputs:uint32; @@ -194,30 +194,32 @@ table NodesToOptimizeIndices{ num_variadic_outputs:uint32; } -table NodeIndexAndKernelDefHash{ +/// deprecated: no longer using kernel def hashes +table DeprecatedNodeIndexAndKernelDefHash { node_index:uint32; kernel_def_hash:uint64; } /// a single runtime optimization /// see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h -table RuntimeOptimizationRecord{ +table RuntimeOptimizationRecord { action_id:string; nodes_to_optimize_indices:NodesToOptimizeIndices; - produced_nodes:[NodeIndexAndKernelDefHash]; + produced_nodes:[DeprecatedNodeIndexAndKernelDefHash] (deprecated); + produced_op_ids:[string]; } -table RuntimeOptimizationRecordContainerEntry{ +table RuntimeOptimizationRecordContainerEntry { optimizer_name:string (key); runtime_optimization_records:[RuntimeOptimizationRecord]; } -table RuntimeOptimizations{ +table RuntimeOptimizations { /// mapping from optimizer name to [RuntimeOptimizationRecord] records:[RuntimeOptimizationRecordContainerEntry]; } -table Graph{ +table Graph { initializers:[Tensor]; node_args:[ValueInfo]; @@ -253,21 +255,49 @@ table Model { metadata_props:[StringStringEntry]; } -table KernelCreateInfos { +/// deprecated: no longer using kernel def hashes +table DeprecatedKernelCreateInfos { node_indices:[uint32]; kernel_def_hashes:[uint64]; } -table SubGraphSessionState { - // graph_id can be used to binary search SubGraphSessionState in SessionState.sub_graph_session_states +/// deprecated: no longer using kernel def hashes +table DeprecatedSubGraphSessionState { + // graph_id can be used to binary search DeprecatedSubGraphSessionState in + // DeprecatedSessionState.sub_graph_session_states graph_id:string (key); - session_state:SessionState; + session_state:DeprecatedSessionState; } -table SessionState { - kernels:KernelCreateInfos; - sub_graph_session_states:[SubGraphSessionState]; +/// deprecated: no longer using kernel def hashes +table DeprecatedSessionState { + kernels:DeprecatedKernelCreateInfos; + sub_graph_session_states:[DeprecatedSubGraphSessionState]; +} + +enum ArgType : int8 { + INPUT = 0, + OUTPUT = 1, +} + +table ArgTypeAndIndex { + arg_type:ArgType; + index:uint32; +} + +table KernelTypeStrArgsEntry { + kernel_type_str:string (key); + args:[ArgTypeAndIndex]; +} + +table OpIdKernelTypeStrArgsEntry { + op_id:string (key); + kernel_type_str_args:[KernelTypeStrArgsEntry]; +} + +table KernelTypeStrResolver { + op_kernel_type_str_args:[OpIdKernelTypeStrArgsEntry]; } table InferenceSession { @@ -277,7 +307,9 @@ table InferenceSession { ort_version:string; model:Model; - session_state:SessionState; + session_state:DeprecatedSessionState (deprecated); + + kernel_type_str_resolver:KernelTypeStrResolver; } root_type InferenceSession; diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs.h b/onnxruntime/core/flatbuffers/schema/ort.fbs.h index 1ee7675bf1cf5..827970c70e4c2 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs.h @@ -56,8 +56,8 @@ struct AttributeBuilder; struct NodesToOptimizeIndices; struct NodesToOptimizeIndicesBuilder; -struct NodeIndexAndKernelDefHash; -struct NodeIndexAndKernelDefHashBuilder; +struct DeprecatedNodeIndexAndKernelDefHash; +struct DeprecatedNodeIndexAndKernelDefHashBuilder; struct RuntimeOptimizationRecord; struct RuntimeOptimizationRecordBuilder; @@ -77,14 +77,26 @@ struct StringStringEntryBuilder; struct Model; struct ModelBuilder; -struct KernelCreateInfos; -struct KernelCreateInfosBuilder; +struct DeprecatedKernelCreateInfos; +struct DeprecatedKernelCreateInfosBuilder; -struct SubGraphSessionState; -struct SubGraphSessionStateBuilder; +struct DeprecatedSubGraphSessionState; +struct DeprecatedSubGraphSessionStateBuilder; -struct SessionState; -struct SessionStateBuilder; +struct DeprecatedSessionState; +struct DeprecatedSessionStateBuilder; + +struct ArgTypeAndIndex; +struct ArgTypeAndIndexBuilder; + +struct KernelTypeStrArgsEntry; +struct KernelTypeStrArgsEntryBuilder; + +struct OpIdKernelTypeStrArgsEntry; +struct OpIdKernelTypeStrArgsEntryBuilder; + +struct KernelTypeStrResolver; +struct KernelTypeStrResolverBuilder; struct InferenceSession; struct InferenceSessionBuilder; @@ -345,6 +357,36 @@ template<> struct TypeInfoValueTraits { bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj, TypeInfoValue type); bool VerifyTypeInfoValueVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); +enum class ArgType : int8_t { + INPUT = 0, + OUTPUT = 1, + MIN = INPUT, + MAX = OUTPUT +}; + +inline const ArgType (&EnumValuesArgType())[2] { + static const ArgType values[] = { + ArgType::INPUT, + ArgType::OUTPUT + }; + return values; +} + +inline const char * const *EnumNamesArgType() { + static const char * const names[3] = { + "INPUT", + "OUTPUT", + nullptr + }; + return names; +} + +inline const char *EnumNameArgType(ArgType e) { + if (flatbuffers::IsOutRange(e, ArgType::INPUT, ArgType::OUTPUT)) return ""; + const size_t index = static_cast(e); + return EnumNamesArgType()[index]; +} + FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) EdgeEnd FLATBUFFERS_FINAL_CLASS { private: uint32_t node_index_; @@ -1793,8 +1835,9 @@ inline flatbuffers::Offset CreateNodesToOptimizeIndicesD num_variadic_outputs); } -struct NodeIndexAndKernelDefHash FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef NodeIndexAndKernelDefHashBuilder Builder; +/// deprecated: no longer using kernel def hashes +struct DeprecatedNodeIndexAndKernelDefHash FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DeprecatedNodeIndexAndKernelDefHashBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_NODE_INDEX = 4, VT_KERNEL_DEF_HASH = 6 @@ -1813,33 +1856,33 @@ struct NodeIndexAndKernelDefHash FLATBUFFERS_FINAL_CLASS : private flatbuffers:: } }; -struct NodeIndexAndKernelDefHashBuilder { - typedef NodeIndexAndKernelDefHash Table; +struct DeprecatedNodeIndexAndKernelDefHashBuilder { + typedef DeprecatedNodeIndexAndKernelDefHash Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_node_index(uint32_t node_index) { - fbb_.AddElement(NodeIndexAndKernelDefHash::VT_NODE_INDEX, node_index, 0); + fbb_.AddElement(DeprecatedNodeIndexAndKernelDefHash::VT_NODE_INDEX, node_index, 0); } void add_kernel_def_hash(uint64_t kernel_def_hash) { - fbb_.AddElement(NodeIndexAndKernelDefHash::VT_KERNEL_DEF_HASH, kernel_def_hash, 0); + fbb_.AddElement(DeprecatedNodeIndexAndKernelDefHash::VT_KERNEL_DEF_HASH, kernel_def_hash, 0); } - explicit NodeIndexAndKernelDefHashBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit DeprecatedNodeIndexAndKernelDefHashBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - NodeIndexAndKernelDefHashBuilder &operator=(const NodeIndexAndKernelDefHashBuilder &); - flatbuffers::Offset Finish() { + DeprecatedNodeIndexAndKernelDefHashBuilder &operator=(const DeprecatedNodeIndexAndKernelDefHashBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateNodeIndexAndKernelDefHash( +inline flatbuffers::Offset CreateDeprecatedNodeIndexAndKernelDefHash( flatbuffers::FlatBufferBuilder &_fbb, uint32_t node_index = 0, uint64_t kernel_def_hash = 0) { - NodeIndexAndKernelDefHashBuilder builder_(_fbb); + DeprecatedNodeIndexAndKernelDefHashBuilder builder_(_fbb); builder_.add_kernel_def_hash(kernel_def_hash); builder_.add_node_index(node_index); return builder_.Finish(); @@ -1852,7 +1895,7 @@ struct RuntimeOptimizationRecord FLATBUFFERS_FINAL_CLASS : private flatbuffers:: enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_ACTION_ID = 4, VT_NODES_TO_OPTIMIZE_INDICES = 6, - VT_PRODUCED_NODES = 8 + VT_PRODUCED_OP_IDS = 10 }; const flatbuffers::String *action_id() const { return GetPointer(VT_ACTION_ID); @@ -1860,8 +1903,8 @@ struct RuntimeOptimizationRecord FLATBUFFERS_FINAL_CLASS : private flatbuffers:: const onnxruntime::fbs::NodesToOptimizeIndices *nodes_to_optimize_indices() const { return GetPointer(VT_NODES_TO_OPTIMIZE_INDICES); } - const flatbuffers::Vector> *produced_nodes() const { - return GetPointer> *>(VT_PRODUCED_NODES); + const flatbuffers::Vector> *produced_op_ids() const { + return GetPointer> *>(VT_PRODUCED_OP_IDS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1869,9 +1912,9 @@ struct RuntimeOptimizationRecord FLATBUFFERS_FINAL_CLASS : private flatbuffers:: verifier.VerifyString(action_id()) && VerifyOffset(verifier, VT_NODES_TO_OPTIMIZE_INDICES) && verifier.VerifyTable(nodes_to_optimize_indices()) && - VerifyOffset(verifier, VT_PRODUCED_NODES) && - verifier.VerifyVector(produced_nodes()) && - verifier.VerifyVectorOfTables(produced_nodes()) && + VerifyOffset(verifier, VT_PRODUCED_OP_IDS) && + verifier.VerifyVector(produced_op_ids()) && + verifier.VerifyVectorOfStrings(produced_op_ids()) && verifier.EndTable(); } }; @@ -1886,8 +1929,8 @@ struct RuntimeOptimizationRecordBuilder { void add_nodes_to_optimize_indices(flatbuffers::Offset nodes_to_optimize_indices) { fbb_.AddOffset(RuntimeOptimizationRecord::VT_NODES_TO_OPTIMIZE_INDICES, nodes_to_optimize_indices); } - void add_produced_nodes(flatbuffers::Offset>> produced_nodes) { - fbb_.AddOffset(RuntimeOptimizationRecord::VT_PRODUCED_NODES, produced_nodes); + void add_produced_op_ids(flatbuffers::Offset>> produced_op_ids) { + fbb_.AddOffset(RuntimeOptimizationRecord::VT_PRODUCED_OP_IDS, produced_op_ids); } explicit RuntimeOptimizationRecordBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { @@ -1905,9 +1948,9 @@ inline flatbuffers::Offset CreateRuntimeOptimizationR flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset action_id = 0, flatbuffers::Offset nodes_to_optimize_indices = 0, - flatbuffers::Offset>> produced_nodes = 0) { + flatbuffers::Offset>> produced_op_ids = 0) { RuntimeOptimizationRecordBuilder builder_(_fbb); - builder_.add_produced_nodes(produced_nodes); + builder_.add_produced_op_ids(produced_op_ids); builder_.add_nodes_to_optimize_indices(nodes_to_optimize_indices); builder_.add_action_id(action_id); return builder_.Finish(); @@ -1917,14 +1960,14 @@ inline flatbuffers::Offset CreateRuntimeOptimizationR flatbuffers::FlatBufferBuilder &_fbb, const char *action_id = nullptr, flatbuffers::Offset nodes_to_optimize_indices = 0, - const std::vector> *produced_nodes = nullptr) { + const std::vector> *produced_op_ids = nullptr) { auto action_id__ = action_id ? _fbb.CreateString(action_id) : 0; - auto produced_nodes__ = produced_nodes ? _fbb.CreateVector>(*produced_nodes) : 0; + auto produced_op_ids__ = produced_op_ids ? _fbb.CreateVector>(*produced_op_ids) : 0; return onnxruntime::fbs::CreateRuntimeOptimizationRecord( _fbb, action_id__, nodes_to_optimize_indices, - produced_nodes__); + produced_op_ids__); } struct RuntimeOptimizationRecordContainerEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -2464,8 +2507,9 @@ inline flatbuffers::Offset CreateModelDirect( metadata_props__); } -struct KernelCreateInfos FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef KernelCreateInfosBuilder Builder; +/// deprecated: no longer using kernel def hashes +struct DeprecatedKernelCreateInfos FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DeprecatedKernelCreateInfosBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_NODE_INDICES = 4, VT_KERNEL_DEF_HASHES = 6 @@ -2486,52 +2530,53 @@ struct KernelCreateInfos FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } }; -struct KernelCreateInfosBuilder { - typedef KernelCreateInfos Table; +struct DeprecatedKernelCreateInfosBuilder { + typedef DeprecatedKernelCreateInfos Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_node_indices(flatbuffers::Offset> node_indices) { - fbb_.AddOffset(KernelCreateInfos::VT_NODE_INDICES, node_indices); + fbb_.AddOffset(DeprecatedKernelCreateInfos::VT_NODE_INDICES, node_indices); } void add_kernel_def_hashes(flatbuffers::Offset> kernel_def_hashes) { - fbb_.AddOffset(KernelCreateInfos::VT_KERNEL_DEF_HASHES, kernel_def_hashes); + fbb_.AddOffset(DeprecatedKernelCreateInfos::VT_KERNEL_DEF_HASHES, kernel_def_hashes); } - explicit KernelCreateInfosBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit DeprecatedKernelCreateInfosBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - KernelCreateInfosBuilder &operator=(const KernelCreateInfosBuilder &); - flatbuffers::Offset Finish() { + DeprecatedKernelCreateInfosBuilder &operator=(const DeprecatedKernelCreateInfosBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateKernelCreateInfos( +inline flatbuffers::Offset CreateDeprecatedKernelCreateInfos( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> node_indices = 0, flatbuffers::Offset> kernel_def_hashes = 0) { - KernelCreateInfosBuilder builder_(_fbb); + DeprecatedKernelCreateInfosBuilder builder_(_fbb); builder_.add_kernel_def_hashes(kernel_def_hashes); builder_.add_node_indices(node_indices); return builder_.Finish(); } -inline flatbuffers::Offset CreateKernelCreateInfosDirect( +inline flatbuffers::Offset CreateDeprecatedKernelCreateInfosDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *node_indices = nullptr, const std::vector *kernel_def_hashes = nullptr) { auto node_indices__ = node_indices ? _fbb.CreateVector(*node_indices) : 0; auto kernel_def_hashes__ = kernel_def_hashes ? _fbb.CreateVector(*kernel_def_hashes) : 0; - return onnxruntime::fbs::CreateKernelCreateInfos( + return onnxruntime::fbs::CreateDeprecatedKernelCreateInfos( _fbb, node_indices__, kernel_def_hashes__); } -struct SubGraphSessionState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef SubGraphSessionStateBuilder Builder; +/// deprecated: no longer using kernel def hashes +struct DeprecatedSubGraphSessionState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DeprecatedSubGraphSessionStateBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_GRAPH_ID = 4, VT_SESSION_STATE = 6 @@ -2539,14 +2584,14 @@ struct SubGraphSessionState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table const flatbuffers::String *graph_id() const { return GetPointer(VT_GRAPH_ID); } - bool KeyCompareLessThan(const SubGraphSessionState *o) const { + bool KeyCompareLessThan(const DeprecatedSubGraphSessionState *o) const { return *graph_id() < *o->graph_id(); } int KeyCompareWithValue(const char *val) const { return strcmp(graph_id()->c_str(), val); } - const onnxruntime::fbs::SessionState *session_state() const { - return GetPointer(VT_SESSION_STATE); + const onnxruntime::fbs::DeprecatedSessionState *session_state() const { + return GetPointer(VT_SESSION_STATE); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -2558,61 +2603,62 @@ struct SubGraphSessionState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table } }; -struct SubGraphSessionStateBuilder { - typedef SubGraphSessionState Table; +struct DeprecatedSubGraphSessionStateBuilder { + typedef DeprecatedSubGraphSessionState Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_graph_id(flatbuffers::Offset graph_id) { - fbb_.AddOffset(SubGraphSessionState::VT_GRAPH_ID, graph_id); + fbb_.AddOffset(DeprecatedSubGraphSessionState::VT_GRAPH_ID, graph_id); } - void add_session_state(flatbuffers::Offset session_state) { - fbb_.AddOffset(SubGraphSessionState::VT_SESSION_STATE, session_state); + void add_session_state(flatbuffers::Offset session_state) { + fbb_.AddOffset(DeprecatedSubGraphSessionState::VT_SESSION_STATE, session_state); } - explicit SubGraphSessionStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit DeprecatedSubGraphSessionStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - SubGraphSessionStateBuilder &operator=(const SubGraphSessionStateBuilder &); - flatbuffers::Offset Finish() { + DeprecatedSubGraphSessionStateBuilder &operator=(const DeprecatedSubGraphSessionStateBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); - fbb_.Required(o, SubGraphSessionState::VT_GRAPH_ID); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, DeprecatedSubGraphSessionState::VT_GRAPH_ID); return o; } }; -inline flatbuffers::Offset CreateSubGraphSessionState( +inline flatbuffers::Offset CreateDeprecatedSubGraphSessionState( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset graph_id = 0, - flatbuffers::Offset session_state = 0) { - SubGraphSessionStateBuilder builder_(_fbb); + flatbuffers::Offset session_state = 0) { + DeprecatedSubGraphSessionStateBuilder builder_(_fbb); builder_.add_session_state(session_state); builder_.add_graph_id(graph_id); return builder_.Finish(); } -inline flatbuffers::Offset CreateSubGraphSessionStateDirect( +inline flatbuffers::Offset CreateDeprecatedSubGraphSessionStateDirect( flatbuffers::FlatBufferBuilder &_fbb, const char *graph_id = nullptr, - flatbuffers::Offset session_state = 0) { + flatbuffers::Offset session_state = 0) { auto graph_id__ = graph_id ? _fbb.CreateString(graph_id) : 0; - return onnxruntime::fbs::CreateSubGraphSessionState( + return onnxruntime::fbs::CreateDeprecatedSubGraphSessionState( _fbb, graph_id__, session_state); } -struct SessionState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef SessionStateBuilder Builder; +/// deprecated: no longer using kernel def hashes +struct DeprecatedSessionState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DeprecatedSessionStateBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_KERNELS = 4, VT_SUB_GRAPH_SESSION_STATES = 6 }; - const onnxruntime::fbs::KernelCreateInfos *kernels() const { - return GetPointer(VT_KERNELS); + const onnxruntime::fbs::DeprecatedKernelCreateInfos *kernels() const { + return GetPointer(VT_KERNELS); } - const flatbuffers::Vector> *sub_graph_session_states() const { - return GetPointer> *>(VT_SUB_GRAPH_SESSION_STATES); + const flatbuffers::Vector> *sub_graph_session_states() const { + return GetPointer> *>(VT_SUB_GRAPH_SESSION_STATES); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -2625,55 +2671,308 @@ struct SessionState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } }; -struct SessionStateBuilder { - typedef SessionState Table; +struct DeprecatedSessionStateBuilder { + typedef DeprecatedSessionState Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_kernels(flatbuffers::Offset kernels) { - fbb_.AddOffset(SessionState::VT_KERNELS, kernels); + void add_kernels(flatbuffers::Offset kernels) { + fbb_.AddOffset(DeprecatedSessionState::VT_KERNELS, kernels); } - void add_sub_graph_session_states(flatbuffers::Offset>> sub_graph_session_states) { - fbb_.AddOffset(SessionState::VT_SUB_GRAPH_SESSION_STATES, sub_graph_session_states); + void add_sub_graph_session_states(flatbuffers::Offset>> sub_graph_session_states) { + fbb_.AddOffset(DeprecatedSessionState::VT_SUB_GRAPH_SESSION_STATES, sub_graph_session_states); } - explicit SessionStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit DeprecatedSessionStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - SessionStateBuilder &operator=(const SessionStateBuilder &); - flatbuffers::Offset Finish() { + DeprecatedSessionStateBuilder &operator=(const DeprecatedSessionStateBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateSessionState( +inline flatbuffers::Offset CreateDeprecatedSessionState( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset kernels = 0, - flatbuffers::Offset>> sub_graph_session_states = 0) { - SessionStateBuilder builder_(_fbb); + flatbuffers::Offset kernels = 0, + flatbuffers::Offset>> sub_graph_session_states = 0) { + DeprecatedSessionStateBuilder builder_(_fbb); builder_.add_sub_graph_session_states(sub_graph_session_states); builder_.add_kernels(kernels); return builder_.Finish(); } -inline flatbuffers::Offset CreateSessionStateDirect( +inline flatbuffers::Offset CreateDeprecatedSessionStateDirect( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset kernels = 0, - std::vector> *sub_graph_session_states = nullptr) { - auto sub_graph_session_states__ = sub_graph_session_states ? _fbb.CreateVectorOfSortedTables(sub_graph_session_states) : 0; - return onnxruntime::fbs::CreateSessionState( + flatbuffers::Offset kernels = 0, + std::vector> *sub_graph_session_states = nullptr) { + auto sub_graph_session_states__ = sub_graph_session_states ? _fbb.CreateVectorOfSortedTables(sub_graph_session_states) : 0; + return onnxruntime::fbs::CreateDeprecatedSessionState( _fbb, kernels, sub_graph_session_states__); } +struct ArgTypeAndIndex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ArgTypeAndIndexBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ARG_TYPE = 4, + VT_INDEX = 6 + }; + onnxruntime::fbs::ArgType arg_type() const { + return static_cast(GetField(VT_ARG_TYPE, 0)); + } + uint32_t index() const { + return GetField(VT_INDEX, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ARG_TYPE) && + VerifyField(verifier, VT_INDEX) && + verifier.EndTable(); + } +}; + +struct ArgTypeAndIndexBuilder { + typedef ArgTypeAndIndex Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_arg_type(onnxruntime::fbs::ArgType arg_type) { + fbb_.AddElement(ArgTypeAndIndex::VT_ARG_TYPE, static_cast(arg_type), 0); + } + void add_index(uint32_t index) { + fbb_.AddElement(ArgTypeAndIndex::VT_INDEX, index, 0); + } + explicit ArgTypeAndIndexBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ArgTypeAndIndexBuilder &operator=(const ArgTypeAndIndexBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateArgTypeAndIndex( + flatbuffers::FlatBufferBuilder &_fbb, + onnxruntime::fbs::ArgType arg_type = onnxruntime::fbs::ArgType::INPUT, + uint32_t index = 0) { + ArgTypeAndIndexBuilder builder_(_fbb); + builder_.add_index(index); + builder_.add_arg_type(arg_type); + return builder_.Finish(); +} + +struct KernelTypeStrArgsEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef KernelTypeStrArgsEntryBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KERNEL_TYPE_STR = 4, + VT_ARGS = 6 + }; + const flatbuffers::String *kernel_type_str() const { + return GetPointer(VT_KERNEL_TYPE_STR); + } + bool KeyCompareLessThan(const KernelTypeStrArgsEntry *o) const { + return *kernel_type_str() < *o->kernel_type_str(); + } + int KeyCompareWithValue(const char *val) const { + return strcmp(kernel_type_str()->c_str(), val); + } + const flatbuffers::Vector> *args() const { + return GetPointer> *>(VT_ARGS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffsetRequired(verifier, VT_KERNEL_TYPE_STR) && + verifier.VerifyString(kernel_type_str()) && + VerifyOffset(verifier, VT_ARGS) && + verifier.VerifyVector(args()) && + verifier.VerifyVectorOfTables(args()) && + verifier.EndTable(); + } +}; + +struct KernelTypeStrArgsEntryBuilder { + typedef KernelTypeStrArgsEntry Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_kernel_type_str(flatbuffers::Offset kernel_type_str) { + fbb_.AddOffset(KernelTypeStrArgsEntry::VT_KERNEL_TYPE_STR, kernel_type_str); + } + void add_args(flatbuffers::Offset>> args) { + fbb_.AddOffset(KernelTypeStrArgsEntry::VT_ARGS, args); + } + explicit KernelTypeStrArgsEntryBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + KernelTypeStrArgsEntryBuilder &operator=(const KernelTypeStrArgsEntryBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, KernelTypeStrArgsEntry::VT_KERNEL_TYPE_STR); + return o; + } +}; + +inline flatbuffers::Offset CreateKernelTypeStrArgsEntry( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset kernel_type_str = 0, + flatbuffers::Offset>> args = 0) { + KernelTypeStrArgsEntryBuilder builder_(_fbb); + builder_.add_args(args); + builder_.add_kernel_type_str(kernel_type_str); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateKernelTypeStrArgsEntryDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *kernel_type_str = nullptr, + const std::vector> *args = nullptr) { + auto kernel_type_str__ = kernel_type_str ? _fbb.CreateString(kernel_type_str) : 0; + auto args__ = args ? _fbb.CreateVector>(*args) : 0; + return onnxruntime::fbs::CreateKernelTypeStrArgsEntry( + _fbb, + kernel_type_str__, + args__); +} + +struct OpIdKernelTypeStrArgsEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OpIdKernelTypeStrArgsEntryBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OP_ID = 4, + VT_KERNEL_TYPE_STR_ARGS = 6 + }; + const flatbuffers::String *op_id() const { + return GetPointer(VT_OP_ID); + } + bool KeyCompareLessThan(const OpIdKernelTypeStrArgsEntry *o) const { + return *op_id() < *o->op_id(); + } + int KeyCompareWithValue(const char *val) const { + return strcmp(op_id()->c_str(), val); + } + const flatbuffers::Vector> *kernel_type_str_args() const { + return GetPointer> *>(VT_KERNEL_TYPE_STR_ARGS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffsetRequired(verifier, VT_OP_ID) && + verifier.VerifyString(op_id()) && + VerifyOffset(verifier, VT_KERNEL_TYPE_STR_ARGS) && + verifier.VerifyVector(kernel_type_str_args()) && + verifier.VerifyVectorOfTables(kernel_type_str_args()) && + verifier.EndTable(); + } +}; + +struct OpIdKernelTypeStrArgsEntryBuilder { + typedef OpIdKernelTypeStrArgsEntry Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_op_id(flatbuffers::Offset op_id) { + fbb_.AddOffset(OpIdKernelTypeStrArgsEntry::VT_OP_ID, op_id); + } + void add_kernel_type_str_args(flatbuffers::Offset>> kernel_type_str_args) { + fbb_.AddOffset(OpIdKernelTypeStrArgsEntry::VT_KERNEL_TYPE_STR_ARGS, kernel_type_str_args); + } + explicit OpIdKernelTypeStrArgsEntryBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OpIdKernelTypeStrArgsEntryBuilder &operator=(const OpIdKernelTypeStrArgsEntryBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, OpIdKernelTypeStrArgsEntry::VT_OP_ID); + return o; + } +}; + +inline flatbuffers::Offset CreateOpIdKernelTypeStrArgsEntry( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset op_id = 0, + flatbuffers::Offset>> kernel_type_str_args = 0) { + OpIdKernelTypeStrArgsEntryBuilder builder_(_fbb); + builder_.add_kernel_type_str_args(kernel_type_str_args); + builder_.add_op_id(op_id); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOpIdKernelTypeStrArgsEntryDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *op_id = nullptr, + std::vector> *kernel_type_str_args = nullptr) { + auto op_id__ = op_id ? _fbb.CreateString(op_id) : 0; + auto kernel_type_str_args__ = kernel_type_str_args ? _fbb.CreateVectorOfSortedTables(kernel_type_str_args) : 0; + return onnxruntime::fbs::CreateOpIdKernelTypeStrArgsEntry( + _fbb, + op_id__, + kernel_type_str_args__); +} + +struct KernelTypeStrResolver FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef KernelTypeStrResolverBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OP_KERNEL_TYPE_STR_ARGS = 4 + }; + const flatbuffers::Vector> *op_kernel_type_str_args() const { + return GetPointer> *>(VT_OP_KERNEL_TYPE_STR_ARGS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_OP_KERNEL_TYPE_STR_ARGS) && + verifier.VerifyVector(op_kernel_type_str_args()) && + verifier.VerifyVectorOfTables(op_kernel_type_str_args()) && + verifier.EndTable(); + } +}; + +struct KernelTypeStrResolverBuilder { + typedef KernelTypeStrResolver Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_op_kernel_type_str_args(flatbuffers::Offset>> op_kernel_type_str_args) { + fbb_.AddOffset(KernelTypeStrResolver::VT_OP_KERNEL_TYPE_STR_ARGS, op_kernel_type_str_args); + } + explicit KernelTypeStrResolverBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + KernelTypeStrResolverBuilder &operator=(const KernelTypeStrResolverBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateKernelTypeStrResolver( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> op_kernel_type_str_args = 0) { + KernelTypeStrResolverBuilder builder_(_fbb); + builder_.add_op_kernel_type_str_args(op_kernel_type_str_args); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateKernelTypeStrResolverDirect( + flatbuffers::FlatBufferBuilder &_fbb, + std::vector> *op_kernel_type_str_args = nullptr) { + auto op_kernel_type_str_args__ = op_kernel_type_str_args ? _fbb.CreateVectorOfSortedTables(op_kernel_type_str_args) : 0; + return onnxruntime::fbs::CreateKernelTypeStrResolver( + _fbb, + op_kernel_type_str_args__); +} + struct InferenceSession FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef InferenceSessionBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_ORT_VERSION = 4, VT_MODEL = 6, - VT_SESSION_STATE = 8 + VT_KERNEL_TYPE_STR_RESOLVER = 10 }; const flatbuffers::String *ort_version() const { return GetPointer(VT_ORT_VERSION); @@ -2681,8 +2980,8 @@ struct InferenceSession FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const onnxruntime::fbs::Model *model() const { return GetPointer(VT_MODEL); } - const onnxruntime::fbs::SessionState *session_state() const { - return GetPointer(VT_SESSION_STATE); + const onnxruntime::fbs::KernelTypeStrResolver *kernel_type_str_resolver() const { + return GetPointer(VT_KERNEL_TYPE_STR_RESOLVER); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -2690,8 +2989,8 @@ struct InferenceSession FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(ort_version()) && VerifyOffset(verifier, VT_MODEL) && verifier.VerifyTable(model()) && - VerifyOffset(verifier, VT_SESSION_STATE) && - verifier.VerifyTable(session_state()) && + VerifyOffset(verifier, VT_KERNEL_TYPE_STR_RESOLVER) && + verifier.VerifyTable(kernel_type_str_resolver()) && verifier.EndTable(); } }; @@ -2706,8 +3005,8 @@ struct InferenceSessionBuilder { void add_model(flatbuffers::Offset model) { fbb_.AddOffset(InferenceSession::VT_MODEL, model); } - void add_session_state(flatbuffers::Offset session_state) { - fbb_.AddOffset(InferenceSession::VT_SESSION_STATE, session_state); + void add_kernel_type_str_resolver(flatbuffers::Offset kernel_type_str_resolver) { + fbb_.AddOffset(InferenceSession::VT_KERNEL_TYPE_STR_RESOLVER, kernel_type_str_resolver); } explicit InferenceSessionBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { @@ -2725,9 +3024,9 @@ inline flatbuffers::Offset CreateInferenceSession( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset ort_version = 0, flatbuffers::Offset model = 0, - flatbuffers::Offset session_state = 0) { + flatbuffers::Offset kernel_type_str_resolver = 0) { InferenceSessionBuilder builder_(_fbb); - builder_.add_session_state(session_state); + builder_.add_kernel_type_str_resolver(kernel_type_str_resolver); builder_.add_model(model); builder_.add_ort_version(ort_version); return builder_.Finish(); @@ -2737,13 +3036,13 @@ inline flatbuffers::Offset CreateInferenceSessionDirect( flatbuffers::FlatBufferBuilder &_fbb, const char *ort_version = nullptr, flatbuffers::Offset model = 0, - flatbuffers::Offset session_state = 0) { + flatbuffers::Offset kernel_type_str_resolver = 0) { auto ort_version__ = ort_version ? _fbb.CreateString(ort_version) : 0; return onnxruntime::fbs::CreateInferenceSession( _fbb, ort_version__, model, - session_state); + kernel_type_str_resolver); } inline bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj, TypeInfoValue type) { diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index ab0a2b3897571..45b03112d0410 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -28,28 +28,18 @@ AllocatorPtr IExecutionProvider::GetAllocator(int device_id, OrtMemType mem_type std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { + const IKernelLookup& kernel_lookup) const { std::vector> result; -#if !defined(ORT_MINIMAL_BUILD) - for (auto& node : graph.Nodes()) { - for (auto registry : kernel_registries) { - if (KernelRegistry::HasImplementationOf(*registry, node, Type())) { - std::unique_ptr sub_graph = std::make_unique(); - sub_graph->nodes.push_back(node.Index()); - result.push_back(std::make_unique(std::move(sub_graph))); - break; - } + for (const auto& node : graph.Nodes()) { + if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node); + kernel_create_info != nullptr) { + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node.Index()); + result.push_back(std::make_unique(std::move(sub_graph))); } } return result; -#else - // We have saved hashes to lookup static kernels in an ORT format model so the default behavior is to return an - // empty vector to leave that in place. An EP that compiles nodes can override this in a minimal build. - ORT_UNUSED_PARAMETER(graph); - ORT_UNUSED_PARAMETER(kernel_registries); - return result; -#endif } // Update allocator in the provider if already present; ignore if not. diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 46b825a1496ca..ab8266cf939fc 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -39,8 +39,7 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node } // namespace std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, - const std::string& provider_type, - gsl::span kernel_registries, + const IExecutionProvider::IKernelLookup& kernel_lookup, gsl::span tentative_nodes) { // automatic conversion from const std::vector& const auto& ordered_nodes = graph.GetNodesInTopologicalOrder(); @@ -69,12 +68,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe provider_nodes.insert(node_id); const Node* node = graph.GetNode(node_id); - const KernelCreateInfo* kernel_info = nullptr; - for (auto registry : kernel_registries) { - auto st = registry->TryFindKernel(*node, provider_type, &kernel_info); - if (st.IsOK()) - break; - } + const KernelCreateInfo* kernel_info = kernel_lookup.LookUpKernel(*node); // at least one registry has a target provider's kernel for this node ORT_ENFORCE(kernel_info != nullptr); node_to_kernel.insert({node_id, kernel_info}); diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index 531c1c07464e1..b6015cdd576f4 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -3,10 +3,11 @@ #pragma once +#include + #include "core/common/inlined_containers_fwd.h" -#include "core/framework/kernel_registry.h" +#include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup #include "core/graph/graph_viewer.h" -#include namespace onnxruntime { @@ -14,13 +15,11 @@ namespace onnxruntime { Returns a list of nodes that are preferred on CPU. They are commonly shape-related computation subgraphs. @param graph Graph viewer - @param provider_type The target execution provider type - @param kernel_registries Kernel registries for the target EP + @param kernel_lookup The kernel lookup for the target execution provider @param tentative_nodes Nodes that are tentative to be placed on on target EP */ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, - const std::string& provider_type, - gsl::span kernel_registries, + const IExecutionProvider::IKernelLookup& kernel_lookup, gsl::span tentative_nodes); } // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index c96d932497bd6..9749dc5676fe2 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - #include "core/framework/graph_partitioner.h" -#include "core/framework/kernel_registry_manager.h" -#include "core/graph/function.h" -#include "core/graph/graph_viewer.h" + +#include + #include "core/framework/compute_capability.h" -#include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" -#include "core/framework/kernel_registry.h" #include "core/framework/func_kernel.h" +#include "core/framework/kernel_lookup.h" +#include "core/framework/kernel_registry_manager.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/function.h" +#include "core/graph/graph_viewer.h" // uncomment this line to count non-CUDA ops in ONNX domain //#define COUNT_NON_CUDA_OPS @@ -40,9 +41,25 @@ class NonCudaOps { NonCudaOps non_cuda; #endif -using namespace ::onnxruntime::common; namespace onnxruntime { +namespace { + +// contains some common parameters used by the partitioning helper functions +struct PartitionParams { + std::reference_wrapper graph; + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + std::reference_wrapper func_mgr; + std::reference_wrapper fused_kernel_registry; + std::reference_wrapper fused_node_unique_id; + TransformLayoutFunction transform_layout_function; +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +}; +} // namespace + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // minimal KernelDef based on MetaDef instead of a Function based node static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph::MetaDef& metadef, const std::string& provider_type) { @@ -52,8 +69,6 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph .Provider(provider_type); } -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - /// /// Check if a node can be placed on a specific provider. If yes, then set the nodes execution provider. /// Do nothing if the node is already assigned. @@ -61,15 +76,15 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph /// Graph in question. /// Indexed subgraph which needs to be assigned /// The EP to assign the Indexed subgraph to -static void AssignNodes(Graph& graph, const IndexedSubGraph& capability, - const std::string& provider_type) { +static bool TryAssignNodes(Graph& graph, const IndexedSubGraph& capability, + const std::string& provider_type) { // Before assigning the ep to any node, first walk through all the nodes and ensure // none of the nodes have already been assigned. If a node is assigned, simply return. for (auto node_index : capability.nodes) { const auto* node = graph.GetNode(node_index); if ((nullptr == node) || (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { - return; + return false; } } @@ -77,10 +92,29 @@ static void AssignNodes(Graph& graph, const IndexedSubGraph& capability, auto* node = graph.GetNode(node_index); node->SetExecutionProviderType(provider_type); } + + return true; } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +static bool TryAssignSingleNode(Graph& graph, + const IndexedSubGraph& indexed_sub_graph, + const std::string& provider_type) { + // The provider can run a single node in the if not using meta-defs. + // A fused kernel is not supported in this case. + ORT_ENFORCE(1 == indexed_sub_graph.nodes.size()); + + auto* node = graph.GetNode(indexed_sub_graph.nodes[0]); + if (nullptr != node && node->GetExecutionProviderType().empty()) { + // The node was not fused or assigned. Assign it to . + node->SetExecutionProviderType(provider_type); + return true; + } + + return false; +} + static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep, GraphPartitioner::Mode mode, std::vector>& capabilities, @@ -93,10 +127,30 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg return Status::OK(); } + auto get_capabilities = [](const IExecutionProvider& ep, + const GraphViewer& graph_viewer, + const IExecutionProvider::IKernelLookup& kernel_lookup) { + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); + + // In theory an EP could return an empty capability. Remove those. + capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), + [](const std::unique_ptr& capability) { + return !capability || !capability->sub_graph; + }), + capabilities.end()); + + return capabilities; + }; + + const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); + const KernelLookup kernel_lookup{ep_type, + kernel_registries_for_ep, + kernel_registry_mgr.GetKernelTypeStrResolver()}; + { - GraphViewer graph_viewer(graph); - capabilities = current_ep.GetCapability(graph_viewer, - kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type)); + const GraphViewer graph_viewer(graph); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); + if (capabilities.empty()) { return Status::OK(); } @@ -108,12 +162,7 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg if (mode != GraphPartitioner::Mode::kAssignOnly && current_ep.GetPreferredLayout() == DataLayout::NHWC) { for (auto& capability : capabilities) { - // in theory an EP could return an empty value... - if (!capability || !capability->sub_graph) { - continue; - } - - AssignNodes(graph, *capability->sub_graph, ep_type); + TryAssignNodes(graph, *capability->sub_graph, ep_type); } const NodeIndex first_new_node = graph.MaxNodeIndex(); @@ -136,9 +185,9 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg const NodeIndex end_node = graph.MaxNodeIndex(); capabilities.clear(); - GraphViewer graph_viewer(graph); - capabilities = current_ep.GetCapability(graph_viewer, - kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type)); + + const GraphViewer graph_viewer(graph); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -164,6 +213,8 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg } } } +#else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + ORT_UNUSED_PARAMETER(mode); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) return Status::OK(); @@ -177,7 +228,8 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg * \param capability * \param kernel_registry_mgr * \param provider_type name of the provider to test - * \param count A counter for generating fused node names. Unique across the entire model. + * \param mode + * \param fused_node_unique_id A counter for generating fused node names. Unique across the entire model. * \return Fused node. Return nullptr if there is no fuse */ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, @@ -188,15 +240,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, Node* result = nullptr; if (nullptr == capability.GetMetaDef()) { - // The can run a single node in the if not using meta-defs. - // A fused kernel is not supported in this case. - ORT_ENFORCE(1 == capability.nodes.size()); - - auto* node = graph.GetNode(capability.nodes[0]); - if (nullptr != node && node->GetExecutionProviderType().empty()) { - // The node was not fused or assigned. Assign it to this . - node->SetExecutionProviderType(provider_type); - } + TryAssignSingleNode(graph, capability, provider_type); } else { // The can run a fused in the . @@ -291,7 +335,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, for (auto& node : graph.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { Graph* subgraph = entry.second; - // we pass through the export_dll value and FuncManager from the top level graph + // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, transform_layout_function)); @@ -335,11 +379,6 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, entry->sub_graph->GetMetaDef() != nullptr; })); for (auto& capability : capabilities) { - // in theory an EP could return an empty value... - if (!capability || !capability->sub_graph) { - continue; - } - Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach @@ -477,16 +516,21 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { return Status::OK(); } -Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr, - KernelRegistry& fused_kernel_registry, Mode mode, - int& fused_node_unique_id, - TransformLayoutFunction transform_layout_function) const { +static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, + const ExecutionProviders& execution_providers, + KernelRegistryManager& kernel_registry_manager) { bool modified_graph = false; + auto& graph = partition_params.graph.get(); + auto& func_mgr = partition_params.func_mgr.get(); + auto& fused_kernel_registry = partition_params.fused_kernel_registry.get(); + auto& fused_node_unique_id = partition_params.fused_node_unique_id.get(); + const auto& transform_layout_function = partition_params.transform_layout_function; + do { // process full graph with each EP - for (const auto& ep : providers_) { - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, + for (const auto& ep : execution_providers) { + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function)); } @@ -506,94 +550,97 @@ Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, FuncManager& fun #endif // !defined(ORT_MINIMAL_BUILD) -static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, +static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, - KernelRegistry& fused_kernel_registry, - IExecutionProvider& current_ep, - std::unordered_map& compiled_kernel_hashes, - int& fused_node_unique_id, - TransformLayoutFunction transform_layout_function) { - // recurse into nested graphs first to partition bottom up. - for (auto& node : graph.Nodes()) { - for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - Graph* subgraph = entry.second; - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, - current_ep, compiled_kernel_hashes, fused_node_unique_id, - transform_layout_function)); - } - } - + IExecutionProvider& current_ep) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability + auto& graph = partition_params.graph.get(); if (graph.NumberOfNodes() == 0) { return Status::OK(); } + // recurse into nested graphs first to partition bottom up. + for (auto& node : graph.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + auto& subgraph = *entry.second; + PartitionParams subgraph_partition_params = partition_params; + subgraph_partition_params.graph = std::ref(subgraph); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep)); + } + } + const std::string& type = current_ep.Type(); - std::vector nodes_and_viewers; std::vector> capabilities; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + TransformLayoutFunction transform_layout_function = partition_params.transform_layout_function; +#else + TransformLayoutFunction transform_layout_function{}; +#endif ORT_RETURN_IF_ERROR(GetCapabilityForEP(graph, kernel_registry_mgr, current_ep, - GraphPartitioner::Mode::kOrtFormatLoad, capabilities, transform_layout_function)); + GraphPartitioner::Mode::kOrtFormatLoad, capabilities, + transform_layout_function)); if (capabilities.empty()) { return Status::OK(); } - // storage for the GraphViewer for each IndexedSubGraph - std::vector> viewers; - viewers.reserve(capabilities.size()); +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + struct CompilationEntry { + std::unique_ptr viewer; + std::reference_wrapper fused_node; + std::reference_wrapper capability; + }; + std::vector compilation_entries; + compilation_entries.reserve(capabilities.size()); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - for (auto& capability : capabilities) { + for (const auto& capability : capabilities) { const IndexedSubGraph& indexed_sub_graph = *capability->sub_graph; const IndexedSubGraph::MetaDef* metadef = indexed_sub_graph.GetMetaDef(); if (!metadef) { - // Static kernel - use the kernel hash that was saved in the ORT format model. - auto* node = graph.GetNode(indexed_sub_graph.nodes[0]); - if (nullptr != node && node->GetExecutionProviderType().empty()) { - // The node was not fused or assigned. Assign it to this . - node->SetExecutionProviderType(type); - } - continue; - } - - std::ostringstream oss; - oss << type << "_" << metadef->name << "_" << fused_node_unique_id++; - std::string node_name = oss.str(); + TryAssignSingleNode(graph, indexed_sub_graph, type); + } else { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + std::ostringstream oss; + oss << type << "_" << metadef->name << "_" << partition_params.fused_node_unique_id++; + const std::string node_name = oss.str(); - Node& fused_node = graph.BeginFuseSubGraph(indexed_sub_graph, node_name); - fused_node.SetExecutionProviderType(type); + Node& fused_node = graph.BeginFuseSubGraph(indexed_sub_graph, node_name); + fused_node.SetExecutionProviderType(type); - // create filtered graph viewer for this set of nodes - // - // TODO: Could avoid the topological sort in the GraphViewer ctor by constructing from an existing - // GraphViewer instance instead of the Graph (copying the topological order instead of recalculating). - viewers.push_back(std::make_unique(graph, indexed_sub_graph)); - nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{fused_node, *viewers.back()}); + // create filtered graph viewer for this set of nodes + // + // TODO: Could avoid the topological sort in the GraphViewer ctor by constructing from an existing + // GraphViewer instance instead of the Graph (copying the topological order instead of recalculating). + auto viewer = std::make_unique(graph, indexed_sub_graph); + compilation_entries.push_back(CompilationEntry{std::move(viewer), fused_node, *capability}); +#else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Compiling capabilities is not supported in this build."); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + } } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // We will compile the fused nodes one by one, and fuse the subgraph if successful. - // If a compilation fails we undo the fusion and leave the original nodes available for other EPs to take - for (size_t j = 0, end = nodes_and_viewers.size(); j < end; ++j) { - Node& node = nodes_and_viewers[j].fused_node; + for (const auto& compilation_entry : compilation_entries) { + Node& node = compilation_entry.fused_node; std::vector single_node_compute_func; - ORT_RETURN_IF_ERROR(current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func)); + ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}}, + single_node_compute_func)); ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element."); + auto& func_mgr = partition_params.func_mgr.get(); ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0]))); - const auto& cur_capability = capabilities[j]; - const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; + const ComputeCapability& cur_capability = compilation_entry.capability; + const IndexedSubGraph& indexed_sub_graph = *cur_capability.sub_graph; const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef(); KernelDefBuilder builder; BuildFusedKernelDef(builder, metadef, type); auto kernel_def = builder.Build(); - // save hash so SessionState can find the kernel. each kernel name should be unique - if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) { - ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name, - ". Execution Provider must generate unique names across the entire model."); - } - + auto& fused_kernel_registry = partition_params.fused_kernel_registry.get(); ORT_RETURN_IF_ERROR(fused_kernel_registry.Register( KernelCreateInfo(std::move(kernel_def), [](FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr& out) -> Status { @@ -603,53 +650,25 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one graph.FinalizeFuseSubGraph(indexed_sub_graph, node); } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) return Status::OK(); } -// If this is an ORT format model the hashes will be for CPU EP kernels, so set the EP of any unassigned nodes -// to kCpuExecutionProvider. -static void AssignRemainingNodesToCpuEp(Graph& graph) { - for (auto& node : graph.Nodes()) { - for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - Graph* subgraph = entry.second; - AssignRemainingNodesToCpuEp(*subgraph); - } - - if (node.GetExecutionProviderType().empty()) { - node.SetExecutionProviderType(kCpuExecutionProvider); - } - } -} - // Simplified partitioning where custom EPs may produce compiled nodes. -// EPs with static kernels do not need to be processed as their kernels are matched via hash information serialized -// as part of the ORT format model. -Status GraphPartitioner::PartitionOrtFormatModel( - Graph& graph, FuncManager& func_mgr, - KernelRegistry& fused_kernel_registry, - std::unordered_map& compiled_kernel_hashes, - int& fused_node_unique_id, - TransformLayoutFunction transform_layout_function) const { +static Status PartitionOrtFormatModel(const PartitionParams& partition_params, + const ExecutionProviders& execution_providers, + KernelRegistryManager& kernel_registry_manager) { // process full graph with each EP - for (const auto& ep : providers_) { - if (ep->Type() == kCpuExecutionProvider) { - // hash for kernel is stored in session state for EPs that have pre-registered kernels - // (vs. runtime fused kernels) so we can simply assign any remaining nodes to the CPU EP - AssignRemainingNodesToCpuEp(graph); - } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry, - *ep, compiled_kernel_hashes, fused_node_unique_id, - transform_layout_function)); - } + for (const auto& ep : execution_providers) { + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep)); } return Status::OK(); } Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, - TransformLayoutFunction transform_layout_function, Mode mode, - std::unordered_map* compiled_kernel_hashes) const { + TransformLayoutFunction transform_layout_function, Mode mode) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. // 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet. @@ -659,9 +678,11 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, // 3. CPU execution provider is expected to be able to run any node and is the last one in execution provider // preference. if (providers_.Empty()) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified."); } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // fused_kernel_registry is preparing the kernels created on the fly for fused sub graph. // It is only visible for current session. std::shared_ptr fused_kernel_registry = std::make_shared(); @@ -669,24 +690,43 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, // we make sure each fused node name is unique across the entire model for clarity int fused_node_unique_id = 0; + PartitionParams partition_params{ + std::ref(graph), + std::ref(func_mgr), + std::ref(*fused_kernel_registry), + std::ref(fused_node_unique_id), + transform_layout_function, + }; + +#else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + + ORT_UNUSED_PARAMETER(func_mgr); + ORT_UNUSED_PARAMETER(transform_layout_function); + PartitionParams partition_params{ + std::ref(graph), + }; + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, func_mgr, *fused_kernel_registry, mode, - fused_node_unique_id, transform_layout_function)); + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, + providers_, kernel_registry_mgr_)); #else - ORT_THROW("Not supported in this build."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { - ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided"); - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes, - fused_node_unique_id, transform_layout_function)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, + providers_, kernel_registry_mgr_)); } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) if (!fused_kernel_registry->IsEmpty()) { kernel_registry_mgr_.RegisterKernelRegistry(fused_kernel_registry); } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) return Status::OK(); } -} // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 4aac959f3e1bc..042c1a89c55f3 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -3,17 +3,13 @@ #pragma once -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - #include "core/common/common.h" -#include "core/graph/graph_viewer.h" -#include "core/framework/op_kernel.h" +#include "core/graph/graph.h" #include "core/framework/fuse_nodes_funcs.h" namespace onnxruntime { class ExecutionProviders; -class KernelRegistry; class KernelRegistryManager; using TransformLayoutFunction = std::function; @@ -25,34 +21,22 @@ class GraphPartitioner { kOrtFormatLoad = 2 // loading ORT format model. Partition with compiling EPs, GraphViewer based Compile. }; - //The order of providers represents the user preference. + // The order of providers represents the user preference. GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers) : kernel_registry_mgr_(kernel_registry_mgr), providers_(providers) { } - // Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad. - Status Partition(Graph& graph, FuncManager& func_mgr, + // Run partitioning. + Status Partition(Graph& graph, FuncManager& func_mgr, TransformLayoutFunction transform_layout_function, - Mode mode = Mode::kNormal, - std::unordered_map* compiled_kernel_hashes = nullptr) const; + Mode mode = Mode::kNormal) const; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner); -#if !defined(ORT_MINIMAL_BUILD) - Status PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr, - KernelRegistry& fused_kernel_registry, Mode mode, - int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const; -#endif - - Status PartitionOrtFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, - std::unordered_map& compiled_kernel_hashes, - int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const; - KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; }; -} // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_def_builder.cc b/onnxruntime/core/framework/kernel_def_builder.cc index b3fffe602faef..c939d96039715 100644 --- a/onnxruntime/core/framework/kernel_def_builder.cc +++ b/onnxruntime/core/framework/kernel_def_builder.cc @@ -9,8 +9,6 @@ #include "gsl/gsl" -#include "core/framework/murmurhash3.h" - namespace onnxruntime { namespace { @@ -31,58 +29,6 @@ inline bool AreVectorsOverlap(const std::vector& v1, const std::vector& v2 } // namespace -void KernelDef::CalculateHash() { - uint32_t hash[4] = {0, 0, 0, 0}; - - auto hash_int = [&hash](int i) { MurmurHash3::x86_128(&i, sizeof(i), hash[0], &hash); }; - auto hash_str = [&hash](const std::string& str) { - MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); - }; - - // use name, start/end, domain, provider and the type constraints. - // we wouldn't have two kernels that only differed by the inplace or alias info or memory types. - // currently nothing sets exec_queue_id either (and would assumably be a runtime thing and not part of the base - // kernel definition) - - hash_str(op_name_); - - if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { - hash_int(op_since_version_start_); - } else { - const uint8_t* c = (const uint8_t*)&op_since_version_start_; - hash_int((uint32_t)c[0] | - (uint32_t)c[1] << 8 | - (uint32_t)c[2] << 16 | - (uint32_t)c[3] << 24); - } - - // If we include op_since_version_end_ the hash of an existing op changes when it's superseded. - // e.g. Unsqueeze 11 had no end version until Unsqueeze 13, at which point the existing op is changed to have - // an end version of 12. That would result in a new ORT build having a different hash for Unsqueeze 11 and a - // previously serialized ORT format model wouldn't find the kernel. In order to select the kernel to include - // in the ORT model the full OpSchema info is used, so it's safe to exclude op_since_version_end_ from the hash. - - hash_str(op_domain_); - hash_str(provider_type_); - - // use the hash_type_constraints_ or default_type_constraints_ list for the hash so the value in an ORT format model - // is stable. - const auto& hash_type_constraints = - hash_type_constraints_.has_value() ? *hash_type_constraints_ : default_type_constraints_; - for (const auto& key_value : hash_type_constraints) { - hash_str(key_value.first); - auto data_type_strings = DataTypeImpl::ToString(key_value.second); - // sort type constraint data type strings so that order does not matter - std::sort(data_type_strings.begin(), data_type_strings.end()); - for (const auto& data_type_string : data_type_strings) { - hash_str(data_type_string); - } - } - - hash_ = hash[0] & 0xfffffff8; // save low 3 bits for hash version info in case we need it in the future - hash_ |= uint64_t(hash[1]) << 32; -} - // TODO: Tell user why it has conflicts // TODO: Investigate why IsConflict() was not triggered when there were duplicate Tile CUDA // kernels registered. Removing `InputMemoryType(OrtMemTypeCPUInput, 1)` in the kernel definition @@ -227,23 +173,6 @@ KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* arg_name, return TypeConstraint(std::string(arg_name), default_type); } -KernelDefBuilder& KernelDefBuilder::FixedTypeConstraintForHash( - const std::string& arg_name, - const std::vector& default_types_for_hash) { - auto& hash_type_constraints = kernel_def_->hash_type_constraints_; - if (!hash_type_constraints.has_value()) { - hash_type_constraints.emplace(); - } - (*hash_type_constraints)[arg_name] = default_types_for_hash; - return *this; -} - -KernelDefBuilder& KernelDefBuilder::FixedTypeConstraintForHash( - const char* arg_name, - const std::vector& default_types_for_hash) { - return FixedTypeConstraintForHash(std::string{arg_name}, default_types_for_hash); -} - KernelDefBuilder& KernelDefBuilder::MayInplace(const std::vector>& inplaces) { kernel_def_->inplace_map_ = inplaces; return *this; diff --git a/onnxruntime/core/framework/kernel_def_hash_helpers.cc b/onnxruntime/core/framework/kernel_def_hash_helpers.cc deleted file mode 100644 index 15e02f2c23f68..0000000000000 --- a/onnxruntime/core/framework/kernel_def_hash_helpers.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/kernel_def_hash_helpers.h" - -#include "core/framework/data_types_internal.h" -#include "core/graph/graph.h" - -namespace onnxruntime { -namespace utils { -std::optional GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version) { - // Layout tranformer can add new nodes to the graph. - // Since layout transformation can happen in an extended build, if these nodes are not picked up and compiled by - // NNAPI or other compiling EPs then we need a way to get the hashes for these nodes. Since the infrastructure - // as well as op_schema required to generate these hashes is not available in an extended minimal build, - // we maintain a static map of nodes to hash value. This hash value can then be used to retrieve the - // kernel for the given op. - static const std::unordered_map static_kernel_hashes{ - // Note: these region_begin/end markers are used by tools/ci_build/reduce_op_kernels.py - // @@region_begin(layout_transformation_required_kernels)@@ - {"Transpose_1", 4324835766923221184ULL}, - {"Transpose_13", 17267477159887372848ULL}, - {"Squeeze_1", 12889825108950034784ULL}, - {"Squeeze_11", 14725795030460042064ULL}, - {"Squeeze_13", 16122603335179721968ULL}, - {"UnSqueeze_1", 15964030255371555232ULL}, - {"UnSqueeze_11", 16989589986691430224ULL}, - {"UnSqueeze_13", 9466011545409597224ULL}, - {"Gather_1", 625186873870077080ULL}, - {"Gather_11", 11761559382112736008ULL}, - {"Gather_13", 7462749543760614528ULL}, - {"Identity_1", 18001636502361632792ULL}, - {"Identity_13", 16879814636194901248ULL}, - {"Identity_14", 16515685968327103576ULL}, - {"Identity_16", 17661628575887109792ULL}, - // @@region_end(layout_transformation_required_kernels)@@ - }; - - auto key = op_type + "_" + std::to_string(since_version); - auto iter = static_kernel_hashes.find(key); - if (iter != static_kernel_hashes.end()) { - return iter->second; - } - - return std::nullopt; -} - -// special case for node NHWC optimizer may insert when running in minimal build -std::optional GetInternalNhwcOpHash(const Node& node) { - if (node.Domain() == kMSDomain) { - const auto& op_type = node.OpType(); - const auto& input_0_type = *node.InputDefs()[0]->TypeAsProto(); - - if (op_type == "QLinearConv") { - // first input is a tensor. could be uint8 or int8 - bool is_uint8 = input_0_type.tensor_type().elem_type() == utils::ToTensorProtoElementType(); - return is_uint8 ? 16835965565578160400ULL : 10904143578341560456ULL; - } else if (op_type == "NhwcMaxPool") { - // first input is a tensor. could be uint8 or int8 - bool is_uint8 = input_0_type.tensor_type().elem_type() == utils::ToTensorProtoElementType(); - return is_uint8 ? 8512357837341844248ULL : 11773579655431087496ULL; - } - } - - return std::nullopt; -} - -void UpdateHashForBackwardsCompatibility(HashValue& hash) { - // map of old hash to new hash if we were forced to break backwards compatibility for a kernel registration - // - // If we need to update the hash for an existing registration, an entry needs to be added here to map the - // old hash to the new. This should rarely be required as historically the only need for it was fixing - // kernel registrations with invalid type constraints. Please carefully read through the information at the top of - // onnxruntime/test/providers/kernel_def_hash_test.cc regarding how/when hashes might change and the best way to - // address that. - static const std::unordered_map hashes{ - // old new domain, operator, opset[, type] - {2832535737534577496ULL, 16708009824840936392ULL}, // kOnnxDomain, Dropout, 7 - {12198479371038564912ULL, 1718418059112844640ULL}, // kOnnxDomain, Scan, 9 - {2560955351529676608ULL, 3668627007850399040ULL}, // kOnnxDomain, Scan, 11 - {10232409728231027688ULL, 5212043150202938416ULL}, // kOnnxDomain, Not, 1 - {11912523891622051440ULL, 10225383741733918632ULL}, // kOnnxDomain, RoiAlign, 10, float - {18084231515768318048ULL, 17022700455473327752ULL}, // kOnnxDomain, RoiAlign, 10, double - {14033689580222898712ULL, 634727773751317256ULL}, // kOnnxDomain, GatherND, 11 - {646512416908411600ULL, 3064028185911332496ULL}, // kOnnxDomain, GatherND, 12 - {15019893097608892000ULL, 11311962292460032936ULL}, // kOnnxDomain, GatherND, 13 - {14259324427750852648ULL, 7767393334034626736ULL}, // kOnnxDomain, StringNormalizer, 10 - // contrib ops - {7642430665819070720ULL, 8620498355864235632ULL}, // kMSDomain, CropAndResize, 1 - {15019666093341768288ULL, 11924582339825775592ULL}, // kMSDomain, GridSample, 1 - {8466416990072218056ULL, 18418354579469131656ULL}, // kOnnxDomain, LayerNormalization, 1, float - {4058615579523172864ULL, 4827261308628792072ULL}, // kOnnxDomain, LayerNormalization, 1, double - {16349480652468900704ULL, 4809288790945391544ULL}, // kOnnxDomain, SimplifiedLayerNormalization, 1, float - {418129161279605176ULL, 13556035637124174064ULL}}; // kOnnxDomain, SimplifiedLayerNormalization, 1, double - - auto iter = hashes.find(hash); - if (iter != hashes.cend()) { - // hash was updated in newer version of ORT kernel registrations - hash = iter->second; - } -} - -} // namespace utils -} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_def_hash_helpers.h b/onnxruntime/core/framework/kernel_def_hash_helpers.h deleted file mode 100644 index a612f503c20e8..0000000000000 --- a/onnxruntime/core/framework/kernel_def_hash_helpers.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/graph/basic_types.h" - -#include -#include - -namespace onnxruntime { -class Node; -namespace utils { -/** - * @brief Gets the hash value for provided op type + version combination if it is available, otherwise - * returns a nullopt. The hash value is available if this node was added by layout transformer. For all other - * nodes, the hash values should be present either in the serialized session state obtained form ort format model - * or from compiled kernel hash map which is generated during partitioning. - * @return std::optional - */ -std::optional GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version); - -/** - * Get hash value for com.microsoft ops with CPU EP implementations that the NHWC optimizer may insert. - * These are required when that optimizer is run using a minimal build and ORT format model. - * @param Node Node to find hash for. - */ -std::optional GetInternalNhwcOpHash(const Node& node); - -/** - * Get replacement hash for backwards compatibility if we had to modify an existing kernel registration. - * @param hash Hash to update if needed. - */ -void UpdateHashForBackwardsCompatibility(HashValue& hash); -} // namespace utils -} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h new file mode 100644 index 0000000000000..492d60fad4975 --- /dev/null +++ b/onnxruntime/core/framework/kernel_lookup.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "gsl/gsl" + +#include "core/common/common.h" +#include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup +#include "core/framework/kernel_registry.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/op_kernel.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +/** + * Utility class for performing kernel lookup. + * Primary usage pattern is to be created during graph partitioning and passed to IExecutionProvider::GetCapability(). + */ +class KernelLookup : public IExecutionProvider::IKernelLookup { + public: + KernelLookup(ProviderType provider_type, + gsl::span> kernel_registries, + const IKernelTypeStrResolver& kernel_type_str_resolver) + : provider_type_{provider_type}, + kernel_registries_{kernel_registries}, + kernel_type_str_resolver_{kernel_type_str_resolver} { + ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified."); + } + + const KernelCreateInfo* LookUpKernel(const Node& node) const override { + const KernelCreateInfo* kernel_create_info{}; + for (const auto& registry : kernel_registries_) { + const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, + &kernel_create_info); + if (lookup_status.IsOK() && kernel_create_info != nullptr) { + return kernel_create_info; + } + } + + return nullptr; + } + + private: + ProviderType provider_type_; + const gsl::span> kernel_registries_; + const IKernelTypeStrResolver& kernel_type_str_resolver_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index 90b0b190a8988..24f11a6a8c9c2 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -5,169 +5,100 @@ #include #include +#include #include +#include "core/framework/kernel_type_str_resolver.h" #include "core/framework/session_state.h" namespace onnxruntime { -#if !defined(ORT_MINIMAL_BUILD) namespace { -// Traverses the node's formal parameters and calls TraverseFn with the formal -// parameter and its associated TypeProto. -// node - the node to traverse -// param_filter_fn - called to determine whether to consider a given formal parameter: -// bool ParamFilterFn(const ONNX_NAMESPACE::OpSchema::FormalParameter& param) -// param - the formal parameter -// returns true if the formal parameter should be considered, false otherwise -// traverse_fn - called to process the formal parameter and its associated TypeProto: -// bool TraverseFn(const ONNX_NAMESPACE::OpSchema::FormalParameter& param, -// const ONNX_NAMESPACE::TypeProto* type) -// param - the formal parameter -// type - the associated TypeProto -// returns true if traversal should continue, false otherwise -template -bool TraverseFormalParametersWithTypeProto(const Node& node, - ParamFilterFn param_filter_fn, - TraverseFn traverse_fn) { - const ONNX_NAMESPACE::OpSchema& op_schema = *node.Op(); - - // was the param name matched in either inputs, outputs or type constraints. - // this validates the name was valid and that the type involved will be returned if available. - // if the name is invalid we do not return a type, and any applicable type constraint can not be applied - // in VerifyKernelDef. - bool matched = false; +bool IsTypeProtoCompatible(gsl::span enabled_types, const ONNX_NAMESPACE::TypeProto& actual_type, + std::string& mismatch_reason) { + const bool is_type_compatible = std::any_of( + enabled_types.begin(), enabled_types.end(), + [&actual_type](const DataTypeImpl* expected_type) { + bool rc = expected_type->IsCompatible(actual_type); // for easier debugging + return rc; + }); - // process inputs: - const size_t len = node.InputArgCount().size(); - ORT_ENFORCE(len <= op_schema.inputs().size()); - int actual_index = 0; - for (size_t formal_index = 0; formal_index != len; ++formal_index) { - const auto& param = op_schema.inputs()[formal_index]; - if (param_filter_fn(param)) { - matched = true; - // get type of any corresponding actual parameter, if present - for (int i = 0, end = node.InputArgCount()[formal_index]; i < end; ++i) { - const NodeArg* arg = node.InputDefs()[static_cast(actual_index) + i]; - if (!arg->Exists()) continue; // a missing optional argument - if (!traverse_fn(param, arg->TypeAsProto())) return matched; - } + if (!is_type_compatible) { + std::ostringstream ostr; + ostr << "This op has been implemented only for the following types ("; + for (const auto& enabled_type : enabled_types) { + ostr << DataTypeImpl::ToString(enabled_type) << ","; } - actual_index += node.InputArgCount()[formal_index]; + ostr << "),"; + const char* actual_type_str = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(actual_type)); + ostr << " but the node in the model has the following type (" << actual_type_str << ")"; + mismatch_reason = ostr.str(); + return false; } - // process outputs: - auto actual_outputs = node.OutputDefs(); - const auto num_actual_outputs = actual_outputs.size(); - const auto& schema_outputs = op_schema.outputs(); - const auto last_formal = schema_outputs.size() - 1; - size_t i = 0; - for (; i != num_actual_outputs; ++i) { - const auto& formal = schema_outputs[std::min(i, last_formal)]; - if (!param_filter_fn(formal)) continue; - matched = true; - const NodeArg* arg = actual_outputs[i]; - if (!arg->Exists()) continue; - if (!traverse_fn(formal, arg->TypeAsProto())) return matched; - } + return true; +} - // missing optional outputs. check if type constraint name was valid if we haven't matched anything yet. - if (!matched) { - while (i <= last_formal) { - if (param_filter_fn(schema_outputs[i])) { - matched = true; - break; +bool MatchKernelDefTypes(const Node& node, + const KernelDef& kernel_def, + const IKernelTypeStrResolver& kernel_type_str_resolver, + std::string& mismatch_reason) { + const auto actual_inputs = node.InputDefs(); + const auto actual_outputs = node.OutputDefs(); + const auto& actual_input_arg_counts = node.InputArgCount(); + const auto actual_input_arg_offsets = [&actual_input_arg_counts]() { + InlinedVector offsets{}; + offsets.reserve(actual_input_arg_counts.size()); + std::exclusive_scan(actual_input_arg_counts.begin(), actual_input_arg_counts.end(), + std::back_inserter(offsets), 0); + return offsets; + }(); + + // for each type constraint + // map type constraint to arg + // check arg type against type constraint enabled types + const auto& kernel_type_constraints = kernel_def.EnabledTypeConstraints(); + for (const auto& [kernel_type_str, enabled_types] : kernel_type_constraints) { + gsl::span constraint_args{}; + ORT_THROW_IF_ERROR(kernel_type_str_resolver.ResolveKernelTypeStr(node, kernel_type_str, + constraint_args)); + + for (const auto& [arg_type, formal_arg_idx] : constraint_args) { + const NodeArg* arg; + if (arg_type == ArgType::kInput) { + if (formal_arg_idx >= actual_input_arg_counts.size() || + actual_input_arg_counts[formal_arg_idx] == 0) { + arg = nullptr; + } else { + const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx]; + ORT_ENFORCE(static_cast(first_arg_idx) < actual_inputs.size()); + arg = actual_inputs[first_arg_idx]; + } + } else { + arg = formal_arg_idx < actual_outputs.size() ? actual_outputs[formal_arg_idx] : nullptr; } - ++i; - } - } + if (arg && arg->Exists()) { + const ONNX_NAMESPACE::TypeProto* type_proto = arg->TypeAsProto(); + ORT_ENFORCE(type_proto != nullptr); - return matched; -} - -class TypeBindingResolver { - public: - TypeBindingResolver(const Node& node, bool use_lookup_map) - : node_(node), - type_binding_map_() { - if (use_lookup_map) { - type_binding_map_ = std::make_unique(); - TraverseFormalParametersWithTypeProto( - node_, - [](const ONNX_NAMESPACE::OpSchema::FormalParameter&) -> bool { return true; }, - [this](const ONNX_NAMESPACE::OpSchema::FormalParameter& param, - const ONNX_NAMESPACE::TypeProto* type) -> bool { - type_binding_map_->emplace(param.GetName(), type); - type_binding_map_->emplace(param.GetTypeStr(), type); - return true; - }); - } - } - - // Resolves a type constraint name to a TypeProto* for a given node. ONNX code checks that all usages of the type - // constraint name by the node are consistent, so we just need to match the first usage to see the actual type - // being used by the node. e.g. if type constraint 'T' allows float and double, any input or output for that node - // that has constraint 'T' must use the same type, be that float or double. - // - // Also can resolve an input/output name to a contraint when a type constraint name is not used. - // e.g. the 'shape' input of Reshape has a directly specified constraint of 'tensor(int64)'. - // - // Returns the resolved TypeProto* or nullptr if unable to resolve due to the - // constraint being for a missing optional output. - const ONNX_NAMESPACE::TypeProto* Resolve(const std::string& name_or_type_str) const { - const ONNX_NAMESPACE::TypeProto* result{}; - bool matched = false; + if (!IsTypeProtoCompatible(enabled_types, *type_proto, mismatch_reason)) { + return false; + } - // lookup if available - if (type_binding_map_) { - auto found_it = type_binding_map_->find(name_or_type_str); - matched = found_it != type_binding_map_->end(); - if (matched) { - result = found_it->second; + // found a match, don't need to check other args with this constraint + break; } } - - if (!matched) { - // fall back to node parameter traversal - matched = TraverseFormalParametersWithTypeProto( - node_, - [&name_or_type_str](const ONNX_NAMESPACE::OpSchema::FormalParameter& param) -> bool { - return param.GetTypeStr() == name_or_type_str || param.GetName() == name_or_type_str; - }, - [&result](const ONNX_NAMESPACE::OpSchema::FormalParameter&, - const ONNX_NAMESPACE::TypeProto* type) -> bool { - result = type; - return false; - }); - } - -// invalid kernel def with type constraints that don't match the schema. this means the type constraints are not -// actually applied, making the kernel def misleading and potentially matching an unexpected/incorrect kernel. -// warn in a release build as we do not have coverage of every single opset for every single operator -// in the unit tests, so issues may be missed and the model may still work (e.g. matches the correct kernel by chance). -// throw in a debug build so the issue is obvious and force it to be fixed. -#ifdef NDEBUG - if (!matched) { - LOGS_DEFAULT(WARNING) << name_or_type_str << " constraint was not found for " << node_.OpType(); - } -#else - ORT_ENFORCE(matched, name_or_type_str, " constraint was not found for ", node_.OpType()); -#endif - return result; } - private: - // map from input/output name or type string to TypeProto pointer - using TypeBindingMap = std::unordered_map; - - const Node& node_; - std::unique_ptr type_binding_map_; -}; -}; // namespace + return true; +} +} // namespace bool KernelRegistry::VerifyKernelDef(const Node& node, const KernelDef& kernel_def, + const IKernelTypeStrResolver& kernel_type_str_resolver, std::string& error_str) { // check if version matches int kernel_start_version; @@ -200,81 +131,21 @@ bool KernelRegistry::VerifyKernelDef(const Node& node, return false; } - // check if type matches - auto& kernel_type_constraints = kernel_def.EnabledTypeConstraints(); - - // Note: The number of formal input/output parameters is N and the number of - // type constraints is M. We select between an O(N*M) and an O(N+M) approach. - // The O(N*M) approach has lower initial overhead. - // kTypeBindingResolverComplexityThreshold is the value of N*M above which we - // will use the O(N+M) approach. - constexpr int kTypeBindingResolverComplexityThreshold = 50 * 50; - const bool use_lookup_map = (kernel_type_constraints.size() * (node.Op()->inputs().size() + node.Op()->outputs().size()) > - kTypeBindingResolverComplexityThreshold); - TypeBindingResolver type_binding_resolver{node, use_lookup_map}; - - for (auto& constraint : kernel_type_constraints) { - const std::string& name = constraint.first; - const std::vector& allowed_types = constraint.second; - const ONNX_NAMESPACE::TypeProto* actual_type = type_binding_resolver.Resolve(name); - - // If actual_type is null, this represents a type-constraint on a - // missing optional parameter, which can be skipped. - // TODO: We should check that names specified in kernel_type_constraints are - // valid names (of types or parameters) at the time that kernels are registered. - if (nullptr != actual_type) { - bool is_type_compatible = std::any_of(allowed_types.begin(), allowed_types.end(), - [actual_type](const DataTypeImpl* expected_type) { - bool rc = expected_type->IsCompatible(*actual_type); // for easier debugging - return rc; - }); - if (!is_type_compatible) { - std::ostringstream ostr; - ostr << "Found kernel for Op with name (" << node.Name() << ")" - << " and type (" << node.OpType() << ")" - << " in the supported version range" - << " (node_version: " << node_since_version - << " kernel start version: " << kernel_start_version - << " kernel_end_version: " << kernel_end_version << ")." - << " However the types are incompatible." - << " This op has been implemented only for the following types ("; - for (const auto& allowed_type : allowed_types) { - ostr << DataTypeImpl::ToString(allowed_type) << ","; - } - ostr << "),"; - const char* actual_type_str = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*actual_type)); - ostr << " but the node in the model has the following type (" << actual_type_str << ")"; - error_str = ostr.str(); - return false; - } - } + if (std::string mismatch_reason; + !MatchKernelDefTypes(node, kernel_def, kernel_type_str_resolver, mismatch_reason)) { + std::ostringstream ostr; + ostr << "Found kernel for Op with name (" << node.Name() << ")" + << " and type (" << node.OpType() << ")" + << " in the supported version range" + << " (node_version: " << node_since_version + << " kernel start version: " << kernel_start_version + << " kernel_end_version: " << kernel_end_version << ")." + << " However the types are incompatible. " << mismatch_reason; + error_str = ostr.str(); + return false; } - return true; -} -Status KernelRegistry::TryCreateKernel(const Node& node, - const IExecutionProvider& execution_provider, - const std::unordered_map& constant_initialized_tensors, - const OrtValueNameIdxMap& ort_value_name_idx_map, - FuncManager& funcs_mgr, - const DataTransferManager& data_transfer_mgr, - /*out*/ std::unique_ptr& op_kernel) const { - const KernelCreateInfo* kernel_create_info = nullptr; - ORT_RETURN_IF_ERROR(TryFindKernel(node, execution_provider.Type(), &kernel_create_info)); - OpKernelInfo kernel_info(node, - *kernel_create_info->kernel_def, - execution_provider, - constant_initialized_tensors, - ort_value_name_idx_map, - data_transfer_mgr); - return kernel_create_info->kernel_create_func(funcs_mgr, kernel_info, op_kernel); -} - -static std::string ToString(const std::vector& error_strs) { - std::ostringstream ostr; - std::for_each(std::begin(error_strs), std::end(error_strs), - [&ostr](const std::string& str) { ostr << str << "\n"; }); - return ostr.str(); + return true; } // It's often this function returns a failed status, but it is totally expected. @@ -284,6 +155,7 @@ static std::string ToString(const std::vector& error_strs) { // otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored. Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, + const IKernelTypeStrResolver& kernel_type_str_resolver, const KernelCreateInfo** out) const { const auto& node_provider = node.GetExecutionProviderType(); const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider); @@ -295,7 +167,7 @@ Status KernelRegistry::TryFindKernel(const Node& node, for (auto i = range.first; i != range.second; ++i) { std::string error_str; - if (VerifyKernelDef(node, *i->second.kernel_def, error_str)) { + if (VerifyKernelDef(node, *i->second.kernel_def, kernel_type_str_resolver, error_str)) { if (out) *out = &i->second; return Status::OK(); } @@ -307,7 +179,10 @@ Status KernelRegistry::TryFindKernel(const Node& node, oss << "Op with name (" << node.Name() << ")" << " and type (" << node.OpType() << ")" << " kernel is not supported in " << expected_provider << "." - << " Encountered following errors: (" << ToString(verify_kernel_def_error_strs) << ")"; + << " Encountered following errors: ("; + std::copy(verify_kernel_def_error_strs.begin(), verify_kernel_def_error_strs.end(), + std::ostream_iterator(oss, "\n")); + oss << ")"; VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); @@ -316,47 +191,38 @@ Status KernelRegistry::TryFindKernel(const Node& node, return Status(common::ONNXRUNTIME, common::FAIL, "Kernel not found"); } +#if !defined(ORT_MINIMAL_BUILD) Status KernelRegistry::TryFindKernel(const std::string& op_name, const std::string& domain, const int& version, const std::unordered_map& type_constraints, - ProviderType exec_provider, const KernelCreateInfo** out) const { - *out = nullptr; + ProviderType exec_provider, const KernelCreateInfo** kernel_out) const { + const KernelCreateInfo* kernel = nullptr; auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_name, domain, exec_provider)); - for (auto i = range.first; i != range.second; ++i) { //loop through all kernels + for (auto i = range.first; i != range.second; ++i) { // loop through all kernels const KernelCreateInfo& kci = i->second; int start_ver{}; int end_ver{}; kci.kernel_def->SinceVersion(&start_ver, &end_ver); - if (start_ver <= version && end_ver >= version) { //try match the version + if (start_ver <= version && end_ver >= version) { // try match the version auto& kci_constraints = kci.kernel_def->TypeConstraints(); bool match = true; - for (auto& constraint : type_constraints) { //try match type constraints + for (auto& constraint : type_constraints) { // try match type constraints auto iter = kci_constraints.find(constraint.first); if (iter == kci_constraints.end() || find(iter->second.begin(), iter->second.end(), constraint.second) == iter->second.end()) { match = false; break; } - } //for + } // for if (match) { - *out = &kci; //found match, exit loop + kernel = &kci; // found match, exit loop break; } - } //if - } //for - return *out == nullptr ? Status(common::ONNXRUNTIME, common::FAIL, "Kernel not found") : Status::OK(); + } // if + } // for + if (kernel_out) *kernel_out = kernel; + return kernel == nullptr ? Status(common::ONNXRUNTIME, common::FAIL, "Kernel not found") : Status::OK(); } #endif // !defined(ORT_MINIMAL_BUILD) -bool KernelRegistry::TryFindKernelByHash(HashValue kernel_def_hash, const KernelCreateInfo** out) const { - const auto hash_lookup_it = kernel_def_hash_lookup_.find(kernel_def_hash); - if (hash_lookup_it == kernel_def_hash_lookup_.end()) { - if (out) *out = nullptr; - return false; - } - - if (out) *out = &hash_lookup_it->second->second; - return true; -} - Status KernelRegistry::Register(KernelDefBuilder& kernel_builder, const KernelCreateFn& kernel_creator) { return Register(KernelCreateInfo(kernel_builder.Build(), kernel_creator)); @@ -378,29 +244,10 @@ Status KernelRegistry::Register(KernelCreateInfo&& create_info) { } } - // check for existing hash conflict - const auto kernel_def_hash = create_info.kernel_def->GetHash(); - ORT_RETURN_IF(kernel_def_hash_lookup_.find(kernel_def_hash) != kernel_def_hash_lookup_.end(), - "Failed to add kernel for " + key + ": Conflict with existing kernel def hash."); - // Register the kernel. // Ownership of the KernelDef is transferred to kernel_creator_fn_map_. - auto it = kernel_creator_fn_map_.emplace(key, std::move(create_info)); - kernel_def_hash_lookup_.emplace(kernel_def_hash, it); + kernel_creator_fn_map_.emplace(key, std::move(create_info)); return Status::OK(); } -KernelDefHashes KernelRegistry::ExportKernelDefHashes() const { - KernelDefHashes result{}; - result.reserve(kernel_creator_fn_map_.size()); - std::transform( - kernel_creator_fn_map_.begin(), kernel_creator_fn_map_.end(), - std::back_inserter(result), - [](const KernelCreateMap::value_type& kvp) { - return std::make_pair(kvp.first, kvp.second.kernel_def->GetHash()); - }); - std::sort(result.begin(), result.end()); - return result; -} - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 2041b0d7bac4a..13bcd71651e0a 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. #include "core/framework/kernel_registry_manager.h" + #include "core/framework/kernel_registry.h" #include "core/framework/execution_providers.h" #include "core/framework/session_state.h" +#include "core/framework/kernel_type_str_resolver.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/framework/customregistry.h" @@ -13,7 +15,7 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, +Status KernelRegistryManager::CreateKernel(const Node& node, const IExecutionProvider& execution_provider, SessionState& session_state, const KernelCreateInfo& kernel_create_info, @@ -53,15 +55,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr kernel_registries = r.GetKernelRegistriesByProviderType(provider_type); - return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) { - return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type); - }); -} - -Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node, +Status KernelRegistryManager::SearchKernelRegistry(const Node& node, /*out*/ const KernelCreateInfo** kernel_create_info) const { Status status; @@ -80,7 +74,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node } for (auto& registry : custom_kernel_registries_) { - status = registry->TryFindKernel(node, std::string(), kernel_create_info); + status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); if (status.IsOK()) { return status; } @@ -93,7 +87,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node } if (p != nullptr) { - status = p->TryFindKernel(node, std::string(), kernel_create_info); + status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); if (status.IsOK()) { return status; } @@ -101,25 +95,12 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); } -#endif -bool KernelRegistryManager::SearchKernelRegistriesByHash(HashValue kernel_def_hash, - const KernelCreateInfo** kernel_create_info) const { -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - for (const auto& registry : custom_kernel_registries_) { - if (registry->TryFindKernelByHash(kernel_def_hash, kernel_create_info)) { - return true; - } - } -#endif - - for (const auto& kv : provider_type_to_registry_) { - if (kv.second->TryFindKernelByHash(kernel_def_hash, kernel_create_info)) { - return true; - } - } - - return false; +bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) { + const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type); + return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) { + return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver()); + }); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 4d9da148af064..344ab220e984d 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -2,11 +2,16 @@ // Licensed under the MIT License. #pragma once -#include -#include #include +#include +#include #include + +#include "gsl/gsl" + +#include "core/common/inlined_containers.h" #include "core/common/status.h" +#include "core/framework/kernel_type_str_resolver.h" #include "core/graph/graph_viewer.h" #include "core/platform/ort_mutex.h" @@ -30,7 +35,7 @@ class KernelRegistryManager { KernelRegistryManager() = default; // Register kernels from providers - Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT; + Status RegisterKernels(const ExecutionProviders& execution_providers); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // The registry passed in this function has highest priority than anything already in this KernelRegistryManager, @@ -41,47 +46,49 @@ class KernelRegistryManager { // RegisterKernelRegistry(B); // Then B > A > providers void RegisterKernelRegistry(std::shared_ptr kernel_registry); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) /** - * Search kernel registry by provider type. - * @param type provider type string - * @return It returns all the possible results. The returned value may contain garbage that doesn't belong to - * this provider. Caller should do the filtering. The returned value won't have no nullptrs. + * Gets kernel registries for the specified provider type. + * @param provider_type provider type string + * @return The kernel registries. This also includes custom registries. These may contain kernels that don't belong + * to this provider. The caller should do the filtering. */ - std::vector GetKernelRegistriesByProviderType(const std::string& type) const { - std::vector result; + InlinedVector> GetKernelRegistriesByProviderType( + const std::string& provider_type) const { + InlinedVector> result; + result.reserve(custom_kernel_registries_.size() + 1); for (auto& registry : custom_kernel_registries_) { result.push_back(registry.get()); } - auto iter = provider_type_to_registry_.find(type); + auto iter = provider_type_to_registry_.find(provider_type); if (iter != provider_type_to_registry_.end()) result.push_back(iter->second.get()); return result; } -#endif -#if !defined(ORT_MINIMAL_BUILD) // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done - Status SearchKernelRegistry(const onnxruntime::Node& node, + Status SearchKernelRegistry(const Node& node, /*out*/ const KernelCreateInfo** kernel_create_info) const; /** * Whether this node can be run on this provider */ static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type); -#endif - /** - * Search the kernel registries given a kernel def hash. - */ - bool SearchKernelRegistriesByHash(HashValue kernel_def_hash, - const KernelCreateInfo** kernel_create_info) const; - - Status CreateKernel(const onnxruntime::Node& node, + Status CreateKernel(const Node& node, const IExecutionProvider& execution_provider, SessionState& session_state, const KernelCreateInfo& kernel_create_info, std::unique_ptr& out) const; + const IKernelTypeStrResolver& GetKernelTypeStrResolver() const { + return std::visit([](auto&& r) -> const IKernelTypeStrResolver& { return r; }, kernel_type_str_resolver_variant_); + } + + void SetKernelTypeStrResolver(KernelTypeStrResolver&& kernel_type_str_resolver) { + kernel_type_str_resolver_variant_ = std::move(kernel_type_str_resolver); + } + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(KernelRegistryManager); private: @@ -90,5 +97,14 @@ class KernelRegistryManager { // Each kernel registry may contain kernels from many different providers. // in order to search kernels from a specific provider, we have to iterate all its elements std::list> custom_kernel_registries_; + + // kernel type str resolver used by kernel registries for kernel matching + using KernelTypeStrResolverVariant = std::variant< +#if !defined(ORT_MINIMAL_BUILD) + OpSchemaKernelTypeStrResolver, // the default in a full build +#endif + KernelTypeStrResolver // the default in a minimal build + >; + KernelTypeStrResolverVariant kernel_type_str_resolver_variant_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.cc b/onnxruntime/core/framework/kernel_type_str_resolver.cc new file mode 100644 index 0000000000000..1d87a77fe4677 --- /dev/null +++ b/onnxruntime/core/framework/kernel_type_str_resolver.cc @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/kernel_type_str_resolver.h" + +#include // for std::lock_guard + +#include "core/flatbuffers/schema/ort.fbs.h" +#include "core/flatbuffers/flatbuffers_utils.h" +#include "core/graph/op_identifier_utils.h" + +namespace fb = flatbuffers; + +namespace onnxruntime { + +Status KernelTypeStrResolver::ResolveKernelTypeStr(const Node& node, std::string_view kernel_type_str, + gsl::span& resolved_args) const { + const auto op_id = utils::MakeOpId(node); + const auto op_it = op_kernel_type_str_map_.find(op_id); + ORT_RETURN_IF(op_it == op_kernel_type_str_map_.end(), "Failed to find op_id: ", op_id); + const auto& type_str_map = op_it->second; + +#ifdef DISABLE_ABSEIL + // TODO(edgchen1) maybe we can use transparent hash/eq to enable lookup with string_view + const auto type_str_it = type_str_map.find(std::string(kernel_type_str)); +#else + const auto type_str_it = type_str_map.find(kernel_type_str); +#endif + + ORT_RETURN_IF(type_str_it == type_str_map.end(), + "Failed to find args for kernel type string '", kernel_type_str, + "'. If type constraint names are available, ensure that they are used in the kernel def type " + "constraints instead of op input or output names. Not doing so will result in this error."); + resolved_args = type_str_it->second; + return Status::OK(); +} + +#if !defined(ORT_MINIMAL_BUILD) +Status KernelTypeStrResolver::RegisterOpSchema(const ONNX_NAMESPACE::OpSchema& op_schema, bool* registered_out) { + auto op_id = utils::MakeOpId(op_schema); + if (Contains(op_kernel_type_str_map_, op_id)) { + if (registered_out) { + *registered_out = false; + } + return Status::OK(); + } + + const auto type_constraint_names = [&]() { + const auto& type_constraints = op_schema.typeConstraintParams(); + InlinedHashSet names{}; + names.reserve(type_constraints.size()); + for (const auto& type_constraint : type_constraints) { + names.emplace(type_constraint.type_param_str); + } + return names; + }(); + + InlinedHashMap> kernel_type_str_map{}; + // at most one entry for each input/output + kernel_type_str_map.reserve(op_schema.inputs().size() + op_schema.outputs().size()); + + auto process_formal_params = [&](ArgType arg_type) -> Status { + const auto& formal_params = arg_type == ArgType::kInput ? op_schema.inputs() : op_schema.outputs(); + for (size_t i = 0; i < formal_params.size(); ++i) { + const auto& formal_param = formal_params[i]; + auto curr_arg_type_and_idx = ArgTypeAndIndex{arg_type, i}; + + // first, try to use type constraint name as kernel type string + if (const auto& type_str = formal_param.GetTypeStr(); + Contains(type_constraint_names, type_str)) { + kernel_type_str_map[type_str].push_back(curr_arg_type_and_idx); + continue; + } + + // otherwise, use input/output name as kernel type string + auto& args_for_io_name = kernel_type_str_map[formal_param.GetName()]; + if (!args_for_io_name.empty()) { + // It's possible that an input and output have the same name (e.g, BatchNormalization-9 has both an input and + // an output named 'mean'). + // If so, their formal parameters also need to have the same type string. Otherwise, it would be ambiguous to + // use that name as a kernel type string. + auto formal_param_type_str = [&op_schema](const ArgTypeAndIndex& arg_type_and_idx) { + const auto& [arg_type, idx] = arg_type_and_idx; + const auto& formal_params = arg_type == ArgType::kInput ? op_schema.inputs() : op_schema.outputs(); + return formal_params[idx].GetTypeStr(); + }; + + ORT_RETURN_IF_NOT( + formal_param_type_str(curr_arg_type_and_idx) == formal_param_type_str(args_for_io_name.front()), + "Kernel type string already exists for formal parameter name '", formal_param.GetName(), + "', but the existing argument with that formal parameter name has a different formal parameter " + "type string."); + } + args_for_io_name.push_back(std::move(curr_arg_type_and_idx)); + } + return Status::OK(); + }; + + ORT_RETURN_IF_ERROR(process_formal_params(ArgType::kInput)); + ORT_RETURN_IF_ERROR(process_formal_params(ArgType::kOutput)); + + op_kernel_type_str_map_.emplace(std::move(op_id), std::move(kernel_type_str_map)); + if (registered_out) { + *registered_out = true; + } + return Status::OK(); +} + +Status KernelTypeStrResolver::RegisterNodeOpSchema(const Node& node) { + ORT_RETURN_IF(node.Op() == nullptr, "Op schema must be available."); + return RegisterOpSchema(*node.Op()); +} + +Status KernelTypeStrResolver::RegisterGraphNodeOpSchemas(const Graph& graph) { + for (const Node& node : graph.Nodes()) { + ORT_RETURN_IF_ERROR(RegisterNodeOpSchema(node)); + + if (node.ContainsSubgraph()) { + const auto subgraphs = node.GetSubgraphs(); + for (const auto& subgraph : subgraphs) { + ORT_RETURN_IF_ERROR(RegisterGraphNodeOpSchemas(*subgraph)); + } + } + } + return Status::OK(); +} + +Status KernelTypeStrResolver::SaveToOrtFormat( + fb::FlatBufferBuilder& builder, + fb::Offset& fbs_kernel_type_str_resolver) const { + std::vector> fbs_op_kernel_type_str_args{}; + fbs_op_kernel_type_str_args.reserve(op_kernel_type_str_map_.size()); + + for (const auto& [op_id, kernel_type_str_map] : op_kernel_type_str_map_) { + std::vector> fbs_kernel_type_str_args{}; + fbs_kernel_type_str_args.reserve(kernel_type_str_map.size()); + + for (const auto& [kernel_type_str, args] : kernel_type_str_map) { + std::vector> fbs_args{}; + fbs_args.reserve(args.size()); + + for (const auto& arg : args) { + auto fbs_arg = fbs::CreateArgTypeAndIndex( + builder, + arg.first == ArgType::kInput ? fbs::ArgType::INPUT : fbs::ArgType::OUTPUT, + gsl::narrow(arg.second)); + fbs_args.push_back(fbs_arg); + } + + auto fbs_kernel_type_str_args_entry = fbs::CreateKernelTypeStrArgsEntry( + builder, + builder.CreateSharedString(kernel_type_str), + builder.CreateVector(fbs_args)); + fbs_kernel_type_str_args.push_back(fbs_kernel_type_str_args_entry); + } + + fb::Offset fbs_op_id{}; + ORT_RETURN_IF_ERROR(fbs::utils::SaveOpIdentifierOrtFormat(builder, op_id, fbs_op_id)); + + auto fbs_op_kernel_type_str_args_entry = fbs::CreateOpIdKernelTypeStrArgsEntry( + builder, + fbs_op_id, + builder.CreateVectorOfSortedTables(&fbs_kernel_type_str_args)); + fbs_op_kernel_type_str_args.push_back(fbs_op_kernel_type_str_args_entry); + } + + fbs_kernel_type_str_resolver = fbs::CreateKernelTypeStrResolver( + builder, + builder.CreateVectorOfSortedTables(&fbs_op_kernel_type_str_args)); + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +Status KernelTypeStrResolver::LoadFromOrtFormat(const fbs::KernelTypeStrResolver& fbs_kernel_type_str_resolver) { + const auto* fbs_op_kernel_type_str_args = fbs_kernel_type_str_resolver.op_kernel_type_str_args(); + ORT_FORMAT_RETURN_IF_NULL(fbs_op_kernel_type_str_args, "op_kernel_type_str_args"); + + OpKernelTypeStrMap op_kernel_type_str_map{}; + op_kernel_type_str_map.reserve(fbs_op_kernel_type_str_args->size()); + for (const auto* fbs_op_kernel_type_str_args_entry : *fbs_op_kernel_type_str_args) { + ORT_FORMAT_RETURN_IF_NULL(fbs_op_kernel_type_str_args_entry, "op_kernel_type_str_args entry"); + + const auto* fbs_op_id = fbs_op_kernel_type_str_args_entry->op_id(); + ORT_FORMAT_RETURN_IF_NULL(fbs_op_id, "op_id"); + + const auto* fbs_kernel_type_str_args = fbs_op_kernel_type_str_args_entry->kernel_type_str_args(); + ORT_FORMAT_RETURN_IF_NULL(fbs_kernel_type_str_args, "kernel_type_str_args"); + + KernelTypeStrToArgsMap kernel_type_str_map{}; + kernel_type_str_map.reserve(fbs_kernel_type_str_args->size()); + for (const auto* fbs_kernel_type_str_args_entry : *fbs_kernel_type_str_args) { + ORT_FORMAT_RETURN_IF_NULL(fbs_kernel_type_str_args_entry, "kernel_type_str_args entry"); + + const auto* fbs_kernel_type_str = fbs_kernel_type_str_args_entry->kernel_type_str(); + ORT_FORMAT_RETURN_IF_NULL(fbs_kernel_type_str, "kernel_type_str"); + + const auto* fbs_args = fbs_kernel_type_str_args_entry->args(); + ORT_FORMAT_RETURN_IF_NULL(fbs_args, "args"); + + InlinedVector args{}; + args.reserve(fbs_args->size()); + for (const auto* fbs_arg : *fbs_args) { + ORT_FORMAT_RETURN_IF_NULL(fbs_arg, "args entry"); + args.push_back(ArgTypeAndIndex{ + fbs_arg->arg_type() == fbs::ArgType::INPUT ? ArgType::kInput : ArgType::kOutput, + fbs_arg->index()}); + } + + const auto [it, inserted] = kernel_type_str_map.try_emplace(fbs_kernel_type_str->str(), std::move(args)); + ORT_RETURN_IF_NOT(inserted, "Duplicate entry for kernel type str: ", it->first, ". ", + fbs::utils::kInvalidOrtFormatModelMessage); + } + + OpIdentifier op_id; + ORT_RETURN_IF_ERROR(fbs::utils::LoadOpIdentifierOrtFormat(*fbs_op_id, op_id)); + const auto [it, inserted] = op_kernel_type_str_map.try_emplace(std::move(op_id), + std::move(kernel_type_str_map)); + ORT_RETURN_IF_NOT(inserted, "Duplicate entry for op id: ", it->first, ". ", + fbs::utils::kInvalidOrtFormatModelMessage); + } + + op_kernel_type_str_map_ = std::move(op_kernel_type_str_map); + return Status::OK(); +} + +void KernelTypeStrResolver::Merge(KernelTypeStrResolver src) { + op_kernel_type_str_map_.merge(src.op_kernel_type_str_map_); +} + +#if !defined(ORT_MINIMAL_BUILD) +Status OpSchemaKernelTypeStrResolver::ResolveKernelTypeStr( + const Node& node, std::string_view kernel_type_str, + gsl::span& resolved_args) const { + std::lock_guard lock{resolver_mutex_}; + ORT_RETURN_IF_ERROR(resolver_.RegisterNodeOpSchema(node)); + ORT_RETURN_IF_ERROR(resolver_.ResolveKernelTypeStr(node, kernel_type_str, resolved_args)); + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h new file mode 100644 index 0000000000000..0a95faf0b40a0 --- /dev/null +++ b/onnxruntime/core/framework/kernel_type_str_resolver.h @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "gsl/gsl" + +#if !defined(ORT_MINIMAL_BUILD) +#include "onnx/defs/schema.h" +#endif // !defined(ORT_MINIMAL_BUILD) + +#include "core/common/inlined_containers.h" +#include "core/common/status.h" +#include "core/graph/op_identifier.h" +#include "core/graph/graph.h" +#include "core/platform/ort_mutex.h" + +namespace flatbuffers { +class FlatBufferBuilder; +template +struct Offset; +} // namespace flatbuffers + +namespace onnxruntime { + +namespace fbs { +struct KernelTypeStrResolver; +} // namespace fbs + +using ArgTypeAndIndex = std::pair; +using KernelTypeStrToArgsMap = InlinedHashMap>; +using OpKernelTypeStrMap = InlinedHashMap; + +/** + * This class interface provides a way to resolve an op's kernel type string to its associated arguments. + * + * A 'kernel type string' is a string that is used in kernel def type constraints. In particular, it can be a named + * type parameter (such as 'T') specified in the op schema or it can be the name of an input or output parameter. + */ +class IKernelTypeStrResolver { + public: + /** + * Resolves an op's kernel type string to its associated arguments. + * @param node The op's node. + * @param kernel_type_str The op kernel type string. + * @param[out] resolved_args The op arguments associated with kernel_type_str. + */ + virtual Status ResolveKernelTypeStr(const Node& node, std::string_view kernel_type_str, + gsl::span& resolved_args) const = 0; +}; + +/** + * A basic implementation of IKernelTypeStrResolver. + * + * Supports loading information from op schemas in a full build and saving to/loading from an ORT format model + * representation. + */ +class KernelTypeStrResolver : public IKernelTypeStrResolver { + public: + Status ResolveKernelTypeStr(const Node& node, std::string_view kernel_type_str, + gsl::span& resolved_args) const override; + +#if !defined(ORT_MINIMAL_BUILD) + + /** + * Registers kernel type string matching info from an op schema. + * This will not overwrite an existing registration for the same op. + * @param op_schema The op schema to register. + * @param[out] registered Whether the op schema was registered or there was already an existing registration. + */ + Status RegisterOpSchema(const ONNX_NAMESPACE::OpSchema& op_schema, bool* registered = nullptr); + + /** + * Registers kernel type string matching info from an op schema from a node. + * @param node The node to register. + */ + Status RegisterNodeOpSchema(const Node& node); + + /** + * Registers kernel type string matching info from op schemas from nodes in a graph. + * @param graph The graph to register. + */ + Status RegisterGraphNodeOpSchemas(const Graph& graph); + + /** + * Saves to an ORT format model representation. + * @param builder The flatbuffers builder. + * @param[out] fbs_kernel_type_str_resolver The saved flatbuffers representation offset. + */ + Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, + flatbuffers::Offset& fbs_kernel_type_str_resolver) const; + +#endif // !defined(ORT_MINIMAL_BUILD) + + /** + * Loads from an ORT format model representation. + * @param fbs_kernel_type_str_resolver The flatbuffers representation to load. + */ + Status LoadFromOrtFormat(const fbs::KernelTypeStrResolver& fbs_kernel_type_str_resolver); + + /** + * Merges kernel type string matching info from another KernelTypeStrResolver. + * @param src The KernelTypeStrResolver to merge from. + */ + void Merge(KernelTypeStrResolver src); + + const OpKernelTypeStrMap& GetOpKernelTypeStrMap() const { return op_kernel_type_str_map_; } + + private: + OpKernelTypeStrMap op_kernel_type_str_map_; +}; + +#if !defined(ORT_MINIMAL_BUILD) + +/** + * An implementation of IKernelTypeStrResolver which loads kernel type string matching info from node op schemas. + * + * As this requires node op schemas, it is only enabled in a full build. + */ +class OpSchemaKernelTypeStrResolver : public IKernelTypeStrResolver { + public: + // Note: `node`'s op schema must be populated. + Status ResolveKernelTypeStr(const Node& node, std::string_view kernel_type_str, + gsl::span& resolved_args) const override; + + private: + // used as a cache when resolving + // since the cache may be modified with a const instance, ensure that access to the cache is thread-safe + mutable KernelTypeStrResolver resolver_; + mutable OrtMutex resolver_mutex_; +}; + +#endif // !defined(ORT_MINIMAL_BUILD) + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc new file mode 100644 index 0000000000000..7984fb219ce29 --- /dev/null +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "core/framework/kernel_type_str_resolver_utils.h" + +#include "flatbuffers/flatbuffers.h" + +#include "core/common/common.h" +#include "core/flatbuffers/schema/ort.fbs.h" +#include "core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h" + +namespace onnxruntime::kernel_type_str_resolver_utils { + +static constexpr auto* kStandaloneKernelTypeStrResolverFileIdentifier = "ktsr"; + +#if !defined(ORT_MINIMAL_BUILD) + +gsl::span GetLayoutTransformationRequiredOpIdentifiers() { + return kLayoutTransformationPotentiallyAddedOps; +} + +Status SaveKernelTypeStrResolverToBuffer(const KernelTypeStrResolver& kernel_type_str_resolver, + flatbuffers::DetachedBuffer& buffer, gsl::span& buffer_span) { + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset fbs_kernel_type_str_resolver; + ORT_RETURN_IF_ERROR(kernel_type_str_resolver.SaveToOrtFormat(builder, fbs_kernel_type_str_resolver)); + builder.Finish(fbs_kernel_type_str_resolver, kStandaloneKernelTypeStrResolverFileIdentifier); + buffer = builder.Release(); + buffer_span = gsl::make_span(buffer.data(), buffer.size()); + return Status::OK(); +} + +#endif // !defined(ORT_MINIMAL_BUILD) + +Status LoadKernelTypeStrResolverFromBuffer(KernelTypeStrResolver& kernel_type_str_resolver, + gsl::span buffer_span) { + flatbuffers::Verifier verifier{buffer_span.data(), buffer_span.size_bytes()}; + ORT_RETURN_IF_NOT(verifier.VerifyBuffer(kStandaloneKernelTypeStrResolverFileIdentifier), + "Failed to verify KernelTypeStrResolver flatbuffers data."); + const auto* fbs_kernel_type_str_resolver = flatbuffers::GetRoot(buffer_span.data()); + ORT_RETURN_IF_ERROR(kernel_type_str_resolver.LoadFromOrtFormat(*fbs_kernel_type_str_resolver)); + return Status::OK(); +} + +Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrResolver& kernel_type_str_resolver) { + KernelTypeStrResolver resolver_with_required_ops{}; + + // to generate kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes, run the test: + // KernelTypeStrResolverUtilsTest.DISABLED_PrintExpectedLayoutTransformationRequiredOpsResolverByteArray + + // clang-format off + constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = { + 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0xfc, 0x04, 0x00, 0x00, + 0x48, 0x06, 0x00, 0x00, 0xac, 0x06, 0x00, 0x00, 0xa4, 0x05, 0x00, 0x00, 0x2c, 0x03, 0x00, 0x00, + 0xd0, 0x01, 0x00, 0x00, 0xe0, 0x05, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x10, 0x02, 0x00, 0x00, + 0x58, 0x01, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x38, 0x05, 0x00, 0x00, 0xc0, 0x02, 0x00, 0x00, + 0xb0, 0x00, 0x00, 0x00, 0x40, 0x02, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x48, 0x03, 0x00, 0x00, + 0x30, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, + 0x3a, 0x31, 0x00, 0x00, 0x54, 0xf9, 0xff, 0xff, 0x7c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x82, 0xf9, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x7c, 0xf9, 0xff, 0xff, 0x78, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, + 0x77, 0x63, 0x4d, 0x61, 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0xac, 0xf9, 0xff, 0xff, + 0x24, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xda, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd4, 0xf9, 0xff, 0xff, + 0xd0, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, + 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, 0xf8, 0xf9, 0xff, 0xff, 0xd8, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x26, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x20, 0xfa, 0xff, 0xff, 0x1c, 0xfa, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x44, 0xfa, 0xff, 0xff, 0x8c, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x72, 0xfa, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x6c, 0xfa, 0xff, 0xff, 0x68, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, + 0x90, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbe, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xb8, 0xfa, 0xff, 0xff, 0xb4, 0xfa, 0xff, 0xff, 0xe4, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa0, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xd0, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0xf8, 0xfa, 0xff, 0xff, 0x1c, 0x04, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x26, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x20, 0xfb, 0xff, 0xff, 0x1c, 0xfb, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, + 0x40, 0xfb, 0xff, 0xff, 0x90, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x68, 0xfb, 0xff, 0xff, 0x64, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, + 0x90, 0xfb, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x88, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xb8, 0xfb, 0xff, 0xff, 0x18, 0x04, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xe6, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xe0, 0xfb, 0xff, 0xff, 0xdc, 0xfb, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x04, 0xfc, 0xff, 0xff, 0xcc, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0xfc, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x2c, 0xfc, 0xff, 0xff, 0x28, 0xfc, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, + 0x50, 0xfc, 0xff, 0xff, 0x80, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x7e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x78, 0xfc, 0xff, 0xff, 0x74, 0xfc, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, 0xbc, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, + 0xe4, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x28, 0x01, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, + 0xc0, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xb8, 0xfc, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0xe8, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1e, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xe8, 0xfc, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x18, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x10, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x48, 0xfd, 0xff, 0xff, 0x44, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x3c, 0xfd, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0x44, 0xfd, 0xff, 0xff, + 0x03, 0x00, 0x00, 0x00, 0x74, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x68, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x98, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x90, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0xc0, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xb8, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xe8, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, + 0x10, 0xfe, 0xff, 0xff, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xfc, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x2c, 0xfe, 0xff, 0xff, + 0xa4, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x5a, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x54, 0xfe, 0xff, 0xff, + 0x50, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x78, 0xfe, 0xff, 0xff, 0x58, 0x01, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xa6, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa0, 0xfe, 0xff, 0xff, 0x9c, 0xfe, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, + 0xc0, 0xfe, 0xff, 0xff, 0x10, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xee, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xe8, 0xfe, 0xff, 0xff, 0xe4, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, 0x0c, 0xff, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x3c, 0xff, 0xff, 0xff, 0x38, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x31, 0x00, 0x00, + 0x60, 0xff, 0xff, 0xff, 0xb0, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x4c, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x7c, 0xff, 0xff, 0xff, + 0x54, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xaa, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa4, 0xff, 0xff, 0xff, + 0xa0, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, + 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0xc8, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + }; + // clang-format on + + ORT_RETURN_IF_ERROR(LoadKernelTypeStrResolverFromBuffer(resolver_with_required_ops, + kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes)); + kernel_type_str_resolver.Merge(std::move(resolver_with_required_ops)); + return Status::OK(); +} + +} // namespace onnxruntime::kernel_type_str_resolver_utils + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.h b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h new file mode 100644 index 0000000000000..b9535c31f15cc --- /dev/null +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "gsl/gsl" + +#include "core/common/status.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/graph/op_identifier.h" + +namespace flatbuffers { +class DetachedBuffer; +} + +namespace onnxruntime::kernel_type_str_resolver_utils { + +#if !defined(ORT_MINIMAL_BUILD) + +/** + * Gets the ops that the layout transformation may potentially add. + */ +gsl::span GetLayoutTransformationRequiredOpIdentifiers(); + +/** + * Saves `kernel_type_str_resolver` to a byte buffer owned by `buffer` and referenced by `buffer_span`. + */ +Status SaveKernelTypeStrResolverToBuffer(const KernelTypeStrResolver& kernel_type_str_resolver, + flatbuffers::DetachedBuffer& buffer, gsl::span& buffer_span); + +#endif // !defined(ORT_MINIMAL_BUILD) + +/** + * Loads `kernel_type_str_resolver` from the byte buffer referenced by `buffer_span`. + */ +Status LoadKernelTypeStrResolverFromBuffer(KernelTypeStrResolver& kernel_type_str_resolver, + gsl::span buffer_span); + +/** + * Adds the ops that the layout transformation may potentially add to `kernel_type_str_resolver`. + * This is needed when loading an ORT format model in a build where layout transformation is enabled. + */ +Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrResolver& kernel_type_str_resolver); + +} // namespace onnxruntime::kernel_type_str_resolver_utils + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 8c5b42ad9bc3c..0dcd7b6be0ca0 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -10,11 +10,9 @@ #include "core/common/safeint.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/allocator.h" -#include "core/framework/kernel_def_hash_helpers.h" #include "core/framework/node_index_info.h" #include "core/framework/op_kernel.h" #include "core/framework/ort_value_pattern_planner.h" -#include "core/framework/session_state_flatbuffers_utils.h" #include "core/framework/session_state_utils.h" #include "core/framework/utils.h" #include "core/providers/cpu/controlflow/utils.h" @@ -114,20 +112,18 @@ void SessionState::CreateGraphInfo() { LOGS(logger_, VERBOSE) << "Done saving OrtValue mappings."; } -#if !defined(ORT_MINIMAL_BUILD) Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kernel_registry_manager, bool saving_ort_format) { for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; - auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci); if (!status.IsOK() && saving_ort_format) { // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. // in that case we assigned the node to that EP but do not compile it into a fused node. // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it. - // we now revert to the CPU EP to include the hash for the kernel as a fallback. at runtime when the model - // is loaded in a minimal build, the compiling EP will replace this node if possible. if that's not possible for - // some reason we can fallback to the CPU EP implementation via this hash. + // we now revert to the CPU EP kernel as a fallback. + // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. + // if that's not possible for some reason we can fallback to the CPU EP implementation. node.SetExecutionProviderType(kCpuExecutionProvider); status = kernel_registry_manager.SearchKernelRegistry(node, &kci); } @@ -141,13 +137,13 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne for (const auto& entry : subgraph_session_states_) { for (const auto& name_to_subgraph_session_state : entry.second) { SessionState& subgraph_session_state = *name_to_subgraph_session_state.second; - ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format)); + ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager, + saving_ort_format)); } } return Status::OK(); } -#endif const KernelCreateInfo& SessionState::GetNodeKernelCreateInfo(NodeIndex node_index) const { auto entry = kernel_create_info_map_.find(node_index); @@ -896,55 +892,6 @@ const InlinedHashSet* SessionState::GetToBeExecutedNodes( return (it != to_be_executed_nodes_.end()) ? &it->second : nullptr; } -static Status GetSubGraphSessionStatesOrtFormat( - flatbuffers::FlatBufferBuilder& builder, - const SubgraphSessionStateMap& subgraph_session_states, - std::vector>& fbs_subgraph_session_states) { - size_t number_of_states = 0; - for (const auto& pair : subgraph_session_states) { - number_of_states += pair.second.size(); - } - fbs_subgraph_session_states.clear(); - fbs_subgraph_session_states.reserve(number_of_states); - for (const auto& [node_idx, session_states] : subgraph_session_states) { - for (const auto& name_to_subgraph_session_state : session_states) { - const std::string& attr_name = name_to_subgraph_session_state.first; - SessionState& subgraph_session_state = *name_to_subgraph_session_state.second; - auto graph_id = builder.CreateString(fbs::utils::GetSubgraphId(node_idx, attr_name)); - flatbuffers::Offset session_state; - ORT_RETURN_IF_ERROR( - subgraph_session_state.SaveToOrtFormat(builder, session_state)); - - fbs_subgraph_session_states.push_back( - fbs::CreateSubGraphSessionState(builder, graph_id, session_state)); - } - } - return Status::OK(); -} - -Status SessionState::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, - flatbuffers::Offset& fbs_session_state) const { - size_t size = kernel_create_info_map_.size(); - std::vector node_indices; - std::vector kernel_def_hashes; - node_indices.reserve(size); - kernel_def_hashes.reserve(size); - for (const auto& kvp : kernel_create_info_map_) { - node_indices.push_back(gsl::narrow(kvp.first)); - kernel_def_hashes.push_back(kvp.second->kernel_def->GetHash()); - } - - auto kernels = fbs::CreateKernelCreateInfosDirect(builder, &node_indices, &kernel_def_hashes); - - // Subgraph session states - std::vector> sub_graph_session_states; - ORT_RETURN_IF_ERROR( - GetSubGraphSessionStatesOrtFormat(builder, subgraph_session_states_, sub_graph_session_states)); - - fbs_session_state = fbs::CreateSessionStateDirect(builder, kernels, &sub_graph_session_states); - return Status::OK(); -} - #endif // !defined(ORT_MINIMAL_BUILD) Status SessionState::CreateSubgraphSessionState() { @@ -981,121 +928,6 @@ Status SessionState::CreateSubgraphSessionState() { return Status::OK(); } -Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_state, - const KernelRegistryManager& kernel_registry_manager) { - using fbs::utils::FbsSessionStateViewer; - const FbsSessionStateViewer fbs_session_state_viewer{fbs_session_state}; - ORT_RETURN_IF_ERROR(fbs_session_state_viewer.Validate()); - - // look up KernelCreateInfo with hash and - // - add KernelCreateInfo for node - // - set node's EP from KernelCreateInfo if unset - auto add_kernel_and_set_node_ep_by_hash = - [&kernel_registry_manager, this](Node& node, HashValue hash) { - const KernelCreateInfo* kci = nullptr; - utils::UpdateHashForBackwardsCompatibility(hash); - - ORT_RETURN_IF_NOT(kernel_registry_manager.SearchKernelRegistriesByHash(hash, &kci), - "Failed to find kernel def hash (", hash, ") in kernel registries for ", - node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'."); - - { - const auto [it, inserted] = kernel_create_info_map_.emplace(node.Index(), - gsl::not_null(kci)); - ORT_RETURN_IF_NOT(inserted, - "Cannot overwrite existing kernel for ", - node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), - "'. Existing kernel def hash: ", it->second->kernel_def->GetHash(), - ", new kernel def hash: ", hash, "."); - } - - if (node.GetExecutionProviderType().empty()) { - node.SetExecutionProviderType(kci->kernel_def->Provider()); - } else { - ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kci->kernel_def->Provider(), - "Node execution provider type mismatch. Existing: ", node.GetExecutionProviderType(), - ", from KernelCreateInfo (via hash lookup): ", kci->kernel_def->Provider()); - } - - return Status::OK(); - }; - - // kernel hashes for model are in top level SessionState - const auto& compiled_kernel_hashes = GetCompiledKernelHashes(); - - // process the nodes that existed when the model was created - for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) { - const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i); - - Node* const node = graph_.GetNode(node_kernel_info.node_index); - if (node == nullptr) { -#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) - // this is OK if we have compiled kernels and the original node was replaced. - // if not the model is invalid. - ORT_RETURN_IF(compiled_kernel_hashes.empty(), - "Can't find node with index ", node_kernel_info.node_index, ". Invalid ORT format model."); -#endif // defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) - continue; - } - - ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(*node, node_kernel_info.kernel_def_hash)); - } - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // process the nodes that were added by replaying any loaded runtime optimizations - for (const auto& [node_index, kernel_def_hash] : - graph_.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) { - auto* node = graph_.GetNode(node_index); - - // NHWC optimizer may replace a node, so a missing node isn't necessarily an error - // ORT_RETURN_IF(node == nullptr, "Can't find runtime optimization produced node with index ", node_index); - - if (node != nullptr) { - ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(*node, kernel_def_hash)); - } - } - - // Look up the hashes for any nodes we compiled or added during graph partitioning or other runtime optimizations. - // These nodes are not in the original model as they were created at runtime. - for (auto& node : graph_.Nodes()) { - if (kernel_create_info_map_.find(node.Index()) != kernel_create_info_map_.end()) { - continue; - } - - if (node.Domain() == kOnnxDomain || node.Domain() == kMSDomain) { - // two possible places to get hash from - auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion()); - if (!kernel_hash.has_value()) { - kernel_hash = utils::GetInternalNhwcOpHash(node); - } - ORT_RETURN_IF_NOT(kernel_hash.has_value(), - "Unable to find kernel hash for node: '", node.Name(), "' optype: ", node.OpType()); - - ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(node, *kernel_hash)); - } else { - const auto hash_info = compiled_kernel_hashes.find(node.OpType()); - ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(), - "Unable to find compiled kernel hash for node '", node.Name(), "'."); - - ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(node, hash_info->second)); - } - } -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - - if (!subgraph_session_states_.empty()) { - for (const auto& [node_idx, session_states] : subgraph_session_states_) { - for (const auto& [attr_name, subgraph_session_state] : session_states) { - const fbs::SessionState* fbs_subgraph_session_state; - ORT_RETURN_IF_ERROR(fbs_session_state_viewer.GetSubgraphSessionState(node_idx, attr_name, fbs_subgraph_session_state)); - - ORT_RETURN_IF_ERROR(subgraph_session_state->LoadFromOrtFormat(*fbs_subgraph_session_state, kernel_registry_manager)); - } - } - } - - return Status::OK(); -} - // Calculate the use count of a constant initialized tensor, including the use in subgraph. // Note: This function doesn't handle the case below: // The main graph has a constant initializer called X, and the subgraph also has a constant initializer called X, which overrides the X from main graph. @@ -1207,7 +1039,6 @@ static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging:: Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, const KernelRegistryManager& kernel_registry_manager, const SessionOptions& session_options, - const onnxruntime::fbs::SessionState* serialized_session_state, bool remove_initializers, bool saving_ort_format) { // recursively create the subgraph session state instances and populate the kernel create info in them. @@ -1215,24 +1046,8 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count; ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count); @@ -1383,8 +1198,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string> unused_initializer_names; - unused_initializer_names.reserve(graph_.GetAllInitializedTensors().size()); + InlinedVector unused_initializer_names; for (const auto& [name, tensor_proto] : graph_.GetAllInitializedTensors()) { ORT_UNUSED_PARAMETER(tensor_proto); int idx; diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 83b94070056a3..4e153dc8c2590 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -301,21 +301,11 @@ class SessionState { #if !defined(ORT_MINIMAL_BUILD) void UpdateToBeExecutedNodes(gsl::span fetch_mlvalue_idxs); const InlinedHashSet* GetToBeExecutedNodes(gsl::span fetch_mlvalue_idxs) const; - Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, - flatbuffers::Offset& fbs_session_state) const; #endif - void SetCompiledKernelHashes(std::unordered_map&& compiled_kernel_hashes) { - compiled_kernel_hashes_ = std::move(compiled_kernel_hashes); - } - - Status LoadFromOrtFormat(const onnxruntime::fbs::SessionState& fbs_session_state, - const KernelRegistryManager& kernel_registry_manager); - Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, const SessionOptions& session_options = {}, - const onnxruntime::fbs::SessionState* serialized_session_state = nullptr, bool remove_initializers = true, bool saving_ort_format = false); @@ -378,9 +368,8 @@ class SessionState { void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name, std::unique_ptr session_state); -#if !defined(ORT_MINIMAL_BUILD) - Status PopulateKernelCreateInfo(const KernelRegistryManager& kernel_registry_manager, bool saving_ort_format); -#endif + Status PopulateKernelCreateInfo(const KernelRegistryManager& kernel_registry_manager, + bool saving_ort_format); Status FinalizeSessionStateImpl(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, @@ -399,21 +388,12 @@ class SessionState { InlinedHashMap& inferred_shapes) const; #endif - // the SessionState for the main Graph contains the compiled kernel hashes for the entire model - const std::unordered_map& GetCompiledKernelHashes() const { - return parent_ ? parent_->GetCompiledKernelHashes() : compiled_kernel_hashes_; - } - // KernelCreateInfo for each node so we do kernel lookup once KernelCreateInfoMap kernel_create_info_map_; // fused_funcs_mgr_ must live longer than the session_kernels_, becaues a kernel could be created from this manager FuncManager fused_funcs_mgr_; - // If we compile kernels in a minimal build we need a way to find the kernel using the hash. - // We populate this map when doing the kernel compilation in GraphPartitioner, and use it in LoadFromOrtFormat. - std::unordered_map compiled_kernel_hashes_; - // cache of the constructed kernels to avoid spending construction time per executor std::vector> session_kernels_; Graph& graph_; diff --git a/onnxruntime/core/framework/session_state_flatbuffers_utils.cc b/onnxruntime/core/framework/session_state_flatbuffers_utils.cc deleted file mode 100644 index 52c2171961ff4..0000000000000 --- a/onnxruntime/core/framework/session_state_flatbuffers_utils.cc +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/session_state_flatbuffers_utils.h" - -#include "core/framework/kernel_def_hash_helpers.h" - -namespace onnxruntime::fbs::utils { - -std::string GetSubgraphId(const NodeIndex node_idx, const std::string& attr_name) { - return std::to_string(node_idx) + "_" + attr_name; -} - -FbsSessionStateViewer::FbsSessionStateViewer(const fbs::SessionState& fbs_session_state) - : fbs_session_state_{fbs_session_state} { -} - -Status FbsSessionStateViewer::Validate() const { - if (fbs_session_state_.sub_graph_session_states() == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SessionState for subgraphs is null. Invalid ORT format model."); - } - - const auto* const fbs_kcis = fbs_session_state_.kernels(); - if (fbs_kcis == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Kernel create info is null. Invalid ORT format model."); - } - - const auto* const fbs_node_indices = fbs_kcis->node_indices(); - if (fbs_node_indices == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Kernel create info node indices are null. Invalid ORT format model."); - } - - const auto* const fbs_kernel_def_hashes = fbs_kcis->kernel_def_hashes(); - if (fbs_kernel_def_hashes == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Kernel create info hashes are null. Invalid ORT format model."); - } - - if (fbs_node_indices->size() != fbs_kernel_def_hashes->size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Size mismatch for kernel create info node indexes and hashes. Invalid ORT format model.", - fbs_node_indices->size(), " != ", fbs_kernel_def_hashes->size()); - } - - return Status::OK(); -} - -FbsSessionStateViewer::NodeKernelInfo FbsSessionStateViewer::GetNodeKernelInfo(Index idx) const { - const auto* const fbs_kcis = fbs_session_state_.kernels(); - const auto* const fbs_node_indices = fbs_kcis->node_indices(); - const auto* const fbs_kernel_def_hashes = fbs_kcis->kernel_def_hashes(); - - HashValue hash = fbs_kernel_def_hashes->Get(idx); - onnxruntime::utils::UpdateHashForBackwardsCompatibility(hash); - - return {fbs_node_indices->Get(idx), hash}; -} - -FbsSessionStateViewer::Index FbsSessionStateViewer::GetNumNodeKernelInfos() const { - return fbs_session_state_.kernels()->node_indices()->size(); -} - -Status FbsSessionStateViewer::GetSubgraphSessionState(NodeIndex node_idx, const std::string& attr_name, - const fbs::SessionState*& fbs_subgraph_session_state_out) const { - const auto key = GetSubgraphId(node_idx, attr_name); - const auto* const fbs_subgraph_session_state_entry = - fbs_session_state_.sub_graph_session_states()->LookupByKey(key.c_str()); - if (fbs_subgraph_session_state_entry == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Subgraph SessionState entry for ", key, " is missing. Invalid ORT format model."); - } - - const auto* const fbs_subgraph_session_state = fbs_subgraph_session_state_entry->session_state(); - if (fbs_subgraph_session_state == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Subgraph SessionState for ", key, " is null. Invalid ORT format model."); - } - - fbs_subgraph_session_state_out = fbs_subgraph_session_state; - return Status::OK(); -} -} // namespace onnxruntime::fbs::utils diff --git a/onnxruntime/core/framework/session_state_flatbuffers_utils.h b/onnxruntime/core/framework/session_state_flatbuffers_utils.h deleted file mode 100644 index 654163470f7bd..0000000000000 --- a/onnxruntime/core/framework/session_state_flatbuffers_utils.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/common/common.h" -#include "core/flatbuffers/schema/ort.fbs.h" -#include "core/graph/basic_types.h" - -namespace onnxruntime::fbs::utils { - -/** - * Gets the key that can be used to look up a fbs::SubGraphSessionState in a fbs::SessionState. - * - * @param node_idx The index of the node in the current graph. - * @param attr_name The name of the node attribute that contains the subgraph. - * @return The subgraph key. - */ -std::string GetSubgraphId(const NodeIndex node_idx, const std::string& attr_name); - -/** - * Provides read-only helper functions for a fbs::SessionState instance. - */ -class FbsSessionStateViewer { - public: - /** - * Creates an instance. - * Validation is not performed here, but in Validate(). - * - * @param fbs_session_state The fbs::SessionState instance. - */ - FbsSessionStateViewer(const fbs::SessionState& fbs_session_state); - - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FbsSessionStateViewer); - - /** - * Validates the underlying fbs::SessionState instance. - * WARNING: Other methods assume that the fbs::SessionState is valid! - * - * @return Whether the fbs::SessionState instance is valid. - */ - Status Validate() const; - - using Index = flatbuffers::uoffset_t; - - struct NodeKernelInfo { - NodeIndex node_index; - HashValue kernel_def_hash; - }; - - /** - * Retrieves the node kernel info element. - * - * @param index The index of the node kernel info element. - * @return The node kernel info element. - */ - NodeKernelInfo GetNodeKernelInfo(Index idx) const; - - /** - * Gets the number of node kernel info elements. - */ - Index GetNumNodeKernelInfos() const; - - /** - * Retrieves the subgraph session state from the fbs::SessionState instance. - * - * @param node_idx The index of the node containing the subgraph. - * @param attr_name The name of the attribute containing the subgraph. - * @param[out] fbs_subgraph_session_state The subgraph session state. Non-null if successful. - * @return Whether the retrieval was successful. - */ - Status GetSubgraphSessionState(NodeIndex node_idx, const std::string& attr_name, - const fbs::SessionState*& fbs_subgraph_session_state) const; - - private: - const fbs::SessionState& fbs_session_state_; -}; -} // namespace onnxruntime::fbs::utils diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 089aa6cd9121a..f7b7e7e52b102 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -756,8 +756,7 @@ Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, fbs::utils::LoadStringFromOrtFormat(op_type_, fbs_node.op_type()); node_type_ = static_cast(fbs_node.type()); // we skip populating the saved EP here - // the node will either be assigned to another EP by the ORT format model-specific graph partitioning or fall back to - // the EP encoded in its kernel def hash + // the node will be assigned to an EP by the ORT format model-specific graph partitioning // fbs::utils::LoadStringFromOrtFormat(execution_provider_type_, fbs_node.execution_provider_type()); ORT_RETURN_IF_ERROR(LoadNodeArgsFromOrtFormat(fbs_node.inputs(), definitions_.input_defs)); @@ -3831,11 +3830,11 @@ Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std func_meta_def->domain); fused_node.SetNodeType(Node::Type::Fused); + fused_node.SetSinceVersion(func_meta_def->since_version); + #if !defined(ORT_MINIMAL_BUILD) // if this is a full build create the lightweight Function implementation that provides the schema so that // kernel lookup works as per usual, if not using an existing schema. - // in an extended minimal build we do the lookup via a hash so don't need a schema. - fused_node.SetSinceVersion(func_meta_def->since_version); if (sub_graph.schema_source == IndexedSubGraph::SourceOfSchema::EXISTING) { ORT_ENFORCE(SetOpSchemaFromRegistryForNode(fused_node), "Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType()); diff --git a/onnxruntime/core/graph/op_identifier.h b/onnxruntime/core/graph/op_identifier.h new file mode 100644 index 0000000000000..c3cdfee1208bf --- /dev/null +++ b/onnxruntime/core/graph/op_identifier.h @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/hash_combine.h" +#include "core/common/status.h" +#include "core/common/string_utils.h" +#include "core/common/parse_string.h" + +namespace onnxruntime { + +template +struct BasicOpIdentifier { + StringType domain; + StringType op_type; + int since_version; + + // comparison + + friend constexpr bool operator<(const BasicOpIdentifier& lhs, + const BasicOpIdentifier& rhs) { + return lhs.Tied() < rhs.Tied(); + } + + friend constexpr bool operator==(const BasicOpIdentifier& lhs, + const BasicOpIdentifier& rhs) { + return lhs.Tied() == rhs.Tied(); + } + + // hash computation + + size_t GetHash() const { + size_t h = std::hash{}(domain); + HashCombine(op_type, h); + HashCombine(since_version, h); + return h; + } + + // string conversion + + std::string ToString() const { + return MakeString(domain, kStringRepresentationDelimiter, + op_type, kStringRepresentationDelimiter, + since_version); + } + + static Status LoadFromString(std::string_view op_id_str, BasicOpIdentifier& op_id) { + const auto components = utils::SplitString(op_id_str, kStringRepresentationDelimiter, true); + ORT_RETURN_IF_NOT(components.size() == 3, "Invalid OpIdentifier string: ", op_id_str); + int since_version{}; + ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(components[2], since_version), + "Failed to parse since_version from ", components[2]); + op_id = BasicOpIdentifier{StringType{components[0]}, StringType{components[1]}, since_version}; + return Status::OK(); + } + + friend std::ostream& operator<<(std::ostream& os, const BasicOpIdentifier& op_id) { + os << op_id.ToString(); + return os; + } + + private: + constexpr auto Tied() const { + return std::tie(domain, op_type, since_version); + } + + static constexpr std::string_view kStringRepresentationDelimiter = ":"; +}; + +using OpIdentifier = BasicOpIdentifier; + +// An op identifier that uses std::string_view to refer to domain and op type values. +// IMPORTANT: Be sure that the underlying strings remain valid for the lifetime of the op identifier. +using OpIdentifierWithStringViews = BasicOpIdentifier; + +} // namespace onnxruntime + +// add std::hash specializations +namespace std { +template <> +struct hash { + size_t operator()(const onnxruntime::OpIdentifier& v) const { + return v.GetHash(); + } +}; + +template <> +struct hash { + size_t operator()(const onnxruntime::OpIdentifierWithStringViews& v) const { + return v.GetHash(); + } +}; +} // namespace std diff --git a/onnxruntime/core/graph/op_identifier_utils.cc b/onnxruntime/core/graph/op_identifier_utils.cc new file mode 100644 index 0000000000000..4f0deca6ef8d7 --- /dev/null +++ b/onnxruntime/core/graph/op_identifier_utils.cc @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/op_identifier_utils.h" + +#include "core/flatbuffers/flatbuffers_utils.h" +#include "core/flatbuffers/schema/ort.fbs.h" + +namespace onnxruntime::fbs::utils { + +#if !defined(ORT_MINIMAL_BUILD) + +Status SaveOpIdentifierOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const onnxruntime::OpIdentifier& op_id, + flatbuffers::Offset& fbs_op_id_str) { + const auto op_id_str = op_id.ToString(); + fbs_op_id_str = builder.CreateSharedString(op_id_str); + return Status::OK(); +} + +#endif // !defined(ORT_MINIMAL_BUILD) + +Status LoadOpIdentifierOrtFormat(const flatbuffers::String& fbs_op_id_str, + onnxruntime::OpIdentifier& op_id) { + ORT_RETURN_IF_ERROR(onnxruntime::OpIdentifier::LoadFromString(fbs_op_id_str.string_view(), op_id)); + return Status::OK(); +} + +} // namespace onnxruntime::fbs::utils diff --git a/onnxruntime/core/graph/op_identifier_utils.h b/onnxruntime/core/graph/op_identifier_utils.h new file mode 100644 index 0000000000000..14cec1f3f0c93 --- /dev/null +++ b/onnxruntime/core/graph/op_identifier_utils.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/op_identifier.h" + +#include "core/common/status.h" +#include "core/graph/graph.h" + +#if !defined(ORT_MINIMAL_BUILD) + +#include "onnx/defs/schema.h" // for ONNX_NAMESPACE::OpSchema + +#endif // !defined(ORT_MINIMAL_BUILD) + +namespace flatbuffers { +class FlatBufferBuilder; + +template +struct Offset; + +struct String; +} // namespace flatbuffers + +namespace onnxruntime { + +namespace fbs::utils { + +#if !defined(ORT_MINIMAL_BUILD) + +Status SaveOpIdentifierOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const onnxruntime::OpIdentifier& op_id, + flatbuffers::Offset& fbs_op_id_str); + +#endif // !defined(ORT_MINIMAL_BUILD) + +Status LoadOpIdentifierOrtFormat(const flatbuffers::String& fbs_op_id_str, + onnxruntime::OpIdentifier& op_id); + +} // namespace fbs::utils + +namespace utils { + +inline onnxruntime::OpIdentifier MakeOpId(const Node& node) { + return onnxruntime::OpIdentifier{node.Domain(), node.OpType(), node.SinceVersion()}; +} + +#if !defined(ORT_MINIMAL_BUILD) + +inline onnxruntime::OpIdentifier MakeOpId(const ONNX_NAMESPACE::OpSchema& op_schema) { + return onnxruntime::OpIdentifier{op_schema.domain(), op_schema.Name(), op_schema.SinceVersion()}; +} + +#endif // !defined(ORT_MINIMAL_BUILD) + +} // namespace utils + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/runtime_optimization_record.h b/onnxruntime/core/graph/runtime_optimization_record.h index bf9177c6aae40..7ac29b6d589d1 100644 --- a/onnxruntime/core/graph/runtime_optimization_record.h +++ b/onnxruntime/core/graph/runtime_optimization_record.h @@ -4,11 +4,12 @@ #pragma once #include -#include #include #include // for std::tie +#include "core/common/inlined_containers.h" #include "core/graph/basic_types.h" +#include "core/graph/op_identifier.h" /* Runtime optimization limitations * @@ -48,7 +49,7 @@ an ORT format model. This also means that non-empty node indices here must be in static_assert(kEmptyNodeIndex <= std::numeric_limits::max()); /** Indices of the nodes in the graph that are considered for optimization. */ - std::vector nodes; + InlinedVector nodes; /** The number of inputs of the target node. */ int num_inputs; /** The number of outputs of the target node. */ @@ -75,21 +76,19 @@ an ORT format model. This also means that non-empty node indices here must be in } }; -struct NodeIndexAndKernelDefHash { - NodeIndex node_index; - HashValue kernel_def_hash; -}; - /** Information for a single runtime optimization. -It does not contain information about the optimizer itself, that should be maintained seperately. +It does not contain information about the optimizer itself, that should be maintained separately. */ struct RuntimeOptimizationRecord { /** The optimization action identifier. */ std::string action_id; + /** The nodes to consider for optimization. */ NodesToOptimizeIndices nodes_to_optimize_indices; - /** Any new nodes introduced by the optimization. */ - std::vector produced_nodes; + + using ProducedOpIdVector = InlinedVector; + /** Op identifiers for any new nodes introduced by the optimization. */ + ProducedOpIdVector produced_op_ids; }; } // namespace onnxruntime diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.cc b/onnxruntime/core/graph/runtime_optimization_record_container.cc index 342ca11bc88a5..9a4e705d9ae49 100644 --- a/onnxruntime/core/graph/runtime_optimization_record_container.cc +++ b/onnxruntime/core/graph/runtime_optimization_record_container.cc @@ -11,10 +11,12 @@ #include "core/flatbuffers/flatbuffers_utils.h" #include "core/flatbuffers/schema/ort.fbs.h" +#include "core/graph/op_identifier_utils.h" namespace onnxruntime { -#if defined(ORT_ENABLE_ADDING_RUNTIME_OPTIMIZATION_RECORDS) +#if !defined(ORT_MINIMAL_BUILD) + bool RuntimeOptimizationRecordContainer::RecordExists(const std::string& optimizer_name, const std::string& action_id, const NodesToOptimizeIndices& nodes_to_optimize_indices) const { @@ -34,7 +36,8 @@ void RuntimeOptimizationRecordContainer::AddRecord(const std::string& optimizer_ auto& optimizations = optimizer_name_to_records_[optimizer_name]; optimizations.emplace_back(std::move(runtime_optimization_record)); } -#endif + +#endif // !defined(ORT_MINIMAL_BUILD) std::vector RuntimeOptimizationRecordContainer::RemoveRecordsForOptimizer( const std::string& optimizer_name) { @@ -46,6 +49,8 @@ std::vector RuntimeOptimizationRecordContainer::Remov return records; } +#if !defined(ORT_MINIMAL_BUILD) + static Status SaveRuntimeOptimizationRecordToOrtFormat( flatbuffers::FlatBufferBuilder& builder, const RuntimeOptimizationRecord& runtime_optimization_record, @@ -66,20 +71,21 @@ static Status SaveRuntimeOptimizationRecordToOrtFormat( nodes_to_optimize_indices.num_variadic_inputs, nodes_to_optimize_indices.num_variadic_outputs); - const auto fbs_produced_nodes = builder.CreateVector>( - runtime_optimization_record.produced_nodes.size(), - [&](size_t i) -> flatbuffers::Offset { - return fbs::CreateNodeIndexAndKernelDefHash( - builder, - gsl::narrow(runtime_optimization_record.produced_nodes[i].node_index), - runtime_optimization_record.produced_nodes[i].kernel_def_hash); - }); + const auto& produced_op_ids = runtime_optimization_record.produced_op_ids; + + std::vector> fbs_produced_op_id_vector; + fbs_produced_op_id_vector.reserve(produced_op_ids.size()); + for (const auto& produced_op_id : produced_op_ids) { + flatbuffers::Offset fbs_produced_op_id; + ORT_RETURN_IF_ERROR(fbs::utils::SaveOpIdentifierOrtFormat(builder, produced_op_id, fbs_produced_op_id)); + fbs_produced_op_id_vector.push_back(fbs_produced_op_id); + } fbs_runtime_optimization_record = fbs::CreateRuntimeOptimizationRecord(builder, builder.CreateSharedString(runtime_optimization_record.action_id), fbs_nodes_to_optimize, - fbs_produced_nodes); + builder.CreateVector(fbs_produced_op_id_vector)); return Status::OK(); } @@ -108,6 +114,8 @@ Status RuntimeOptimizationRecordContainer::SaveToOrtFormat( return Status::OK(); } +#endif // !defined(ORT_MINIMAL_BUILD) + static Status LoadRuntimeOptimizationRecordFromOrtFormat( const fbs::RuntimeOptimizationRecord& fbs_runtime_optimization_record, RuntimeOptimizationRecord& runtime_optimization_record_out) { @@ -120,7 +128,7 @@ static Status LoadRuntimeOptimizationRecordFromOrtFormat( if (const auto* fbs_nodes_to_optimize_indices = fbs_runtime_optimization_record.nodes_to_optimize_indices()) { if (const auto* fbs_node_indices = fbs_nodes_to_optimize_indices->node_indices()) { nodes_to_optimize_indices.nodes = [&]() { - std::vector result; + InlinedVector result; result.reserve(fbs_node_indices->size()); std::transform(fbs_node_indices->begin(), fbs_node_indices->end(), std::back_inserter(result), [](const uint32_t idx) { return static_cast(idx); }); @@ -136,14 +144,14 @@ static Status LoadRuntimeOptimizationRecordFromOrtFormat( nodes_to_optimize_indices.num_variadic_outputs = fbs_nodes_to_optimize_indices->num_variadic_outputs(); } - if (const auto* fbs_produced_nodes = fbs_runtime_optimization_record.produced_nodes()) { - runtime_optimization_record.produced_nodes.reserve(fbs_produced_nodes->size()); - for (const auto* fbs_node_index_and_kernel_def_hash : *fbs_produced_nodes) { - if (!fbs_node_index_and_kernel_def_hash) continue; - - runtime_optimization_record.produced_nodes.push_back( - NodeIndexAndKernelDefHash{static_cast(fbs_node_index_and_kernel_def_hash->node_index()), - fbs_node_index_and_kernel_def_hash->kernel_def_hash()}); + auto& produced_op_ids = runtime_optimization_record.produced_op_ids; + if (const auto* fbs_produced_op_ids = fbs_runtime_optimization_record.produced_op_ids()) { + produced_op_ids.reserve(fbs_produced_op_ids->size()); + for (const auto* fbs_produced_op_id : *fbs_produced_op_ids) { + ORT_FORMAT_RETURN_IF_NULL(fbs_produced_op_id, "runtime optimization record produced op id"); + OpIdentifier produced_op_id; + ORT_RETURN_IF_ERROR(fbs::utils::LoadOpIdentifierOrtFormat(*fbs_produced_op_id, produced_op_id)); + produced_op_ids.push_back(std::move(produced_op_id)); } } diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.h b/onnxruntime/core/graph/runtime_optimization_record_container.h index 278425cc04139..5db784f1a27af 100644 --- a/onnxruntime/core/graph/runtime_optimization_record_container.h +++ b/onnxruntime/core/graph/runtime_optimization_record_container.h @@ -12,10 +12,6 @@ #include "core/common/common.h" #include "core/graph/runtime_optimization_record.h" -#if !defined(ORT_MINIMAL_BUILD) -#define ORT_ENABLE_ADDING_RUNTIME_OPTIMIZATION_RECORDS -#endif // !defined(ORT_MINIMAL_BUILD) - namespace flatbuffers { class FlatBufferBuilder; template @@ -34,13 +30,15 @@ class RuntimeOptimizationRecordContainer { public: bool IsEmpty() const { return optimizer_name_to_records_.empty(); } -#if defined(ORT_ENABLE_ADDING_RUNTIME_OPTIMIZATION_RECORDS) +#if !defined(ORT_MINIMAL_BUILD) + bool RecordExists(const std::string& optimizer_name, const std::string& action_id, const NodesToOptimizeIndices& nodes_to_optimize_indices) const; void AddRecord(const std::string& optimizer_name, RuntimeOptimizationRecord&& runtime_optimization_record); -#endif + +#endif // !defined(ORT_MINIMAL_BUILD) std::vector RemoveRecordsForOptimizer(const std::string& optimizer_name); diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index 48e0a5413ca11..e9d9c7de45eb7 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -66,7 +66,7 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); if (domain != kMSDomain) { - SwapNodeOpTypeAndDomain(*api_graph, *node, "QLinearConv", kMSDomain); + SwapNodeOpTypeDomainAndSinceVersion(*api_graph, *node, "QLinearConv", kMSDomain, 1); } modified = true; diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index b858e79bfb34d..a9c11604a65f9 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -1,19 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/optimizer/optimizer_execution_frame.h" + #include "core/common/common.h" -#include "core/common/status.h" #include "core/common/logging/logging.h" #include "core/common/logging/macros.h" +#include "core/common/status.h" +#include "core/framework/callback.h" #include "core/framework/data_transfer_manager.h" -#include "core/framework/tensorprotoutils.h" #include "core/framework/data_types.h" -#include "core/framework/mldata_type_utils.h" -#include "core/framework/kernel_registry.h" #include "core/framework/fuse_nodes_funcs.h" -#include "core/framework/callback.h" +#include "core/framework/kernel_registry.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/mldata_type_utils.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensorprotoutils.h" #include "core/framework/TensorSeq.h" -#include "core/optimizer/optimizer_execution_frame.h" namespace onnxruntime { @@ -119,18 +122,40 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, node_index_info_ = std::make_unique(nodes, ort_value_name_idx_map_); } -Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const{ +Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const { std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); - return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), out); + const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; + return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out); +} + +static Status TryCreateKernel(const Node& node, + const KernelRegistry& kernel_registry, + const IExecutionProvider& execution_provider, + const std::unordered_map& constant_initialized_tensors, + const OrtValueNameIdxMap& ort_value_name_idx_map, + FuncManager& funcs_mgr, + const DataTransferManager& data_transfer_mgr, + /*out*/ std::unique_ptr& op_kernel) { + const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; + const KernelCreateInfo* kernel_create_info = nullptr; + ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, + &kernel_create_info)); + OpKernelInfo kernel_info(node, + *kernel_create_info->kernel_def, + execution_provider, + constant_initialized_tensors, + ort_value_name_idx_map, + data_transfer_mgr); + return kernel_create_info->kernel_create_func(funcs_mgr, kernel_info, op_kernel); } std::unique_ptr OptimizerExecutionFrame::Info::CreateKernel(const Node* node) const { std::unique_ptr op_kernel; std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); FuncManager func; - auto status = kernel_registry->TryCreateKernel(*node, execution_provider_, initializers_, - ort_value_name_idx_map_, func, data_transfer_mgr_, - op_kernel); + auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_, + ort_value_name_idx_map_, func, data_transfer_mgr_, + op_kernel); // Kernel found in the CPU kernel registry if (status.IsOK()) diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index 8c772f7a91d66..61d4388aa9905 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -125,14 +125,6 @@ static bool IsRelevantOutput(const Node* node, const NodeArg* output) { return true; } -namespace { -// borrowed from providers/common.h -template -inline bool Contains(const AssociativeContainer& container, const Key& key) { - return container.find(key) != container.end(); -} -} // namespace - // Check whether the given opcode is fp16 allowed for the given level of optimization. static bool IsFP16Allow(const std::string& op_type, size_t level, const FP16AllowOps& fp16_allow_level0_ops) { // XXX: Shall we add a check for unsupported level or just ignore it as the current code does? diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 849ad875212ec..d7407fcf20a80 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -139,7 +139,7 @@ struct SetOptionalZeroPoint { }; void SetOptionalZeroPoint::UpdateNodes(Graph& graph, const NodesToOptimize& selected_nodes) { - std::vector nodes = selected_nodes.AllNodes(); + const auto nodes = selected_nodes.AllNodes(); for (Node* node_ptr : nodes) { if (node_ptr == nullptr) { continue; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 4972d749bfd41..6b588314c4f16 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -85,8 +85,11 @@ std::optional BaseSelector::Select(const GraphViewer& gr } NodesToOptimizeIndicesBuilder builder; - builder.input_nodes = qdq_group->dq_nodes; - builder.output_nodes = qdq_group->q_nodes; + // TODO(edgchen1) update NodeGroup to use InlinedVector + builder.input_nodes.assign(qdq_group->dq_nodes.begin(), qdq_group->dq_nodes.end()); + builder.output_nodes.assign(qdq_group->q_nodes.begin(), qdq_group->q_nodes.end()); + //builder.input_nodes = qdq_group->dq_nodes; + //builder.output_nodes = qdq_group->q_nodes; builder.target_node = qdq_group->target_node; UpdateBuilder(builder); diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.cc b/onnxruntime/core/optimizer/selectors_actions/actions.cc index 0075d55c33393..2229a280a976d 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.cc +++ b/onnxruntime/core/optimizer/selectors_actions/actions.cc @@ -112,9 +112,9 @@ Status ReplaceWithNew::Run(Graph& graph, const NodesToOptimize& selected_nodes) #if !defined(ORT_MINIMAL_BUILD) Status ReplaceWithNew::RunForSave(Graph& graph, const NodesToOptimize& selected_nodes, - const SatRuntimeOptimizationSaveContext& save_context, + const SatRuntimeOptimizationSaveContext& /*save_context*/, SavedState& saved_state, bool& graph_modified) const { - // make temporary node, use it to look up kernel def hash, remove temporary node + // make temporary node, save its op schema, remove temporary node const RuntimeState runtime_state{graph, selected_nodes}; Node* replacement{}; ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes, @@ -125,12 +125,7 @@ Status ReplaceWithNew::RunForSave(Graph& graph, const NodesToOptimize& selected_ /* only_update_dest_definitions */ true, &replacement)); ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(*replacement), "Failed to set node op schema."); - - const KernelCreateInfo* kernel_create_info{}; - ORT_RETURN_IF_ERROR(save_context.kernel_registry_manager.get().SearchKernelRegistry(*replacement, - &kernel_create_info)); - const auto replacement_kernel_def_hash = kernel_create_info->kernel_def->GetHash(); - saved_state.produced_nodes.push_back({replacement->Index(), replacement_kernel_def_hash}); + saved_state.produced_node_op_schemas.push_back(replacement->Op()); ORT_RETURN_IF_NOT(graph.RemoveNode(replacement->Index()), "Failed to remove node."); diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h index 8bfec9489912e..4bd4f6cadfd6d 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.h +++ b/onnxruntime/core/optimizer/selectors_actions/actions.h @@ -5,12 +5,20 @@ #include +#include "gsl/gsl" + #include "core/common/common.h" #include "core/graph/graph_utils.h" // TODO: Minimize usage of this given we want to use Actions in a minimal build #include "core/graph/runtime_optimization_record.h" #include "core/optimizer/selectors_actions/helpers.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" +#if !defined(ORT_MINIMAL_BUILD) +namespace ONNX_NAMESPACE { +class OpSchema; +} // namespace ONNX_NAMESPACE +#endif // !defined(ORT_MINIMAL_BUILD) + namespace onnxruntime { class Graph; @@ -25,7 +33,7 @@ struct Action { #if !defined(ORT_MINIMAL_BUILD) // per-action saved state struct SavedState { - std::vector produced_nodes; + std::vector> produced_node_op_schemas; }; // saving interface diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.cc b/onnxruntime/core/optimizer/selectors_actions/helpers.cc index 41444ba8000b4..9eb03badb80b9 100644 --- a/onnxruntime/core/optimizer/selectors_actions/helpers.cc +++ b/onnxruntime/core/optimizer/selectors_actions/helpers.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/optimizer/selectors_actions/helpers.h" + #include "core/optimizer/selectors_actions/actions.h" using namespace ONNX_NAMESPACE; @@ -126,7 +128,7 @@ Node* GetNodeByNodeIndex(Graph& graph, NodeIndex idx, bool& missing) { return node; } -bool GetNodesByNodeIndex(Graph& graph, const std::vector& indices, std::vector& nodes) { +bool GetNodesByNodeIndex(Graph& graph, gsl::span indices, InlinedVector& nodes) { nodes.reserve(indices.size()); bool missing = false; @@ -150,7 +152,7 @@ bool GetNodesByNodeIndex(Graph& graph, const std::vector& indices, st // Helper to create the NodesToOptimizeIndices // specify num_input_defs/num_output_defs if the last input/output is variadic (default is non-variadic) static NodesToOptimizeIndices GetNodesToOptimizeIndices( - const std::vector& input_nodes, NodeIndex target_node, const std::vector& output_nodes, + gsl::span input_nodes, NodeIndex target_node, gsl::span output_nodes, int num_input_defs, int num_output_defs) { size_t num_inputs = num_input_defs == -1 ? input_nodes.size() : static_cast(num_input_defs); size_t num_outputs = num_output_defs == -1 ? output_nodes.size() : static_cast(num_output_defs); @@ -169,7 +171,7 @@ static NodesToOptimizeIndices GetNodesToOptimizeIndices( num_variadic_outputs = gsl::narrow_cast(output_nodes.size()) - num_output_defs + 1; } - std::vector node_indices; + InlinedVector node_indices; node_indices.reserve(NumIOEntries(variadic_input, num_inputs, num_variadic_inputs) + 1 + NumIOEntries(variadic_output, num_outputs, num_variadic_outputs)); std::copy(input_nodes.begin(), input_nodes.end(), std::back_inserter(node_indices)); @@ -191,9 +193,9 @@ NodesToOptimizeIndices NodesToOptimizeIndicesBuilder::Build() const { return GetNodesToOptimizeIndices(input_nodes, target_node, output_nodes, num_input_defs, num_output_defs); } -NodesToOptimize::NodesToOptimize(const std::vector& input_nodes, +NodesToOptimize::NodesToOptimize(gsl::span input_nodes, Node& target_node, - const std::vector& output_nodes, + gsl::span output_nodes, int num_input_defs, int num_output_defs) : num_inputs{num_input_defs == -1 ? gsl::narrow_cast(input_nodes.size()) : num_input_defs}, num_outputs{num_output_defs == -1 ? gsl::narrow_cast(output_nodes.size()) : num_output_defs} { @@ -228,7 +230,7 @@ NodesToOptimize::NodesToOptimize(Graph& graph, } NodesToOptimizeIndices NodesToOptimize::ToIndices() const { - std::vector node_indices; + InlinedVector node_indices; node_indices.reserve(nodes_.size()); std::for_each(nodes_.cbegin(), nodes_.cend(), [&node_indices](const Node* node) { const NodeIndex node_idx = node != nullptr ? node->Index() : NodesToOptimizeIndices::kEmptyNodeIndex; @@ -242,8 +244,8 @@ NodesToOptimizeIndices NodesToOptimize::ToIndices() const { num_variadic_inputs_, num_variadic_outputs_}; } -std::vector NodesToOptimize::Inputs(const std::vector& indices, bool required) const { - std::vector results; +InlinedVector NodesToOptimize::Inputs(gsl::span indices, bool required) const { + InlinedVector results; results.reserve(NumInputEntries()); for (auto idx : indices) { @@ -259,8 +261,8 @@ std::vector NodesToOptimize::Inputs(const std::vector& indices, bool return results; } -std::vector NodesToOptimize::Outputs(const std::vector& indices, bool required) const { - std::vector results; +InlinedVector NodesToOptimize::Outputs(const std::vector& indices, bool required) const { + InlinedVector results; results.reserve(NumOutputEntries()); // offset by all the inputs and the target node @@ -279,13 +281,14 @@ std::vector NodesToOptimize::Outputs(const std::vector& indices, boo return results; } -std::vector NodesToOptimize::GetNodesAtLocation(const NodeLocation& location, bool required) const { +InlinedVector NodesToOptimize::GetNodesAtLocation(const NodeLocation& location, bool required) const { if (location.type == NodeType::kInput) { return Inputs({location.index}, required); } else if (location.type == NodeType::kOutput) { return Outputs({location.index}, required); - } else + } else { return {&Target()}; + } }; size_t NodesToOptimize::NumInputEntries() const { diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.h b/onnxruntime/core/optimizer/selectors_actions/helpers.h index 4a81cfa8a2219..b72996e129e98 100644 --- a/onnxruntime/core/optimizer/selectors_actions/helpers.h +++ b/onnxruntime/core/optimizer/selectors_actions/helpers.h @@ -3,7 +3,11 @@ #pragma once +#include "gsl/gsl" + #include "core/common/basic_types.h" +#include "core/common/inlined_containers.h" +#include "core/graph/graph.h" #include "core/graph/runtime_optimization_record.h" namespace onnxruntime { @@ -37,9 +41,9 @@ class NodesToOptimize { // nodes to assemble. num_inputs and num_outputs default to the size of input_nodes and output_nodes. // specify num_input_defs/num_output_defs if the last input/output is variadic - NodesToOptimize(const std::vector& input_nodes, + NodesToOptimize(gsl::span input_nodes, Node& target_node, - const std::vector& output_nodes, + gsl::span output_nodes, int num_input_defs = -1, int num_output_defs = -1); // construct from saved NodeIndex values. IsValid() will return false if one or more nodes were missing. @@ -81,7 +85,7 @@ class NodesToOptimize { } // inputs filtered by index. includes all variadic. - std::vector Inputs(const std::vector& indices, bool required = true) const; + InlinedVector Inputs(gsl::span indices, bool required = true) const; Node& Target() const { return *GetNode(NumInputEntries() + 0, /*required*/ true); @@ -92,12 +96,12 @@ class NodesToOptimize { } // outputs filtered by index. includes all variadic. - std::vector Outputs(const std::vector& indices, bool required = true) const; + InlinedVector Outputs(const std::vector& indices, bool required = true) const; // Get the Node or Nodes (if variadic) at a specific index. - std::vector GetNodesAtLocation(const NodeLocation& location, bool required = true) const; + InlinedVector GetNodesAtLocation(const NodeLocation& location, bool required = true) const; - const std::vector& AllNodes() const { return nodes_; } + gsl::span AllNodes() const { return nodes_; } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(NodesToOptimize); @@ -117,15 +121,15 @@ class NodesToOptimize { bool variadic_output_{false}; int num_variadic_inputs_{0}; // how many values does the variadic input have. can be zero or more. int num_variadic_outputs_{0}; - std::vector nodes_; + InlinedVector nodes_; }; // Helper to build a NodesToOptimizeIndices instance // Use in selector to incrementally add pieces struct NodesToOptimizeIndicesBuilder { - std::vector input_nodes; + InlinedVector input_nodes; NodeIndex target_node{NodesToOptimizeIndices::kEmptyNodeIndex}; - std::vector output_nodes; + InlinedVector output_nodes; int num_input_defs{-1}; int num_output_defs{-1}; diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index 85d2b0ebe085d..540e0e92d30cd 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -8,6 +8,7 @@ #include #include +#include "core/graph/op_identifier_utils.h" #include "core/graph/runtime_optimization_record_container.h" namespace onnxruntime { @@ -146,11 +147,31 @@ static Status MatchAndProcess( break; } - graph.MutableRuntimeOptimizations().AddRecord( - transformer_name, - RuntimeOptimizationRecord{selector_action_entry.name, - node_selection, - action_saved_state.produced_nodes}); + RuntimeOptimizationRecord::ProducedOpIdVector produced_op_ids{}; + produced_op_ids.reserve(action_saved_state.produced_node_op_schemas.size()); + + for (const auto op_schema : action_saved_state.produced_node_op_schemas) { + produced_op_ids.push_back(utils::MakeOpId(*op_schema)); + if (save_context->record_produced_node_op_schema) { + status = save_context->record_produced_node_op_schema(*op_schema); + if (!status.IsOK()) { + break; + } + } + } + + // handle break out of above for loop on error + if (!status.IsOK()) { + break; + } + + RuntimeOptimizationRecord runtime_optimization_record{selector_action_entry.name, + node_selection, + std::move(produced_op_ids)}; + + graph.MutableRuntimeOptimizations().AddRecord(transformer_name, + std::move(runtime_optimization_record)); + } else { status = action.Run(graph, node_group); if (!status.IsOK()) { @@ -191,21 +212,20 @@ Status SelectorActionTransformer::ApplySelectorsAndActions( #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) -static Status RegisterProducedNodesWithGraph(NodeIndex pre_action_max_num_nodes, NodeIndex post_action_max_num_nodes, - const RuntimeOptimizationRecord& record, - Graph& graph) { +static Status SetOpSinceVersionForProducedNodes(NodeIndex pre_action_max_num_nodes, + NodeIndex post_action_max_num_nodes, + const RuntimeOptimizationRecord& record, + Graph& graph) { assert(post_action_max_num_nodes >= pre_action_max_num_nodes); const auto num_new_node_indices = post_action_max_num_nodes - pre_action_max_num_nodes; - auto produced_node_it = record.produced_nodes.begin(); - const auto produced_nodes_end = record.produced_nodes.end(); - - std::unordered_map node_index_to_kernel_def_hash{}; + auto produced_op_id_it = record.produced_op_ids.begin(); + const auto produced_op_ids_end = record.produced_op_ids.end(); for (NodeIndex i = 0; i < num_new_node_indices; ++i) { const NodeIndex new_node_idx = pre_action_max_num_nodes + i; - const auto* new_node = graph.GetNode(new_node_idx); + auto* new_node = graph.GetNode(new_node_idx); // only account for new nodes that still exist // an action could add a temporary node and then remove it @@ -213,18 +233,23 @@ static Status RegisterProducedNodesWithGraph(NodeIndex pre_action_max_num_nodes, continue; } - ORT_RETURN_IF(produced_node_it == produced_nodes_end, + ORT_RETURN_IF(produced_op_id_it == produced_op_ids_end, "Not enough produced nodes in the runtime optimization record."); - node_index_to_kernel_def_hash.emplace(new_node_idx, produced_node_it->kernel_def_hash); + ORT_RETURN_IF(new_node->Domain() != produced_op_id_it->domain || + new_node->OpType() != produced_op_id_it->op_type, + "New node op (", new_node->Domain(), ':', new_node->OpType(), + ") does not match produced node op in runtime optimization record (", + produced_op_id_it->domain, ':', produced_op_id_it->op_type, ")."); - ++produced_node_it; - } + assert(new_node->SinceVersion() == -1); - ORT_RETURN_IF(produced_node_it != produced_nodes_end, "Too many produced nodes in the runtime optimization record."); + new_node->SetSinceVersion(produced_op_id_it->since_version); - graph.MutableRuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash.merge( - node_index_to_kernel_def_hash); + ++produced_op_id_it; + } + + ORT_RETURN_IF(produced_op_id_it != produced_op_ids_end, "Too many produced nodes in the runtime optimization record."); return Status::OK(); } @@ -261,10 +286,8 @@ Status SelectorActionTransformer::ApplySavedRuntimeOptimizations( const NodeIndex post_action_num_nodes = graph.MaxNodeIndex(); - ORT_RETURN_IF_ERROR(RegisterProducedNodesWithGraph(pre_action_num_nodes, post_action_num_nodes, - record, graph)); - - ++graph.MutableRuntimeOptimizationReplayCtx().num_replayed_optimizations; + ORT_RETURN_IF_ERROR(SetOpSinceVersionForProducedNodes(pre_action_num_nodes, post_action_num_nodes, + record, graph)); } return Status::OK(); diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h index 51062e7ebd0ca..9ae0d6023dfa7 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h @@ -3,10 +3,20 @@ #pragma once -#include #include -#include "core/framework/kernel_registry_manager.h" +#if !defined(ORT_MINIMAL_BUILD) +#include + +#include "core/common/status.h" +#endif // !defined(ORT_MINIMAL_BUILD) + +#if !defined(ORT_MINIMAL_BUILD) +using onnxruntime::common::Status; +namespace ONNX_NAMESPACE { +class OpSchema; +} +#endif // !defined(ORT_MINIMAL_BUILD) namespace onnxruntime { @@ -24,7 +34,9 @@ struct SatDirectApplicationContext { // Context to save runtime optimizations for later replay. struct SatRuntimeOptimizationSaveContext { - std::reference_wrapper kernel_registry_manager; +#if !defined(ORT_MINIMAL_BUILD) + std::function record_produced_node_op_schema; +#endif // !defined(ORT_MINIMAL_BUILD) }; // Context to load runtime optimizations and replay them. diff --git a/onnxruntime/core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h b/onnxruntime/core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h new file mode 100644 index 0000000000000..7d126235a6e88 --- /dev/null +++ b/onnxruntime/core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Note: This file should be self-contained, i.e., no implementation in a .cc file. +// This is to allow it to be used from code that doesn't otherwise have any dependencies on the ORT optimizers. + +#pragma once + +#include + +#include "core/graph/constants.h" +#include "core/graph/op_identifier.h" + +namespace onnxruntime { + +// This is a list of ops and their versions which layout transformations can potentially add to the graph. +// This is needed in minimal build since opschema is not available. +inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { + // Note: these region_begin/end markers are used by tools/ci_build/reduce_op_kernels.py + // @@region_begin(extended_minimal_build_required_kernels)@@ + + // kOnnxDomain ops + OpIdentifierWithStringViews{kOnnxDomain, "Gather", 1}, + OpIdentifierWithStringViews{kOnnxDomain, "Gather", 11}, + OpIdentifierWithStringViews{kOnnxDomain, "Gather", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "Identity", 1}, + OpIdentifierWithStringViews{kOnnxDomain, "Identity", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "Identity", 14}, + OpIdentifierWithStringViews{kOnnxDomain, "Identity", 16}, + OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 1}, + OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 11}, + OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "Transpose", 1}, + OpIdentifierWithStringViews{kOnnxDomain, "Transpose", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 1}, + OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 11}, + OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 13}, + +#if !defined(DISABLE_CONTRIB_OPS) + // kMSDomain ops + OpIdentifierWithStringViews{kMSDomain, "NhwcMaxPool", 1}, + OpIdentifierWithStringViews{kMSDomain, "QLinearConv", 1}, +#endif // !defined(DISABLE_CONTRIB_OPS) + + // @@region_end(extended_minimal_build_required_kernels)@@ +}; + +namespace detail { +// std::is_sorted is not constexpr in C++17, so use our own constexpr version for now +template +constexpr bool IsSorted(It begin, It end, Compare cmp) { + if (begin == end) return true; + It curr = begin, next = begin; + while (++next != end) { + if (cmp(*next, *curr)) return false; + curr = next; + } + return true; +} +} // namespace detail + +static_assert(detail::IsSorted(kLayoutTransformationPotentiallyAddedOps.begin(), + kLayoutTransformationPotentiallyAddedOps.end(), + std::less{}), + "kLayoutTransformationPotentiallyAddedOps entries must be in sorted order."); + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h index 63269aa175057..d925d1d38cfa2 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h @@ -359,8 +359,11 @@ class GraphRef { /// /// The new node's op type /// The new node's domain. Empty string signifies default onnx domain. + /// The new node's since_version. If unspecified, use that of the old node. /// The new node - virtual std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain = "") = 0; + virtual std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, + std::string_view domain = "", + std::optional since_version = std::nullopt) = 0; /// /// Deletes a node from the graph. Behavior is undefined if node has any consumers. @@ -532,9 +535,10 @@ std::vector ChannelFirstToLastPerm(size_t rank); std::vector ChannelLastToFirstPerm(size_t rank); /// -/// Swaps out a node for a new copy of that node with the specified op type and domain. Current API does not all nodes -/// to have their op types or domains changed, so a new node is needed. All attributes, inputs, and outputs are moved -/// to the new node. The old node is removed from the graph and should no longer be accessed. +/// Swaps out a node for a new copy of that node with the specified op type and domain. +/// Current API does not allow nodes to have their op types or domains changed, so a new node is needed. All +/// attributes, inputs, and outputs are moved to the new node. The old node is removed from the graph and should no +/// longer be accessed. /// /// Graph containing the node /// Node to copy and remove @@ -544,4 +548,20 @@ std::vector ChannelLastToFirstPerm(size_t rank); std::unique_ptr SwapNodeOpTypeAndDomain(api::GraphRef& graph, api::NodeRef& node, std::string_view op_type, std::string_view domain); +/// +/// Swaps out a node for a new copy of that node with the specified op type, domain, and since version. +/// Current API does not allow nodes to have their op types or domains changed, so a new node is needed. All +/// attributes, inputs, and outputs are moved to the new node. The old node is removed from the graph and should no +/// longer be accessed. +/// +/// Graph containing the node +/// Node to copy and remove +/// New node op_type +/// New node domain. "" for the default domain. +/// New node since version. +/// The newly created node. +std::unique_ptr SwapNodeOpTypeDomainAndSinceVersion(api::GraphRef& graph, api::NodeRef& node, + std::string_view op_type, std::string_view domain, + int since_version); + } // namespace onnx_layout_transformation diff --git a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc index a002d108f865c..1fcfa33010693 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc @@ -3,11 +3,17 @@ #include "optimizer_api.h" #include "optimizer_utils.h" + +#include #include -#include "core/graph/graph_utils.h" -#include "core/framework/tensorprotoutils.h" +#include +#include + #include "core/framework/execution_provider.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h" #include "core/providers/cpu/tensor/transpose.h" using namespace ONNX_NAMESPACE; @@ -118,7 +124,8 @@ class ApiGraph final : public api::GraphRef { size_t num_outputs = 1, std::string_view domain = "") override; std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, - std::string_view domain = "") override; + std::string_view domain = "", + std::optional since_version = std::nullopt) override; void RemoveNode(api::NodeRef& node) override; void RemoveInitializer(std::string_view name) override; std::string_view AddInitializer(api::DataType dtype, const std::vector& shape, @@ -642,40 +649,56 @@ static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_typ return node; } -// This is a list of onnx ops and their versions which transpose_optimizer can potentially add to the graph. -// This is needed in minimal build since opschema is not available. -// The versions MUST be sorted due to how the model opset is matched with the most recent operator version. -static const std::unordered_map> onnx_ops_available_versions = { - {"Squeeze", {1, 11, 13}}, - {"Unsqueeze", {1, 11, 13}}, - {"Gather", {1, 11, 13}}, - {"Transpose", {1, 13}}, - {"Identity", {1, 13, 14, 16}}, -}; +static std::optional GetLayoutTransformationPotentiallyAddedOpSinceVersion( + std::string_view domain, std::string_view op_type, int opset_version) { + auto compare_ignoring_since_version = [](const OpIdentifierWithStringViews& a, const OpIdentifierWithStringViews& b) { + if (a.domain == b.domain) { + return a.op_type < b.op_type; + } + return a.domain < b.domain; + }; + + const auto [range_begin, range_end] = + std::equal_range(kLayoutTransformationPotentiallyAddedOps.begin(), + kLayoutTransformationPotentiallyAddedOps.end(), + OpIdentifierWithStringViews{domain, op_type, 0}, + compare_ignoring_since_version); + + // versions are in increasing order + // search backwards for largest since version <= opset_version + const auto range_rbegin = std::make_reverse_iterator(range_end), + range_rend = std::make_reverse_iterator(range_begin); + + const auto result = + std::find_if(range_rbegin, range_rend, + [&opset_version](const OpIdentifierWithStringViews& a) { + return a.since_version <= opset_version; + }); + + if (result != range_rend) { + return result->since_version; + } + + return std::nullopt; +} // Based on the opset version imported for this model, returns the since version for the node. static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view domain, const std::unordered_map& domain_to_version_map) { - int since_version = -1; + // TODO do we need this check? we will also check kLayoutTransformationPotentiallyAddedOps ORT_ENFORCE(domain == kOnnxDomain, "Transpose optimizer is expected to add only onnx domain ops. Domain: ", domain, " provided for op: ", op_type); - auto opset_import_iter = domain_to_version_map.find(std::string(domain)); - ORT_ENFORCE(opset_import_iter != domain_to_version_map.end(), "Onnx domain not found in opset imports."); + const auto opset_import_iter = domain_to_version_map.find(std::string(domain)); + ORT_ENFORCE(opset_import_iter != domain_to_version_map.end(), domain, " domain not found in opset imports."); - int opset_version = opset_import_iter->second; - auto iter = onnx_ops_available_versions.find(std::string(op_type)); - ORT_ENFORCE(iter != onnx_ops_available_versions.end(), + const int opset_version = opset_import_iter->second; + const auto since_version = GetLayoutTransformationPotentiallyAddedOpSinceVersion(domain, op_type, opset_version); + ORT_ENFORCE(since_version.has_value(), "Transpose Optimizer is adding an unexpected node: ", op_type, - "An entry for this node should be added in onnx_ops_available_versions and static_kernel_hashes map."); - - for (auto version : iter->second) { - if (version <= opset_version) { - since_version = version; - } - } + "An entry for this node should be added in kLayoutTransformationPotentiallyAddedOps."); - return since_version; + return *since_version; } std::unique_ptr ApiGraph::AddNode(std::string_view op_type, @@ -689,9 +712,10 @@ std::unique_ptr ApiGraph::AddNode(std::string_view op_type, } std::unique_ptr ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type, - std::string_view domain) { + std::string_view domain, std::optional since_version) { + const int new_node_since_version = since_version.has_value() ? *since_version : source_node.SinceVersion(); Node& node = CreateNodeHelper(graph_, op_type, source_node.Inputs(), - source_node.Outputs().size(), domain, source_node.SinceVersion(), + source_node.Outputs().size(), domain, new_node_since_version, source_node.GetExecutionProviderType()); std::unique_ptr new_node = std::make_unique(node, graph_); diff --git a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc index 2cac02033e90f..3fc0070127670 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc @@ -1607,7 +1607,7 @@ static bool HandleMaxPool(HandlerArgs& args) { return false; } - auto new_node = SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, "NhwcMaxPool", "com.microsoft"); + auto new_node = SwapNodeOpTypeDomainAndSinceVersion(args.ctx.graph, args.node, "NhwcMaxPool", "com.microsoft", 1); new_node->ClearAttribute("storage_order"); // Only relevant for indices output. Prohibited for NhwcMaxPool. TransposeFirstInput(args.ctx, *new_node, args.perm_inv); TransposeOutputs(args.ctx, *new_node, args.perm); @@ -2074,10 +2074,11 @@ void WrapTransposesAroundNode(api::GraphRef& graph, api::NodeRef& node, } } -std::unique_ptr SwapNodeOpTypeAndDomain(api::GraphRef& graph, api::NodeRef& node, - std::string_view op_type, std::string_view domain) { +static std::unique_ptr SwapNodeImpl(api::GraphRef& graph, api::NodeRef& node, + std::string_view op_type, std::string_view domain, + std::optional since_version) { auto outputs = node.Outputs(); - auto new_node = graph.CopyNode(node, op_type, domain); + auto new_node = graph.CopyNode(node, op_type, domain, since_version); for (size_t j = 0; j < outputs.size(); ++j) { if (outputs[j] != "") { @@ -2088,4 +2089,15 @@ std::unique_ptr SwapNodeOpTypeAndDomain(api::GraphRef& graph, api: return new_node; } +std::unique_ptr SwapNodeOpTypeAndDomain(api::GraphRef& graph, api::NodeRef& node, + std::string_view op_type, std::string_view domain) { + return SwapNodeImpl(graph, node, op_type, domain, std::nullopt); +} + +std::unique_ptr SwapNodeOpTypeDomainAndSinceVersion(api::GraphRef& graph, api::NodeRef& node, + std::string_view op_type, std::string_view domain, + int since_version) { + return SwapNodeImpl(graph, node, op_type, domain, since_version); +} + } // namespace onnx_layout_transformation diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index 2a17d8b05f170..6c9010559481d 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -119,13 +119,4 @@ std::shared_ptr ACLExecutionProvider::GetKernelRegistry() const return kernel_registry; } -std::vector> -ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { - std::vector> - result = IExecutionProvider::GetCapability(graph, kernel_registries); - - return result; -} - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h index c85a6899b18f6..587d28e47ce50 100755 --- a/onnxruntime/core/providers/acl/acl_execution_provider.h +++ b/onnxruntime/core/providers/acl/acl_execution_provider.h @@ -26,10 +26,6 @@ class ACLExecutionProvider : public IExecutionProvider { explicit ACLExecutionProvider(const ACLExecutionProviderInfo& info); virtual ~ACLExecutionProvider(); - std::vector> GetCapability( - const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const override; - const void* GetExecutionHandle() const noexcept override { // The ACL interface does not return anything interesting. return nullptr; diff --git a/onnxruntime/core/providers/armnn/armnn_execution_provider.cc b/onnxruntime/core/providers/armnn/armnn_execution_provider.cc index 60d2ef2fe6a22..fdbf3712ccbc0 100644 --- a/onnxruntime/core/providers/armnn/armnn_execution_provider.cc +++ b/onnxruntime/core/providers/armnn/armnn_execution_provider.cc @@ -125,13 +125,4 @@ std::shared_ptr ArmNNExecutionProvider::GetKernelRegistry() cons return kernel_registry; } -std::vector> -ArmNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { - std::vector> - result = IExecutionProvider::GetCapability(graph, kernel_registries); - - return result; -} - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/armnn/armnn_execution_provider.h b/onnxruntime/core/providers/armnn/armnn_execution_provider.h index d16bacc5c594b..5728ec906a114 100755 --- a/onnxruntime/core/providers/armnn/armnn_execution_provider.h +++ b/onnxruntime/core/providers/armnn/armnn_execution_provider.h @@ -26,10 +26,6 @@ class ArmNNExecutionProvider : public IExecutionProvider { explicit ArmNNExecutionProvider(const ArmNNExecutionProviderInfo& info); virtual ~ArmNNExecutionProvider(); - std::vector> GetCapability( - const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const override; - const void* GetExecutionHandle() const noexcept override { // The ArmNN interface does not return anything interesting. return nullptr; diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index dbaab52d116a7..c3d29d70d3a88 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -43,7 +43,7 @@ CoreMLExecutionProvider::~CoreMLExecutionProvider() {} std::vector> CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 6977dfdc1ff9b..d1ecd32207a8e 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -18,7 +18,7 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index a3955018832d2..8f0c55f0228e9 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -790,8 +790,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, ST // !!PLEASE READ BELOW!! Following that, add new entries above this comment /* *** IMPORTANT! *** - If kernel registrations are incorrectly updated, ORT format models get broken as the kernel hashes may be invalidated. - NEVER update a versioned entry to change the start or end version. These MUST be treated as immutable. i.e. if the macro has 'VERSIONED' in it, do not modify that entry diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.cc b/onnxruntime/core/providers/cpu/nn/batch_norm.cc index 46e812efb48fa..737371f3883e2 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.cc @@ -42,10 +42,6 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, float, KernelDefBuilder() .Alias(3, 1) .Alias(4, 2) - // ORT 1.8 was shipped with just the "T" type constraint and - // we want to maintain backwards compatibility for - // the hash and hence just use "T" for the hash generation - .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}) .TypeConstraint("T", DataTypeImpl::GetTensorType()) .TypeConstraint("U", DataTypeImpl::GetTensorType()), BatchNorm); @@ -54,10 +50,6 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, double, KernelDefBuilder() .Alias(3, 1) .Alias(4, 2) - // ORT 1.8 was shipped with just the "T" type constraint and - // we want to maintain backwards compatibility for - // the hash and hence just use "T" for the hash generation - .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}) .TypeConstraint("T", DataTypeImpl::GetTensorType()) .TypeConstraint("U", DataTypeImpl::GetTensorType()), BatchNorm); diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc index 90daf96ae7c09..f63c91b374d27 100644 --- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc +++ b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc @@ -18,8 +18,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 12, KernelDefBuilder() .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}), SpaceToDepth); ONNX_CPU_OPERATOR_KERNEL( @@ -27,8 +26,7 @@ ONNX_CPU_OPERATOR_KERNEL( 13, KernelDefBuilder() .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}), SpaceToDepth); ONNX_CPU_OPERATOR_VERSIONED_KERNEL( @@ -36,8 +34,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 1, 10, KernelDefBuilder() .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}), DepthToSpace); ONNX_CPU_OPERATOR_VERSIONED_KERNEL( @@ -46,8 +43,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 12, KernelDefBuilder() .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}), DepthToSpace); ONNX_CPU_OPERATOR_KERNEL( @@ -55,8 +51,7 @@ ONNX_CPU_OPERATOR_KERNEL( 13, KernelDefBuilder() .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}), DepthToSpace); // intermediate tensor shapes are: diff --git a/onnxruntime/core/providers/cpu/tensor/split.cc b/onnxruntime/core/providers/cpu/tensor/split.cc index 9d09f112b81f0..7dc5b0d4fde24 100644 --- a/onnxruntime/core/providers/cpu/tensor/split.cc +++ b/onnxruntime/core/providers/cpu/tensor/split.cc @@ -27,18 +27,13 @@ using SplitDataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( using EnabledSplitDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Split, Input, 0); -using OldSplitDataTypes = onnxruntime::TypeList; - ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Split, 2, 10, KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList(), - BuildKernelDefConstraintsFromTypeList()) - .FixedTypeConstraintForHash( - "T", - BuildKernelDefConstraintsFromTypeList()), + BuildKernelDefConstraintsFromTypeList()), Split); // Opset 11 starts to support Neg Axis. @@ -48,10 +43,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 12, KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList(), - BuildKernelDefConstraintsFromTypeList()) - .FixedTypeConstraintForHash( - "T", - BuildKernelDefConstraintsFromTypeList()), + BuildKernelDefConstraintsFromTypeList()), Split); // Opset 13 starts to supports 'split' as optional input. @@ -60,10 +52,7 @@ ONNX_CPU_OPERATOR_KERNEL( 13, KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList(), - BuildKernelDefConstraintsFromTypeList()) - .FixedTypeConstraintForHash( - "T", - BuildKernelDefConstraintsFromTypeList()), + BuildKernelDefConstraintsFromTypeList()), Split); Status SplitBase::PrepareForCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc old mode 100755 new mode 100644 index 24e55a884efab..34ae76a0a60d3 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2400,7 +2400,7 @@ std::unique_ptr CUDAExecutionProvider::GetDataTransf std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { + const IKernelLookup& kernel_lookup) const { InlinedVector candidates; for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); @@ -2408,19 +2408,11 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, continue; const auto& node = *p_node; - const KernelCreateInfo* cuda_kernel_def = nullptr; if (!node.GetExecutionProviderType().empty()) { continue; } - for (auto registry : kernel_registries) { - auto st = registry->TryFindKernel(node, Type(), &cuda_kernel_def); - - // at least one registry has a CUDA kernel for this node - if (st.IsOK()) - break; - } - + const KernelCreateInfo* cuda_kernel_def = kernel_lookup.LookUpKernel(node); // none of the provided registries has a CUDA kernel for this node if (cuda_kernel_def == nullptr) { LOGS_DEFAULT(INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); @@ -2462,7 +2454,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, Type(), kernel_registries, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index fdbd7c4a23888..f420ab98ce267 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -97,7 +97,7 @@ class CUDAExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const override; + const IKernelLookup& kernel_lookup) const override; int GetDeviceId() const override { return info_.device_id; } const cudaDeviceProp& GetDeviceProp() const { return device_prop_; }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index b9b00d91f430e..8d6b50592f57a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -38,7 +38,7 @@ namespace Dml using namespace onnxruntime::common; ExecutionProvider::~ExecutionProvider() - { + { if (m_impl) { m_impl->Close(); @@ -53,7 +53,7 @@ namespace Dml Dml::RegisterDmlOperators(abiRegistry.Get()); assert(abiRegistry->GetRegistries().size() == 1); - + auto customRegistry = *abiRegistry->GetRegistries().begin(); *registry = customRegistry->GetKernelRegistry(); *internalRegInfoMap = abiRegistry->GetInternalRegInfoMap(); @@ -86,12 +86,12 @@ namespace Dml std::vector> ExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_registries); + return m_impl->GetCapability(graph, kernel_lookup); #else - return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_registries); + return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup); #endif } @@ -100,16 +100,16 @@ namespace Dml m_context->Close(); } - void ExecutionProviderImpl::WaitForOutstandingWork() + void ExecutionProviderImpl::WaitForOutstandingWork() { Flush(); m_context->GetCurrentCompletionEvent().WaitForSignal(); } - + HRESULT __stdcall ExecutionProviderImpl::AllocatePooledResource( - size_t size, + size_t size, AllocatorRoundingMode roundingMode, - ID3D12Resource **d3dResource, + ID3D12Resource **d3dResource, IUnknown** pooledResource ) const noexcept { @@ -142,7 +142,7 @@ namespace Dml } // ORT release pipelines agent pools do not have 19H1 SDK installed which defines D3D_FEATURE_LEVEL_1_0_CORE. -// Once ORT/WinML github project can be built with VS2019, we can update these pools to use install the 19H1 SDK +// Once ORT/WinML github project can be built with VS2019, we can update these pools to use install the 19H1 SDK // using the command line installer tool with VS2019 // Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap #define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000) @@ -189,7 +189,7 @@ namespace Dml m_uploadHeap = std::make_unique(m_d3d12Device.Get(), m_context); m_readbackHeap = std::make_unique(m_d3d12Device.Get(), m_context); - + // CPU Allocator used to create buffers for the MemcpyFromHost operator. m_cpuInputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUInput); m_cpuOutputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUOutput); @@ -271,8 +271,8 @@ namespace Dml inputBufferArrayDesc.BindingCount = gsl::narrow_cast(inputBufferBindings.size()); inputBufferArrayDesc.Bindings = inputBufferBindings.data(); - DML_BINDING_DESC inputArrayBindingDesc = hasInputsToBind ? - DML_BINDING_DESC{ DML_BINDING_TYPE_BUFFER_ARRAY, &inputBufferArrayDesc } : + DML_BINDING_DESC inputArrayBindingDesc = hasInputsToBind ? + DML_BINDING_DESC{ DML_BINDING_TYPE_BUFFER_ARRAY, &inputBufferArrayDesc } : DML_BINDING_DESC{ DML_BINDING_TYPE_NONE, nullptr }; m_context->InitializeOperator( @@ -360,7 +360,7 @@ namespace Dml FillBindings(outputBufferBindings, outputBindings, outputTensors); ORT_THROW_IF_FAILED(ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings)); - + return S_OK; } ORT_CATCH_RETURN @@ -376,14 +376,14 @@ namespace Dml ORT_TRY { assert(!m_closed); - + DML_BINDING_DESC persistentResourceBindingDesc = persistentResourceBinding ? DML_BINDING_DESC{ DML_BINDING_TYPE_BUFFER, persistentResourceBinding } : DML_BINDING_DESC{ DML_BINDING_TYPE_NONE, nullptr }; m_context->ExecuteOperator( - op, + op, persistentResourceBindingDesc, inputTensors, outputTensors); @@ -420,11 +420,11 @@ namespace Dml if (src->IsCpuData() && !dst->IsCpuData()) { - // + // // CPU -> GPU copy (upload) - // + // const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(dst).GetDataInterface().Get()); - + ID3D12Resource* dstData = dstAllocInfo->GetResource(); const void* srcData = src->GetData(); @@ -435,9 +435,9 @@ namespace Dml } else if (!src->IsCpuData() && dst->IsCpuData()) { - // + // // GPU -> CPU copy (readback) - // + // void* dstData = dst->GetData(); const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src).GetDataInterface().Get()); @@ -452,9 +452,9 @@ namespace Dml } else if (!src->IsCpuData() && !dst->IsCpuData()) { - // + // // GPU -> GPU copy - // + // const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src).GetDataInterface().Get()); const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(dst).GetDataInterface().Get()); @@ -521,7 +521,7 @@ namespace Dml std::vector> ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, - const std::vector& registries) const + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const { std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_"; uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); @@ -529,16 +529,16 @@ namespace Dml return PartitionGraph( graph, *m_internalRegInfoMap, - registries, + kernel_lookup, deviceDataTypeMask, m_kernelRegistry.get(), partitionKernelPrefix ); } - - bool IsGpuTensor(const onnxruntime::Tensor& tensor) + + bool IsGpuTensor(const onnxruntime::Tensor& tensor) { - return strcmp(tensor.Location().name, onnxruntime::CPU) && + return strcmp(tensor.Location().name, onnxruntime::CPU) && !(tensor.Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || tensor.Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput); } @@ -549,14 +549,14 @@ namespace Dml auto provider = const_cast(this); TensorWrapper destInternal( - &dst, - IsGpuTensor(dst), + &dst, + IsGpuTensor(dst), provider, true); TensorWrapper srcInternal( - const_cast(&src), - IsGpuTensor(src), + const_cast(&src), + IsGpuTensor(src), provider, true); @@ -579,21 +579,21 @@ namespace Dml { // This batching implementation only handles GPU -> CPU copies. Other copies do not require synchronization // and are batched across multiple calls to CopyTensor. - if (!IsGpuTensor(src_dst_pairs[i].src) || IsGpuTensor(src_dst_pairs[i].dst)) + if (!IsGpuTensor(src_dst_pairs[i].src) || IsGpuTensor(src_dst_pairs[i].dst)) { ORT_RETURN_IF_ERROR(CopyTensor(src_dst_pairs[i].src, src_dst_pairs[i].dst)); continue; } - + TensorWrapper srcWrapper = TensorWrapper( - const_cast(&src_dst_pairs[i].src.get()), + const_cast(&src_dst_pairs[i].src.get()), true, provider, true); TensorWrapper dstWrapper = TensorWrapper( - &src_dst_pairs[i].dst.get(), - false, + &src_dst_pairs[i].dst.get(), + false, provider, true); @@ -619,7 +619,7 @@ namespace Dml // Performs a blocking call to synchronize and read back data from the GPU into the destination buffer m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcDatas, srcState); - + return onnxruntime::common::Status::OK(); } @@ -639,7 +639,7 @@ namespace Dml m_context->ReleaseCompletedReferences(); } - void ExecutionProviderImpl::QueueReference(IUnknown* object) + void ExecutionProviderImpl::QueueReference(IUnknown* object) { assert(!m_closed); m_context->QueueReference(object); @@ -669,7 +669,7 @@ namespace Dml data->AddRef(); } else - { + { #ifdef _GAMING_XBOX ComPtr wrappedResource = Microsoft::WRL::Make(m_allocator->DecodeDataHandle(data)->GetResource()); *abiData = wrappedResource.Detach(); @@ -677,7 +677,7 @@ namespace Dml ComPtr resource = m_allocator->DecodeDataHandle(data)->GetResource(); *abiData = resource.Detach(); #endif - } + } } uint64_t ExecutionProviderImpl::TryGetPooledAllocationId( @@ -690,7 +690,7 @@ namespace Dml void ExecutionProviderImpl::GetABIExecutionInterfaceAndInvalidateState( bool isInternalOperator, - IUnknown** abiExecutionObject) const + IUnknown** abiExecutionObject) const { assert(!m_closed); @@ -709,9 +709,9 @@ namespace Dml #else *abiExecutionObject = commandList.Detach(); #endif - } + } } - + bool ExecutionProviderImpl::TransitionsRequiredForOperator( bool isInternalOperator ) @@ -725,7 +725,7 @@ namespace Dml bool isBeforeOp, uint32_t resourceCount, IUnknown** resources - ) + ) { std::vector barriers; barriers.reserve(resourceCount); @@ -738,8 +738,8 @@ namespace Dml // Custom operators receive resources in Common state and must return them to Common // state when finished. Resources are otherwise kept in UAV state (or are promotable to UAV). barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition( - resource.Get(), - isBeforeOp ? D3D12_RESOURCE_STATE_UNORDERED_ACCESS : D3D12_RESOURCE_STATE_COMMON, + resource.Get(), + isBeforeOp ? D3D12_RESOURCE_STATE_UNORDERED_ACCESS : D3D12_RESOURCE_STATE_COMMON, isBeforeOp ? D3D12_RESOURCE_STATE_COMMON : D3D12_RESOURCE_STATE_UNORDERED_ACCESS )); } @@ -754,7 +754,7 @@ namespace Dml { return m_context->GetCommandListTypeForQueue(); } - + bool __stdcall ExecutionProviderImpl::IsMcdmDevice() const noexcept { return m_isMcdmDevice; @@ -765,7 +765,7 @@ namespace Dml return m_areMetacommandsEnabled; } - std::shared_ptr + std::shared_ptr ExecutionProviderImpl::GetInternalRegistrationInfoMap() const { return m_internalRegInfoMap; @@ -786,8 +786,8 @@ namespace Dml return m_cpuOutputAllocator; } - - onnxruntime::common::Status ExecutionProviderImpl::OnSessionInitializationEnd() + + onnxruntime::common::Status ExecutionProviderImpl::OnSessionInitializationEnd() { // Flush and trim resources, including staging memory used to upload weights. // This reduces memory usage immediately after session creation, and avoids @@ -833,8 +833,8 @@ namespace Dml } onnxruntime::common::Status CopyTensor( - onnxruntime::IExecutionProvider* provider, - const onnxruntime::Tensor& src, + onnxruntime::IExecutionProvider* provider, + const onnxruntime::Tensor& src, onnxruntime::Tensor& dst ) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 8024657e13984..3fba9c755a3de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -83,7 +83,7 @@ namespace Dml std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const std::vector& registries + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup ) const; uint32_t GetSupportedDeviceDataTypeMask() const; @@ -103,14 +103,14 @@ namespace Dml bool isInternalOperator, IUnknown* data, IUnknown** abiData) const override; - + uint64_t TryGetPooledAllocationId( IUnknown* data, bool isInternalOperator) override; void GetABIExecutionInterfaceAndInvalidateState( bool isInternalOperator, - IUnknown** abiExecutionObject) const override; + IUnknown** abiExecutionObject) const override; bool TransitionsRequiredForOperator( bool isInternalOperator @@ -127,27 +127,27 @@ namespace Dml void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode); - // Waits for flushed work, discards unflushed work, and discards associated references to + // Waits for flushed work, discards unflushed work, and discards associated references to // prevent circular references. Must be the last call on the object before destruction. - void Close() override; + void Close() override; void WaitForOutstandingWork(); - + // Allocate a resource from pools. Releasing pooledResource returns it to the pool. STDMETHOD(AllocatePooledResource)( size_t size, AllocatorRoundingMode roundingMode, - ID3D12Resource **d3dResource, + ID3D12Resource **d3dResource, IUnknown* *pooledResource ) const noexcept final; - + STDMETHOD_(ID3D12Resource*, DecodeResource)(void* allocation) const noexcept final; std::shared_ptr GetKernelRegistry() const { return m_kernelRegistry; } - + STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; @@ -155,9 +155,9 @@ namespace Dml std::shared_ptr GetCpuInputAllocator(); std::shared_ptr GetCpuOutputAllocator(); - std::shared_ptr - GetInternalRegistrationInfoMap() const; - + std::shared_ptr + GetInternalRegistrationInfoMap() const; + onnxruntime::common::Status OnSessionInitializationEnd(); private: @@ -199,8 +199,8 @@ namespace Dml assert(exec_queue_id == 0); return m_impl->CopyTensor(src, dst); } - - onnxruntime::common::Status CopyTensors(const std::vector& src_dst_pairs) const + + onnxruntime::common::Status CopyTensors(const std::vector& src_dst_pairs) const { return m_impl->CopyTensors(src_dst_pairs); } @@ -226,7 +226,7 @@ namespace Dml ID3D12CommandQueue* commandQueue, bool enableMetacommands = true ); - + std::unique_ptr GetDataTransfer() const final override { return std::make_unique(m_impl.Get()); @@ -244,10 +244,10 @@ namespace Dml std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const final override; + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const final override; onnxruntime::common::Status OnSessionInitializationEnd() override - { + { return m_impl->OnSessionInitializationEnd(); } @@ -270,18 +270,18 @@ namespace Dml void Flush() { return m_impl->Flush(); - } - + } + void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode) { return m_impl->SetDefaultRoundingMode(roundingMode); } - + void ReleaseCompletedReferences() { return m_impl->ReleaseCompletedReferences(); } - + ExecutionProviderImpl* GetImpl() { return m_impl.Get(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 1698733bf9d3c..362e110fbd5bd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -17,7 +17,7 @@ #include "GraphPartitioner.h" //#define PRINT_PARTITON_INFO - + using namespace Windows::AI::MachineLearning::Adapter; namespace Dml @@ -87,7 +87,7 @@ namespace Dml m_nodeIndices.push_back(index); } - + void GraphPartition::AddInput(const std::string& name) { assert(!IsFinalized()); @@ -117,7 +117,7 @@ namespace Dml assert(partitionToMerge->IsDmlGraphPartition() == IsDmlGraphPartition()); partitionToMerge->m_mergedPartition = this; - + m_nodeIndices.insert(m_nodeIndices.begin(), partitionToMerge->m_nodeIndices.begin(), partitionToMerge->m_nodeIndices.end()); m_inputs.insert(partitionToMerge->m_inputs.begin(), partitionToMerge->m_inputs.end()); m_outputs.insert(partitionToMerge->m_outputs.begin(), partitionToMerge->m_outputs.end()); @@ -126,12 +126,12 @@ namespace Dml // Adds the outputs of a node to the specified partition void AddNodeOutputsToPartitionMap( - const onnxruntime::Node& node, + const onnxruntime::Node& node, GraphPartition* partition, std::unordered_map& nodeNameToPartitionMap ) { - for (uint32_t i = 0; i < node.OutputDefs().size(); ++i) + for (uint32_t i = 0; i < node.OutputDefs().size(); ++i) { const auto* arg = node.OutputDefs()[i]; if (arg->Exists()) @@ -247,14 +247,13 @@ namespace Dml bool IsNodeSupportedByDml( const onnxruntime::Node& node, - const onnxruntime::KernelRegistry& registry, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. const InternalRegistrationInfoMap& internalRegInfoMap ) { - const onnxruntime::KernelCreateInfo* createInfo; - Status st = registry.TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo); - if (!st.IsOK()) + const onnxruntime::KernelCreateInfo* createInfo = kernel_lookup.LookUpKernel(node); + if (!createInfo) { return false; } @@ -284,7 +283,7 @@ namespace Dml void GetRegistrationProperties( const onnxruntime::GraphViewer& graph, const onnxruntime::Node& node, - const std::vector& dmlRegistries, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. const InternalRegistrationInfoMap& internalRegInfoMap, _In_opt_ const std::unordered_map* nodeNameToPartitionMap, @@ -299,66 +298,60 @@ namespace Dml // Find the highest priority DML registry supporting this node, and get its highest-priority // registration. Determine if that registration supports usage as a graph node. - for (auto registry : dmlRegistries) + + if (IsNodeSupportedByDml(node, kernel_lookup, supportedDeviceDataTypeMask, + internalRegInfoMap)) { - if (IsNodeSupportedByDml(node, *registry, supportedDeviceDataTypeMask, internalRegInfoMap)) - { - *isDmlNode = true; + *isDmlNode = true; - // Get the kernel creation info for the registration, and check if it carries the property - // set during registration of kernels that support DML graph node usage. - auto graphNodeProperty = dmlNodePropertyMap.insert(std::make_pair(&node, GraphNodeProperties())); + // Get the kernel creation info for the registration, and check if it carries the property + // set during registration of kernels that support DML graph node usage. + auto graphNodeProperty = dmlNodePropertyMap.insert(std::make_pair(&node, GraphNodeProperties())); - // Ensure that shape information is known statically for the inputs and outputs of the node, - // which is required for MLGraph compilation. - const onnxruntime::KernelCreateInfo* createInfo; - if (!registry->TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo).IsOK()) - { - continue; - } + // Ensure that shape information is known statically for the inputs and outputs of the node, + // which is required for MLGraph compilation. + const onnxruntime::KernelCreateInfo* createInfo = kernel_lookup.LookUpKernel(node); + assert(createInfo != nullptr); // since IsNodeSupportedByDml() returned true - auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get()); - if (regInfoIter != internalRegInfoMap.end()) - { - auto internalRegInfo = regInfoIter->second; + auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get()); + if (regInfoIter != internalRegInfoMap.end()) + { + auto internalRegInfo = regInfoIter->second; - if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) + if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) + { + bool requiredCpuInputsConstant = true; + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { - bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) { - if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) - { - continue; - } - - const onnx::TensorProto* tensor = nullptr; - const std::string& inputName = node.InputDefs()[inputIndex]->Name(); - - if (!graph.GetInitializedTensor(inputName, tensor)) - { - requiredCpuInputsConstant = false; - break; - } - - requiredInitializerMap.insert(inputName); + continue; } - std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && - TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && - TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && - (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + + if (!graph.GetInitializedTensor(inputName, tensor)) { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + requiredCpuInputsConstant = false; + break; } + + requiredInitializerMap.insert(inputName); } - } - break; + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && + TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && + TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && + (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } + } } } } @@ -366,7 +359,7 @@ namespace Dml // Creates a partition for a node which is not a DML graph node, and finalizes partitions // which are inputs of the new partition. std::unique_ptr CreateNonGraphNodePartitionAndFinalizeInputs( - const onnxruntime::Node& node, + const onnxruntime::Node& node, bool isDmlNode, std::unordered_map& nodeNameToPartitionMap ) @@ -376,19 +369,19 @@ namespace Dml partition->SetIsDmlPartition(isDmlNode); partition->AddNodeIndex(node.Index()); - for (uint32_t i = 0; i < node.InputDefs().size(); ++i) + for (uint32_t i = 0; i < node.InputDefs().size(); ++i) { const auto* arg = node.InputDefs()[i]; if (arg->Exists()) { const std::string& argName = arg->Name(); - + if (nodeNameToPartitionMap.find(argName) != nodeNameToPartitionMap.end()) { // Finalize the partition which contains an input to a non-DML-graph partition. - // The connections from that partition to other partitions, such as this one, - // must become outputs of that partition. As subsequent downstream nodes of - // the finalized partition are visited, other outputs will subsequently be + // The connections from that partition to other partitions, such as this one, + // must become outputs of that partition. As subsequent downstream nodes of + // the finalized partition are visited, other outputs will subsequently be // added to the partition, too. GraphPartition* inputPartition = nodeNameToPartitionMap[argName]->GetRootMergedPartition(); inputPartition->SetFinalized(); @@ -407,13 +400,13 @@ namespace Dml // Get the partitions which are inputs to the specified node and which are not finalized. std::vector GetNonFinalizedInputPartitions( - const onnxruntime::Node& node, + const onnxruntime::Node& node, std::unordered_map& nodeNameToPartitionMap ) { std::vector inputNonFinalPartitions; - for (uint32_t i = 0; i < node.InputDefs().size(); ++i) + for (uint32_t i = 0; i < node.InputDefs().size(); ++i) { const auto* arg = node.InputDefs()[i]; if (arg->Exists()) @@ -437,15 +430,15 @@ namespace Dml return inputNonFinalPartitions; } - + // Add graph outputs of the new node to a partition. void AddGraphOutputsFromNodeToPartition( - const onnxruntime::Node& node, + const onnxruntime::Node& node, const std::set& graphOutputs, GraphPartition* partition ) { - for (uint32_t i = 0; i < node.OutputDefs().size(); ++i) + for (uint32_t i = 0; i < node.OutputDefs().size(); ++i) { const auto* arg = node.OutputDefs()[i]; if (arg->Exists()) @@ -459,7 +452,7 @@ namespace Dml } std::unique_ptr CreateNewPartitionWithFinalizedInputPartitions( - const onnxruntime::Node& node, + const onnxruntime::Node& node, const std::set& graphOutputs, std::unordered_map& nodeNameToPartitionMap ) @@ -471,7 +464,7 @@ namespace Dml // Inputs of the partition are added when partitions are created and extended when // nodes are added with inputs which are not inside the partition - for (uint32_t i = 0; i < node.InputDefs().size(); ++i) + for (uint32_t i = 0; i < node.InputDefs().size(); ++i) { const auto* arg = node.InputDefs()[i]; if (arg->Exists()) @@ -496,11 +489,11 @@ namespace Dml return partition; } - + std::unique_ptr ComputationCapacityFromPartition( - GraphPartition* partition, - uint32_t partitionIndex, - const onnxruntime::GraphViewer& graph, + GraphPartition* partition, + uint32_t partitionIndex, + const onnxruntime::GraphViewer& graph, std::unordered_map&& graphNodePropertyMap, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::string& partitionKernelPrefix, @@ -511,7 +504,7 @@ namespace Dml if (partition->IsDmlGraphPartition()) { assert(partition->IsDmlGraphPartition()); - + // Create a definition for the node. The name must be unique. auto def = std::make_unique(); def->name = std::string("DmlFusedNode_") + partitionKernelPrefix + std::to_string(partitionIndex); @@ -525,7 +518,7 @@ namespace Dml for (auto nodeIndex : partition->GetNodeIndices()) { const onnxruntime::Node* node = graph.GetNode(nodeIndex); - + #ifdef PRINT_PARTITON_INFO printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str()); #endif @@ -560,7 +553,7 @@ namespace Dml .Provider(onnxruntime::kDmlExecutionProvider); ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); - + subGraph->SetMetaDef(std::move(def)); } @@ -581,7 +574,7 @@ namespace Dml const std::vector& toplogicalOrder = graph.GetNodesInTopologicalOrder(); - for (size_t nodeIndex : toplogicalOrder) + for (size_t nodeIndex : toplogicalOrder) { const onnxruntime::Node& node = *graph.GetNode(nodeIndex); if (node.ContainsSubgraph()) @@ -593,21 +586,21 @@ namespace Dml return false; } - // + // // A simple graph partitioning algorithm is used: // // - If a node has any input which is already in a graph, and that graph is not finalized, // then the node and all such input graphs are merged. // - // - Once a node has an output which cannot be merged with its graph, its graph is marked - // as final, which disallows its future extensions. This ensures that no indirect + // - Once a node has an output which cannot be merged with its graph, its graph is marked + // as final, which disallows its future extensions. This ensures that no indirect // downstream dependencies of the external output node are later merged. // std::vector> BuildPartitions( const onnxruntime::GraphViewer& graph, const InternalRegistrationInfoMap& internalRegInfoMap, - const std::vector& registries, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, @@ -628,28 +621,28 @@ namespace Dml for (const auto* arg : graph.GetInputsIncludingInitializers()) { graphInputs.insert(arg->Name()); - } - + } + // If a model contains an intializer which is not also a graph input, it will not be returned // by GetInputsIncludingInitializers above. Such models would be invalid, however they loaded - // in RS5. For compatibility, this ensures that such models continue to load. This is + // in RS5. For compatibility, this ensures that such models continue to load. This is // verified by an ONNX conformance test for Add. for (const auto& arg : graph.GetAllInitializedTensors()) { // This adds the initializer to the input set if it didn't already exist. graphInputs.insert(arg.first); } - + for (const auto* arg : graph.GetOutputs()) { graphOutputs.insert(arg->Name()); } - - // Check whether this graph is a subgraph, or contains any node with a subgraph. + + // Check whether this graph is a subgraph, or contains any node with a subgraph. bool modelUsesSubgraph = ModelUsesSubgraph(graph); // Build up partitions while traversing the graph. - for (size_t nodeIndex : toplogicalOrder) + for (size_t nodeIndex : toplogicalOrder) { const onnxruntime::Node& node = *graph.GetNode(nodeIndex); @@ -659,12 +652,12 @@ namespace Dml // Whether the node is implemented through DML and as a graph node, meaning it // can generate DML operations through a private interface for use as an MLGraph node. bool isDmlGraphNode = false; - + // Get the registration properties above and populate nodeNameToPartitionMap. GetRegistrationProperties( graph, node, - registries, + kernel_lookup, supportedDeviceDataTypeMask, internalRegInfoMap, &nodeNameToPartitionMap, @@ -676,9 +669,9 @@ namespace Dml // Add a unique partition if graph node usage is not supported. // - // Partitioning is disabled in models with subgraphs to work around issues with implicit inputs. - // The partitioning algorithm does not currently consider such inputs. Transfering shared initializers - // for partitions could also cause problems. Note, operators with subgraphs are currently not efficient + // Partitioning is disabled in models with subgraphs to work around issues with implicit inputs. + // The partitioning algorithm does not currently consider such inputs. Transfering shared initializers + // for partitions could also cause problems. Note, operators with subgraphs are currently not efficient // anyhow due to CPU/GPU copies. if (modelUsesSubgraph || !isDmlGraphNode) { @@ -690,9 +683,9 @@ namespace Dml partitions.push_back(CreateNonGraphNodePartitionAndFinalizeInputs(node, isDmlNode, nodeNameToPartitionMap)); continue; } - + std::vector inputNonFinalPartitions = GetNonFinalizedInputPartitions(node, nodeNameToPartitionMap); - + if (inputNonFinalPartitions.empty()) { partitions.push_back(CreateNewPartitionWithFinalizedInputPartitions(node, graphOutputs, nodeNameToPartitionMap)); @@ -706,7 +699,7 @@ namespace Dml AddNodeOutputsToPartitionMap(node, firstNonFinalInputPartition, nodeNameToPartitionMap); // Add inputs for the new node which span partitions - for (uint32_t i = 0; i < node.InputDefs().size(); ++i) + for (uint32_t i = 0; i < node.InputDefs().size(); ++i) { const auto* arg = node.InputDefs()[i]; if (arg->Exists()) @@ -716,18 +709,18 @@ namespace Dml // Add the input of the current node into the partition which the node will be merged into. // Skip this if the input is already merged into the same partition or is not finalized, // and so will be subsequently merged below. - if (inputPartition != nodeNameToPartitionMap.end() && + if (inputPartition != nodeNameToPartitionMap.end() && inputPartition->second->GetRootMergedPartition() != firstNonFinalInputPartition && inputPartition->second->GetRootMergedPartition()->IsFinalized()) { - // Add this input of the current node as an output of the final partition to which - // it belongs. + // Add this input of the current node as an output of the final partition to which + // it belongs. inputPartition->second->GetRootMergedPartition()->AddOutput(arg->Name()); firstNonFinalInputPartition->AddInput(arg->Name()); } - + if (graphInputs.find(arg->Name()) != graphInputs.end()) - { + { firstNonFinalInputPartition->AddInput(arg->Name()); } } @@ -740,7 +733,7 @@ namespace Dml if (inputNonFinalPartitions.size() > 1) { firstNonFinalInputPartition->Merge(gsl::span(&inputNonFinalPartitions[1], inputNonFinalPartitions.size() - 1)); - } + } } } @@ -784,7 +777,7 @@ namespace Dml PartitionGraph( const onnxruntime::GraphViewer& graph, const InternalRegistrationInfoMap& internalRegInfoMap, - const std::vector& registries, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. onnxruntime::KernelRegistry* registryForPartitionKernels, const std::string& partitionKernelPrefix @@ -798,10 +791,10 @@ namespace Dml std::unordered_map graphNodePropertyMap; std::vector> partitions = BuildPartitions( graph, - internalRegInfoMap, - registries, + internalRegInfoMap, + kernel_lookup, supportedDeviceDataTypeMask, - graphNodePropertyMap, + graphNodePropertyMap, requiredInitializerMap); // Create a map between each initialized tensor and the partition(s) it is part of. @@ -817,7 +810,7 @@ namespace Dml continue; } - // Create a map which will store by name each initializer which should be transferred to the + // Create a map which will store by name each initializer which should be transferred to the // partition. This prevents OnnxRuntime from allocating GPU resources and uploading those initializers, // so the partiton's kernel can do so. In the process, it will pre-process weights while consuming a CPU // backed resource, avoiding an extra set of GPU resources in memory. @@ -853,16 +846,16 @@ namespace Dml onnx::TensorProto partitionTensor; graphTensor.Swap(&partitionTensor); (*transferredInitializerMap)[input] = std::move(partitionTensor); - + const_cast(graph.GetAllInitializedTensors()).erase(graph.GetAllInitializedTensors().find(input)); } } } result.push_back(ComputationCapacityFromPartition( - partition.get(), - partitionIndex, - graph, + partition.get(), + partitionIndex, + graph, std::move(graphNodePropertyMap), registryForPartitionKernels, partitionKernelPrefix, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index e0fd8af31d6af..b82999caffa1b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -44,7 +44,7 @@ namespace Dml BuildPartitions( const onnxruntime::GraphViewer& graph, const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap, - const std::vector& registries, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, @@ -54,7 +54,7 @@ namespace Dml PartitionGraph( const onnxruntime::GraphViewer& graph, const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap, - const std::vector& registries, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. onnxruntime::KernelRegistry* registryForPartitionKernels, const std::string& partitionKernelPrefix @@ -62,7 +62,7 @@ namespace Dml bool IsNodeSupportedByDml( const onnxruntime::Node& node, - const onnxruntime::KernelRegistry& registry, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap ); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp index dc8399059f41b..e1076ab235b81 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp @@ -10,13 +10,15 @@ #include "core/providers/dml/OperatorAuthorHelper/Attributes.h" #include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h" #include "core/providers/dml/OperatorAuthorHelper/OperatorVersions.h" +#include "core/framework/kernel_lookup.h" #include "core/framework/kernel_registry.h" +#include "core/framework/kernel_type_str_resolver.h" #include "core/graph/graph_utils.h" namespace Dml { GraphTransformer::GraphTransformer( - const std::string& name, + const std::string& name, const onnxruntime::IExecutionProvider* provider ) : onnxruntime::GraphTransformer(name), @@ -27,7 +29,7 @@ namespace Dml onnxruntime::common::Status GraphTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, - int graph_level, const onnxruntime::logging::Logger&) const + int graph_level, const onnxruntime::logging::Logger&) const { modified = false; @@ -41,7 +43,7 @@ namespace Dml PerformQuantizedOperatorDecomposition(&graph, &transformModifiedGraph); modified |= transformModifiedGraph; - if (modified) + if (modified) { ORT_RETURN_IF_ERROR(graph.Resolve()); } @@ -60,11 +62,9 @@ namespace Dml } return ss.str(); } - + void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const { - onnxruntime::KernelRegistry* registry = m_providerImpl->GetKernelRegistry().get(); - struct NodeToAdd { std::string name; @@ -84,6 +84,13 @@ namespace Dml // graph while iterating over it std::vector nodesToAdd; + onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider; + const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); + const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernel_lookup = onnxruntime::KernelLookup{provider_type, + gsl::make_span(®istry, 1), + kernel_type_str_resolver}; + for (auto& node : graph->Nodes()) { // We need to predict whether the nodes will be assigned to the DML transformer by Lotus, @@ -91,7 +98,7 @@ namespace Dml if (!IsNodeSupportedByDml( node, - *registry, + kernel_lookup, m_providerImpl->GetSupportedDeviceDataTypeMask(), *m_providerImpl->GetInternalRegistrationInfoMap().get() )) @@ -114,7 +121,7 @@ namespace Dml // We need to predict whether the nodes will be assigned to the DML transformer by Lotus, // which occurs in IExecutionProvider::GetCapability. - if (!onnxruntime::KernelRegistry::HasImplementationOf(*registry, outputNode, onnxruntime::kDmlExecutionProvider)) + if (!kernel_lookup.LookUpKernel(outputNode)) { // Can't fuse nodes that don't belong to this execution provider continue; @@ -165,7 +172,7 @@ namespace Dml fusedNode.activationAttributes = activationNode.GetAttributes(); // Inputs to the fused node are the inputs to the fuseable node - for (const auto *input : fuseableNode.InputDefs()) + for (const auto *input : fuseableNode.InputDefs()) { fusedNode.inputs.push_back(graph->GetNodeArg(input->Name())); } @@ -230,7 +237,7 @@ namespace Dml std::vector inputs; std::vector outputs; }; - + // Defer adding and removing nodes in the graph until after we're done iterating over it, because we can't mutate the // graph while iterating over it std::vector nodesToAdd; @@ -255,13 +262,13 @@ namespace Dml dequantizeNode.name = "decomposed_QLinearSigmoid_DequantizeLinear_" + GetUniqueNodeName(&node); dequantizeNode.description = ""; dequantizeNode.opType = "DequantizeLinear"; - dequantizeNode.domain = ""; + dequantizeNode.domain = ""; dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[0]->Name())); dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[1]->Name())); dequantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[2]->Name())); dequantizeNode.outputs.push_back(sigmoidInputArg); - + nodesToAdd.push_back(std::move(dequantizeNode)); } @@ -270,10 +277,10 @@ namespace Dml sigmoidNode.name = "decomposed_QLinearSigmoid_Sigmoid_" + GetUniqueNodeName(&node); sigmoidNode.description = ""; sigmoidNode.opType = "Sigmoid"; - sigmoidNode.domain = ""; + sigmoidNode.domain = ""; sigmoidNode.inputs.push_back(sigmoidInputArg); - sigmoidNode.outputs.push_back(sigmoidOutputArg); - nodesToAdd.push_back(std::move(sigmoidNode)); + sigmoidNode.outputs.push_back(sigmoidOutputArg); + nodesToAdd.push_back(std::move(sigmoidNode)); } { @@ -282,13 +289,13 @@ namespace Dml quantizeNode.description = ""; quantizeNode.opType = "QuantizeLinear"; quantizeNode.domain = ""; - + quantizeNode.inputs.push_back(sigmoidOutputArg); quantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[3]->Name())); quantizeNode.inputs.push_back(graph->GetNodeArg(node.InputDefs()[4]->Name())); quantizeNode.outputs.push_back(graph->GetNodeArg(node.OutputDefs()[0]->Name())); - - nodesToAdd.push_back(std::move(quantizeNode)); + + nodesToAdd.push_back(std::move(quantizeNode)); } nodesToRemove.push_back(node.Index()); @@ -308,7 +315,7 @@ namespace Dml nodeToAdd.domain); } - for (const auto& nodeIndex : nodesToRemove) + for (const auto& nodeIndex : nodesToRemove) { onnxruntime::Node* node = graph->GetNode(nodeIndex); onnxruntime::graph_utils::RemoveNodeOutputEdges(*graph, *node); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h index 1dc650c5d6f00..fa04bcf6edf41 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h @@ -14,9 +14,9 @@ namespace SchemaInferenceOverrider // require type inference functions. template void OverrideSchemaInferenceFunction( - _In_z_ const char* name, - int version, - bool isLatest, + _In_z_ const char* name, + int version, + bool isLatest, gsl::span constantCpuInputs ) { @@ -41,8 +41,8 @@ namespace SchemaInferenceOverrider } auto abiContext = Windows::AI::MachineLearning::Adapter::MLSchemaInferenceContext::Create( - &nodeInfo, - &ctx, + &nodeInfo, + &ctx, constantCpuInputsCapture); ORT_THROW_IF_FAILED(shapeInferrer->InferOutputShapes(abiContext.Get())); @@ -54,7 +54,7 @@ namespace SchemaInferenceOverrider { // Assert that this is the latest schema version for the operator, since a new version might need // the same treatment. - const uint32_t maxVersion = 9; + [[maybe_unused]] constexpr uint32_t maxVersion = 9; assert( !onnx::OpSchemaRegistry::Schema(name, maxVersion) || onnx::OpSchemaRegistry::Schema(name, maxVersion) == onnx::OpSchemaRegistry::Schema(name, version)); @@ -65,7 +65,7 @@ namespace SchemaInferenceOverrider #define OVERRIDE_SCHEMA(version, isLatest, opName) \ OverrideSchemaInferenceFunction( \ #opName, OperatorHelper::OnnxOperatorSet##version##::sc_sinceVer_##opName, isLatest, gsl::span()); - + #pragma push_macro("OVERRIDE_SCHEMA_EX") #define OVERRIDE_SCHEMA_EX(version, isLatest, opName, shapeInferenceName, /*CPU constant tensor indices*/ ...) \ OverrideSchemaInferenceFunction( \ @@ -93,4 +93,4 @@ OverrideSchemaInferenceFunction> DNNLExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer) const { std::vector> supported_node_vecs; std::vector supported_node_vec; - + std::unordered_map all_nodes_count; std::unordered_map supported_nodes_count; @@ -119,18 +119,16 @@ std::vector> DNNLExecutionProvider::GetSupportedNodes(con LOGS_DEFAULT(ERROR) << "Total coverge: " << support_counts << ":" << all_counts << " percentage: " << (float)support_counts / (float)all_counts; } - + return supported_node_vecs; } std::vector> DNNLExecutionProvider::GetCapability( const GraphViewer& graph_viewer, - const std::vector& kernel_registries) const { + const IKernelLookup& /*kernel_lookup*/) const { //follow from coreml ep's Getcapability - ORT_UNUSED_PARAMETER(kernel_registries); - std::vector> result; if (graph_viewer.IsSubgraph()) { diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index d4e49c2e17b4e..99cab3f9d237f 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -32,7 +32,7 @@ class DNNLExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 1f1ee2239f22f..bc9e079363851 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -866,7 +866,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order, const s std::vector> MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; auto model = graph_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index af0c0aeec01ff..d16a982414f6d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -43,7 +43,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& kernel_registries) const override; + const IKernelLookup& /*kernel_lookup*/) const override; common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; @@ -62,7 +62,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool fp16_enable_ = false; bool dump_model_ops_ = false; int device_id_; - migraphx::target t_; + migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc index 7cf451123e6b0..caa78358020a7 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc @@ -389,7 +389,7 @@ Status Shaper::ConcatImpl(const std::vector& input_names, Status Shaper::SplitImpl(const std::string& input_name, int32_t axis, const std::vector& output_names) { const auto& input_shape = shape_map_.at(input_name); - const auto count = output_names.size(); + const auto count = static_cast(output_names.size()); ORT_RETURN_IF_NOT(input_shape[axis] % count == 0, "count [", count, "] does not evenly divide dimension ", axis, " [", input_shape[axis], "]"); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 52c4cb5732c54..782b384a63c96 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -71,7 +71,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; // TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index d59660821d4da..fd2c621f9f9c8 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -21,7 +21,7 @@ class NnapiExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 16bd72117aa97..d9fe6f95b011c 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -93,9 +93,8 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv } std::vector> -OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const std::vector& kernel_registries) const { - ORT_UNUSED_PARAMETER(kernel_registries); - +OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; //Enable CI Logs if (!(GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG").empty())) { @@ -144,7 +143,7 @@ common::Status OpenVINOExecutionProvider::Compile( openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true; #else openvino_ep::BackendManager::GetGlobalContext().use_api_2 = false; -#endif +#endif std::shared_ptr backend_manager = std::make_shared(fused_node, graph_body_viewer, *GetLogger()); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index f8b949f4c3dc9..4d62944a1436a 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -155,7 +155,7 @@ class OpenVINOExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph_viewer, - const std::vector& kernel_registries) const override; + const IKernelLookup& /*kernel_lookup*/) const override; Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index cfe2527befafc..085d496a52a82 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -65,7 +65,7 @@ std::vector> RknpuExecutionProvider::GetSupportedNodes( std::vector> RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { // Find inputs, initializers and outputs for each supported subgraph std::vector> result; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h index 22ad8097816a8..1289c8569f8e8 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h @@ -19,7 +19,7 @@ class RknpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index de1fd8e7af0f7..f7df4c367b748 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2251,7 +2251,7 @@ std::unique_ptr ROCMExecutionProvider::GetDataTransf std::vector> ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { + const IKernelLookup& kernel_lookup) const { InlinedVector candidates; for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); @@ -2259,19 +2259,11 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, continue; const auto& node = *p_node; - const KernelCreateInfo* rocm_kernel_def = nullptr; if (!node.GetExecutionProviderType().empty()) { continue; } - for (auto registry : kernel_registries) { - auto st = registry->TryFindKernel(node, Type(), &rocm_kernel_def); - - // at least one registry has a ROCM kernel for this node - if (st.IsOK()) - break; - } - + const KernelCreateInfo* rocm_kernel_def = kernel_lookup.LookUpKernel(node); // none of the provided registries has a ROCM kernel for this node if (rocm_kernel_def == nullptr) { LOGS_DEFAULT(INFO) << "ROCM kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); @@ -2302,7 +2294,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For ROCM EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, Type(), kernel_registries, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); std::vector> result; for (auto& node_index : candidates) { diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 7a528cd1131b3..721364836c3b8 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -84,7 +84,7 @@ class ROCMExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const override; + const IKernelLookup& kernel_lookup) const override; int GetDeviceId() const override { return info_.device_id; } const hipDeviceProp_t& GetDeviceProp() const { return device_prop_; }; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index a7fcfe8901fd4..f1ae644551e4a 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -253,8 +253,7 @@ std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const c std::unique_ptr CreateGPUDataTransfer(void* stream); std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, - const std::string& provider_type, - gsl::span kernel_registries, + const IExecutionProvider::IKernelLookup& kernel_lookup, gsl::span tentative_nodes); std::string GetEnvironmentVar(const std::string& var_name); @@ -327,4 +326,3 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { r #define LOGS_DEFAULT(severity) \ LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) - diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 94048e78a5734..0824a82c2b3b8 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -288,8 +288,8 @@ void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) { } std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& kernel_registries) const { - return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_registries); + const IKernelLookup& kernel_lookup) const { + return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup); } common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { @@ -337,10 +337,9 @@ std::string GetEnvironmentVar(const std::string& var_name) { } std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, - const std::string& provider_type, - gsl::span kernel_registries, + const IExecutionProvider::IKernelLookup& kernel_lookup, gsl::span tentative_nodes) { - return g_host->GetCpuPreferredNodes(graph, provider_type, kernel_registries, tentative_nodes); + return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); } namespace profiling { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 5c6ac3a623d92..6d0ad41cfc78c 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -182,8 +182,7 @@ struct ProviderHost { #endif virtual std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, - const std::string& provider_type, - gsl::span kernel_registries, + const IExecutionProvider::IKernelLookup& kernel_lookup, gsl::span tentative_nodes) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) = 0; @@ -224,7 +223,7 @@ struct ProviderHost { virtual AllocatorPtr IExecutionProvider__GetAllocator(const IExecutionProvider* p, int id, OrtMemType mem_type) = 0; virtual void IExecutionProvider__InsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) = 0; virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const std::vector& kernel_registries) = 0; + const IExecutionProvider::IKernelLookup& kernel_lookup) = 0; virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; @@ -506,7 +505,6 @@ struct ProviderHost { virtual std::shared_ptr KernelRegistry__construct() = 0; virtual void KernelRegistry__operator_delete(KernelRegistry* p) = 0; virtual Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) = 0; - virtual Status KernelRegistry__TryFindKernel(const KernelRegistry* p, const Node& node, ProviderType exec_provider, const KernelCreateInfo** out) = 0; // PrimitiveDataTypeBase virtual int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 597ee06850197..90718cc7ddd55 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -511,8 +511,6 @@ struct KernelRegistry final { Status Register(KernelCreateInfo&& create_info) { return g_host->KernelRegistry__Register(this, std::move(create_info)); } - Status TryFindKernel(const Node& node, ProviderType exec_provider, const KernelCreateInfo** out) const { return g_host->KernelRegistry__TryFindKernel(this, node, exec_provider, out); } - KernelRegistry() = delete; KernelRegistry(const KernelRegistry&) = delete; void operator=(const KernelRegistry&) = delete; diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc index d103ba43b4538..e514b197e47e9 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc @@ -80,7 +80,7 @@ SNPEExecutionProvider::~SNPEExecutionProvider() {} std::vector> SNPEExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { + const IKernelLookup& kernel_lookup) const { std::vector candidates; for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); @@ -88,24 +88,11 @@ SNPEExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, continue; const auto& node = *p_node; - const KernelCreateInfo* snpe_kernel_def = nullptr; if (!node.GetExecutionProviderType().empty()) { continue; } - for (auto registry : kernel_registries) { -#if defined(ORT_MINIMAL_BUILD) - auto st = registry->TryFindKernel(node, Type(), uint64_t(0), &snpe_kernel_def); -#else - auto st = registry->TryFindKernel(node, Type(), &snpe_kernel_def); -#endif - - // at least one registry has a SNPE kernel for this node - if (st.IsOK()) - break; - } - - // none of the provided registries has a SNPE kernel for this node + const KernelCreateInfo* snpe_kernel_def = kernel_lookup.LookUpKernel(node); if (snpe_kernel_def == nullptr) { LOGS_DEFAULT(WARNING) << "Snpe kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.h b/onnxruntime/core/providers/snpe/snpe_execution_provider.h index 7d2a6ac2e100e..c0a62eea11a25 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.h +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.h @@ -18,7 +18,7 @@ class SNPEExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const std::vector&) const override; + const IKernelLookup& kernel_lookup) const override; std::shared_ptr GetKernelRegistry() const override; std::unordered_map GetRuntimeOptions() const { return runtime_options_; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 2454c0c790eea..53932df1efd0f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -564,7 +564,7 @@ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const GraphViewer& gra } -// Check whether all the nodes of the graph are assigned to specific ep +// Check whether all the nodes of the graph are assigned to specific ep bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const GraphViewer& graph, const std::string& provider_type) const { const int number_of_ort_nodes = graph.NumberOfNodes(); std::vector nodes_vector(number_of_ort_nodes); @@ -1030,7 +1030,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { // Get ModelPath const auto& path_string = graph.ModelPath().ToPathString(); #ifdef _WIN32 @@ -1050,7 +1050,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, const auto& node = graph.GetNode(node_index[index]); /* If current node is control flow op, we take different approach based on following four cases: - * + * * (1) control flow op is supported by TRT, and its subgraphs are all supported by TRT. Assign this node to TRT. * (2) control flow op is supported by TRT, but not all its subgraphs supported by TRT. Don't assign this node to TRT. * (3) control flow op is not supported by TRT, but its subgraphs all supported by TRT. Don't assign this node to TRT. @@ -1145,7 +1145,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, if (all_subgraphs_are_supported) { // We want the subgraph nodes to be assigned to TRT EP but don't want them to be fused until later at the control flow op level. - // Simply request the subgraph nodes with a single ComputeCapability for each with no MetaDef (i.e. what the default implementation for IExecutionProvider::GetCapability does). + // Simply request the subgraph nodes with a single ComputeCapability for each with no MetaDef (i.e. what the default implementation for IExecutionProvider::GetCapability does). for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { for (const auto& index : group.first) { @@ -1157,8 +1157,8 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; return result; - } - } + } + } int number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 76ee443ca50c1..a89350a0112f1 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -124,7 +124,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return device_id_; } @@ -168,7 +168,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { int device_id_; AllocatorPtr allocator_; bool context_memory_sharing_enable_ = false; - size_t max_ctx_mem_size_ = 0; + size_t max_ctx_mem_size_ = 0; IAllocatorUniquePtr context_memory_ = nullptr; mutable char model_path_[4096]; // Reserved for max path length bool engine_decryption_enable_ = false; @@ -201,8 +201,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph, bool remove_cycles = true) const; - /** - Get a unique_lock object to control the concurrency behavior. + /** + Get a unique_lock object to control the concurrency behavior. Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading) should be protected by a lock when invoked by multiple threads concurrently. */ @@ -211,7 +211,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { /**Check the graph is the subgraph of control flow op*/ bool IsSubGraphOfControlFlowOp(const GraphViewer& graph) const; - /**Check whether all the nodes of the graph are assigned to specific ep*/ + /**Check whether all the nodes of the graph are assigned to specific ep*/ bool AllNodesAssignedToSpecificEP(const GraphViewer& graph, const std::string& provider_type) const; /**Check whether all the nodes of subgraph are supported*/ diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc b/onnxruntime/core/providers/tvm/tvm_execution_provider.cc index 95316a349827e..c677a5cd79446 100644 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc +++ b/onnxruntime/core/providers/tvm/tvm_execution_provider.cc @@ -57,7 +57,7 @@ TvmExecutionProvider::~TvmExecutionProvider() {} std::vector> TvmExecutionProvider::GetCapability(const GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; if (graph_viewer.IsSubgraph()) { return result; diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.h b/onnxruntime/core/providers/tvm/tvm_execution_provider.h index 7be73e985e993..1cfd18e794b13 100644 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.h +++ b/onnxruntime/core/providers/tvm/tvm_execution_provider.h @@ -34,7 +34,7 @@ class TvmExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc index aa4b4b8d96afd..f4a57f9e6dd0e 100644 --- a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc +++ b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc @@ -56,7 +56,7 @@ TvmSoExecutionProvider::~TvmSoExecutionProvider() {} std::vector> TvmSoExecutionProvider::GetCapability(const GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; if (graph_viewer.IsSubgraph()) { return result; diff --git a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h index 43d433b53b760..52095edb83390 100644 --- a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h +++ b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h @@ -34,7 +34,7 @@ class TvmSoExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index db0c5e209cd73..c0241c02b12c7 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -215,9 +215,7 @@ static void AppendClusterToSubGraph(const std::vector& nodes, std::vector> VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { - ORT_UNUSED_PARAMETER(kernel_registries); - + const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; // Dump model Proto to file to pass it to pyxir diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 1282e0283f5d8..10d2845ed8a27 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -25,7 +25,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return device_id_; } diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 8b670e89ddb84..c1984d7b0201e 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -182,7 +182,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, - const std::vector& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { std::vector> capabilities; std::shared_ptr registry = GetKernelRegistry(); diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h index 4d478d79df694..5266ad7465ac3 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h @@ -25,8 +25,8 @@ class XnnpackExecutionProvider : public IExecutionProvider { ~XnnpackExecutionProvider() override; std::vector> GetCapability( - const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const override; + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c7cfa14ff4dd9..a6c480c1fbd96 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -14,6 +14,7 @@ #include "core/common/denormal.h" #include "core/common/logging/logging.h" #include "core/common/parse_string.h" +#include "core/common/path_string.h" #include "core/flatbuffers/flatbuffers_utils.h" #include "core/flatbuffers/ort_format_version.h" #include "core/framework/allocatormgr.h" @@ -24,15 +25,15 @@ #include "core/framework/graph_partitioner.h" #include "core/framework/kernel_def_builder.h" #include "core/framework/kernel_registry.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/kernel_type_str_resolver_utils.h" #include "core/framework/mldata_type_utils.h" -#include "core/framework/session_state_flatbuffers_utils.h" #include "core/framework/TensorSeq.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/op_kernel_context_internal.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/utils.h" -#include "core/framework/kernel_def_hash_helpers.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/optimizer/graph_transformer_utils.h" @@ -594,29 +595,33 @@ common::Status InferenceSession::SaveToOrtFormat(const PathString& filepath) con flatbuffers::FlatBufferBuilder builder(fbs_buffer_size); auto ort_model_version = builder.CreateString(kOrtModelVersion); - flatbuffers::Offset model; + flatbuffers::Offset fbs_model; ORT_RETURN_IF_ERROR( - model_->SaveToOrtFormat(builder, model)); + model_->SaveToOrtFormat(builder, fbs_model)); + + flatbuffers::Offset fbs_kernel_type_str_resolver; + KernelTypeStrResolver kernel_type_str_resolver{}; + ORT_RETURN_IF_ERROR(kernel_type_str_resolver.RegisterGraphNodeOpSchemas(model_->MainGraph())); + for (const auto* op_schema : saved_runtime_optimization_produced_node_op_schemas_) { + ORT_RETURN_IF_ERROR(kernel_type_str_resolver.RegisterOpSchema(*op_schema)); + } - flatbuffers::Offset session_state; ORT_RETURN_IF_ERROR( - session_state_->SaveToOrtFormat(builder, session_state)); + kernel_type_str_resolver.SaveToOrtFormat(builder, fbs_kernel_type_str_resolver)); fbs::InferenceSessionBuilder sb(builder); sb.add_ort_version(ort_model_version); - sb.add_model(model); - sb.add_session_state(session_state); + sb.add_model(fbs_model); + sb.add_kernel_type_str_resolver(fbs_kernel_type_str_resolver); auto session = sb.Finish(); builder.Finish(session, fbs::InferenceSessionIdentifier()); - // TODO: Do we need to catch any std::exceptions from creating/writing to disk and convert to Status codes? { std::ofstream file(filepath, std::ios::binary); - uint8_t* buf = builder.GetBufferPointer(); int size = builder.GetSize(); file.write(reinterpret_cast(buf), size); - file.close(); + ORT_RETURN_IF_NOT(file, "Failed to save ORT format model to file: ", ToUTF8String(filepath)); } return Status::OK(); @@ -1020,9 +1025,18 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort // Check version mismatch, for now we will only proceed when runtime version matches the model's ort version const auto* fbs_ort_model_version = fbs_session->ort_version(); ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model."); - ORT_RETURN_IF_NOT(IsOrtModelVersionSupported(fbs_ort_model_version->str()), - "The ORT format model version [", fbs_ort_model_version->str(), - "] is not supported this build ", ORT_VERSION); + + // Note about the ORT format version 5 breaking change. + // TODO This change was introduced in 1.13. Remove this note a few releases later, e.g., 1.15. + // TODO(edgchen1) update link to point to 1.13 release branch + constexpr auto* kOrtFormatVersion5BreakingChangeNote = + "This build doesn't support ORT format models older than version 5. " + "See: https://github.com/microsoft/onnxruntime/blob/main/docs/ORT_Format_Update_in_1.13.md"; + + ORT_RETURN_IF_NOT(IsOrtModelVersionSupported(fbs_ort_model_version->string_view()), + "The ORT format model version [", fbs_ort_model_version->string_view(), + "] is not supported in this build ", ORT_VERSION, ". ", + kOrtFormatVersion5BreakingChangeNote); const auto* fbs_model = fbs_session->model(); ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model."); @@ -1051,9 +1065,17 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model)); model_ = std::move(tmp_model); - // Initialize takes the session_mutex_ as well so we need to have released it prior to calling this - const auto* fbs_sess_state = fbs_session->session_state(); - ORT_RETURN_IF(nullptr == fbs_sess_state, "SessionState is null. Invalid ORT format model."); + KernelTypeStrResolver kernel_type_str_resolver{}; + if (const auto* fbs_kernel_type_str_resolver = fbs_session->kernel_type_str_resolver(); + fbs_kernel_type_str_resolver != nullptr) { + ORT_RETURN_IF_ERROR(kernel_type_str_resolver.LoadFromOrtFormat(*fbs_kernel_type_str_resolver)); + } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + ORT_RETURN_IF_ERROR( + kernel_type_str_resolver_utils::AddLayoutTransformationRequiredOpsToKernelTypeStrResolver( + kernel_type_str_resolver)); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + kernel_registry_manager_.SetKernelTypeStrResolver(std::move(kernel_type_str_resolver)); is_model_loaded_ = true; @@ -1122,32 +1144,29 @@ common::Status InferenceSession::AddPrePackedWeightsContainer(PrepackedWeightsCo } namespace { -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, SessionState& session_state) { - std::unordered_map compiled_kernel_hashes; - +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // only provide NCWH to NHWC layout transformer if supported TransformLayoutFunction transform_layout_fn = layout_transformer::IsSupportedOpset(graph) ? layout_transformer::TransformLayoutForEP : nullptr; +#else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + TransformLayoutFunction transform_layout_fn{}; +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) GraphPartitioner partitioner(kernel_registry_manager, providers); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, - GraphPartitioner::Mode::kOrtFormatLoad, - &compiled_kernel_hashes)); - - if (!compiled_kernel_hashes.empty()) { - session_state.SetCompiledKernelHashes(std::move(compiled_kernel_hashes)); - } + GraphPartitioner::Mode::kOrtFormatLoad)); return Status::OK(); } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) { @@ -1324,13 +1343,8 @@ common::Status InferenceSession::Initialize() { return false; }(); - const fbs::SessionState* serialized_session_state = - loading_ort_format - ? fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state() - : nullptr; - -#if !defined(ORT_MINIMAL_BUILD) if (!loading_ort_format) { +#if !defined(ORT_MINIMAL_BUILD) const auto minimal_build_opt_config_value = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigMinimalBuildOptimizations, ""); MinimalBuildOptimizationHandling minimal_build_optimization_handling{}; @@ -1338,10 +1352,16 @@ common::Status InferenceSession::Initialize() { saving_ort_format, minimal_build_optimization_handling)); + auto record_runtime_optimization_produced_op_schema = [this](const ONNX_NAMESPACE::OpSchema& op_schema) { + saved_runtime_optimization_produced_node_op_schemas_.insert(&op_schema); + return Status::OK(); + }; + // add predefined transformers ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, - minimal_build_optimization_handling)); + minimal_build_optimization_handling, + record_runtime_optimization_produced_op_schema)); // apply any transformations to the main graph and any subgraphs ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_, @@ -1391,17 +1411,16 @@ common::Status InferenceSession::Initialize() { // Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_)); - } else +#else // !defined(ORT_MINIMAL_BUILD) + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Loading anything other than ORT format models is not enabled in this build.")); #endif // !defined(ORT_MINIMAL_BUILD) - { - ORT_ENFORCE(loading_ort_format && serialized_session_state != nullptr); + } else { + ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, + *session_state_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // nodes are already partitioned, but a custom EP may compile some at runtime. - // run the partitioning to allow that to happen. - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, - *session_state_)); - const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, @@ -1412,7 +1431,6 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_( session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, session_options_, - serialized_session_state, // need to keep the initializers if saving the optimized model !saving_model, saving_ort_format)); @@ -1496,6 +1514,7 @@ common::Status InferenceSession::Initialize() { #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif + // This method should be called from within Initialize() only and before the creation of the session state. // This ensures all providers have been registered in the session and the session state is consistent with the providers. void InferenceSession::UpdateProvidersWithSharedAllocators() { @@ -1884,7 +1903,7 @@ Status InferenceSession::Run(const RunOptions& run_options, // scope of owned_run_logger is just the call to Execute. // If Execute ever becomes async we need a different approach std::unique_ptr owned_run_logger; - auto run_logger = CreateLoggerForRun(run_options, owned_run_logger); + const auto& run_logger = CreateLoggerForRun(run_options, owned_run_logger); std::optional> sequential_run_lock; if (is_concurrent_run_supported_ == false) { @@ -2325,7 +2344,8 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { common::Status InferenceSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, - MinimalBuildOptimizationHandling minimal_build_optimization_handling) const { + MinimalBuildOptimizationHandling minimal_build_optimization_handling, + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); @@ -2343,7 +2363,8 @@ common::Status InferenceSession::AddPredefinedTransformers( const auto sat_context = minimal_build_optimization_handling == MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations - ? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{kernel_registry_manager_}} + ? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{ + record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, optimizers_to_disable_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index aae2394ebd1bc..303fedcc45faa 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -122,6 +122,8 @@ class InferenceSession { OnlyApplyMinimalBuildOptimizations, }; + using RecordRuntimeOptimizationProducedNodeOpSchemaFn = std::function; + #endif /** @@ -632,7 +634,8 @@ class InferenceSession { virtual common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, - MinimalBuildOptimizationHandling minimal_build_optimization_handling) const; + MinimalBuildOptimizationHandling minimal_build_optimization_handling, + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const; common::Status TransformGraph(onnxruntime::Graph& graph, const onnxruntime::GraphTransformerManager& graph_transformer_mgr, @@ -645,6 +648,8 @@ class InferenceSession { InsertCastTransformer insert_cast_transformer_; + // assuming that OpSchema* elements are not null. our version of gsl::not_null doesn't specialize std::hash. + InlinedHashSet saved_runtime_optimization_produced_node_op_schemas_; #endif // Any GraphTransformer/RewriteRule name in this set will not be enabled. InlinedHashSet optimizers_to_disable_; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 54aa606bae3d5..c4a7799442111 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -224,10 +224,9 @@ struct ProviderHostImpl : ProviderHost { std::string demangle(const std::string& name) override { return onnxruntime::profiling::demangle(name); } std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, - const std::string& provider_type, - gsl::span kernel_registries, + const IExecutionProvider::IKernelLookup& kernel_lookup, gsl::span tentative_nodes) override { - return onnxruntime::GetCpuPreferredNodes(graph, provider_type, kernel_registries, tentative_nodes); + return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } @@ -278,8 +277,11 @@ struct ProviderHostImpl : ProviderHost { // IExecutionProvider (direct) AllocatorPtr IExecutionProvider__GetAllocator(const IExecutionProvider* p, int id, OrtMemType mem_type) override { return p->IExecutionProvider::GetAllocator(id, mem_type); } void IExecutionProvider__InsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) override { return p->IExecutionProvider::InsertAllocator(allocator); } - std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const std::vector& kernel_registries) override { return p->IExecutionProvider::GetCapability(graph_viewer, kernel_registries); } + std::vector> IExecutionProvider__GetCapability( + const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, + const IExecutionProvider::IKernelLookup& kernel_lookup) override { + return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup); + } common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { return p->IExecutionProvider::Compile(fused_nodes_and_graphs, node_compute_funcs); @@ -587,10 +589,6 @@ struct ProviderHostImpl : ProviderHost { void KernelRegistry__operator_delete(KernelRegistry* p) override { delete p; } Status KernelRegistry__Register(KernelRegistry* p, KernelCreateInfo&& create_info) override { return p->Register(std::move(create_info)); } - Status KernelRegistry__TryFindKernel(const KernelRegistry* p, const Node& node, ProviderType exec_provider, const KernelCreateInfo** out) override { - return p->TryFindKernel(node, exec_provider, out); - } - // PrimitiveDataTypeBase (wrapped) int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) override { return p->GetDataType(); } diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 3004c61b6cfb5..c7fe119659363 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -262,15 +262,19 @@ class PlannerTest : public ::testing::Test { state_->GetDataTransferMgr()); op_kernel_infos_.push_back(std::move(info)); - if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider)) { - auto st = reg->Register( + const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{}; + if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider, + kernel_type_str_resolver)) { + ASSERT_STATUS_OK(reg->Register( KernelCreateInfo(std::make_unique(kernel_def), - [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); })); - ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + out = std::make_unique(info); + return Status::OK(); + }))); } const KernelCreateInfo* kci; - ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", &kci)); + ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, &kci)); kernel_create_info_map.insert({p_node->Index(), gsl::not_null(kci)}); } @@ -301,7 +305,7 @@ class PlannerTest : public ::testing::Test { // CreatePlan is called inside FinalizeSessionState and usually the initializers are removed following that. // Leave initializers so we can duplicate the call to CreatePlan from here to validate. constexpr bool remove_initializers = false; - status = state_->FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, {}, nullptr, remove_initializers); + status = state_->FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, {}, remove_initializers); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); SequentialPlannerTestContext test_context(&shape_map_); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 9f0b639588f57..a96f645e442e4 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -131,7 +131,7 @@ class FuseExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& /*kernel_registries*/) const override { + const IKernelLookup& /*kernel_lookup*/) const override { // Fuse two add into one. std::vector> result; std::unique_ptr sub_graph = std::make_unique(); diff --git a/onnxruntime/test/framework/kernel_def_test.cc b/onnxruntime/test/framework/kernel_def_test.cc deleted file mode 100644 index 6d9d9fba75e47..0000000000000 --- a/onnxruntime/test/framework/kernel_def_test.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/kernel_def_builder.h" - -#include "gtest/gtest.h" - -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -namespace test { - -TEST(KernelDefTest, HashIgnoresTypeConstraintTypeOrdering) { - auto build_kernel_def = [](std::vector type_constraint_types) { - return KernelDefBuilder{} - .SetName("MyOp") - .TypeConstraint("T", type_constraint_types) - .Build(); - }; - - const auto a = build_kernel_def(BuildKernelDefConstraints()); - const auto b = build_kernel_def(BuildKernelDefConstraints()); - - ASSERT_EQ(a->GetHash(), b->GetHash()); -} - -TEST(KernelDefTest, HashUsesFixedTypeConstraint) { - const auto a = - KernelDefBuilder{} - .SetName("MyOp") - .TypeConstraint("T", BuildKernelDefConstraints()) - .Build(); - const auto b = - KernelDefBuilder{} - .SetName("MyOp") - .TypeConstraint("T", BuildKernelDefConstraints()) - .FixedTypeConstraintForHash("T", BuildKernelDefConstraints()) - .Build(); - - ASSERT_EQ(a->GetHash(), b->GetHash()); -} - -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/framework/kernel_registry_test.cc b/onnxruntime/test/framework/kernel_registry_test.cc index e2490b6ff437b..9c9030d89ae83 100644 --- a/onnxruntime/test/framework/kernel_registry_test.cc +++ b/onnxruntime/test/framework/kernel_registry_test.cc @@ -101,36 +101,4 @@ TEST(KernelRegistryTests, two_versions4) { ASSERT_STATUS_NOT_OK(RegKernels(r, function_table, CreateFakeKernel)); } -TEST(KernelRegistryTests, TryFindKernelByHash) { - auto kernel_def = - KernelDefBuilder() - .MayInplace(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .SetName("Elu") - .SetDomain("") - .SinceVersion(6) - .Provider(kCpuExecutionProvider) - .Build(); - const auto kernel_def_hash = kernel_def->GetHash(); - std::vector> function_table{}; - function_table.emplace_back(std::move(kernel_def)); - KernelRegistry r{}; - ASSERT_STATUS_OK(RegKernels(r, function_table, CreateFakeKernel)); - - const KernelCreateInfo* pkci = nullptr; - ASSERT_TRUE(r.TryFindKernelByHash(kernel_def_hash, &pkci)); - - const auto unregistered_kernel_def_hash = - KernelDefBuilder() - .MayInplace(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .SetName("Elu") - .SetDomain("") - .SinceVersion(1) // different from registered kernel - .Provider(kCpuExecutionProvider) - .Build() - ->GetHash(); - ASSERT_FALSE(r.TryFindKernelByHash(unregistered_kernel_def_hash, &pkci)); -} - } // namespace onnxruntime::test diff --git a/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc b/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc new file mode 100644 index 0000000000000..69dd81b0fca00 --- /dev/null +++ b/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/kernel_type_str_resolver_utils.h" + +#include +#include + +#include "gtest/gtest.h" + +#include "core/flatbuffers/schema/ort.fbs.h" +#include "core/graph/schema_registry.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime::test { + +static Status LoadLayoutTransformationRequiredOpsFromOpSchemas(KernelTypeStrResolver& kernel_type_str_resolver) { + const auto required_op_ids = kernel_type_str_resolver_utils::GetLayoutTransformationRequiredOpIdentifiers(); + const auto schema_registry = SchemaRegistryManager{}; + for (const auto& op_id : required_op_ids) { + const auto* op_schema = schema_registry.GetSchema(std::string{op_id.op_type}, op_id.since_version, + std::string{op_id.domain}); + ORT_RETURN_IF(op_schema == nullptr, "Failed to get op schema."); + ORT_RETURN_IF_ERROR(kernel_type_str_resolver.RegisterOpSchema(*op_schema)); + } + return Status::OK(); +} + +TEST(KernelTypeStrResolverUtilsTest, VerifyLayoutTransformationRequiredOpsResolver) { + KernelTypeStrResolver expected_resolver; + ASSERT_STATUS_OK(LoadLayoutTransformationRequiredOpsFromOpSchemas(expected_resolver)); + + KernelTypeStrResolver actual_resolver; + ASSERT_STATUS_OK( + kernel_type_str_resolver_utils::AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(actual_resolver)); + +#if !defined(DISABLE_CONTRIB_OPS) + ASSERT_EQ(actual_resolver.GetOpKernelTypeStrMap(), expected_resolver.GetOpKernelTypeStrMap()); +#else // !defined(DISABLE_CONTRIB_OPS) + // check that each element of expected_resolver is present and equivalent in actual_resolver + const auto& expected_op_kernel_type_str_map = expected_resolver.GetOpKernelTypeStrMap(); + const auto& actual_op_kernel_type_str_map = actual_resolver.GetOpKernelTypeStrMap(); + + for (const auto& [expected_op_id, expected_kernel_type_str_map] : expected_op_kernel_type_str_map) { + const auto actual_op_kernel_type_str_map_it = actual_op_kernel_type_str_map.find(expected_op_id); + ASSERT_NE(actual_op_kernel_type_str_map_it, actual_op_kernel_type_str_map.end()); + ASSERT_EQ(actual_op_kernel_type_str_map_it->second, expected_kernel_type_str_map); + } +#endif // !defined(DISABLE_CONTRIB_OPS) +} + +// run this test manually to output a hard-coded byte array +TEST(KernelTypeStrResolverUtilsTest, DISABLED_PrintExpectedLayoutTransformationRequiredOpsResolverByteArray) { +#if defined(DISABLE_CONTRIB_OPS) + FAIL() << "Contrib ops must be enabled."; +#endif // defined(DISABLE_CONTRIB_OPS) + KernelTypeStrResolver expected_resolver; + ASSERT_STATUS_OK(LoadLayoutTransformationRequiredOpsFromOpSchemas(expected_resolver)); + + flatbuffers::DetachedBuffer buffer; + gsl::span buffer_span; + ASSERT_STATUS_OK(kernel_type_str_resolver_utils::SaveKernelTypeStrResolverToBuffer(expected_resolver, + buffer, buffer_span)); + + constexpr size_t kBytesPerLine = 16; + std::ostringstream os; + os << std::hex << std::setfill('0') + << " constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = {\n "; + for (size_t i = 0; i < buffer_span.size(); ++i) { + os << "0x" << std::setw(2) << static_cast(buffer_span[i]) << ","; + if (i < buffer_span.size() - 1) { + os << ((i % kBytesPerLine == kBytesPerLine - 1) ? "\n " : " "); + } + } + os << "\n };\n"; + + std::cout << os.str(); +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/framework/opaque_kernels_test.cc b/onnxruntime/test/framework/opaque_kernels_test.cc index 07755fe2a12cc..f039a83d95ba1 100644 --- a/onnxruntime/test/framework/opaque_kernels_test.cc +++ b/onnxruntime/test/framework/opaque_kernels_test.cc @@ -15,6 +15,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "asserts.h" #include "test_utils.h" using namespace ONNX_NAMESPACE; @@ -174,13 +175,13 @@ KernelDefBuilder ConstructSparseTensorDef() { .SetDomain(onnxruntime::kMLDomain) .SinceVersion(8) .Provider(onnxruntime::kCpuExecutionProvider) - .TypeConstraint("sparse_values", + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("sparse_indicies", + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("sparse_shape", + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) - .TypeConstraint("sparse_rep", + .TypeConstraint("T", DataTypeImpl::GetType()); return def; } @@ -191,9 +192,9 @@ KernelDefBuilder ConstructFetchSparseShape() { .SetDomain(onnxruntime::kMLDomain) .SinceVersion(8) .Provider(onnxruntime::kCpuExecutionProvider) - .TypeConstraint("sparse_rep", + .TypeConstraint("T1", DataTypeImpl::GetType()) - .TypeConstraint("sparse_tensor_shape", + .TypeConstraint("T", DataTypeImpl::GetTensorType()); return def; } @@ -283,17 +284,27 @@ TEST_F(OpaqueTypeTests, RunModel) { // so we construct it here before the model std::shared_ptr registry = std::make_shared(); InferenceSession session_object{so, GetEnvironment()}; - EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); + ASSERT_STATUS_OK(session_object.RegisterCustomRegistry(registry)); auto ops_schema = GetConstructSparseTensorSchema(); auto shape_schema = GetFetchSparseShapeSchema(); std::vector schemas = {ops_schema, shape_schema}; - EXPECT_TRUE(registry->RegisterOpSet(schemas, onnxruntime::kMLDomain, 8, 9).IsOK()); + ASSERT_STATUS_OK(registry->RegisterOpSet(schemas, onnxruntime::kMLDomain, 8, 9)); // Register our kernels here auto ctor_def = ConstructSparseTensorDef(); - EXPECT_TRUE(registry->RegisterCustomKernel(ctor_def, [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) { out = std::make_unique(info); return Status::OK(); }).IsOK()); + ASSERT_STATUS_OK(registry->RegisterCustomKernel( + ctor_def, + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) { + out = std::make_unique(info); + return Status::OK(); + })); auto shape_def = ConstructFetchSparseShape(); - EXPECT_TRUE(registry->RegisterCustomKernel(shape_def, [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) { out = std::make_unique(info); return Status::OK(); }).IsOK()); + ASSERT_STATUS_OK(registry->RegisterCustomKernel( + shape_def, + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) { + out = std::make_unique(info); + return Status::OK(); + })); IOnnxRuntimeOpSchemaRegistryList custom_schema_registries_ = {registry->GetOpschemaRegistry()}; std::unordered_map domain_to_version = {{onnxruntime::kMLDomain, 8}}; @@ -348,15 +359,15 @@ TEST_F(OpaqueTypeTests, RunModel) { node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); } - EXPECT_TRUE(graph.Resolve().IsOK()); + ASSERT_STATUS_OK(graph.Resolve()); // Get a proto and load from it std::string serialized_model; auto model_proto = model.ToProto(); EXPECT_TRUE(model_proto.SerializeToString(&serialized_model)); std::stringstream sstr(serialized_model); - EXPECT_TRUE(session_object.Load(sstr).IsOK()); - EXPECT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; @@ -390,7 +401,7 @@ TEST_F(OpaqueTypeTests, RunModel) { output_names.push_back("sparse_tensor_shape"); std::vector fetches; - EXPECT_TRUE(session_object.Run(run_options, feeds, output_names, &fetches).IsOK()); + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); ASSERT_EQ(1u, fetches.size()); auto& rtensor = fetches.front().Get(); // Should get the original shape back in the form of a tensor diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index a73780c4ee678..0898895be9580 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -369,7 +369,7 @@ TEST(OrtModelOnlyTests, MetadataSerialization) { #if !defined(DISABLE_ML_OPS) TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { const std::basic_string ort_file = - ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft_converted.test_output.ort"); + ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx.test_output.ort"); SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file); OrtModelTestInfo test_info; @@ -479,7 +479,7 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopyInitializersUseBuffer) // for a model with sequence and map outputs OrtModelTestInfo GetTestInfoForLoadOrtFormatModelMLOps() { OrtModelTestInfo test_info; - test_info.model_filename = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.ort"); + test_info.model_filename = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.onnx.ort"); test_info.logid = "LoadOrtFormatModelMLOps"; OrtValue ml_value; @@ -536,26 +536,6 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBufferNoCopy) { RunOrtModel(test_info); } -TEST(OrtModelOnlyTests, TestBackwardsCompat) { - auto v110_dir = ORT_TSTR("testdata/ort_backwards_compat/ORTv1.10/"); - std::vector models = {"gathernd9.basic.ort", - "not1.basic.ort", - "roialign10.basic.ort", - "scan9.basic.ort"}; - - SessionOptions session_options; - session_options.session_logid = "TestBackwardsCompat"; - - for (const auto& model : models) { - // test loading old model succeeds. if it does the hash replacement worked. - InferenceSession session{session_options, GetEnvironment()}; - auto model_uri = v110_dir + ToPathString(model); - - ASSERT_STATUS_OK(session.Load(model_uri)); - ASSERT_STATUS_OK(session.Initialize()); - } -} - #endif // !defined(DISABLE_ML_OPS) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index a049af3edc87d..62cc441e6de92 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -15,6 +15,7 @@ #include "core/graph/model.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" +#include "asserts.h" #include "test_utils.h" #include "file_util.h" #include "default_providers.h" @@ -147,13 +148,10 @@ This operator constructs a sparse tensor from three tensors that provide a COO static KernelDefBuilder KernelDef() { KernelDefBuilder def; def.SetName(SparseFromCOO::OpName()) - .TypeConstraint("values", DataTypeImpl::GetTensorType()) - .TypeConstraint("indices", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) #if !defined(DISABLE_SPARSE_TENSORS) - .TypeConstraint("shape", DataTypeImpl::GetTensorType()) - .TypeConstraint("sparse_rep", DataTypeImpl::GetSparseTensorType()); -#else - .TypeConstraint("shape", DataTypeImpl::GetTensorType()); + .TypeConstraint("T", DataTypeImpl::GetSparseTensorType()); #endif return def; } @@ -294,8 +292,8 @@ struct SparseToValues { KernelDefBuilder def; #if !defined(DISABLE_SPARSE_TENSORS) def.SetName(OpName()) - .TypeConstraint("sparse_rep", DataTypeImpl::GetSparseTensorType()) - .TypeConstraint("values", DataTypeImpl::GetTensorType()); + .TypeConstraint("T1", DataTypeImpl::GetSparseTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()); #endif return def; } @@ -313,10 +311,12 @@ class SparseTensorTests : public testing::Test { std::vector register_actions; std::vector types; - public: SparseTensorTests() : session_object(SessionOptions(), GetEnvironment()), registry(std::make_shared()) { - EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); + } + + void SetUp() override { + ASSERT_STATUS_OK(session_object.RegisterCustomRegistry(registry)); } template @@ -334,13 +334,13 @@ class SparseTensorTests : public testing::Test { .SinceVersion(10) .Provider(onnxruntime::kCpuExecutionProvider); KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) { out = std::make_unique(info); return Status::OK(); }; - EXPECT_TRUE(registry2->RegisterCustomKernel(kernel_def_builder, kernel_create_fn).IsOK()); + ASSERT_STATUS_OK(registry2->RegisterCustomKernel(kernel_def_builder, kernel_create_fn)); }; register_actions.push_back(register_kernel); } void RegisterOps() { - EXPECT_TRUE(registry->RegisterOpSet(schemas, onnxruntime::kMLDomain, 10, 11).IsOK()); + ASSERT_STATUS_OK(registry->RegisterOpSet(schemas, onnxruntime::kMLDomain, 10, 11)); for (auto& registerop : register_actions) registerop(registry.get()); } @@ -357,8 +357,8 @@ class SparseTensorTests : public testing::Test { auto model_proto = model->ToProto(); EXPECT_TRUE(model_proto.SerializeToString(&serialized_model)); std::stringstream sstr(serialized_model); - EXPECT_TRUE(session_object.Load(sstr).IsOK()); - EXPECT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); } #if !defined(DISABLE_SPARSE_TENSORS) @@ -438,7 +438,7 @@ class SparseTensorTests : public testing::Test { RunOptions run_options; std::vector fetches; - EXPECT_TRUE(session_object.Run(run_options, feeds, output_names, &fetches).IsOK()); + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); ASSERT_EQ(expected_output.size(), fetches.size()); for (size_t i = 0; i < fetches.size(); ++i) { @@ -481,7 +481,7 @@ TEST_F(SparseTensorTests, Test1) { // Check graph, serialize it and deserialize it back Graph& graph = model->MainGraph(); - EXPECT_TRUE(graph.Resolve().IsOK()); + ASSERT_STATUS_OK(graph.Resolve()); SerializeAndLoad(); // Run the model @@ -525,7 +525,7 @@ TEST_F(SparseTensorTests, Test2) { // Check graph, serialize it and deserialize it back Graph& graph = model->MainGraph(); - EXPECT_TRUE(graph.Resolve().IsOK()); + ASSERT_STATUS_OK(graph.Resolve()); SerializeAndLoad(); // Run the model @@ -537,16 +537,16 @@ TEST_F(SparseTensorTests, Test2) { } TEST(SparseCrcsFormatTests, Test1) { - //const std::vector input_data = { - // 0, 1, 2, 0, 0, 0, 3, 4, 5, - // 6, 7, 8, 0, 0, 0, 9, 10, 11, - // 12, 13, 14, 0, 0, 0, 15, 16, 17, - // 0, 0, 0, 18, 19, 20, 21, 22, 23, - // 0, 0, 0, 24, 25, 26, 27, 28, 29, - // 0, 0, 0, 30, 31, 32, 33, 34, 35, - // 36, 37, 38, 39, 40, 41, 0, 0, 0, - // 42, 43, 44, 45, 46, 47, 0, 0, 0, - // 48, 49, 50, 51, 52, 53, 0, 0, 0}; + // const std::vector input_data = { + // 0, 1, 2, 0, 0, 0, 3, 4, 5, + // 6, 7, 8, 0, 0, 0, 9, 10, 11, + // 12, 13, 14, 0, 0, 0, 15, 16, 17, + // 0, 0, 0, 18, 19, 20, 21, 22, 23, + // 0, 0, 0, 24, 25, 26, 27, 28, 29, + // 0, 0, 0, 30, 31, 32, 33, 34, 35, + // 36, 37, 38, 39, 40, 41, 0, 0, 0, + // 42, 43, 44, 45, 46, 47, 0, 0, 0, + // 48, 49, 50, 51, 52, 53, 0, 0, 0}; auto* cpu_provider = TestCPUExecutionProvider(); auto cpu_transfer = cpu_provider->GetDataTransfer(); @@ -698,10 +698,9 @@ struct InsertIndices { std::vector indices_data; insert_indices_data(indices_1D, values_size, shape_size, indices_data, indices_tp); indices_tp.set_data_type(utils::ToTensorProtoElementType()); - if constexpr(sizeof(T) == sizeof(int8_t)) { + if constexpr (sizeof(T) == sizeof(int8_t)) { indices_tp.mutable_raw_data()->assign(reinterpret_cast(indices_data.data()), indices_data.size()); - } - else { + } else { // Conversion on the fly to the target data type std::vector indices(indices_data.cbegin(), indices_data.cend()); indices_tp.mutable_raw_data()->assign(reinterpret_cast(indices.data()), indices.size() * sizeof(T)); diff --git a/onnxruntime/test/optimizer/layout_transformation_potentially_added_ops_test.cc b/onnxruntime/test/optimizer/layout_transformation_potentially_added_ops_test.cc new file mode 100644 index 0000000000000..184bd6dacf1b1 --- /dev/null +++ b/onnxruntime/test/optimizer/layout_transformation_potentially_added_ops_test.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h" + +#include "gtest/gtest.h" + +#include "onnx/defs/schema.h" + +#include "core/graph/constants.h" + +namespace onnxruntime::test { + +// This test is to ensure the latest opset version for ops which can be added +// during layout transformation step are added. If this test fails then it means +// there is a new version available for one of the ops in the map. +TEST(LayoutTransformationPotentiallyAddedOpsTests, OpsHaveLatestVersions) { + const auto* schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); + + // kLayoutTransformationPotentiallyAddedOps is sorted in increasing order of + // iterate backwards and check the latest since_version of each domain, op_type + std::string_view prev_domain, prev_op_type{}; + + for (auto it = std::rbegin(kLayoutTransformationPotentiallyAddedOps), + end = std::rend(kLayoutTransformationPotentiallyAddedOps); + it != end; ++it) { + if (prev_domain != it->domain || prev_op_type != it->op_type) { + const auto* schema = schema_registry->GetSchema(std::string{it->op_type}, INT_MAX, std::string{it->domain}); + ASSERT_NE(schema, nullptr); + EXPECT_EQ(schema->SinceVersion(), it->since_version) + << "A new version for op " << it->op_type << " (" << schema->SinceVersion() + << ") is available. Please update kLayoutTransformationPotentiallyAddedOps to include it."; + prev_domain = it->domain; + prev_op_type = it->op_type; + } + } +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc index d5c5e6de8e637..037108b3ebe00 100644 --- a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc +++ b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc @@ -88,12 +88,6 @@ class TestTransformer : public SelectorActionTransformer { } }; } // namespace sat - -std::unique_ptr CreateKernelRegistryManager() { - auto krm = std::make_unique(); - krm->RegisterKernelRegistry(TestCPUExecutionProvider()->GetKernelRegistry()); - return krm; -} } // namespace TEST(GraphRuntimeOptimizationTest, SaveRuntimeOptimizationToOrtFormat) { @@ -112,8 +106,7 @@ TEST(GraphRuntimeOptimizationTest, SaveRuntimeOptimizationToOrtFormat) { // run SAT to save runtime optimization { - auto kernel_registry_manager = CreateKernelRegistryManager(); - auto save_context = SatRuntimeOptimizationSaveContext{std::cref(*kernel_registry_manager)}; + auto save_context = SatRuntimeOptimizationSaveContext{}; auto test_transformer = std::make_unique(save_context); auto transformer_manager = GraphTransformerManager{/* steps */ 5}; ASSERT_STATUS_OK(transformer_manager.Register(std::move(test_transformer), TransformerLevel::Level1)); @@ -211,9 +204,9 @@ void SaveAndLoadRuntimeOptimizationsForModel( const GraphOpCountsCheckerFn& graph_op_counts_checker_for_replay) { auto run_test = [&](bool do_save) { // the two versions of the saved runtime optimizations file should be the same - // the one with the ".generated" suffix is generated by the test and the other is checked in + // the one with the ".test_output" suffix is generated by the test and the other is checked in const PathString saved_runtime_optimizations_model_path = - do_save ? ort_model_with_runtime_opt_path + ORT_TSTR(".generated") + do_save ? ort_model_with_runtime_opt_path + ORT_TSTR(".test_output") : ort_model_with_runtime_opt_path; SCOPED_TRACE(MakeString("ONNX model: '", ToUTF8String(onnx_model_path), diff --git a/onnxruntime/test/platform/ios/ios_package_test/README.md b/onnxruntime/test/platform/ios/ios_package_test/README.md index 96c3e055ee4e0..4b41036e59d1f 100644 --- a/onnxruntime/test/platform/ios/ios_package_test/README.md +++ b/onnxruntime/test/platform/ios/ios_package_test/README.md @@ -4,7 +4,7 @@ This End-to-End test app for iOS will test ORT Mobile C/C++ API framework using ## Requirements -- [Prerequisites for building ORT-Mobile for iOS](https://onnxruntime.ai/docs/build/android-ios.html#prerequisites-1) +- [Prerequisites for building ORT-Mobile for iOS](https://onnxruntime.ai/docs/build/ios.html#prerequisites) - [CocoaPods](https://cocoapods.org/) ## iOS End-to-End Test App Overview @@ -14,11 +14,13 @@ The iOS End-to-End Test App will use CocoaPods to install the Onnx Runtime C/C++ ### Model used - [sigmoid ONNX model](https://github.com/onnx/onnx/blob/f9b0cc99344869c246b8f4011b8586a39841284c/onnx/backend/test/data/node/test_sigmoid/model.onnx) converted to ORT format - Here's the [document](https://onnxruntime.ai/docs/tutorials/mobile/model-conversion.html) about how you can convert an ONNX model into ORT format. + Here's [documentation](https://onnxruntime.ai/docs/reference/ort-format-models.html#convert-onnx-models-to-ort-format) about how you can convert an ONNX model into ORT format. + + Run `python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed /path/to/model.onnx` and rename the resulting .ort file accordingly. ### Tests - [Tests for C++ API ](./ios_package_testUITests/ios_package_uitest_cpp_api.mm) ## Build and Test iOS Framework using [build.py](../../../../../tools/ci_build/build.py) -Use the [build for iOS simulator](https://onnxruntime.ai/docs/build/android-ios.html#cross-build-for-ios-simulator) with `--build_apple_framework` +Use the [build for iOS simulator](https://onnxruntime.ai/docs/build/ios.html#cross-compile-for-ios-simulator) with `--build_apple_framework` diff --git a/onnxruntime/test/platform/ios/ios_package_test/models/sigmoid.ort b/onnxruntime/test/platform/ios/ios_package_test/models/sigmoid.ort index 355bef8d92927..70d7659bbee25 100644 Binary files a/onnxruntime/test/platform/ios/ios_package_test/models/sigmoid.ort and b/onnxruntime/test/platform/ios/ios_package_test/models/sigmoid.ort differ diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 3165ec6352f8a..cb212aece8647 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -143,7 +143,7 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) { // mnist model that has only had basic optimizations applied. CoreML should be able to take at least some of the nodes - const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/mnist.level1_opt.ort"); + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/mnist.basic.ort"); #if defined(__APPLE__) RandomValueGenerator random{}; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 1ec8d72af1656..58729bf933854 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -82,7 +82,7 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& /*registries*/) const { + const IKernelLookup& /*kernel_lookup*/) const { // find nodes that have ops in our supported list std::unordered_set supported_static_nodes; std::unordered_set supported_compiled_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 88e39c9572993..6313dd4e8dce3 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -17,7 +17,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, - const std::vector& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_lookup*/) const override; common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index a87932adf36bd..5727c44df34c5 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -33,7 +33,7 @@ using namespace onnxruntime::internal_testing_ep; #define ORT_MODEL_FOLDER ORT_TSTR("testdata/") -static void CreateSession(const SessionOptions& so, std::unique_ptr& session, +static Status CreateSession(const SessionOptions& so, std::unique_ptr& session, const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "mnist.onnx", // arbitrary test model bool enable_custom_ep = true, const std::unordered_set* override_supported_ops = nullptr) { @@ -48,12 +48,13 @@ static void CreateSession(const SessionOptions& so, std::unique_ptrRegisterExecutionProvider( + ORT_RETURN_IF_ERROR(session->RegisterExecutionProvider( std::make_unique(*supported_ops))); } - ASSERT_STATUS_OK(session->Load(model_path)); - ASSERT_STATUS_OK(session->Initialize()); + ORT_RETURN_IF_ERROR(session->Load(model_path)); + ORT_RETURN_IF_ERROR(session->Initialize()); + return Status::OK(); } static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enabled) { @@ -104,20 +105,20 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { SessionOptions so; so.optimized_model_filepath = ort_model_path; - CreateSession(so, session); + ASSERT_STATUS_OK(CreateSession(so, session)); // this graph should include the original nodes that the custom EP will take at runtime auto num_nodes = session->GetGraph().NumberOfNodes(); // // Second, load the ORT format model with just the CPU EP to make sure it can be executed. This tests that the - // fallback to the CPU EP kernel hashes works. + // fallback to the CPU EP works. // std::unique_ptr session2; so.optimized_model_filepath.clear(); bool enable_custom_ep = false; - CreateSession(so, session2, ort_model_path, enable_custom_ep); + ASSERT_STATUS_OK(CreateSession(so, session2, ort_model_path, enable_custom_ep)); const auto& graph1 = session2->GetGraph(); // model should have all the original nodes and we should be able to execute with the fallback to CPU EP ASSERT_EQ(graph1.NumberOfNodes(), num_nodes); @@ -129,7 +130,7 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { // for the ORT format model. // enable_custom_ep = true; - CreateSession(so, session2, ort_model_path, enable_custom_ep); + ASSERT_STATUS_OK(CreateSession(so, session2, ort_model_path, enable_custom_ep)); const auto& graph2 = session2->GetGraph(); // model should be able to be loaded, and we should compile using custom ep. that will result in one node for the // custom EP (with Conv/Add/Relu/MaxPool), one for a reshape, and one for the fused MatMul+Add. @@ -331,7 +332,7 @@ TEST(InternalTestingEP, TestLoadOrtModel) { std::unique_ptr session; bool enable_custom_ep = true; - CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep); + ASSERT_STATUS_OK(CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep)); ExecuteMnist(*session, enable_custom_ep); } @@ -344,7 +345,7 @@ TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) { std::unique_ptr session; bool enable_custom_ep = true; - CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops); + ASSERT_STATUS_OK(CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops)); const auto& graph = session->GetGraph(); // Conv+Add gets fused by level 1 optimizer into single node. The 'Conv'/'Add'/'Relu' nodes should be compiled and @@ -405,7 +406,7 @@ TEST(InternalTestingEP, TestModelWithSubgraph) { std::unique_ptr session; bool enable_custom_ep = true; - CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops); + ASSERT_STATUS_OK(CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops)); const auto& graph = session->GetGraph(); auto& func_mgr = const_cast(session->GetSessionState()).GetMutableFuncMgr(); diff --git a/onnxruntime/test/providers/kernel_def_hash_test.cc b/onnxruntime/test/providers/kernel_def_hash_test.cc deleted file mode 100644 index 5e946ac39dd96..0000000000000 --- a/onnxruntime/test/providers/kernel_def_hash_test.cc +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -/** - * IMPORTANT NOTE AT THE TOP OF THE FILE - * - * This file contains tests which verify expected kernel def hashes. - * It is important for these to remain stable so that ORT format models are - * backward compatible. - * - * If you are seeing a test failure from one of these tests, it is likely that - * some kernel definition changed in a way that updated its hash value. - * This is what we want to catch! Please update the kernel definition. - * If adding more supported types to an existing kernel definition, consider - * using KernelDefBuilder::FixedTypeConstraintForHash(). - * - * For example: - * Say we have a kernel definition like this, which supports types int and - * double: - * KernelDefBuilder{} - * .TypeConstraint( - * "T", BuildKernelDefConstraints()) - * If we want to update the kernel definition to add support for float, we can - * change it to something like this: - * KernelDefBuilder{} - * .TypeConstraint( - * "T", BuildKernelDefConstraints()) - * .FixedTypeConstraintForHash( - * "T", BuildKernelDefConstraints()) - * In the updated kernel definition, the original types are specified with - * FixedTypeConstraintForHash(). - * - * New kernel definitions should not use FixedTypeConstraintForHash(). - * It is a way to keep the hash stable as kernel definitions change. - * - * It is also possible that you have added a new kernel definition and are - * seeing a message from one of these tests about updating the expected data. - * Please do that if appropriate. - * - * The expected value files are in this directory: - * onnxruntime/test/testdata/kernel_def_hashes - * The data is specified in JSON as an array of key-value arrays. - * Example data can be written to stdout with this test: - * KernelDefHashTest.DISABLED_PrintCpuKernelDefHashes - * Use the option --gtest_also_run_disabled_tests to enable it. - * Be careful about updating the expected values - as mentioned before, the - * values should be stable. Typically, we should only add new entries. - * - * In the unlikely event that we need to make a change to the kernel def - * hashing that breaks backward compatibility, the expected values may need to - * be updated. You will also need to update UpdateHashForBackwardsCompatibility - * in onnxruntime/core/framework/kernel_def_hash_helpers.cc, add a - * test model for the operator in question to onnxruntime/test/testdata/ort_backwards_compat - * and update OrtModelOnlyTests.TestBackwardsCompat in onnxruntime/test/framework/ort_model_only_test.cc - * to load the new model and validate the hash replacement works correctly. - */ - -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "onnxruntime_config.h" - -#ifdef _WIN32 -#pragma warning(push) -#pragma warning(disable : 28020) -#elif __aarch64__ && defined(HAS_FORMAT_TRUNCATION) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat-truncation" -#endif -#include "nlohmann/json.hpp" -#ifdef _WIN32 -#pragma warning(pop) -#elif __aarch64__ && defined(HAS_FORMAT_TRUNCATION) -#pragma GCC diagnostic pop -#endif - -#include "asserts.h" -#include "core/common/common.h" -#include "core/common/path_string.h" -#include "core/framework/kernel_registry.h" -#include "core/mlas/inc/mlas.h" -#include "core/platform/env_var_utils.h" -#include "core/providers/cpu/cpu_execution_provider.h" -#include "gtest/gtest.h" - -using json = nlohmann::json; - -namespace onnxruntime { -namespace test { - -namespace { -// If set to 1, do strict checking of the kernel def hash values. -// With strict checking, the expected and actual values must match exactly. -// Otherwise, the expected values must be present in the actual values. -static constexpr const char* kStrictKernelDefHashCheckEnvVar = - "ORT_TEST_STRICT_KERNEL_DEF_HASH_CHECK"; - -std::string DumpKernelDefHashes(const onnxruntime::KernelDefHashes& kernel_def_hashes) { - const json j(kernel_def_hashes); - return j.dump(/* indent */ 4); -} - -KernelDefHashes ParseKernelDefHashes(std::istream& in) { - KernelDefHashes kernel_def_hashes{}; - const json j = json::parse(in); - j.get_to(kernel_def_hashes); - return kernel_def_hashes; -} - -void AppendKernelDefHashesFromFile(const PathString& path, KernelDefHashes& kernel_def_hashes) { - std::ifstream in{path}; - ORT_ENFORCE(in, "Failed to open file: ", ToUTF8String(path)); - const auto file_kernel_def_hashes = ParseKernelDefHashes(in); - kernel_def_hashes.insert( - kernel_def_hashes.end(), file_kernel_def_hashes.begin(), file_kernel_def_hashes.end()); -} - -void CheckKernelDefHashes(const KernelDefHashes& actual, const KernelDefHashes& expected, bool is_strict) { - ASSERT_TRUE(std::is_sorted(actual.begin(), actual.end())); - ASSERT_TRUE(std::is_sorted(expected.begin(), expected.end())); - - constexpr const char* kNoteReference = "Note: Please read the note at the top of this file: " __FILE__; - - KernelDefHashes expected_minus_actual{}; - std::set_difference(expected.begin(), expected.end(), actual.begin(), actual.end(), - std::back_inserter(expected_minus_actual)); - if (!expected_minus_actual.empty()) { - const auto message = MakeString( - "Some expected kernel def hashes were not found.\n", - kNoteReference, "\n", - DumpKernelDefHashes(expected_minus_actual)); - ADD_FAILURE() << message; - } - - KernelDefHashes actual_minus_expected{}; - std::set_difference(actual.begin(), actual.end(), expected.begin(), expected.end(), - std::back_inserter(actual_minus_expected)); - if (!actual_minus_expected.empty()) { - const auto message = MakeString( - "Unexpected kernel def hashes were found, please update the expected values as needed " - "(see the output below).\n", - kNoteReference, "\n", - DumpKernelDefHashes(actual_minus_expected)); - if (is_strict) { - ADD_FAILURE() << message; - } else { - std::cerr << message << "\n"; - } - } -} -} // namespace - -TEST(KernelDefHashTest, DISABLED_PrintCpuKernelDefHashes) { - KernelRegistry kernel_registry{}; - ASSERT_STATUS_OK(RegisterCPUKernels(kernel_registry)); - const auto cpu_kernel_def_hashes = kernel_registry.ExportKernelDefHashes(); - std::cout << DumpKernelDefHashes(cpu_kernel_def_hashes) << "\n"; -} - -TEST(KernelDefHashTest, ExpectedCpuKernelDefHashes) { - const bool is_strict = ParseEnvironmentVariableWithDefault(kStrictKernelDefHashCheckEnvVar, false); - - const auto expected_cpu_kernel_def_hashes = []() { - KernelDefHashes result{}; - AppendKernelDefHashesFromFile(ORT_TSTR("testdata/kernel_def_hashes/onnx.cpu.json"), result); -#if !defined(DISABLE_ML_OPS) - AppendKernelDefHashesFromFile(ORT_TSTR("testdata/kernel_def_hashes/onnx.ml.cpu.json"), result); -#endif // !DISABLE_ML_OPS -#if !defined(DISABLE_CONTRIB_OPS) - AppendKernelDefHashesFromFile(ORT_TSTR("testdata/kernel_def_hashes/contrib.cpu.json"), result); - // NCHWc kernels are enabled if MlasNchwcGetBlockSize() > 1 - if (MlasNchwcGetBlockSize() > 1) { - AppendKernelDefHashesFromFile(ORT_TSTR("testdata/kernel_def_hashes/contrib.nchwc.cpu.json"), result); - } -#endif // !DISABLE_CONTRIB_OPS -#if defined(ENABLE_TRAINING_OPS) - AppendKernelDefHashesFromFile(ORT_TSTR("testdata/kernel_def_hashes/training_ops.cpu.json"), result); -#endif // ENABLE_TRAINING_OPS -#if !defined(DISABLE_OPTIONAL_TYPE) - AppendKernelDefHashesFromFile(ORT_TSTR("testdata/kernel_def_hashes/onnx.optional_type_ops.cpu.json"), result); -#endif // !DISABLE_OPTIONAL_TYPE - // TODO also handle kernels enabled by these symbols: BUILD_MS_EXPERIMENTAL_OPS - std::sort(result.begin(), result.end()); - return result; - }(); - - KernelRegistry kernel_registry{}; - ASSERT_STATUS_OK(RegisterCPUKernels(kernel_registry)); - auto cpu_kernel_def_hashes = kernel_registry.ExportKernelDefHashes(); - - CheckKernelDefHashes(cpu_kernel_def_hashes, expected_cpu_kernel_def_hashes, is_strict); -} - -// This test is to ensure the latest opset version for ops which can be added -// during layout transformation step are added. IF this test fails then it means -// there is a new version available for one of the ops in the map. -// Adding this test here because resolution for this test failure requires fetching the hash -// for one of the ops in the list below and this file has information around that. -// Please update the following 3 places: -// 1. optimizer/transpose_optimizer/optimizer_api_impl.cc "onnx_ops_available_versions" map, -// include the latest version in the map -// 2. framework/kernel_def_hash_helpers.cc:GetHashValueFromStaticKernelHashMap "static_kernel_hashes" map, -// add an entry for latest version and its associated hash -// 3. KernelDefHashTest.TestNewOpsVersionSupportDuringLayoutTransform "onnx_ops_available_versions" map, -// include the latest version in the map -TEST(KernelDefHashTest, TestNewOpsVersionSupportDuringLayoutTransform) { - static const std::unordered_map> onnx_ops_available_versions = { - {"Squeeze", {1, 11, 13}}, - {"Unsqueeze", {1, 11, 13}}, - {"Gather", {1, 11, 13}}, - {"Transpose", {1, 13}}, - {"Identity", {1, 13, 14, 16}}, - }; - - auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); - for (const auto& [op_type, version_list] : onnx_ops_available_versions) { - auto schema = schema_registry->GetSchema(op_type, INT_MAX, kOnnxDomain); - EXPECT_EQ(schema->SinceVersion(), version_list[version_list.size() - 1]) << "A new version for op: " << op_type - << "is available. Please update the files mentioned in the comments of this test."; - } -} -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index 3d557ad6da57c..e16c4b15f876b 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -34,9 +34,10 @@ using namespace ::onnxruntime::logging; namespace onnxruntime { namespace test { -#if !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) namespace { +[[maybe_unused]] void TestModelLoad(const ORTCHAR_T* model_file_name, const std::function& check_graph) { SessionOptions so; InferenceSessionWrapper session_object{so, GetEnvironment()}; @@ -47,6 +48,10 @@ void TestModelLoad(const ORTCHAR_T* model_file_name, const std::function/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -484,7 +489,7 @@ TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) { TEST(NnapiExecutionProviderTest, TestOrtFormatModel) { // mnist model that has only had basic optimizations applied. nnapi should be able to take at least some of the nodes - const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/mnist.level1_opt.ort"); + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/mnist.basic.ort"); // The execution can only be performed on Android #if defined(__ANDROID__) @@ -514,7 +519,7 @@ TEST(NnapiExecutionProviderTest, TestOrtFormatModel) { // test that NNAPI EP can process an activation node that is outside of its partition TEST(NnapiExecutionProviderTest, ActivationOutsideOfPartition) { // model starts with Conv -> Relu - constexpr auto* model_file_name = ORT_TSTR("testdata/mnist.level1_opt.ort"); + constexpr auto* model_file_name = ORT_TSTR("testdata/mnist.basic.ort"); // stop NNAPI partitioning at Relu so NNAPI EP only takes first Conv const auto nnapi_partitioning_stop_ops = "Relu"; SessionOptions so; diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 3412a6d4334c3..1079b01031cae 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -1142,6 +1142,7 @@ void OpTester::Run( continue; bool valid = true; + const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; // set execution provider for all nodes in the graph for (auto& node : graph.Nodes()) { @@ -1159,11 +1160,13 @@ void OpTester::Run( provider_type == onnxruntime::kSnpeExecutionProvider) continue; auto reg = execution_provider->GetKernelRegistry(); - if (!KernelRegistry::HasImplementationOf(*reg, node, execution_provider->Type())) { + if (!KernelRegistry::HasImplementationOf(*reg, node, execution_provider->Type(), + kernel_type_str_resolver)) { valid = false; for (auto& custom_session_registry : custom_session_registries_) { if (KernelRegistry::HasImplementationOf(*custom_session_registry->GetKernelRegistry(), - node, execution_provider->Type())) { + node, execution_provider->Type(), + kernel_type_str_resolver)) { valid = true; break; } diff --git a/onnxruntime/test/testdata/foo_1.onnx.ort b/onnxruntime/test/testdata/foo_1.onnx.ort index 71a863db06546..7ebbbe736ee7f 100644 Binary files a/onnxruntime/test/testdata/foo_1.onnx.ort and b/onnxruntime/test/testdata/foo_1.onnx.ort differ diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json deleted file mode 100644 index 5fb55faa14c5c..0000000000000 --- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json +++ /dev/null @@ -1,306 +0,0 @@ -[ - [ - "Affine ai.onnx CPUExecutionProvider", - 7811918192248490408 - ], - [ - "BeamSearch com.microsoft CPUExecutionProvider", - 6968087233460196528 - ], - [ - "Crop ai.onnx CPUExecutionProvider", - 6914973556202621376 - ], - [ - "DynamicSlice ai.onnx CPUExecutionProvider", - 5387668728763060584 - ], - [ - "GreedySearch com.microsoft CPUExecutionProvider", - 9790977725959310408 - ], - [ - "ImageScaler ai.onnx CPUExecutionProvider", - 2013093418027264536 - ], - [ - "LayerNormalization ai.onnx CPUExecutionProvider", - 4827261308628792072 - ], - [ - "LayerNormalization ai.onnx CPUExecutionProvider", - 18418354579469131656 - ], - [ - "MeanVarianceNormalization ai.onnx CPUExecutionProvider", - 13114085849278607104 - ], - [ - "NhwcMaxPool com.microsoft CPUExecutionProvider", - 11773579655431087496 - ], - [ - "ParametricSoftplus ai.onnx CPUExecutionProvider", - 17971715260566574960 - ], - [ - "Scale ai.onnx CPUExecutionProvider", - 12599351089228483328 - ], - [ - "ScaledTanh ai.onnx CPUExecutionProvider", - 15584477984618710520 - ], - [ - "SimplifiedLayerNormalization ai.onnx CPUExecutionProvider", - 4809288790945391544 - ], - [ - "SimplifiedLayerNormalization ai.onnx CPUExecutionProvider", - 13556035637124174064 - ], - [ - "SparseToDenseMatMul com.microsoft CPUExecutionProvider", - 1738555621134130312 - ], - [ - "ThresholdedRelu ai.onnx CPUExecutionProvider", - 17820769706565099200 - ], - [ - "Attention com.microsoft CPUExecutionProvider", - 16464502426529915864 - ], - [ - "AttnLSTM com.microsoft CPUExecutionProvider", - 15421184737689665128 - ], - [ - "BiasGelu com.microsoft CPUExecutionProvider", - 12457646955212583504 - ], - [ - "BifurcationDetector com.microsoft CPUExecutionProvider", - 12148442056374193608 - ], - [ - "CDist com.microsoft CPUExecutionProvider", - 889036143745127232 - ], - [ - "CDist com.microsoft CPUExecutionProvider", - 6280134801002897280 - ], - [ - "ConvTransposeWithDynamicPads com.microsoft CPUExecutionProvider", - 1596732273609633752 - ], - [ - "CropAndResize com.microsoft CPUExecutionProvider", - 8620498355864235632 - ], - [ - "DequantizeLinear com.microsoft CPUExecutionProvider", - 9034670407031092344 - ], - [ - "DequantizeLinear com.microsoft CPUExecutionProvider", - 12760451019866331016 - ], - [ - "DynamicQuantizeLSTM com.microsoft CPUExecutionProvider", - 640700714499684624 - ], - [ - "DynamicQuantizeMatMul com.microsoft CPUExecutionProvider", - 9591826880179639184 - ], - [ - "EmbedLayerNormalization com.microsoft CPUExecutionProvider", - 14614049725238705256 - ], - [ - "ExpandDims com.microsoft CPUExecutionProvider", - 5671892069881567792 - ], - [ - "FastGelu com.microsoft CPUExecutionProvider", - 18210983793195477200 - ], - [ - "FusedConv com.microsoft CPUExecutionProvider", - 11366858116389652832 - ], - [ - "FusedGemm com.microsoft CPUExecutionProvider", - 1341171831223136792 - ], - [ - "FusedMatMul com.microsoft CPUExecutionProvider", - 665364151288353496 - ], - [ - "GatherND com.microsoft CPUExecutionProvider", - 8466578404783779600 - ], - [ - "Gelu com.microsoft CPUExecutionProvider", - 4658746266161736328 - ], - [ - "GridSample com.microsoft CPUExecutionProvider", - 11924582339825775592 - ], - [ - "Inverse com.microsoft CPUExecutionProvider", - 1037755270231788608 - ], - [ - "MatMulInteger16 com.microsoft CPUExecutionProvider", - 5265636774129358144 - ], - [ - "MatMulIntegerToFloat com.microsoft CPUExecutionProvider", - 1363870470731747600 - ], - [ - "MatMulIntegerToFloat com.microsoft CPUExecutionProvider", - 7172777464471435800 - ], - [ - "MaxpoolWithMask com.microsoft CPUExecutionProvider", - 3144686615632467360 - ], - [ - "MurmurHash3 com.microsoft CPUExecutionProvider", - 2533733396673225096 - ], - [ - "NhwcMaxPool com.microsoft CPUExecutionProvider", - 8512357837341844248 - ], - [ - "NGramRepeatBlock com.microsoft CPUExecutionProvider", - 17162613206685017176 - ], - [ - "Pad com.microsoft CPUExecutionProvider", - 15076596470814458544 - ], - [ - "QAttention com.microsoft CPUExecutionProvider", - 9844377440996919912 - ], - [ - "QLinearAdd com.microsoft CPUExecutionProvider", - 9958112514164905192 - ], - [ - "QLinearAdd com.microsoft CPUExecutionProvider", - 16322459350118343880 - ], - [ - "QLinearAveragePool com.microsoft CPUExecutionProvider", - 9152647959212466896 - ], - [ - "QLinearConv com.microsoft CPUExecutionProvider", - 10904143578341560456 - ], - [ - "QLinearConv com.microsoft CPUExecutionProvider", - 16835965565578160400 - ], - [ - "QLinearGlobalAveragePool com.microsoft CPUExecutionProvider", - 8729391959357542728 - ], - [ - "QLinearLeakyRelu com.microsoft CPUExecutionProvider", - 3677670974923917280 - ], - [ - "QLinearLeakyRelu com.microsoft CPUExecutionProvider", - 17073324515720209136 - ], - [ - "QLinearMul com.microsoft CPUExecutionProvider", - 2406593953080780408 - ], - [ - "QLinearMul com.microsoft CPUExecutionProvider", - 17403503869116794888 - ], - [ - "QLinearSigmoid com.microsoft CPUExecutionProvider", - 17020165931626188400 - ], - [ - "QLinearSigmoid com.microsoft CPUExecutionProvider", - 17315947486917903320 - ], - [ - "QuantizeLinear com.microsoft CPUExecutionProvider", - 616915237400456368 - ], - [ - "QuantizeLinear com.microsoft CPUExecutionProvider", - 13556449850953958792 - ], - [ - "Range com.microsoft CPUExecutionProvider", - 9333951582187402912 - ], - [ - "SampleOp com.microsoft CPUExecutionProvider", - 11028204786545834016 - ], - [ - "SkipLayerNormalization com.microsoft CPUExecutionProvider", - 1829676129267529920 - ], - [ - "SkipLayerNormalization com.microsoft CPUExecutionProvider", - 15124962608939318760 - ], - [ - "Tokenizer com.microsoft CPUExecutionProvider", - 12821105347567077024 - ], - [ - "TransposeMatMul com.microsoft CPUExecutionProvider", - 3696625852111461496 - ], - [ - "Trilu com.microsoft CPUExecutionProvider", - 1828108687906670152 - ], - [ - "Unique com.microsoft CPUExecutionProvider", - 17512097873619224240 - ], - [ - "WordConvEmbedding com.microsoft CPUExecutionProvider", - 7416606351345164776 - ], - [ - "QLinearConcat com.microsoft CPUExecutionProvider", - 1734858160766311432 - ], - [ - "QEmbedLayerNormalization com.microsoft CPUExecutionProvider", - 9235385557940152248 - ], - [ - "QGemm com.microsoft CPUExecutionProvider", - 13009794669709617232 - ], - [ - "QGemm com.microsoft CPUExecutionProvider", - 13737193491843065240 - ], - [ - "QLinearSoftmax com.microsoft CPUExecutionProvider", - 10339195975968977840 - ] -] diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.nchwc.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.nchwc.cpu.json deleted file mode 100644 index 0553469b86706..0000000000000 --- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.nchwc.cpu.json +++ /dev/null @@ -1,34 +0,0 @@ -[ - [ - "AveragePool com.microsoft.nchwc CPUExecutionProvider", - 12528194512485261552 - ], - [ - "Conv com.microsoft.nchwc CPUExecutionProvider", - 10643058043438608528 - ], - [ - "GlobalAveragePool com.microsoft.nchwc CPUExecutionProvider", - 9401543287182687288 - ], - [ - "GlobalMaxPool com.microsoft.nchwc CPUExecutionProvider", - 17341568537930161320 - ], - [ - "MaxPool com.microsoft.nchwc CPUExecutionProvider", - 14527249939908647936 - ], - [ - "ReorderInput com.microsoft.nchwc CPUExecutionProvider", - 14330795113746035424 - ], - [ - "ReorderOutput com.microsoft.nchwc CPUExecutionProvider", - 13428915370009679360 - ], - [ - "Upsample com.microsoft.nchwc CPUExecutionProvider", - 16347985363638744760 - ] -] \ No newline at end of file diff --git a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json deleted file mode 100644 index d9271921c1e90..0000000000000 --- a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json +++ /dev/null @@ -1,2730 +0,0 @@ -[ - [ - "Abs ai.onnx CPUExecutionProvider", - 504798580737304624 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 2183148449174722344 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 2655036503625551040 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 3009851702411644032 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 3190939484633723656 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 6055268975966688000 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 6145569101795971584 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 7827897323524555008 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 8502448309239022272 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 8971473958183026704 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 9434261760991393328 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 10431403221717028816 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 11069696326942726312 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 11486507016884713936 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 11869111214782765680 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 14387905014097763280 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 15605730591228317016 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 16777769116337640080 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 17083986781747169264 - ], - [ - "Abs ai.onnx CPUExecutionProvider", - 17122435461807292320 - ], - [ - "Acos ai.onnx CPUExecutionProvider", - 12559777476169161176 - ], - [ - "Acosh ai.onnx CPUExecutionProvider", - 8147308773618631360 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 2809400508152416640 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 4765849540931245312 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 7755855355830999664 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 9430048329647241312 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 2811294571843125440 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 3911830094991457848 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 7526632132387248720 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 10898881467681895920 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 10954442170366922400 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 14227555443561151544 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 17163619535281201712 - ], - [ - "Add ai.onnx CPUExecutionProvider", - 17837365715369996800 - ], - [ - "And ai.onnx CPUExecutionProvider", - 7931711152704979424 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 725939091209285112 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 1506815612154586624 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 2317240841713684752 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 2424977109168024608 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 2722235460370749928 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 3530049861321274904 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 3776111440658888968 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 4620340115180881848 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 10486830423259333928 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 12157492629288967928 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 12195502223979275536 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 13027735142126919896 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 14910778597193697496 - ], - [ - "ArgMax ai.onnx CPUExecutionProvider", - 18394842867298045376 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 3259714541817453752 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 6306867649701766768 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 7388204937281788320 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 9847316221864129112 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 15228181434236004952 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 15675055269913165744 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 17063009147054142544 - ], - [ - "ArgMin ai.onnx CPUExecutionProvider", - 18140858399351623728 - ], - [ - "Asin ai.onnx CPUExecutionProvider", - 2279037228747137360 - ], - [ - "Asinh ai.onnx CPUExecutionProvider", - 4335637335083209024 - ], - [ - "Atan ai.onnx CPUExecutionProvider", - 6156207978523300880 - ], - [ - "Atanh ai.onnx CPUExecutionProvider", - 7281471694245287776 - ], - [ - "AveragePool ai.onnx CPUExecutionProvider", - 1691409737937146952 - ], - [ - "AveragePool ai.onnx CPUExecutionProvider", - 7064530023477327152 - ], - [ - "AveragePool ai.onnx CPUExecutionProvider", - 10567645262006979816 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 3151445409522059920 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 8260480971610783200 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 12009127760899929456 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 18128921553709069152 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 13094179255141648608 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 17832136363477464736 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 3016597991190826984 - ], - [ - "BatchNormalization ai.onnx CPUExecutionProvider", - 9270095107043637928 - ], - [ - "BitShift ai.onnx CPUExecutionProvider", - 4758677670685660688 - ], - [ - "BitShift ai.onnx CPUExecutionProvider", - 6590346922848896960 - ], - [ - "BitShift ai.onnx CPUExecutionProvider", - 8765933529403563240 - ], - [ - "BlackmanWindow ai.onnx CPUExecutionProvider", - 4230790036355038984 - ], - [ - "Cast ai.onnx CPUExecutionProvider", - 4892631558605514456 - ], - [ - "Cast ai.onnx CPUExecutionProvider", - 13977811166737747504 - ], - [ - "Ceil ai.onnx CPUExecutionProvider", - 2776352081059075072 - ], - [ - "Ceil ai.onnx CPUExecutionProvider", - 8175374217867982592 - ], - [ - "Celu ai.onnx CPUExecutionProvider", - 6586235226200663216 - ], - [ - "Clip ai.onnx CPUExecutionProvider", - 10165399686383129936 - ], - [ - "Clip ai.onnx CPUExecutionProvider", - 11050032992058305024 - ], - [ - "Clip ai.onnx CPUExecutionProvider", - 15464277967303611680 - ], - [ - "Clip ai.onnx CPUExecutionProvider", - 15835199378384293888 - ], - [ - "Compress ai.onnx CPUExecutionProvider", - 7362152988702475176 - ], - [ - "Compress ai.onnx CPUExecutionProvider", - 12011480601252498104 - ], - [ - "Concat ai.onnx CPUExecutionProvider", - 6735462296507386696 - ], - [ - "Concat ai.onnx CPUExecutionProvider", - 13881250836434776192 - ], - [ - "Concat ai.onnx CPUExecutionProvider", - 17984471790875795296 - ], - [ - "ConcatFromSequence ai.onnx CPUExecutionProvider", - 12684746449905165864 - ], - [ - "ConstantOfShape ai.onnx CPUExecutionProvider", - 11399309062544840088 - ], - [ - "Conv ai.onnx CPUExecutionProvider", - 8328794455908578232 - ], - [ - "Conv ai.onnx CPUExecutionProvider", - 16516917846545343592 - ], - [ - "ConvInteger ai.onnx CPUExecutionProvider", - 100167825365193136 - ], - [ - "ConvTranspose ai.onnx CPUExecutionProvider", - 35623394092856472 - ], - [ - "ConvTranspose ai.onnx CPUExecutionProvider", - 4454044225968077856 - ], - [ - "Cos ai.onnx CPUExecutionProvider", - 1499142588988375424 - ], - [ - "Cosh ai.onnx CPUExecutionProvider", - 15345496796160944720 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 5082944229190913256 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 10469673094092328264 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 10623377395346323656 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 11339703050808462936 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 12655022500731400520 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 13407812519646293824 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 14253657812461767552 - ], - [ - "CumSum ai.onnx CPUExecutionProvider", - 15718188077307851520 - ], - [ - "DepthToSpace ai.onnx CPUExecutionProvider", - 1243399198864605832 - ], - [ - "DepthToSpace ai.onnx CPUExecutionProvider", - 2375217984187160248 - ], - [ - "DepthToSpace ai.onnx CPUExecutionProvider", - 8969883522517745168 - ], - [ - "DequantizeLinear ai.onnx CPUExecutionProvider", - 659465705888746048 - ], - [ - "DequantizeLinear ai.onnx CPUExecutionProvider", - 4681578354178800352 - ], - [ - "DequantizeLinear ai.onnx CPUExecutionProvider", - 10633816119191824336 - ], - [ - "DequantizeLinear ai.onnx CPUExecutionProvider", - 11078037650349877208 - ], - [ - "DequantizeLinear ai.onnx CPUExecutionProvider", - 12063337043167780184 - ], - [ - "DequantizeLinear ai.onnx CPUExecutionProvider", - 13994311551260810568 - ], - [ - "Det ai.onnx CPUExecutionProvider", - 4355346295804324544 - ], - [ - "DFT ai.onnx CPUExecutionProvider", - 2809655513372322840 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 3765227735719542728 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 6096830237778328064 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 8806467529483118056 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 10863732091582123872 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 1070579688547233192 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 3530003233152221064 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 4576785625651393824 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 4887042330757426576 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 10014579724542017176 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 15607446330033001248 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 15682512038856025312 - ], - [ - "Div ai.onnx CPUExecutionProvider", - 18087152238491883272 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 959924377557845840 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 16708009824840936392 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 4039545479597223936 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 4144638777043373176 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 7435465647001245928 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 10786060100999379200 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 10811776928758670648 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 13037112407948355832 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 17469316612540314360 - ], - [ - "Dropout ai.onnx CPUExecutionProvider", - 17666880288828056264 - ], - [ - "DynamicQuantizeLinear ai.onnx CPUExecutionProvider", - 15568473034242820680 - ], - [ - "Einsum ai.onnx CPUExecutionProvider", - 14280390279553192696 - ], - [ - "Elu ai.onnx CPUExecutionProvider", - 3332615861526569160 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 465760068542515872 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 2485312216842316792 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 5003664709567541688 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 5513564698408644984 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 6741326113368386488 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 6935339698501351552 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 7519280692663295968 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 9545015110259334152 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 10522089514617653216 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 11314599873697707856 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 12332165502849781208 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 13454187278267603432 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 13975464527332248864 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 16427932130421637448 - ], - [ - "Equal ai.onnx CPUExecutionProvider", - 16722605182207544888 - ], - [ - "Erf ai.onnx CPUExecutionProvider", - 260353630666949200 - ], - [ - "Erf ai.onnx CPUExecutionProvider", - 16431777616328093792 - ], - [ - "Exp ai.onnx CPUExecutionProvider", - 4663516500907248816 - ], - [ - "Exp ai.onnx CPUExecutionProvider", - 7129230664682662520 - ], - [ - "Exp ai.onnx CPUExecutionProvider", - 10374157387442478048 - ], - [ - "Exp ai.onnx CPUExecutionProvider", - 12745858707361077304 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 511570327851406656 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 999468057623660600 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 1198571072854699920 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 1971417416603291272 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 1981128892509939080 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 4183894922616455816 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 4263550559080595904 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 5514423297929304016 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 6347847198186518064 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 6784735496995453856 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 6974166611804396280 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 6977454377250272128 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 7035613483753156272 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 9054423380544288056 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 9206973352625403832 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 13173453356648366440 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 14040663054230481832 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 14528904907882240528 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 14796966018183006720 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 15910553252200910392 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 16797173020098987784 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 16840132632255706048 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 17054642365717222176 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 17179539307526447632 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 17988714816908604160 - ], - [ - "Expand ai.onnx CPUExecutionProvider", - 18204147471244917832 - ], - [ - "EyeLike ai.onnx CPUExecutionProvider", - 14424733160521273664 - ], - [ - "Flatten ai.onnx CPUExecutionProvider", - 2834436778710670368 - ], - [ - "Flatten ai.onnx CPUExecutionProvider", - 15460288903846889664 - ], - [ - "Flatten ai.onnx CPUExecutionProvider", - 15671451779761409376 - ], - [ - "Flatten ai.onnx CPUExecutionProvider", - 16880471824466361464 - ], - [ - "Floor ai.onnx CPUExecutionProvider", - 4415851876932201928 - ], - [ - "Floor ai.onnx CPUExecutionProvider", - 10741594030400229240 - ], - [ - "Gather ai.onnx CPUExecutionProvider", - 625186873870077080 - ], - [ - "Gather ai.onnx CPUExecutionProvider", - 7462749543760614528 - ], - [ - "Gather ai.onnx CPUExecutionProvider", - 11761559382112736008 - ], - [ - "GatherElements ai.onnx CPUExecutionProvider", - 4500561021631132744 - ], - [ - "GatherElements ai.onnx CPUExecutionProvider", - 8274891310147440848 - ], - [ - "GatherND ai.onnx CPUExecutionProvider", - 3064028185911332496 - ], - [ - "GatherND ai.onnx CPUExecutionProvider", - 634727773751317256 - ], - [ - "GatherND ai.onnx CPUExecutionProvider", - 11311962292460032936 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 924315840375058080 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 2778484524162833808 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 8509578291145888416 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 9230597922178086344 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 9404518639984045696 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 11532684775587515192 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 11876146058839305680 - ], - [ - "Gemm ai.onnx CPUExecutionProvider", - 13401942613499179992 - ], - [ - "GlobalAveragePool ai.onnx CPUExecutionProvider", - 13997705024068872760 - ], - [ - "GlobalLpPool ai.onnx CPUExecutionProvider", - 12804454614028655440 - ], - [ - "GlobalMaxPool ai.onnx CPUExecutionProvider", - 1549541000860714808 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 449531580553467312 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 6244347306129731472 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 9673327496509586840 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 9970070545282376512 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 10403955381817977536 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 11554950443553084568 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 15279867060861067528 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 16240681799461676968 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 16777287786353424024 - ], - [ - "Greater ai.onnx CPUExecutionProvider", - 16852011221046024392 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 3999586969438630368 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 8317279776362716048 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 14896183015337647264 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 17416867432093505280 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 4445196831337347808 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 5607996550991708432 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 15145566961923527384 - ], - [ - "GreaterOrEqual ai.onnx CPUExecutionProvider", - 16172564801671050120 - ], - [ - "GridSample ai.onnx CPUExecutionProvider", - 15150264021585158264 - ], - [ - "GRU ai.onnx CPUExecutionProvider", - 5656895183302764352 - ], - [ - "GRU ai.onnx CPUExecutionProvider", - 2706165712066264784 - ], - [ - "HammingWindow ai.onnx CPUExecutionProvider", - 7960927909626268504 - ], - [ - "HannWindow ai.onnx CPUExecutionProvider", - 11998243503561799520 - ], - [ - "Hardmax ai.onnx CPUExecutionProvider", - 3471079605532327368 - ], - [ - "Hardmax ai.onnx CPUExecutionProvider", - 4275195189155928568 - ], - [ - "Hardmax ai.onnx CPUExecutionProvider", - 17512886821071582168 - ], - [ - "HardSigmoid ai.onnx CPUExecutionProvider", - 8073874088399301792 - ], - [ - "Identity ai.onnx CPUExecutionProvider", - 16879814636194901248 - ], - [ - "Identity ai.onnx CPUExecutionProvider", - 18001636502361632792 - ], - [ - "Identity ai.onnx CPUExecutionProvider", - 16515685968327103576 - ], - [ - "Identity ai.onnx CPUExecutionProvider", - 17661628575887109792 - ], - [ - "If ai.onnx CPUExecutionProvider", - 2236645891232685176 - ], - [ - "If ai.onnx CPUExecutionProvider", - 8375779015031532008 - ], - [ - "If ai.onnx CPUExecutionProvider", - 17525337343748410800 - ], - [ - "If ai.onnx CPUExecutionProvider", - 14605093928914589528 - ], - [ - "InstanceNormalization ai.onnx CPUExecutionProvider", - 17511351958774814976 - ], - [ - "IsInf ai.onnx CPUExecutionProvider", - 17892003161986514744 - ], - [ - "IsNaN ai.onnx CPUExecutionProvider", - 1410678655050221856 - ], - [ - "IsNaN ai.onnx CPUExecutionProvider", - 4005964532771713904 - ], - [ - "IsNaN ai.onnx CPUExecutionProvider", - 6643135320397294616 - ], - [ - "IsNaN ai.onnx CPUExecutionProvider", - 8601031213185457224 - ], - [ - "LeakyRelu ai.onnx CPUExecutionProvider", - 7020782930120154768 - ], - [ - "LeakyRelu ai.onnx CPUExecutionProvider", - 830582302303937272 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 2529281912870061232 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 2613688346938587336 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 2728241242035668880 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 2895263051054160696 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 9397354372953308456 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 15035453610574415224 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 15639388961500593472 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 15763680961107898296 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 16996631125248308216 - ], - [ - "Less ai.onnx CPUExecutionProvider", - 17960128831236491008 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 1261667279452953168 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 2051143717905239376 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 4697898477799165704 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 8848289292300988248 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 2622939867325521672 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 10411797434628406592 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 13001179628642472376 - ], - [ - "LessOrEqual ai.onnx CPUExecutionProvider", - 15565321713560893128 - ], - [ - "Log ai.onnx CPUExecutionProvider", - 268464912229648680 - ], - [ - "Log ai.onnx CPUExecutionProvider", - 3112553152557524648 - ], - [ - "Log ai.onnx CPUExecutionProvider", - 11759591563691157568 - ], - [ - "Log ai.onnx CPUExecutionProvider", - 17676497422436172176 - ], - [ - "LogSoftmax ai.onnx CPUExecutionProvider", - 115736193724605672 - ], - [ - "LogSoftmax ai.onnx CPUExecutionProvider", - 2520466664629116256 - ], - [ - "LogSoftmax ai.onnx CPUExecutionProvider", - 4492727363140130712 - ], - [ - "LogSoftmax ai.onnx CPUExecutionProvider", - 7315901443224771096 - ], - [ - "LogSoftmax ai.onnx CPUExecutionProvider", - 11503158751336478728 - ], - [ - "LogSoftmax ai.onnx CPUExecutionProvider", - 14810560469352034472 - ], - [ - "Loop ai.onnx CPUExecutionProvider", - 4722530828690003040 - ], - [ - "Loop ai.onnx CPUExecutionProvider", - 12504563769172665504 - ], - [ - "Loop ai.onnx CPUExecutionProvider", - 18226282479886582688 - ], - [ - "Loop ai.onnx CPUExecutionProvider", - 3437102755395866648 - ], - [ - "LpNormalization ai.onnx CPUExecutionProvider", - 5940113166682524360 - ], - [ - "LpNormalization ai.onnx CPUExecutionProvider", - 7839016206143985376 - ], - [ - "LpPool ai.onnx CPUExecutionProvider", - 2659465814558270816 - ], - [ - "LpPool ai.onnx CPUExecutionProvider", - 8193428427416727168 - ], - [ - "LRN ai.onnx CPUExecutionProvider", - 12835456436860157072 - ], - [ - "LRN ai.onnx CPUExecutionProvider", - 15047083116516536152 - ], - [ - "LSTM ai.onnx CPUExecutionProvider", - 4392316607181063320 - ], - [ - "LSTM ai.onnx CPUExecutionProvider", - 5860801277476873352 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 52556316079319400 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 3037708961966197464 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 3051970820885644088 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 3074778955695403856 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 4366527517563157992 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 6380816295259527720 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 9907944282496968536 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 10090084904454358640 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 12944936747196752560 - ], - [ - "MatMul ai.onnx CPUExecutionProvider", - 16997422825780227992 - ], - [ - "MatMulInteger ai.onnx CPUExecutionProvider", - 10315437332566125928 - ], - [ - "MatMulInteger ai.onnx CPUExecutionProvider", - 8304157459354278720 - ], - [ - "Max ai.onnx CPUExecutionProvider", - 11384130473520199552 - ], - [ - "Max ai.onnx CPUExecutionProvider", - 11400391412664954200 - ], - [ - "Max ai.onnx CPUExecutionProvider", - 13265498576778444128 - ], - [ - "Max ai.onnx CPUExecutionProvider", - 15826943726640087288 - ], - [ - "MaxPool ai.onnx CPUExecutionProvider", - 8090321298879394920 - ], - [ - "MaxPool ai.onnx CPUExecutionProvider", - 8183201210657604880 - ], - [ - "MaxPool ai.onnx CPUExecutionProvider", - 9018732444147417520 - ], - [ - "MaxRoiPool ai.onnx CPUExecutionProvider", - 15885844739553398488 - ], - [ - "MaxUnpool ai.onnx CPUExecutionProvider", - 32284576063639008 - ], - [ - "MaxUnpool ai.onnx CPUExecutionProvider", - 12239285446297040368 - ], - [ - "Mean ai.onnx CPUExecutionProvider", - 674524861415863432 - ], - [ - "Mean ai.onnx CPUExecutionProvider", - 15995564204118007280 - ], - [ - "Mean ai.onnx CPUExecutionProvider", - 17663800684295189072 - ], - [ - "MeanVarianceNormalization ai.onnx CPUExecutionProvider", - 4538073571749444680 - ], - [ - "MeanVarianceNormalization ai.onnx CPUExecutionProvider", - 17242016597551698064 - ], - [ - "MelWeightMatrix ai.onnx CPUExecutionProvider", - 1589563865873170600 - ], - [ - "Min ai.onnx CPUExecutionProvider", - 5444634510407971152 - ], - [ - "Min ai.onnx CPUExecutionProvider", - 6810496931206290712 - ], - [ - "Min ai.onnx CPUExecutionProvider", - 8711551713147301744 - ], - [ - "Min ai.onnx CPUExecutionProvider", - 9049294222697989120 - ], - [ - "Mod ai.onnx CPUExecutionProvider", - 7252013141748353440 - ], - [ - "Mod ai.onnx CPUExecutionProvider", - 10733303722078722368 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 6564441461158321512 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 12026653868244773992 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 12682993328532509792 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 15671930608208592024 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 2785903326433781024 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 4791260151803733872 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 4865785170965629472 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 6426675508006464640 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 7050294657620187432 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 7162660788918165080 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 10589749217942850296 - ], - [ - "Mul ai.onnx CPUExecutionProvider", - 13192662774919983744 - ], - [ - "Multinomial ai.onnx CPUExecutionProvider", - 4035027560640217112 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 299941700632775384 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 941339082116483920 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 5647833637661484080 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 5951204590602223848 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 8536718915610937776 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 9219512640203829032 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 10410816017567060816 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 11383475793879864256 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 11497410339400054552 - ], - [ - "Neg ai.onnx CPUExecutionProvider", - 13062264815462063544 - ], - [ - "NonMaxSuppression ai.onnx CPUExecutionProvider", - 1370201499365506920 - ], - [ - "NonMaxSuppression ai.onnx CPUExecutionProvider", - 5599941468729814576 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 4582779599163399952 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 5030526776968826448 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 7409810841275121112 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 9279205752738594840 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 9946863128009620032 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 10368470876283158368 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 11258950761227257408 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 14796541421586107008 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 15791942659448002792 - ], - [ - "NonZero ai.onnx CPUExecutionProvider", - 17784759486752176416 - ], - [ - "Not ai.onnx CPUExecutionProvider", - 5212043150202938416 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 978583397600040800 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 1520450948168779496 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 2271501714099757272 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 3388280028895725048 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 5791206859684296928 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 6100777310996997520 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 6920116261827618000 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 7709731812980042552 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 8289807350985807680 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 8313592734779484512 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 8361965853380969048 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 8392130661000703936 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 8728814239843992472 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 9133480318268319368 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 10784514718782139952 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 11582157536088100696 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 11996617401421737624 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 13987905095574261288 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 14326347270202402576 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 14331450753330761992 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 14348901166058188912 - ], - [ - "OneHot ai.onnx CPUExecutionProvider", - 15162687909384154912 - ], - [ - "Or ai.onnx CPUExecutionProvider", - 18295541712828245416 - ], - [ - "Pad ai.onnx CPUExecutionProvider", - 12904240253005862936 - ], - [ - "Pad ai.onnx CPUExecutionProvider", - 15778168444418158464 - ], - [ - "Pad ai.onnx CPUExecutionProvider", - 17900876414644757728 - ], - [ - "Pow ai.onnx CPUExecutionProvider", - 2732884601395419376 - ], - [ - "Pow ai.onnx CPUExecutionProvider", - 8377308095432624176 - ], - [ - "Pow ai.onnx CPUExecutionProvider", - 12963226513247425672 - ], - [ - "Pow ai.onnx CPUExecutionProvider", - 16138602580714332296 - ], - [ - "PRelu ai.onnx CPUExecutionProvider", - 3282999003886175808 - ], - [ - "PRelu ai.onnx CPUExecutionProvider", - 18013428254891374496 - ], - [ - "PRelu ai.onnx CPUExecutionProvider", - 17872917958807301128 - ], - [ - "QLinearConv ai.onnx CPUExecutionProvider", - 1301685544574905024 - ], - [ - "QLinearConv ai.onnx CPUExecutionProvider", - 6630340551594954312 - ], - [ - "QLinearMatMul ai.onnx CPUExecutionProvider", - 17071666635484846840 - ], - [ - "QLinearMatMul ai.onnx CPUExecutionProvider", - 7936962119982759552 - ], - [ - "QuantizeLinear ai.onnx CPUExecutionProvider", - 1532667988771795616 - ], - [ - "QuantizeLinear ai.onnx CPUExecutionProvider", - 2341045801698043944 - ], - [ - "QuantizeLinear ai.onnx CPUExecutionProvider", - 7634601085712939440 - ], - [ - "QuantizeLinear ai.onnx CPUExecutionProvider", - 11284864816493289608 - ], - [ - "RandomNormal ai.onnx CPUExecutionProvider", - 11106330559691891632 - ], - [ - "RandomNormalLike ai.onnx CPUExecutionProvider", - 16132637741722034712 - ], - [ - "RandomUniform ai.onnx CPUExecutionProvider", - 13265289781964897352 - ], - [ - "RandomUniformLike ai.onnx CPUExecutionProvider", - 16430898061899701520 - ], - [ - "Range ai.onnx CPUExecutionProvider", - 3325111002682387088 - ], - [ - "Reciprocal ai.onnx CPUExecutionProvider", - 2850891354519740808 - ], - [ - "Reciprocal ai.onnx CPUExecutionProvider", - 11806012495947779832 - ], - [ - "Reciprocal ai.onnx CPUExecutionProvider", - 12484092722323390008 - ], - [ - "Reciprocal ai.onnx CPUExecutionProvider", - 17341443694738537648 - ], - [ - "ReduceL1 ai.onnx CPUExecutionProvider", - 1139813399733469792 - ], - [ - "ReduceL1 ai.onnx CPUExecutionProvider", - 3243638789933182864 - ], - [ - "ReduceL1 ai.onnx CPUExecutionProvider", - 6808199264646043248 - ], - [ - "ReduceL1 ai.onnx CPUExecutionProvider", - 12784657175698452520 - ], - [ - "ReduceL1 ai.onnx CPUExecutionProvider", - 14869942543039134512 - ], - [ - "ReduceL1 ai.onnx CPUExecutionProvider", - 16118051457711035712 - ], - [ - "ReduceL2 ai.onnx CPUExecutionProvider", - 867424791763712096 - ], - [ - "ReduceL2 ai.onnx CPUExecutionProvider", - 7648877332496785568 - ], - [ - "ReduceL2 ai.onnx CPUExecutionProvider", - 11351012496150467016 - ], - [ - "ReduceL2 ai.onnx CPUExecutionProvider", - 11760351378486873352 - ], - [ - "ReduceL2 ai.onnx CPUExecutionProvider", - 12635395574144286256 - ], - [ - "ReduceL2 ai.onnx CPUExecutionProvider", - 14689639463724553808 - ], - [ - "ReduceLogSum ai.onnx CPUExecutionProvider", - 2332039309841303536 - ], - [ - "ReduceLogSum ai.onnx CPUExecutionProvider", - 4047536642033380384 - ], - [ - "ReduceLogSum ai.onnx CPUExecutionProvider", - 5847333566715597800 - ], - [ - "ReduceLogSum ai.onnx CPUExecutionProvider", - 7095373462244394120 - ], - [ - "ReduceLogSum ai.onnx CPUExecutionProvider", - 14147219412305287720 - ], - [ - "ReduceLogSum ai.onnx CPUExecutionProvider", - 14876792603687228632 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 1578137980614546184 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 2444082231177199048 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 2552318884715787112 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 2802074054827186944 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 4954470811027637464 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 7534771629024139208 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 8383855610701728736 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 9293246934181404656 - ], - [ - "ReduceLogSumExp ai.onnx CPUExecutionProvider", - 14893941010795227112 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 1506640785403361944 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 2151353831968849952 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 2985763464867451504 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 3762588521304654072 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 5551647046811272248 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 6269158263682288056 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 6953345132994073600 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 8407340488798592744 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 8759268275556416800 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 9625377663291722928 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 10803038243496953112 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 11991870565148818504 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 12690494510801443616 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 12851170686026716664 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 13618950924415359360 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 13664399148702906192 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 15032536557946298480 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 15302447573039984352 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 15968588025975286400 - ], - [ - "ReduceMax ai.onnx CPUExecutionProvider", - 17355842716358289032 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 1841014235270425536 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 2483440124489784048 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 3218961468728384336 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 7017607560662713312 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 7620337157139517672 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 10980558378895407008 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 12368118057245055784 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 14968655452447684912 - ], - [ - "ReduceMean ai.onnx CPUExecutionProvider", - 16630173800581470640 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 268786330812272408 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 2781368088787513656 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 3131568620751235160 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 5314914069158001048 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 5697549792153090816 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 5749891696478874968 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 6050316872530083912 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 6981923349123980904 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 7357167418382928448 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 7385064883153786808 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 10018537486271003776 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 11291266537870182088 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 13294165344051481168 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 15207099834172347424 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 16456448302154466592 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 16859238740139814880 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 16937803876974883736 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 17340745031509256168 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 17758190144581814176 - ], - [ - "ReduceMin ai.onnx CPUExecutionProvider", - 17784882475811026608 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 1183314972083054304 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 5737910396606289640 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 6038794103189808640 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 6535050096259888824 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 7402146841475223544 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 8181303911764095056 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 8637768506869245832 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 12584354805776722856 - ], - [ - "ReduceProd ai.onnx CPUExecutionProvider", - 16349073532452506688 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 1980255433118254864 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 5051926940187829840 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 5506740466752429664 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 6035268158473211448 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 6869090123498269616 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 8382746145639334288 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 9380715542666407912 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 9402732746386583360 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 11742441148498967432 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 12554723003393183392 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 15179274571338977128 - ], - [ - "ReduceSum ai.onnx CPUExecutionProvider", - 17432693777620803080 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 1312654986065495352 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 2140627660946213648 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 3963605022599305624 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 4704537570264301616 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 9885037980564018056 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 10444747001582642680 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 10952132708040173936 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 11399363300408446304 - ], - [ - "ReduceSumSquare ai.onnx CPUExecutionProvider", - 16124074064020394392 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 6348137465784795920 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 8714935363812984712 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 9308941683254701488 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 11010461083452725976 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 12612615492664612000 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 17992752176767437224 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 800813007613791272 - ], - [ - "Relu ai.onnx CPUExecutionProvider", - 8895401642537447048 - ], - [ - "Reshape ai.onnx CPUExecutionProvider", - 15585679549439997664 - ], - [ - "Reshape ai.onnx CPUExecutionProvider", - 312434585414181376 - ], - [ - "Reshape ai.onnx CPUExecutionProvider", - 2930053564253999576 - ], - [ - "Reshape ai.onnx CPUExecutionProvider", - 7420064847738548128 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 2176125637575734736 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 4943403435871631024 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 5590380249602790576 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 6635061613257537184 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 6688747786976434096 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 7586529787086160288 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 8129711815686221072 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 9417332069479988832 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 16654410418270805888 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 11612050915202267528 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 17391179092212899176 - ], - [ - "Resize ai.onnx CPUExecutionProvider", - 18197536715869401872 - ], - [ - "ReverseSequence ai.onnx CPUExecutionProvider", - 777686786711298232 - ], - [ - "RNN ai.onnx CPUExecutionProvider", - 1524012948315159160 - ], - [ - "RNN ai.onnx CPUExecutionProvider", - 7150196822928564664 - ], - [ - "RoiAlign ai.onnx CPUExecutionProvider", - 10225383741733918632 - ], - [ - "RoiAlign ai.onnx CPUExecutionProvider", - 17022700455473327752 - ], - [ - "RoiAlign ai.onnx CPUExecutionProvider", - 12536918025219759808 - ], - [ - "RoiAlign ai.onnx CPUExecutionProvider", - 17630436044120124376 - ], - [ - "Round ai.onnx CPUExecutionProvider", - 6445503327475574168 - ], - [ - "Round ai.onnx CPUExecutionProvider", - 7643555187101537632 - ], - [ - "Round ai.onnx CPUExecutionProvider", - 13985378343775212776 - ], - [ - "Scan ai.onnx CPUExecutionProvider", - 3668627007850399040 - ], - [ - "Scan ai.onnx CPUExecutionProvider", - 11384233164992114368 - ], - [ - "Scan ai.onnx CPUExecutionProvider", - 1718418059112844640 - ], - [ - "Scan ai.onnx CPUExecutionProvider", - 220271302879298784 - ], - [ - "Scatter ai.onnx CPUExecutionProvider", - 15759064509848656392 - ], - [ - "ScatterElements ai.onnx CPUExecutionProvider", - 6975525526054194192 - ], - [ - "ScatterElements ai.onnx CPUExecutionProvider", - 18432023504032212784 - ], - [ - "ScatterElements ai.onnx CPUExecutionProvider", - 6811950547871129912 - ], - [ - "ScatterND ai.onnx CPUExecutionProvider", - 3058556000123462920 - ], - [ - "ScatterND ai.onnx CPUExecutionProvider", - 8479082925474841976 - ], - [ - "ScatterND ai.onnx CPUExecutionProvider", - 8962669673590364688 - ], - [ - "Selu ai.onnx CPUExecutionProvider", - 17981896312364280064 - ], - [ - "SequenceAt ai.onnx CPUExecutionProvider", - 4973237574994879072 - ], - [ - "SequenceConstruct ai.onnx CPUExecutionProvider", - 1758418193363705792 - ], - [ - "SequenceEmpty ai.onnx CPUExecutionProvider", - 1140894844554479184 - ], - [ - "SequenceErase ai.onnx CPUExecutionProvider", - 9132449567156154920 - ], - [ - "SequenceInsert ai.onnx CPUExecutionProvider", - 13543825903947046128 - ], - [ - "SequenceLength ai.onnx CPUExecutionProvider", - 12782512452558492336 - ], - [ - "Shape ai.onnx CPUExecutionProvider", - 7116832049885431040 - ], - [ - "Shape ai.onnx CPUExecutionProvider", - 14989007508280400584 - ], - [ - "Shape ai.onnx CPUExecutionProvider", - 9917761852037658112 - ], - [ - "Shrink ai.onnx CPUExecutionProvider", - 4706529740707835200 - ], - [ - "Sigmoid ai.onnx CPUExecutionProvider", - 394481973749682984 - ], - [ - "Sigmoid ai.onnx CPUExecutionProvider", - 5082642407570374784 - ], - [ - "Sigmoid ai.onnx CPUExecutionProvider", - 11065037976771470120 - ], - [ - "Sigmoid ai.onnx CPUExecutionProvider", - 17627299285547200664 - ], - [ - "Sign ai.onnx CPUExecutionProvider", - 7242670280432058360 - ], - [ - "Sign ai.onnx CPUExecutionProvider", - 15771057264632726008 - ], - [ - "Sin ai.onnx CPUExecutionProvider", - 4888589719281923384 - ], - [ - "Sin ai.onnx CPUExecutionProvider", - 17128842231045328968 - ], - [ - "Sinh ai.onnx CPUExecutionProvider", - 4576126796107617992 - ], - [ - "Size ai.onnx CPUExecutionProvider", - 217840791673793560 - ], - [ - "Size ai.onnx CPUExecutionProvider", - 17016678530831229680 - ], - [ - "Slice ai.onnx CPUExecutionProvider", - 5357620905248938008 - ], - [ - "Slice ai.onnx CPUExecutionProvider", - 11391608604339585104 - ], - [ - "Slice ai.onnx CPUExecutionProvider", - 13519381513691870136 - ], - [ - "Slice ai.onnx CPUExecutionProvider", - 18363944821618547616 - ], - [ - "Softmax ai.onnx CPUExecutionProvider", - 3465130326277671496 - ], - [ - "Softmax ai.onnx CPUExecutionProvider", - 9903615929201046200 - ], - [ - "Softmax ai.onnx CPUExecutionProvider", - 10498752306304274880 - ], - [ - "Softmax ai.onnx CPUExecutionProvider", - 11571890266623078672 - ], - [ - "Softmax ai.onnx CPUExecutionProvider", - 14127813883883160600 - ], - [ - "Softmax ai.onnx CPUExecutionProvider", - 17624244089184427864 - ], - [ - "Softplus ai.onnx CPUExecutionProvider", - 1453844285936857688 - ], - [ - "Softsign ai.onnx CPUExecutionProvider", - 6090105046651540184 - ], - [ - "SpaceToDepth ai.onnx CPUExecutionProvider", - 9873252942934071232 - ], - [ - "SpaceToDepth ai.onnx CPUExecutionProvider", - 12194265012469761912 - ], - [ - "Split ai.onnx CPUExecutionProvider", - 3088983478592774216 - ], - [ - "Split ai.onnx CPUExecutionProvider", - 12436994357532875256 - ], - [ - "Split ai.onnx CPUExecutionProvider", - 15368382353097470168 - ], - [ - "SplitToSequence ai.onnx CPUExecutionProvider", - 8162957939300717576 - ], - [ - "Sqrt ai.onnx CPUExecutionProvider", - 2567720088904231480 - ], - [ - "Sqrt ai.onnx CPUExecutionProvider", - 10412643283809566304 - ], - [ - "Sqrt ai.onnx CPUExecutionProvider", - 10542781197602744432 - ], - [ - "Sqrt ai.onnx CPUExecutionProvider", - 10587863051847811344 - ], - [ - "Squeeze ai.onnx CPUExecutionProvider", - 12889825108950034784 - ], - [ - "Squeeze ai.onnx CPUExecutionProvider", - 14725795030460042064 - ], - [ - "Squeeze ai.onnx CPUExecutionProvider", - 16122603335179721968 - ], - [ - "STFT ai.onnx CPUExecutionProvider", - 1739051453790648552 - ], - [ - "StringNormalizer ai.onnx CPUExecutionProvider", - 7767393334034626736 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 2039033718071997632 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 4778873438453976024 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 10264597047275083736 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 12226892770165869472 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 419566716728265872 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 2727932803080255528 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 6712677947990193048 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 8541521562748754016 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 9330359905365434456 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 13039184128400389864 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 16924361586861576328 - ], - [ - "Sub ai.onnx CPUExecutionProvider", - 17828222941588972448 - ], - [ - "Sum ai.onnx CPUExecutionProvider", - 590100139402794968 - ], - [ - "Sum ai.onnx CPUExecutionProvider", - 5647171547812521792 - ], - [ - "Sum ai.onnx CPUExecutionProvider", - 11630776365971756320 - ], - [ - "Sum ai.onnx CPUExecutionProvider", - 17085797681836391096 - ], - [ - "Sum ai.onnx CPUExecutionProvider", - 17995615846776008176 - ], - [ - "Sum ai.onnx CPUExecutionProvider", - 18320202736703015384 - ], - [ - "Tan ai.onnx CPUExecutionProvider", - 3500007982075582512 - ], - [ - "Tanh ai.onnx CPUExecutionProvider", - 9604368444687947576 - ], - [ - "Tanh ai.onnx CPUExecutionProvider", - 10324147978489520384 - ], - [ - "Tanh ai.onnx CPUExecutionProvider", - 11123828994494025000 - ], - [ - "Tanh ai.onnx CPUExecutionProvider", - 12012944136719804976 - ], - [ - "TfIdfVectorizer ai.onnx CPUExecutionProvider", - 12361724165659823144 - ], - [ - "ThresholdedRelu ai.onnx CPUExecutionProvider", - 4781858005566667480 - ], - [ - "Tile ai.onnx CPUExecutionProvider", - 13093106569145134440 - ], - [ - "Tile ai.onnx CPUExecutionProvider", - 14102078343076871784 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 1153626550939059536 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 1432673385034650208 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 2297925672512898808 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 9888321934045104504 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 12624528444405735856 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 12652594006657629568 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 13224909753327646984 - ], - [ - "TopK ai.onnx CPUExecutionProvider", - 15636662090924439512 - ], - [ - "Transpose ai.onnx CPUExecutionProvider", - 4324835766923221184 - ], - [ - "Transpose ai.onnx CPUExecutionProvider", - 17267477159887372848 - ], - [ - "Trilu ai.onnx CPUExecutionProvider", - 9981347960818883600 - ], - [ - "Unique ai.onnx CPUExecutionProvider", - 248906295587923264 - ], - [ - "Unsqueeze ai.onnx CPUExecutionProvider", - 9466011545409597224 - ], - [ - "Unsqueeze ai.onnx CPUExecutionProvider", - 15964030255371555232 - ], - [ - "Unsqueeze ai.onnx CPUExecutionProvider", - 16989589986691430224 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 1134372751280084232 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 1504517427294670632 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 2377927852844585288 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 12301152523715249896 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 13679625484942243168 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 17854969343220970496 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 4881414272874428968 - ], - [ - "Upsample ai.onnx CPUExecutionProvider", - 7886741677344006832 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 4809895774227191528 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 6240304355560301824 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 8769536045886744832 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 14651270063533962384 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 18032063618220677104 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 18237388578268294064 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 186521469797128048 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 1124762945029945944 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 8727601196279637232 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 10010972164133933208 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 12896498917540022856 - ], - [ - "Where ai.onnx CPUExecutionProvider", - 17544214758602217832 - ], - [ - "Xor ai.onnx CPUExecutionProvider", - 14631049987911195736 - ] -] diff --git a/onnxruntime/test/testdata/kernel_def_hashes/onnx.ml.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/onnx.ml.cpu.json deleted file mode 100644 index 8aef60e11fb6d..0000000000000 --- a/onnxruntime/test/testdata/kernel_def_hashes/onnx.ml.cpu.json +++ /dev/null @@ -1,202 +0,0 @@ -[ - [ - "ArrayFeatureExtractor ai.onnx.ml CPUExecutionProvider", - 2350526174341531704 - ], - [ - "ArrayFeatureExtractor ai.onnx.ml CPUExecutionProvider", - 6530469673967836712 - ], - [ - "ArrayFeatureExtractor ai.onnx.ml CPUExecutionProvider", - 6756475889913415064 - ], - [ - "ArrayFeatureExtractor ai.onnx.ml CPUExecutionProvider", - 11064940903354222584 - ], - [ - "ArrayFeatureExtractor ai.onnx.ml CPUExecutionProvider", - 16432528150559367472 - ], - [ - "Binarizer ai.onnx.ml CPUExecutionProvider", - 5046510072939685048 - ], - [ - "CastMap ai.onnx.ml CPUExecutionProvider", - 8048827990321743320 - ], - [ - "CategoryMapper ai.onnx.ml CPUExecutionProvider", - 2303612724843538232 - ], - [ - "DictVectorizer ai.onnx.ml CPUExecutionProvider", - 6689651410341671832 - ], - [ - "DictVectorizer ai.onnx.ml CPUExecutionProvider", - 7005642460916971440 - ], - [ - "DictVectorizer ai.onnx.ml CPUExecutionProvider", - 9815805762175478424 - ], - [ - "DictVectorizer ai.onnx.ml CPUExecutionProvider", - 10888259614400691552 - ], - [ - "DictVectorizer ai.onnx.ml CPUExecutionProvider", - 12166013579191601224 - ], - [ - "DictVectorizer ai.onnx.ml CPUExecutionProvider", - 14175801436008296008 - ], - [ - "FeatureVectorizer ai.onnx.ml CPUExecutionProvider", - 13725963185932500216 - ], - [ - "Imputer ai.onnx.ml CPUExecutionProvider", - 1198484622450883960 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 72584427791628640 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 643259145804007712 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 2667163092143996224 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 5922496927153994000 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 11838196408428523536 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 12429922703862798200 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 12910259746827434448 - ], - [ - "LabelEncoder ai.onnx.ml CPUExecutionProvider", - 13358273281151927664 - ], - [ - "LinearClassifier ai.onnx.ml CPUExecutionProvider", - 2370253020407350728 - ], - [ - "LinearRegressor ai.onnx.ml CPUExecutionProvider", - 16822518791234245504 - ], - [ - "Normalizer ai.onnx.ml CPUExecutionProvider", - 4575839837655511352 - ], - [ - "OneHotEncoder ai.onnx.ml CPUExecutionProvider", - 566681355974248448 - ], - [ - "OneHotEncoder ai.onnx.ml CPUExecutionProvider", - 5824294864440328760 - ], - [ - "OneHotEncoder ai.onnx.ml CPUExecutionProvider", - 10362613019971997872 - ], - [ - "OneHotEncoder ai.onnx.ml CPUExecutionProvider", - 10397925769639456800 - ], - [ - "Scaler ai.onnx.ml CPUExecutionProvider", - 2957692969286147560 - ], - [ - "Scaler ai.onnx.ml CPUExecutionProvider", - 4622934914100034264 - ], - [ - "Scaler ai.onnx.ml CPUExecutionProvider", - 15151724670894748016 - ], - [ - "Scaler ai.onnx.ml CPUExecutionProvider", - 17146108806137553320 - ], - [ - "SVMClassifier ai.onnx.ml CPUExecutionProvider", - 16410276550989508464 - ], - [ - "SVMRegressor ai.onnx.ml CPUExecutionProvider", - 6414907892091941152 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 486797053154945264 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 7696156941733005088 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 7775841366762359432 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 11182621524629506824 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 15634834299270082824 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 16976754609836955392 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 17582366242455530152 - ], - [ - "TreeEnsembleClassifier ai.onnx.ml CPUExecutionProvider", - 18326787492602399840 - ], - [ - "TreeEnsembleRegressor ai.onnx.ml CPUExecutionProvider", - 1006399804521896912 - ], - [ - "TreeEnsembleRegressor ai.onnx.ml CPUExecutionProvider", - 10267269339573101144 - ], - [ - "TreeEnsembleRegressor ai.onnx.ml CPUExecutionProvider", - 12993125630596348064 - ], - [ - "TreeEnsembleRegressor ai.onnx.ml CPUExecutionProvider", - 17387929614635939072 - ], - [ - "ZipMap ai.onnx.ml CPUExecutionProvider", - 868519487849210656 - ] -] \ No newline at end of file diff --git a/onnxruntime/test/testdata/kernel_def_hashes/onnx.optional_type_ops.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/onnx.optional_type_ops.cpu.json deleted file mode 100644 index 49c93994862d0..0000000000000 --- a/onnxruntime/test/testdata/kernel_def_hashes/onnx.optional_type_ops.cpu.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - [ - "Optional ai.onnx CPUExecutionProvider", - 4007199385789893408 - ], - [ - "OptionalGetElement ai.onnx CPUExecutionProvider", - 8727767224223660008 - ], - [ - "OptionalHasElement ai.onnx CPUExecutionProvider", - 103583056104706000 - ] -] diff --git a/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json deleted file mode 100644 index 44ea17b71760f..0000000000000 --- a/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json +++ /dev/null @@ -1,282 +0,0 @@ -[ - [ - "AdamOptimizer com.microsoft CPUExecutionProvider", - 13275719513674142848 - ], - [ - "AdamWOptimizer com.microsoft CPUExecutionProvider", - 8132205541603584160 - ], - [ - "AveragePoolGrad ai.onnx CPUExecutionProvider", - 5748823370585834408 - ], - [ - "BiasFastGeluGrad_dX com.microsoft CPUExecutionProvider", - 18012658855595136536 - ], - [ - "BiasGeluGrad_dX com.microsoft CPUExecutionProvider", - 15594101660509653368 - ], - [ - "BroadcastGradientArgs com.microsoft CPUExecutionProvider", - 11924624129611280440 - ], - [ - "ConcatTraining com.microsoft CPUExecutionProvider", - 407435603592769928 - ], - [ - "ConvGrad com.microsoft CPUExecutionProvider", - 6051867985469399832 - ], - [ - "DropoutGrad com.microsoft CPUExecutionProvider", - 5281827689086376112 - ], - [ - "DropoutGrad com.microsoft CPUExecutionProvider", - 5974036940246406232 - ], - [ - "DropoutGrad com.microsoft CPUExecutionProvider", - 6251139593746398664 - ], - [ - "DropoutGrad com.microsoft CPUExecutionProvider", - 11134433709210415000 - ], - [ - "DropoutGrad com.microsoft CPUExecutionProvider", - 14442689431073529904 - ], - [ - "DropoutGrad com.microsoft CPUExecutionProvider", - 14487527510076876072 - ], - [ - "FastGeluGrad com.microsoft CPUExecutionProvider", - 15449034010577224840 - ], - [ - "GatherElementsGrad com.microsoft CPUExecutionProvider", - 4371565130613539784 - ], - [ - "GatherGrad com.microsoft CPUExecutionProvider", - 16767260238333903216 - ], - [ - "GatherNDGrad com.microsoft CPUExecutionProvider", - 14458526033602252960 - ], - [ - "GeluGrad com.microsoft CPUExecutionProvider", - 3065626784130845072 - ], - [ - "Group com.microsoft CPUExecutionProvider", - 488667512000820344 - ], - [ - "InPlaceAccumulator com.microsoft CPUExecutionProvider", - 10152728201494720480 - ], - [ - "InvertibleLayerNormalizationGrad com.microsoft CPUExecutionProvider", - 7138710605488227064 - ], - [ - "InvertibleLayerNormalizationGrad com.microsoft CPUExecutionProvider", - 14718779537160086944 - ], - [ - "LayerNormalizationGrad com.microsoft CPUExecutionProvider", - 12121902982758237936 - ], - [ - "LayerNormalizationGrad com.microsoft CPUExecutionProvider", - 17776280973585908456 - ], - [ - "LogSoftmaxGrad com.microsoft CPUExecutionProvider", - 2657523710083167200 - ], - [ - "LogSoftmaxGrad_13 com.microsoft CPUExecutionProvider", - 1917456134240183096 - ], - [ - "MaxPoolGrad ai.onnx CPUExecutionProvider", - 17526822836083413768 - ], - [ - "PassThrough com.microsoft CPUExecutionProvider", - 15753758832962034552 - ], - [ - "ReduceAllL2 com.microsoft CPUExecutionProvider", - 10206006236140935288 - ], - [ - "ReduceSumTraining com.microsoft CPUExecutionProvider", - 4143442227179196240 - ], - [ - "ReduceSumTraining com.microsoft CPUExecutionProvider", - 8270932107741781432 - ], - [ - "ReduceSumTraining com.microsoft CPUExecutionProvider", - 10014758284200613760 - ], - [ - "ReduceSumTraining com.microsoft CPUExecutionProvider", - 11970256914235931792 - ], - [ - "ReluGrad com.microsoft CPUExecutionProvider", - 6194712211707544696 - ], - [ - "SGDOptimizer com.microsoft CPUExecutionProvider", - 6413752339355984752 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 4626702086191057400 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 7066774539152819712 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 7086129391615471904 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 8529929466096310624 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 9194708927698241584 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 15004008446550052608 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 16665599877841166880 - ], - [ - "Scale com.microsoft CPUExecutionProvider", - 17615601449933173216 - ], - [ - "SigmoidGrad com.microsoft CPUExecutionProvider", - 13197741041840309432 - ], - [ - "SimplifiedLayerNormalizationGrad com.microsoft CPUExecutionProvider", - 6770225290706127808 - ], - [ - "SimplifiedLayerNormalizationGrad com.microsoft CPUExecutionProvider", - 13749371409358355088 - ], - [ - "SinGrad ai.onnx CPUExecutionProvider", - 6971400065913943256 - ], - [ - "SliceGrad com.microsoft CPUExecutionProvider", - 18003932513454931536 - ], - [ - "SoftmaxCrossEntropy com.microsoft CPUExecutionProvider", - 5470752940282787040 - ], - [ - "SoftmaxCrossEntropyGrad com.microsoft CPUExecutionProvider", - 17579254067477868408 - ], - [ - "SoftmaxCrossEntropyLoss ai.onnx CPUExecutionProvider", - 379111282162709688 - ], - [ - "SoftmaxCrossEntropyLoss ai.onnx CPUExecutionProvider", - 5256121368123320104 - ], - [ - "SoftmaxCrossEntropyLoss ai.onnx CPUExecutionProvider", - 14833827121724789864 - ], - [ - "SoftmaxCrossEntropyLoss ai.onnx CPUExecutionProvider", - 15405100773745075656 - ], - [ - "SoftmaxCrossEntropyLossGrad com.microsoft CPUExecutionProvider", - 8253282220433537112 - ], - [ - "SoftmaxCrossEntropyLossGrad com.microsoft CPUExecutionProvider", - 16283082098560169504 - ], - [ - "SoftmaxCrossEntropyLossInternal com.microsoft CPUExecutionProvider", - 265850656999021376 - ], - [ - "SoftmaxCrossEntropyLossInternal com.microsoft CPUExecutionProvider", - 11233773279707270760 - ], - [ - "SoftmaxCrossEntropyLossInternalGrad com.microsoft CPUExecutionProvider", - 5510696224816806976 - ], - [ - "SoftmaxCrossEntropyLossInternalGrad com.microsoft CPUExecutionProvider", - 16653880432442754672 - ], - [ - "SoftmaxGrad com.microsoft CPUExecutionProvider", - 4483165757863027152 - ], - [ - "SoftmaxGrad_13 com.microsoft CPUExecutionProvider", - 8375491041422269560 - ], - [ - "SparseSoftmaxCrossEntropy ai.onnx CPUExecutionProvider", - 10638058507241762520 - ], - [ - "SparseSoftmaxCrossEntropyGrad ai.onnx CPUExecutionProvider", - 17612183648106847376 - ], - [ - "SplitTraining com.microsoft CPUExecutionProvider", - 12689204749897364688 - ], - [ - "TanhGrad com.microsoft CPUExecutionProvider", - 7147744030478490408 - ], - [ - "ZeroGradient com.microsoft CPUExecutionProvider", - 3284255990062374928 - ], - [ - "InPlaceAccumulatorV2 com.microsoft CPUExecutionProvider", - 12968279839987729832 - ], - [ - "InplaceClipGradNorm com.microsoft CPUExecutionProvider", - 10251631611024722504 - ] -] diff --git a/onnxruntime/test/testdata/mnist.level1_opt.ort b/onnxruntime/test/testdata/mnist.basic.ort similarity index 88% rename from onnxruntime/test/testdata/mnist.level1_opt.ort rename to onnxruntime/test/testdata/mnist.basic.ort index 066e14dcab2ea..75c45fc73ec82 100644 Binary files a/onnxruntime/test/testdata/mnist.level1_opt.ort and b/onnxruntime/test/testdata/mnist.basic.ort differ diff --git a/onnxruntime/test/testdata/mnist.internal_testing_ep.ort b/onnxruntime/test/testdata/mnist.internal_testing_ep.ort index 066e14dcab2ea..49b08f8f97678 100644 Binary files a/onnxruntime/test/testdata/mnist.internal_testing_ep.ort and b/onnxruntime/test/testdata/mnist.internal_testing_ep.ort differ diff --git a/onnxruntime/test/testdata/mnist.readme.txt b/onnxruntime/test/testdata/mnist.readme.txt index bf5cca9451fbf..b6d9c602029a4 100644 --- a/onnxruntime/test/testdata/mnist.readme.txt +++ b/onnxruntime/test/testdata/mnist.readme.txt @@ -1,12 +1,9 @@ The mnist model is used in multiple tests for minimal/mobile builds in both ONNX and ORT formats. -We also save both ONNX and ORT format versions of the model with level 1 (aka 'basic') optimizations applied. - - mnist.level1_opt.onnx makes sure the required operators for this model are automatically included in - required_ops.config, which is used in the reduced ops CI build. - - mnist.level1_opt.ort is used in NNAPI unit tests. +We also save the ORT format version of the model with level 1 (aka 'basic') optimizations applied: mnist.basic.ort. +This file is used in NNAPI and CoreML EP unit tests. -The level 1 optimized model files can be generated with the following steps: +It can be generated with the following steps: - Set environment variable ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL=basic - From this directory, run - $ python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed --save_optimized_onnx_model ./mnist.onnx -- Rename the resulting .onnx and .ort files accordingly + $ python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./mnist.onnx diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/gathernd9.basic.ort b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/gathernd9.basic.ort deleted file mode 100644 index d25a3e3fca38c..0000000000000 Binary files a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/gathernd9.basic.ort and /dev/null differ diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/gathernd9.onnx b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/gathernd9.onnx deleted file mode 100644 index 6db1f8bf13af2..0000000000000 --- a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/gathernd9.onnx +++ /dev/null @@ -1,16 +0,0 @@ - backend-test:‰ -! -data -indicesoutput"GatherNDtest_gathernd_example_int32Z -data -  - -Z -indices -  - -b -output - - -B \ No newline at end of file diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/not1.basic.ort b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/not1.basic.ort deleted file mode 100644 index 78110f009396f..0000000000000 Binary files a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/not1.basic.ort and /dev/null differ diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/not1.onnx b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/not1.onnx deleted file mode 100644 index 166007f42a926..0000000000000 --- a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/not1.onnx +++ /dev/null @@ -1,13 +0,0 @@ - backend-test:P - -xnot"Not test_not_3dZ -x -  - - -b -not -  - - -B \ No newline at end of file diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/roialign10.basic.ort b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/roialign10.basic.ort deleted file mode 100644 index cc13ee5968e62..0000000000000 Binary files a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/roialign10.basic.ort and /dev/null differ diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/roialign10.onnx b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/roialign10.onnx deleted file mode 100644 index d32472c10b3ef..0000000000000 Binary files a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/roialign10.onnx and /dev/null differ diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/scan9.basic.ort b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/scan9.basic.ort deleted file mode 100644 index f894d2436e1f5..0000000000000 Binary files a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/scan9.basic.ort and /dev/null differ diff --git a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/scan9.onnx b/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/scan9.onnx deleted file mode 100644 index 6b665b65f1962..0000000000000 Binary files a/onnxruntime/test/testdata/ort_backwards_compat/ORTv1.10/scan9.onnx and /dev/null differ diff --git a/onnxruntime/test/testdata/ort_backwards_compat/readme.md b/onnxruntime/test/testdata/ort_backwards_compat/readme.md deleted file mode 100644 index c8f31f8f70e0f..0000000000000 --- a/onnxruntime/test/testdata/ort_backwards_compat/readme.md +++ /dev/null @@ -1,18 +0,0 @@ -This directory contains ORT format models to test for backwards compatibility when we are forced to make an update that invalidates a kernel hash. - -When this happens, first create a directory for the currently released ORT version if one doesn't already exist. - -Find a model that uses the operator with the kernel hash change and copy it to the directory for the currently released ORT version. -The ONNX test data is generally a good place to do this. See cmake/external/onnx/onnx/backend/test/data/node. - -Convert the model to ORT format using the currently released ORT version. This model will contain the original hash. - -e.g. -Setting environment variable ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL=basic -and then running `python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ORTv1.10/not1.onnx` -will create the ORT format model `not1.basic.ort` - -Add both the ONNX and the ORT format models to the repository. - -See onnxruntime/test/providers/kernel_def_hash_test.cc for information on updating the backwards compatibility hash -information and unit tests to use the new model. diff --git a/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort b/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort index 3935cf3d428e7..bb312eac8e1ea 100644 Binary files a/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort and b/onnxruntime/test/testdata/ort_github_issue_4031.onnx.ort differ diff --git a/onnxruntime/test/testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort b/onnxruntime/test/testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort index 11db2e4d53f28..540570e16fb53 100644 Binary files a/onnxruntime/test/testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort and b/onnxruntime/test/testdata/ort_minimal_test_models/sparse_initializer_handling.onnx.ort differ diff --git a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx.ort b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx.ort new file mode 100644 index 0000000000000..c93672be600f0 Binary files /dev/null and b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx.ort differ diff --git a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort deleted file mode 100644 index 9095106951e4e..0000000000000 Binary files a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort and /dev/null differ diff --git a/onnxruntime/test/testdata/transform/runtime_optimization/conv_clip11.runtime_optimizations.ort b/onnxruntime/test/testdata/transform/runtime_optimization/conv_clip11.runtime_optimizations.ort index d3e098ec1c6c7..49728e028a9a0 100644 Binary files a/onnxruntime/test/testdata/transform/runtime_optimization/conv_clip11.runtime_optimizations.ort and b/onnxruntime/test/testdata/transform/runtime_optimization/conv_clip11.runtime_optimizations.ort differ diff --git a/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.extended.ort b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.extended.ort index 82fd4d708d4f0..173f571baac23 100644 Binary files a/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.extended.ort and b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.extended.ort differ diff --git a/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.ort b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.ort index 2fc70dedb3c52..1dab3d743acc7 100644 Binary files a/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.ort and b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.ort differ diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 84c8b5e9f28d2..1bf08fa55ca88 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -757,7 +757,8 @@ void TrainingSession::AddPreTrainingTransformers(const IExecutionProvider& execu Status TrainingSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, - MinimalBuildOptimizationHandling minimal_build_optimization_handling) const { + MinimalBuildOptimizationHandling minimal_build_optimization_handling, + RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/) const { ORT_RETURN_IF_NOT( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations, "Only applying full build optimizations is supported by TrainingSession."); diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index f6bf8985835ab..37b708fb7d1dd 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -488,7 +488,8 @@ class TrainingSession : public InferenceSession { common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, - MinimalBuildOptimizationHandling minimal_build_optimization_handling) const override; + MinimalBuildOptimizationHandling minimal_build_optimization_handling, + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const override; /** Perform auto-diff to add backward graph into the model. @param weights_to_train a set of weights to be training. diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index 1c7d6ac8bb7f7..fa8b7f782232a 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gradient_op_test_utils.h" +#include "core/framework/kernel_type_str_resolver.h" #include "core/session/inference_session.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/framework/gradient_graph_builder.h" @@ -180,6 +181,8 @@ void GradientOpTester::Run( bool valid = true; + OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; + // set execution provider for all nodes in the graph for (auto& node : graph.Nodes()) { if (node.OpType() == kConstant) @@ -195,7 +198,7 @@ void GradientOpTester::Run( auto reg = execution_provider->GetKernelRegistry(); const KernelCreateInfo* kci; - auto st = reg->TryFindKernel(node, execution_provider->Type(), &kci); + auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, &kci); if (!st.IsOK()) { // The goal here is unclear. It seems best to leave it to the Session // creation to figure out whether the model can be executed using some diff --git a/orttraining/orttraining/training_ops/cuda/tensor/view.cc b/orttraining/orttraining/training_ops/cuda/tensor/view.cc index 47de07ae1f9d8..efa265ea24177 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/view.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/view.cc @@ -34,7 +34,7 @@ ONNX_OPERATOR_KERNEL_EX( kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) - .TypeConstraint("Shape", DataTypeImpl::GetTensorType()) + .TypeConstraint("shapes", DataTypeImpl::GetTensorType()) .InputMemoryType(OrtMemTypeCPUInput, GenerateInputMemoryType()) // all shape inputs are in CPU .Alias(GenerateAliasMapping()), // all output tensors are sharing the same bffer as input[0], // execept that the byte_offset is different diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index d434b23eda4c0..7bfa470c1bfd9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -203,8 +203,8 @@ jobs: inputs: sourceFolder: $(Build.BinariesDirectory)/android-mobile-aar contents: onnxruntime-mobile-*.aar - targetFolder: $(Build.SourcesDirectory)/js/react_native/e2e/node_modules/onnxruntime-react-native/android/libs - displayName: Copy Android package to React Native e2e directory + targetFolder: $(Build.SourcesDirectory)/js/react_native/e2e/android/app/libs + displayName: Copy Android package to Android e2e test directory - task: Gradle@3 inputs: diff --git a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh index 0f91f1ae0d7c2..7dc0160ec1f4f 100755 --- a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh +++ b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh @@ -30,11 +30,6 @@ python3 /onnxruntime_src/tools/ci_build/build.py \ --use_nnapi \ --use_coreml -# Run kernel def hash verification test -pushd ${BUILD_DIR}/Debug -ORT_TEST_STRICT_KERNEL_DEF_HASH_CHECK=1 ./onnxruntime_test_all --gtest_filter="KernelDefHashTest.ExpectedCpuKernelDefHashes" -popd - # Install the ORT python wheel python3 -m pip install --user ${BUILD_DIR}/Debug/dist/* diff --git a/tools/ci_build/op_registration_utils.py b/tools/ci_build/op_registration_utils.py index a0c27a39594e1..5120552e81330 100644 --- a/tools/ci_build/op_registration_utils.py +++ b/tools/ci_build/op_registration_utils.py @@ -6,6 +6,7 @@ """ import os +import pathlib import sys import typing @@ -14,11 +15,12 @@ log = get_logger("op_registration_utils") -def map_ort_constant_to_domain(ort_constant_name: str): +def map_ort_constant_to_domain(ort_constant_name: str, allow_unknown_constant: bool = True): """ Map the name of the internal ONNX Runtime constant used in operator kernel registrations to the domain name used in ONNX models and configuration files. :param ort_constant_name: ONNX Runtime constant name for the domain from a kernel registration entry. + :param allow_unknown_constant: Whether an unknown constant is allowed or treated as an error. :return: String with public domain name. """ @@ -36,9 +38,13 @@ def map_ort_constant_to_domain(ort_constant_name: str): if ort_constant_name in constant_to_domain_map: return constant_to_domain_map[ort_constant_name] - else: - log.warning("Unknown domain for ONNX Runtime constant of {}.".format(ort_constant_name)) - return None + + unknown_constant_message = "Unknown domain for ONNX Runtime constant of {}.".format(ort_constant_name) + if not allow_unknown_constant: + raise ValueError(unknown_constant_message) + + log.warning(unknown_constant_message) + return None def get_kernel_registration_files(ort_root=None, include_cuda=False): @@ -204,7 +210,9 @@ def _process_lines(lines: typing.List[str], offset: int, registration_processor: return offset + 1 -def process_kernel_registration_file(filename: str, registration_processor: RegistrationProcessor): +def process_kernel_registration_file( + filename: typing.Union[str, pathlib.Path], registration_processor: RegistrationProcessor +): """ Process a kernel registration file using registration_processor. :param filename: Path to file containing kernel registrations. @@ -231,3 +239,5 @@ def process_kernel_registration_file(filename: str, registration_processor: Regi else: registration_processor.process_other_line(line) offset += 1 + + return True diff --git a/tools/ci_build/reduce_op_kernels.py b/tools/ci_build/reduce_op_kernels.py index 0ed8b6f1505af..e2cf5981363fc 100755 --- a/tools/ci_build/reduce_op_kernels.py +++ b/tools/ci_build/reduce_op_kernels.py @@ -16,7 +16,7 @@ # directory containing the reduced op files, relative to the build directory OP_REDUCTION_DIR = "op_reduction.generated" -# add the path to /tools/python so we can import the config parsing and type reduction processing +# add the path to tools/python so we can import the config parsing and type reduction processing SCRIPT_DIR = Path(__file__).parent.resolve() ORT_ROOT = SCRIPT_DIR.parents[1] sys.path.append(str(ORT_ROOT / "tools" / "python")) @@ -34,13 +34,18 @@ def _adapt_filters_for_extended_minimal_build( Adapts the values returned by parse_config() for an extended minimal build or higher. In particular: - Includes ONNX ops needed by layout transformation + - Includes MS ops needed by NHWC optimizer """ - # layout transformation requires certain ONNX ops to be available - layout_transformation_required_ops = dict() # op name -> set of opset versions - layout_transformation_required_ops_file = ORT_ROOT / "onnxruntime/core/framework/kernel_def_hash_helpers.cc" - with open(layout_transformation_required_ops_file, mode="r") as f: - region_boundary_pattern = re.compile(r"@@region_(begin|end)\(layout_transformation_required_kernels\)@@") - op_to_hash_pattern = re.compile(r'\{"(\w+)_(\d+)",\s+\w+\},') + # graph transformations in an extended minimal build require certain ops to be available + extended_minimal_build_required_op_ids = set() # set of (domain, optype, opset) + with open( + ORT_ROOT / "onnxruntime/core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h", + mode="r", + ) as f: + region_boundary_pattern = re.compile(r"@@region_(begin|end)\(extended_minimal_build_required_kernels\)@@") + op_id_pattern = re.compile( + r'OpIdentifierWithStringViews{(?P\w+),\s+"(?P\w+)",\s+(?P\d+)}' + ) in_region = False for line in f: region_boundary_match = region_boundary_pattern.search(line) @@ -51,31 +56,36 @@ def _adapt_filters_for_extended_minimal_build( if not in_region: continue - op_to_hash_match = op_to_hash_pattern.search(line) - if op_to_hash_match: - op_name, opset = op_to_hash_match.group(1, 2) - layout_transformation_required_ops.setdefault(op_name, set()).add(int(opset)) + op_id_match = op_id_pattern.search(line) + if op_id_match: + domain = op_registration_utils.map_ort_constant_to_domain( + op_id_match.group("domain"), allow_unknown_constant=False + ) + optype = op_id_match.group("optype") + opset = int(op_id_match.group("opset")) + extended_minimal_build_required_op_ids.add((domain, optype, opset)) adapted_required_ops = None if base_required_ops is not None: adapted_required_ops = base_required_ops.copy() - required_onnx_ops = adapted_required_ops.setdefault("ai.onnx", dict()) - for op_type, opsets in layout_transformation_required_ops.items(): - for opset in opsets: - required_onnx_opset_ops = required_onnx_ops.setdefault(opset, set()) - required_onnx_opset_ops.add(op_type) + for domain, optype, opset in extended_minimal_build_required_op_ids: + adapted_required_ops.setdefault(domain, dict()).setdefault(opset, set()).add(optype) adapted_op_type_impl_filter = None if base_op_type_impl_filter is not None: class _AdaptedFilter(OpTypeImplFilterInterface): - def __init__(self, filter_to_adapt: OpTypeImplFilterInterface, required_optypes: typing.Set[str]): + def __init__( + self, + filter_to_adapt: OpTypeImplFilterInterface, + required_domain_and_optypes: typing.Set[typing.Tuple[str, str]], + ): self.filter_to_adapt = filter_to_adapt - self.required_optypes = required_optypes + self.required_domain_and_optypes = required_domain_and_optypes def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str): - # Always require registration for ONNX ops in self.required_optypes. - if domain == "ai.onnx" and optype in self.required_optypes: + # Always require registration for ops in self.required_domain_and_optypes. + if (domain, optype) in self.required_domain_and_optypes: return True return self.filter_to_adapt.is_typed_registration_needed(domain, optype, type_registration_str) @@ -86,7 +96,8 @@ def get_cpp_entries(self): return self.filter_to_adapt.get_cpp_entries() adapted_op_type_impl_filter = _AdaptedFilter( - base_op_type_impl_filter, set(layout_transformation_required_ops.keys()) + base_op_type_impl_filter, + set([(domain, optype) for (domain, optype, opset) in extended_minimal_build_required_op_ids]), ) return (adapted_required_ops, adapted_op_type_impl_filter) @@ -284,23 +295,23 @@ def reduce_ops( :param is_extended_minimal_build_or_higher: Whether this build has at least the features of an extended minimal build enabled. """ - build_dir = Path(build_dir).resolve() - build_dir.mkdir(parents=True, exist_ok=True) + build_dir_path = Path(build_dir).resolve() + build_dir_path.mkdir(parents=True, exist_ok=True) required_ops, op_type_impl_filter = parse_config(config_path, enable_type_reduction) if is_extended_minimal_build_or_higher: required_ops, op_type_impl_filter = _adapt_filters_for_extended_minimal_build(required_ops, op_type_impl_filter) # delete any existing generated files first - op_reduction_root = _get_op_reduction_root(build_dir) + op_reduction_root = _get_op_reduction_root(build_dir_path) if op_reduction_root.is_dir(): log.info(f"Deleting existing op reduction file root directory: {op_reduction_root}") shutil.rmtree(op_reduction_root) - _generate_provider_registrations(ORT_ROOT, build_dir, use_cuda, required_ops, op_type_impl_filter) + _generate_provider_registrations(ORT_ROOT, build_dir_path, use_cuda, required_ops, op_type_impl_filter) type_control_cpp_code = op_type_impl_filter.get_cpp_entries() if op_type_impl_filter is not None else [] - _generate_type_control_overrides(ORT_ROOT, build_dir, type_control_cpp_code) + _generate_type_control_overrides(ORT_ROOT, build_dir_path, type_control_cpp_code) if __name__ == "__main__": diff --git a/tools/python/dump_ort_model.py b/tools/python/dump_ort_model.py index ced54ae73adb1..acb69a593eb27 100644 --- a/tools/python/dump_ort_model.py +++ b/tools/python/dump_ort_model.py @@ -2,13 +2,14 @@ # Licensed under the MIT License. import argparse +import contextlib import os import sys import typing # the import of FbsTypeInfo sets up the path so we can import ort_flatbuffers_py -import ort_flatbuffers_py.fbs as fbs -from util.ort_format_model.types import FbsTypeInfo +from util.ort_format_model.types import FbsTypeInfo # isort:skip +import ort_flatbuffers_py.fbs as fbs # isort:skip class OrtFormatModelDumper: @@ -23,7 +24,8 @@ def __init__(self, model_path: str): self._buffer = bytearray(self._file) if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0): raise RuntimeError("File does not appear to be a valid ORT format model: '{}'".format(model_path)) - self._model = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0).Model() + self._inference_session = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0) + self._model = self._inference_session.Model() def _dump_initializers(self, graph: fbs.Graph): print("Initializers:") @@ -72,12 +74,13 @@ def _dump_nodeargs(self, graph: fbs.Graph): def _dump_node(self, node: fbs.Node): optype = node.OpType().decode() domain = node.Domain().decode() or "ai.onnx" # empty domain defaults to ai.onnx + since_version = node.SinceVersion() inputs = [node.Inputs(i).decode() for i in range(0, node.InputsLength())] outputs = [node.Outputs(i).decode() for i in range(0, node.OutputsLength())] print( - f"{node.Index()}:{node.Name().decode()}({domain}:{optype}) " - f'inputs=[{",".join(inputs)} outputs=[{",".join(outputs)}]' + f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) " + f'inputs=[{",".join(inputs)}] outputs=[{",".join(outputs)}]' ) def _dump_graph(self, graph: fbs.Graph): @@ -110,12 +113,12 @@ def _dump_graph(self, graph: fbs.Graph): print(f"## End Subgraph {k} ##") def dump(self, output: typing.IO): - graph = self._model.Graph() + with contextlib.redirect_stdout(output): + print(f"ORT format version: {self._inference_session.OrtVersion().decode()}") + print("--------") - original_stdout = sys.stdout - sys.stdout = output - self._dump_graph(graph) - sys.stdout = original_stdout + graph = self._model.Graph() + self._dump_graph(graph) def parse_args(): diff --git a/winml/test/scenario/cppwinrt/CustomOps.cpp b/winml/test/scenario/cppwinrt/CustomOps.cpp index a69aba9be2956..926bae30b3dc8 100644 --- a/winml/test/scenario/cppwinrt/CustomOps.cpp +++ b/winml/test/scenario/cppwinrt/CustomOps.cpp @@ -593,7 +593,7 @@ static void CustomKernelWithCustomSchema() { floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor; floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float; - MLOperatorEdgeTypeConstrant kernelConstraint = {"T", &floatTensorEdgeDesc, 1}; + MLOperatorEdgeTypeConstrant kernelConstraint = {"T1", &floatTensorEdgeDesc, 1}; MLOperatorKernelDescription kernelDesc = { @@ -602,7 +602,7 @@ static void CustomKernelWithCustomSchema() { 7, MLOperatorExecutionType::Cpu, &kernelConstraint, - 1}; + testCases[caseIndex].useTypeLabel ? 1u : 0u}; if (!testCases[caseIndex].attributeDefaultsInSchema) { kernelDesc.defaultAttributes = defaultAttributes;