Commit

Update kernel matching logic: decouple from op schemas and remove kernel def hashes (#12791)

# Motivation
Currently, ORT minimal builds use kernel def hashes to map nodes to the
kernels that execute them when loading the model. As the kernel def hashes
must be known ahead of time, this works for statically registered kernels,
and it works well for the CPU EP.
For this approach to work, the kernel def hashes must also be known at
ORT format model conversion time, which means any EP with statically
registered kernels must also be enabled then. This is not an issue for
the always-available CPU EP. However, we do not want to require that every
EP which statically registers kernels is always available too.
Consequently, we explore another approach to matching nodes to kernels that
does not rely on kernel def hashes. An added benefit is the
possibility of moving away from kernel def hashes completely, which
would eliminate the maintenance burden of keeping the hashes stable.

# Approach
In a full build, ORT uses information from the ONNX op schema to
match a node to a kernel. We want to avoid including ONNX op schemas
in a minimal build to reduce binary size. Essentially, we take the
necessary information from the ONNX op schema and make it available in a
minimal build.
We decouple the kernel matching logic from the ONNX op schema: the
matching logic instead relies on per-op information which can
be obtained either from the ONNX op schema or from another source.
As this per-op information must be available in a minimal build, where there
are no ONNX op schemas, we store it in the ORT format model.
Existing uses of kernel def hashes to look up kernels are replaced
with the updated kernel matching logic. We no longer store
kernel def hashes in the ORT format model's session state and runtime
optimization representations, and we no longer keep the logic to
generate kernel def hashes and ensure their stability.
edgchen1 authored Sep 20, 2022
1 parent 32878a1 commit 454f77c
Showing 236 changed files with 3,282 additions and 6,178 deletions.
3 changes: 1 addition & 2 deletions CODEOWNERS
@@ -1,6 +1,5 @@
# Mobile
/onnxruntime/test/testdata/kernel_def_hashes/ @microsoft/onnxruntime-mobile
/onnxruntime/core/framework/kernel_def_hash_helpers.* @microsoft/onnxruntime-mobile
/onnxruntime/core/flatbuffers/schema/ort.fbs @microsoft/onnxruntime-mobile

# Contrib Ops
/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc @microsoft/onnxruntime-mlas
12 changes: 12 additions & 0 deletions docs/ORT_Format_Update_in_1.13.md
@@ -0,0 +1,12 @@
# ORT Format Update in 1.13

In ONNX Runtime 1.13, there was a breaking change to the
[ORT format](https://onnxruntime.ai/docs/reference/ort-format-models.html) in order to enable additional execution
providers with statically registered kernels in a minimal build.
More details can be found [here](../onnxruntime/core/flatbuffers/schema/README.md#version-5).

Unfortunately, this means that any older models (prior to ORT format version 5) will no longer work with ONNX Runtime
1.13 or later and must be re-converted.
Please refer
[here](https://onnxruntime.ai/docs/reference/ort-format-models.html#convert-onnx-models-to-ort-format) for instructions
on how to convert an ONNX model to ORT format.
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -418,7 +418,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**X** = tensor(float)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcMaxPool|*in* x:**T**<br> *out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
14 changes: 8 additions & 6 deletions include/onnxruntime/core/common/common.h
@@ -17,9 +17,10 @@

#pragma once

#include <algorithm>
#include <cstring>
#include <climits>
#include <cstring>
#include <algorithm>
#include <chrono>
#include <functional>
#include <memory>
#include <numeric>
@@ -28,8 +29,8 @@
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include <chrono>

#include "core/common/code_location.h"
#include "core/common/exceptions.h"
@@ -279,9 +280,10 @@ constexpr size_t kMaxStrLen = 2048;
// Returns whether `key` is in `container`.
// Like C++20's map/set contains() member function.
template <typename Key, typename... OtherContainerArgs,
template <typename...> typename AssociativeContainer>
inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, const Key& key) {
return container.find(key) != container.end();
template <typename...> typename AssociativeContainer,
typename LookupKey>
inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, LookupKey&& key) {
return container.find(std::forward<LookupKey>(key)) != container.end();
}

} // namespace onnxruntime
21 changes: 21 additions & 0 deletions include/onnxruntime/core/common/hash_combine.h
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {

// Combine hash value `seed` with hash value `h`, updating `seed` in place.
// TODO(edgchen1) find a better implementation? e.g., see a more recent version of boost::hash_combine()
inline void HashCombineWithHashValue(size_t h, size_t& seed) {
seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Combine hash value `seed` with the hash value of `value`, updating `seed` in place.
// The hash value computation is specified by the `Hash` template parameter.
template <typename T, typename Hash = std::hash<T>>
inline void HashCombine(const T& value, size_t& seed) {
HashCombineWithHashValue(Hash{}(value), seed);
}

} // namespace onnxruntime
13 changes: 7 additions & 6 deletions include/onnxruntime/core/common/parse_string.h
@@ -5,6 +5,7 @@

#include <locale>
#include <sstream>
#include <string_view>
#include <type_traits>

#include "core/common/common.h"
@@ -15,7 +16,7 @@ namespace onnxruntime {
* Tries to parse a value from an entire string.
*/
template <typename T>
bool TryParseStringWithClassicLocale(const std::string& str, T& value) {
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
// if T is unsigned integral type, reject negative values which will wrap
if (!str.empty() && str[0] == '-') {
@@ -28,7 +29,7 @@ bool TryParseStringWithClassicLocale(const std::string& str, T& value) {
return false;
}

std::istringstream is{str};
std::istringstream is{std::string{str}};
is.imbue(std::locale::classic());
T parsed_value{};

@@ -43,12 +44,12 @@
return true;
}

inline bool TryParseStringWithClassicLocale(const std::string& str, std::string& value) {
inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
value = str;
return true;
}

inline bool TryParseStringWithClassicLocale(const std::string& str, bool& value) {
inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
if (str == "0" || str == "False" || str == "false") {
value = false;
return true;
@@ -66,7 +67,7 @@ inline bool TryParseStringWithClassicLocale(const std::string& str, bool& value)
* Parses a value from an entire string.
*/
template <typename T>
Status ParseStringWithClassicLocale(const std::string& s, T& value) {
Status ParseStringWithClassicLocale(std::string_view s, T& value) {
ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\"");
return Status::OK();
}
@@ -75,7 +76,7 @@ Status ParseStringWithClassicLocale(const std::string& s, T& value) {
* Parses a value from an entire string.
*/
template <typename T>
T ParseStringWithClassicLocale(const std::string& s) {
T ParseStringWithClassicLocale(std::string_view s) {
T value{};
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value));
return value;
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/data_types.h
@@ -458,6 +458,7 @@ class TensorType : public TensorTypeBase {

#if defined(DISABLE_OPTIONAL_TYPE)

// TODO is this still needed after removing kernel def hashes?
/// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build
/// with disabled types to keep the ORT format model kernel hashes stable.
class DisabledTypeBase : public DataTypeImpl {
26 changes: 20 additions & 6 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -15,11 +15,10 @@

namespace onnxruntime {
class GraphViewer;
class Node;
struct ComputeCapability;
class KernelRegistry;
class KernelRegistryManager;

struct KernelCreateInfo;
class Node;
} // namespace onnxruntime
#else
#include <memory>
@@ -89,29 +88,44 @@ class IExecutionProvider {
return nullptr;
}

/**
* Interface for performing kernel lookup within kernel registries.
* Abstracts away lower-level details about kernel registries and kernel matching.
*/
class IKernelLookup {
public:
/**
* Given `node`, try to find a matching kernel for this EP.
* The return value is non-null if and only if a matching kernel was found.
*/
virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;
};

/**
Get execution provider's capability for the specified <graph>.
Return a bunch of IndexedSubGraphs <*this> execution provider can run if
the sub-graph contains only one node or can fuse to run if the sub-graph
contains more than one node. The node indexes contained in sub-graphs may
have overlap, and it's ONNXRuntime's responsibility to do the partition
and decide whether a node will be assigned to <*this> execution provider.
For kernels registered in a kernel registry, `kernel_lookup` must be used
to find a matching kernel for this EP.
*/
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const;
const IKernelLookup& kernel_lookup) const;

/**
Get kernel registry per execution provider type.
The KernelRegistry share pointer returned is shared across sessions.
NOTE: this is a tricky but final solution to achieve following goals,
NOTE: this approach was taken to achieve the following goals,
1. The execution provider type based kernel registry should be shared
across sessions.
Only one copy of this kind of kernel registry exists in ONNXRuntime
with multiple sessions/models.
2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime
frameowrk/session code.
framework/session code.
3. onnxruntime (framework/session) does not depend on any specific
execution provider lib.
*/
44 changes: 7 additions & 37 deletions include/onnxruntime/core/framework/kernel_def_builder.h
@@ -53,13 +53,15 @@ class KernelDef {
return provider_type_;
}

// TODO(edgchen1) do we need both TypeConstraints() and EnabledTypeConstraints()?

// type constraints with types supported by default
const std::map<std::string, std::vector<MLDataType>>& TypeConstraints() const {
const std::unordered_map<std::string, std::vector<MLDataType>>& TypeConstraints() const {
return default_type_constraints_;
}

// type constraints with types supported in this build
const std::map<std::string, std::vector<MLDataType>>& EnabledTypeConstraints() const {
const std::unordered_map<std::string, std::vector<MLDataType>>& EnabledTypeConstraints() const {
return enabled_type_constraints_;
}

@@ -108,19 +110,9 @@

bool IsConflict(const KernelDef& other) const;

HashValue GetHash() const noexcept {
// if we need to support different hash versions we can update CalculateHash to take a version number
// and calculate any non-default versions dynamically. we only use this during kernel lookup so
// it's not performance critical
return hash_;
}

private:
friend class KernelDefBuilder;

// called once by KernelDefBuilder::Build
void CalculateHash();

// The operator name supported by <*this> kernel..
std::string op_name_;

@@ -139,18 +131,11 @@
std::string provider_type_;

// The data types that are supported by default for inputs/outputs.
// Key is input/output name defined in op schema, Value are supported types.
// note: std::map as we need the order to be deterministic for the hash
// Note: default_type_constraints_ are used to calculate the kernel hash so that the hash is
// stable across builds with and without kernel type reduction enabled.
std::map<std::string, std::vector<MLDataType>> default_type_constraints_;
// Key is input/output/type constraint name defined in op schema, Value are supported types.
std::unordered_map<std::string, std::vector<MLDataType>> default_type_constraints_;

// the type constraints that are supported in this build (enabled) for the kernel
std::map<std::string, std::vector<MLDataType>> enabled_type_constraints_;

// optional alternate type constraints to use to calculate the hash instead of default_type_constraints_
// note: this provides a way to update the default type constraints while preserving the hash value
optional<std::map<std::string, std::vector<MLDataType>>> hash_type_constraints_;
std::unordered_map<std::string, std::vector<MLDataType>> enabled_type_constraints_;

// An element <i, j> means that output j reuses the memory of input i.
std::vector<std::pair<int, int>> inplace_map_;
@@ -186,9 +171,6 @@
OrtMemType default_inputs_mem_type_{OrtMemTypeDefault};
// Default memory type for all outputs
OrtMemType default_outputs_mem_type_{OrtMemTypeDefault};

// hash of kernel definition for lookup in minimal build
HashValue hash_ = 0;
};

class KernelDefBuilder {
@@ -259,17 +241,6 @@
KernelDefBuilder& TypeConstraint(const std::string& arg_name, MLDataType default_type);
KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType default_type);

/**
Specify the original set of types that this kernel supports by default to use when computing the kernel def hash.
The set of types supported by default may change over time, but the hash should stay the same.
*/
KernelDefBuilder& FixedTypeConstraintForHash(
const std::string& arg_name,
const std::vector<MLDataType>& default_types_for_hash);
KernelDefBuilder& FixedTypeConstraintForHash(
const char* arg_name,
const std::vector<MLDataType>& default_types_for_hash);

/**
Inplace mapping from inputs to outputs allowed.
It means that uplayer runtime could do memory in-place optimization
@@ -392,7 +363,6 @@ class KernelDefBuilder {
Return the kernel definition, passing ownership of the KernelDef to the caller
*/
std::unique_ptr<KernelDef> Build() {
kernel_def_->CalculateHash();
return std::move(kernel_def_);
}
