
Update kernel matching logic: decouple from op schemas and remove kernel def hashes #12791

Merged: 134 commits, Sep 20, 2022

Commits
0db5888
Save work.
edgchen1 May 3, 2022
c346b73
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 3, 2022
21545be
Save work
edgchen1 May 4, 2022
0746d1b
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 10, 2022
983294d
save work
edgchen1 May 11, 2022
4d6ba51
Remove unused code.
edgchen1 May 13, 2022
43f6340
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 13, 2022
8859fbb
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 13, 2022
106eba9
save work
edgchen1 May 17, 2022
33a940c
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 17, 2022
ed0d91d
Fix to pass tests.
edgchen1 May 17, 2022
bed0a2e
save work
edgchen1 May 18, 2022
568ebbf
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 18, 2022
1cc880f
Update flatbuffers schema.
edgchen1 May 20, 2022
8a8e4b2
Save work
edgchen1 May 24, 2022
a3c78b2
Save work.
edgchen1 May 24, 2022
f96986c
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 24, 2022
9463b11
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 25, 2022
0753293
Update compile_schema.py to first delete generated Python files.
edgchen1 May 27, 2022
1b21a0c
save changes
edgchen1 May 28, 2022
3f9e936
build fix
edgchen1 May 31, 2022
532467c
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 May 31, 2022
31fd820
small fix
edgchen1 Jun 2, 2022
92436bf
Update KernelRegistry, KernelRegistryManager, KernelTypeStrResolver c…
edgchen1 Jun 7, 2022
b35c048
Add KernelTypeStrResolver parameter to IExecutionProvider::GetCapabil…
edgchen1 Jun 7, 2022
937d694
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 Jun 7, 2022
9a4acad
save/load kernel_type_str_resolver
edgchen1 Jun 9, 2022
ee010ba
remove kernel hashes from graph partitioning, other updates
edgchen1 Jun 10, 2022
8a2863b
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 Jun 22, 2022
8e6af8b
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 Jul 5, 2022
070e744
Merge branch 'edgchen1/kernel_matching_experiment' of github.com:micr…
edgchen1 Jul 5, 2022
58da3fe
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 Jul 25, 2022
b7ecbea
save work
edgchen1 Jul 26, 2022
30823d1
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 Aug 2, 2022
2eb888b
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 Aug 4, 2022
af03399
Fix build error.
edgchen1 Aug 5, 2022
0a5672f
Save work, moving to IKernelTypeStrResolver.
edgchen1 Aug 6, 2022
60c04ea
Fix allocation planner test failures.
edgchen1 Aug 8, 2022
b38b290
Merge remote-tracking branch 'origin/master' into edgchen1/kernel_mat…
edgchen1 Aug 8, 2022
3e3a037
save work, adding kernel type str resolver info for minimal build opt…
edgchen1 Aug 11, 2022
0cbb6ca
Merge remote-tracking branch 'origin/main' into edgchen1/kernel_match…
edgchen1 Aug 16, 2022
c90211d
refine utils, use only one of type constraint/io name for kernel type…
edgchen1 Aug 17, 2022
0bfb399
improve error message when kernel defs use an op input/output name in…
edgchen1 Aug 18, 2022
7774117
short term fix for kernel def hash change
edgchen1 Aug 18, 2022
48b0ef0
regenerate ORT models to pass tests
edgchen1 Aug 18, 2022
6778da1
disable OrtModelOnlyTests.TestBackwardsCompat
edgchen1 Aug 18, 2022
6103aa3
clean up unused container utils
edgchen1 Aug 18, 2022
d037585
use AutoRegisteringKernelTypeStrResolver in gradient_op_test_utils.cc
edgchen1 Aug 18, 2022
2c2b8f1
update flatbuffers schema to remove kernel hash usage
edgchen1 Aug 19, 2022
6a3536b
Remove kernel def hashes from ORT format.
edgchen1 Aug 19, 2022
a563dcf
Update reduce_op_kernels.py to check kernel_type_str_resolver_utils.c…
edgchen1 Aug 20, 2022
6271f1f
Remove kernel def hashes, change OpIdentifier to struct.
edgchen1 Aug 23, 2022
98bc5b5
Remove KernelDefBuilder::FixedTypeConstraintForHash().
edgchen1 Aug 23, 2022
419d43e
Fix regex in reduce_op_kernels.py.
edgchen1 Aug 24, 2022
a65eabe
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Aug 24, 2022
3994bef
fix issues to compile in extended minimal build, fix issue in KernelT…
edgchen1 Aug 24, 2022
c104c80
save work, change ORT format to use OpIdentifier
edgchen1 Aug 25, 2022
b7f2412
save work - extended minimal build fixes
edgchen1 Aug 27, 2022
8c86a48
set fused node since version from Graph::CreateFusedSubGraphNode()
edgchen1 Aug 29, 2022
24a6dae
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Aug 29, 2022
5823c30
save work - get minimal build compiling
edgchen1 Aug 30, 2022
a924f32
Refactor graph_partitioner
edgchen1 Aug 30, 2022
4db4fd6
add mutex to AutoRegisteringKernelTypeStrResolver
edgchen1 Aug 30, 2022
4b1bccc
remove unnecessary helpers from KernelTypeStrResolver
edgchen1 Aug 30, 2022
5a54720
Remove using common::Status
edgchen1 Aug 30, 2022
4638538
revert XNNPACK version update
edgchen1 Aug 30, 2022
480c4bc
Fix compile warnings
edgchen1 Aug 30, 2022
3ff00b6
fix build error in InferenceSession::AddPredefinedTransformers
edgchen1 Aug 30, 2022
0e04f96
fix typo in dnnl_execution_provider.cc
edgchen1 Aug 30, 2022
b2576b3
fix XNNPACK EP build error
edgchen1 Aug 30, 2022
c668f52
Fix build error in allocation_planner_test
edgchen1 Aug 30, 2022
42ed4d9
remove duplicate Contains in propagate_cast_ops.cc
edgchen1 Aug 30, 2022
82d4e8f
fix DML compile error
edgchen1 Aug 30, 2022
ff384be
Update docs/OperatorKernels.md
edgchen1 Aug 30, 2022
b327bf8
move typedef to public section
edgchen1 Aug 30, 2022
3927fc9
fix unused parameter
edgchen1 Aug 31, 2022
fe4ec03
fix nuphar build
edgchen1 Aug 31, 2022
af9e6b6
build fixes
edgchen1 Aug 31, 2022
74de408
fix nuphar build
edgchen1 Aug 31, 2022
1df7bfc
fix formatting
edgchen1 Aug 31, 2022
b41fbd0
renaming in kernel_type_str_resolver_utils_test.cc
edgchen1 Aug 31, 2022
5cb403d
fix nuphar test failures
edgchen1 Aug 31, 2022
7fd426c
Increment ORT format version.
edgchen1 Sep 1, 2022
ccc2e86
fix error message
edgchen1 Sep 1, 2022
e8e8f39
Regenerate some ORT format files, update readmes.
edgchen1 Sep 1, 2022
b2c4fce
Remove onnxruntime/test/testdata/ort_backwards_compat.
edgchen1 Sep 1, 2022
0e806a0
regenerate js/ ORT format files
edgchen1 Sep 2, 2022
3d5a868
Fix TVM test failure.
edgchen1 Sep 2, 2022
68ef87d
temporary test change to try to get useful output
edgchen1 Sep 2, 2022
5b1bd3a
Update dump_ort_model.py to fix import order and add version number o…
edgchen1 Sep 2, 2022
f7d01b1
Fix winml test code.
edgchen1 Sep 2, 2022
db9d3dd
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Sep 2, 2022
5420406
another try to get more info from test failure
edgchen1 Sep 2, 2022
5cd51c0
fix java test issue
edgchen1 Sep 3, 2022
3c24144
Update js/node/test/testdata/test_types_x.onnx models to opset 15.
edgchen1 Sep 6, 2022
785565e
Regenerate js/react_native ort format models.
edgchen1 Sep 6, 2022
a15f607
Add since version to dump_ort_model.py output.
edgchen1 Sep 6, 2022
de3591c
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Sep 6, 2022
f03265f
Fix test model paths.
edgchen1 Sep 6, 2022
caa44c2
get more test output for other failing react native tests
edgchen1 Sep 7, 2022
cb41680
disable react native tests using op types that are not enabled in mob…
edgchen1 Sep 7, 2022
0fbd6bd
Merge branch 'edgchen1/static_kernel_update_fix' into edgchen1/static…
edgchen1 Sep 7, 2022
1815a5a
update readme for generating test ORT models
edgchen1 Sep 7, 2022
5dc8ba0
update codeowners file
edgchen1 Sep 7, 2022
161d6fb
skip react native ios tests for unsupported op types
edgchen1 Sep 7, 2022
c678b17
address some PR comments
edgchen1 Sep 8, 2022
25564f5
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Sep 8, 2022
e047701
add documentation, remove unnecessary comment, make KernelTypeStrReso…
edgchen1 Sep 8, 2022
3670b09
fix odd formatting
edgchen1 Sep 8, 2022
a50b640
add TODO to hash_combine.h
edgchen1 Sep 9, 2022
bfcc36c
use string for op id in ORT format
edgchen1 Sep 9, 2022
02ecd7d
Add TODO comments.
edgchen1 Sep 12, 2022
e312452
update GetCapability() to take a IKernelLookup, get CPU/NNAPI build c…
edgchen1 Sep 12, 2022
16073fc
convert other EP::GetCapability calls
edgchen1 Sep 13, 2022
1f1bc2b
Fix warning.
edgchen1 Sep 13, 2022
985dcbd
fix build errors
edgchen1 Sep 13, 2022
1996468
Fix bug keeping reference to out of scope string.
edgchen1 Sep 13, 2022
88c1f61
update react native e2e tests for new mnist
edgchen1 Sep 13, 2022
1122239
remove unnecessary include
edgchen1 Sep 13, 2022
8bdf586
name todo
edgchen1 Sep 13, 2022
d90b5d4
try to fix react native e2e tests, convert models again, update gradl…
edgchen1 Sep 14, 2022
6beb71d
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Sep 14, 2022
9d7cbf9
Address some PR comments.
edgchen1 Sep 14, 2022
359ebc1
fix for non-abseil build
edgchen1 Sep 14, 2022
d059c83
line length
edgchen1 Sep 15, 2022
9ffd4e4
remove shared provider KernelRegistry::TryFindKernel
edgchen1 Sep 15, 2022
22071cf
Add react_native/e2e/src/mnist.onnx.
edgchen1 Sep 15, 2022
00f057f
address PR comments
edgchen1 Sep 15, 2022
5222c10
Add reference about ORT format model breaking change to version check…
edgchen1 Sep 15, 2022
da5c9f4
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Sep 15, 2022
716b756
small fixes
edgchen1 Sep 16, 2022
adf351a
Merge remote-tracking branch 'origin/main' into edgchen1/static_kerne…
edgchen1 Sep 16, 2022
e15fbe8
more fixes
edgchen1 Sep 17, 2022
396a957
update comments referring to kernel def hashes
edgchen1 Sep 17, 2022
3 changes: 1 addition & 2 deletions CODEOWNERS
@@ -1,6 +1,5 @@
# Mobile
/onnxruntime/test/testdata/kernel_def_hashes/ @microsoft/onnxruntime-mobile
/onnxruntime/core/framework/kernel_def_hash_helpers.* @microsoft/onnxruntime-mobile
/onnxruntime/core/flatbuffers/schema/ort.fbs @microsoft/onnxruntime-mobile

# Contrib Ops
/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc @microsoft/onnxruntime-mlas
12 changes: 12 additions & 0 deletions docs/ORT_Format_Update_in_1.13.md
@@ -0,0 +1,12 @@
# ORT Format Update in 1.13

In ONNX Runtime 1.13, there was a breaking change to the
[ORT format](https://onnxruntime.ai/docs/reference/ort-format-models.html) in order to enable additional execution
providers with statically registered kernels in a minimal build.
More details can be found [here](../onnxruntime/core/flatbuffers/schema/README.md#version-5).

Unfortunately, this means that any older models (prior to ORT format version 5) will no longer work with ONNX Runtime
1.13 or later and must be re-converted.
Please refer
[here](https://onnxruntime.ai/docs/reference/ort-format-models.html#convert-onnx-models-to-ort-format) for instructions
on how to convert an ONNX model to ORT format.
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -418,7 +418,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**X** = tensor(float)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcMaxPool|*in* x:**T**<br> *out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
14 changes: 8 additions & 6 deletions include/onnxruntime/core/common/common.h
@@ -17,9 +17,10 @@

#pragma once

#include <algorithm>
#include <cstring>
#include <climits>
#include <cstring>
#include <algorithm>
#include <chrono>
#include <functional>
#include <memory>
#include <numeric>
@@ -28,8 +29,8 @@
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include <chrono>

#include "core/common/code_location.h"
#include "core/common/exceptions.h"
@@ -279,9 +280,10 @@ constexpr size_t kMaxStrLen = 2048;
// Returns whether `key` is in `container`.
// Like C++20's map/set contains() member function.
template <typename Key, typename... OtherContainerArgs,
template <typename...> typename AssociativeContainer>
inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, const Key& key) {
return container.find(key) != container.end();
template <typename...> typename AssociativeContainer,
typename LookupKey>
inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, LookupKey&& key) {
return container.find(std::forward<LookupKey>(key)) != container.end();
}

} // namespace onnxruntime
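The updated `Contains` helper above can be exercised standalone. This sketch copies the template out of the `onnxruntime` namespace with the headers it needs; the point of the separate `LookupKey` parameter is that containers with transparent comparators (e.g. `std::map<std::string, V, std::less<>>`) can be probed with a string literal or `std::string_view` without constructing a temporary `std::string`:

```cpp
#include <map>
#include <string>
#include <utility>

// Standalone copy of the updated Contains() helper. LookupKey is forwarded
// to find(), enabling heterogeneous lookup on transparent comparators.
template <typename Key, typename... OtherContainerArgs,
          template <typename...> typename AssociativeContainer,
          typename LookupKey>
inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, LookupKey&& key) {
  return container.find(std::forward<LookupKey>(key)) != container.end();
}
```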
21 changes: 21 additions & 0 deletions include/onnxruntime/core/common/hash_combine.h
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {

// Combine hash value `seed` with hash value `h`, updating `seed` in place.
// TODO(edgchen1) find a better implementation? e.g., see a more recent version of boost::hash_combine()
inline void HashCombineWithHashValue(size_t h, size_t& seed) {
seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Combine hash value `seed` with the hash value of `value`, updating `seed` in place.
// The hash value computation is specified by the `Hash` template parameter.
template <typename T, typename Hash = std::hash<T>>
inline void HashCombine(const T& value, size_t& seed) {
HashCombineWithHashValue(Hash{}(value), seed);
}

} // namespace onnxruntime
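As a usage sketch, the two helpers above compose per-field hashes into a single value. `HashOpId` below is a hypothetical caller (not code from this PR), hashing an op name and since-version the way a composite key such as `OpIdentifier` might be hashed; `<functional>` and `<string>` are included here for `std::hash`:

```cpp
#include <cstddef>
#include <functional>
#include <string>

// Copies of the helpers added in hash_combine.h (boost-style hash_combine).
inline void HashCombineWithHashValue(size_t h, size_t& seed) {
  seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

template <typename T, typename Hash = std::hash<T>>
inline void HashCombine(const T& value, size_t& seed) {
  HashCombineWithHashValue(Hash{}(value), seed);
}

// Hypothetical caller: fold an op name and a since-version into one hash.
inline size_t HashOpId(const std::string& op_name, int since_version) {
  size_t seed = 0;
  HashCombine(op_name, seed);
  HashCombine(since_version, seed);
  return seed;
}
```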
13 changes: 7 additions & 6 deletions include/onnxruntime/core/common/parse_string.h
@@ -5,6 +5,7 @@

#include <locale>
#include <sstream>
#include <string_view>
#include <type_traits>

#include "core/common/common.h"
@@ -15,7 +16,7 @@ namespace onnxruntime {
* Tries to parse a value from an entire string.
*/
template <typename T>
bool TryParseStringWithClassicLocale(const std::string& str, T& value) {
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
// if T is unsigned integral type, reject negative values which will wrap
if (!str.empty() && str[0] == '-') {
@@ -28,7 +29,7 @@ bool TryParseStringWithClassicLocale(const std::string& str, T& value) {
return false;
}

std::istringstream is{str};
std::istringstream is{std::string{str}};
is.imbue(std::locale::classic());
T parsed_value{};

@@ -43,12 +44,12 @@ bool TryParseStringWithClassicLocale(const std::string& str, T& value) {
return true;
}

inline bool TryParseStringWithClassicLocale(const std::string& str, std::string& value) {
inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
value = str;
return true;
}

inline bool TryParseStringWithClassicLocale(const std::string& str, bool& value) {
inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
if (str == "0" || str == "False" || str == "false") {
value = false;
return true;
@@ -66,7 +67,7 @@ inline bool TryParseStringWithClassicLocale(const std::string& str, bool& value)
* Parses a value from an entire string.
*/
template <typename T>
Status ParseStringWithClassicLocale(const std::string& s, T& value) {
Status ParseStringWithClassicLocale(std::string_view s, T& value) {
ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\"");
return Status::OK();
}
@@ -75,7 +76,7 @@ Status ParseStringWithClassicLocale(const std::string& s, T& value) {
* Parses a value from an entire string.
*/
template <typename T>
T ParseStringWithClassicLocale(const std::string& s) {
T ParseStringWithClassicLocale(std::string_view s) {
T value{};
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value));
return value;
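A standalone sketch of the `string_view`-based parser above, simplified from the full header. Note that `std::istringstream` has no `string_view` constructor, which is why the diff copies the view into an owning `std::string` at the parse boundary:

```cpp
#include <locale>
#include <sstream>
#include <string>
#include <string_view>
#include <type_traits>

// Simplified standalone version of TryParseStringWithClassicLocale after the
// std::string_view migration. Succeeds only if the ENTIRE string parses.
template <typename T>
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
  if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
    // Reject negative input for unsigned types, which would otherwise wrap.
    if (!str.empty() && str[0] == '-') {
      return false;
    }
  }

  // istringstream needs an owning string, so the view is copied once here.
  std::istringstream is{std::string{str}};
  is.imbue(std::locale::classic());
  T parsed_value{};
  if (!(is >> parsed_value) || !is.eof()) {
    return false;  // parse failed or trailing characters remain
  }
  value = parsed_value;
  return true;
}
```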
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/data_types.h
@@ -458,6 +458,7 @@ class TensorType : public TensorTypeBase {

#if defined(DISABLE_OPTIONAL_TYPE)

// TODO is this still needed after removing kernel def hashes?
/// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build
/// with disabled types to keep the ORT format model kernel hashes stable.
class DisabledTypeBase : public DataTypeImpl {
26 changes: 20 additions & 6 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -15,11 +15,10 @@

namespace onnxruntime {
class GraphViewer;
class Node;
struct ComputeCapability;
class KernelRegistry;
class KernelRegistryManager;

struct KernelCreateInfo;
class Node;
} // namespace onnxruntime
#else
#include <memory>
@@ -89,29 +88,44 @@ class IExecutionProvider {
return nullptr;
}

/**
* Interface for performing kernel lookup within kernel registries.
* Abstracts away lower-level details about kernel registries and kernel matching.
*/
class IKernelLookup {
public:
/**
* Given `node`, try to find a matching kernel for this EP.
* The return value is non-null if and only if a matching kernel was found.
*/
virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;
};

/**
Get execution provider's capability for the specified <graph>.
Return a bunch of IndexedSubGraphs <*this> execution provider can run if
the sub-graph contains only one node or can fuse to run if the sub-graph
contains more than one node. The node indexes contained in sub-graphs may
have overlap, and it's ONNXRuntime's responsibility to do the partition
and decide whether a node will be assigned to <*this> execution provider.
For kernels registered in a kernel registry, `kernel_lookup` must be used
to find a matching kernel for this EP.
*/
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const;
const IKernelLookup& kernel_lookup) const;

/**
Get kernel registry per execution provider type.
The KernelRegistry share pointer returned is shared across sessions.

NOTE: this is a tricky but final solution to achieve following goals,
NOTE: this approach was taken to achieve the following goals,
1. The execution provider type based kernel registry should be shared
across sessions.
Only one copy of this kind of kernel registry exists in ONNXRuntime
with multiple sessions/models.
2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime
frameowrk/session code.
framework/session code.
3. onnxruntime (framework/session) does not depend on any specific
execution provider lib.
*/
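To illustrate the new contract, here is a toy EP-side use of a lookup interface shaped like `IKernelLookup`. The types below are simplified stand-ins for illustration only; the real `Node`, `KernelCreateInfo`, and `GetCapability` signature (with `GraphViewer` and `ComputeCapability`) are much richer:

```cpp
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for the real ORT types (illustrative only).
struct KernelCreateInfo { std::string kernel_name; };
struct Node { std::string op_type; };

// Mirrors IExecutionProvider::IKernelLookup from this PR: instead of walking
// kernel registries (or comparing kernel def hashes), the EP asks the lookup
// whether a matching kernel exists for a node.
class IKernelLookup {
 public:
  virtual ~IKernelLookup() = default;
  virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;
};

// A lookup that matches a single op type, standing in for a kernel registry.
class SingleOpLookup : public IKernelLookup {
 public:
  explicit SingleOpLookup(std::string op_type) : op_type_(std::move(op_type)) {}
  const KernelCreateInfo* LookUpKernel(const Node& node) const override {
    return node.op_type == op_type_ ? &info_ : nullptr;
  }

 private:
  std::string op_type_;
  KernelCreateInfo info_{"stub_kernel"};
};

// Toy GetCapability: claim exactly the nodes the lookup can match.
std::vector<const Node*> GetCapability(const std::vector<Node>& nodes,
                                       const IKernelLookup& kernel_lookup) {
  std::vector<const Node*> capable;
  for (const auto& node : nodes) {
    if (kernel_lookup.LookUpKernel(node) != nullptr) {
      capable.push_back(&node);
    }
  }
  return capable;
}
```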
44 changes: 7 additions & 37 deletions include/onnxruntime/core/framework/kernel_def_builder.h
@@ -53,13 +53,15 @@ class KernelDef {
return provider_type_;
}

// TODO(edgchen1) do we need both TypeConstraints() and EnabledTypeConstraints()?
Contributor: Possibly not. Do we use them anywhere in the new setup? We need to write the full list of types into the ORT format model, don't we, as we wouldn't know which types are included in the ORT build running the model.

edgchen1 (author), Sep 14, 2022: I think the two were needed to preserve the hash value (which used TypeConstraints()). We don't need to store these kernel def types in the ORT format model; we only store the info needed to associate the op args with these types. We can probably just change TypeConstraints() to return the types enabled in the current build. One place I'm not sure about is this usage in the Python API:

const auto& tempResult = kernelDef.TypeConstraints();

If that were to start returning only the types enabled in the current build, would it be OK?

Contributor: That file is only included in a build where we use the Python bindings to generate the documentation of supported op schemas for the CPU/CUDA/DML EPs (--gen_doc flag to build.py). It is not included by default. If someone did actually create a Python package from a build with reduced type support AND wanted to generate schema documentation from it (highly unlikely, I suspect), it's probably more correct to return only the types enabled.

edgchen1 (author): That's good to know. I think I will consolidate them in another PR, as that will require updating more files.

// type constraints with types supported by default
const std::map<std::string, std::vector<MLDataType>>& TypeConstraints() const {
const std::unordered_map<std::string, std::vector<MLDataType>>& TypeConstraints() const {
return default_type_constraints_;
}

// type constraints with types supported in this build
const std::map<std::string, std::vector<MLDataType>>& EnabledTypeConstraints() const {
const std::unordered_map<std::string, std::vector<MLDataType>>& EnabledTypeConstraints() const {
return enabled_type_constraints_;
}

@@ -108,19 +110,9 @@

bool IsConflict(const KernelDef& other) const;

HashValue GetHash() const noexcept {
// if we need to support different hash versions we can update CalculateHash to take a version number
// and calculate any non-default versions dynamically. we only use this during kernel lookup so
// it's not performance critical
return hash_;
}

private:
friend class KernelDefBuilder;

// called once by KernelDefBuilder::Build
void CalculateHash();

// The operator name supported by <*this> kernel..
std::string op_name_;

@@ -139,18 +131,11 @@ class KernelDef {
std::string provider_type_;

// The data types that are supported by default for inputs/outputs.
// Key is input/output name defined in op schema, Value are supported types.
// note: std::map as we need the order to be deterministic for the hash
// Note: default_type_constraints_ are used to calculate the kernel hash so that the hash is
// stable across builds with and without kernel type reduction enabled.
std::map<std::string, std::vector<MLDataType>> default_type_constraints_;
// Key is input/output/type constraint name defined in op schema, Value are supported types.
std::unordered_map<std::string, std::vector<MLDataType>> default_type_constraints_;

// the type constraints that are supported in this build (enabled) for the kernel
std::map<std::string, std::vector<MLDataType>> enabled_type_constraints_;

// optional alternate type constraints to use to calculate the hash instead of default_type_constraints_
// note: this provides a way to update the default type constraints while preserving the hash value
optional<std::map<std::string, std::vector<MLDataType>>> hash_type_constraints_;
std::unordered_map<std::string, std::vector<MLDataType>> enabled_type_constraints_;

// An element <i, j> means that output j reuses the memory of input i.
std::vector<std::pair<int, int>> inplace_map_;
@@ -186,9 +171,6 @@
OrtMemType default_inputs_mem_type_{OrtMemTypeDefault};
// Default memory type for all outputs
OrtMemType default_outputs_mem_type_{OrtMemTypeDefault};

// hash of kernel definition for lookup in minimal build
HashValue hash_ = 0;
};

class KernelDefBuilder {
@@ -259,17 +241,6 @@
KernelDefBuilder& TypeConstraint(const std::string& arg_name, MLDataType default_type);
KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType default_type);

/**
Specify the original set of types that this kernel supports by default to use when computing the kernel def hash.
The set of types supported by default may change over time, but the hash should stay the same.
*/
KernelDefBuilder& FixedTypeConstraintForHash(
const std::string& arg_name,
const std::vector<MLDataType>& default_types_for_hash);
KernelDefBuilder& FixedTypeConstraintForHash(
const char* arg_name,
const std::vector<MLDataType>& default_types_for_hash);

/**
Inplace mapping from inputs to outputs allowed.
It means that uplayer runtime could do memory in-place optimization
@@ -392,7 +363,6 @@
Return the kernel definition, passing ownership of the KernelDef to the caller
*/
std::unique_ptr<KernelDef> Build() {
kernel_def_->CalculateHash();
return std::move(kernel_def_);
}

Expand Down
Loading
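The net effect of the `Build()` change above can be shown with toy versions of `KernelDef` and `KernelDefBuilder`. Everything here is a heavily simplified sketch, not the real classes: after this PR, `Build()` simply hands over the kernel def, with no `CalculateHash()` step, since matching now keys on the def's fields rather than a precomputed hash:

```cpp
#include <memory>
#include <string>
#include <utility>

// Toy KernelDef: after this PR it carries no hash member.
class KernelDef {
 public:
  const std::string& OpName() const { return op_name_; }

 private:
  friend class KernelDefBuilder;
  std::string op_name_;
};

// Toy builder: Build() just releases ownership; previously a
// kernel_def_->CalculateHash() call ran here.
class KernelDefBuilder {
 public:
  KernelDefBuilder& SetName(std::string op_name) {
    kernel_def_->op_name_ = std::move(op_name);
    return *this;
  }
  std::unique_ptr<KernelDef> Build() {
    return std::move(kernel_def_);
  }

 private:
  std::unique_ptr<KernelDef> kernel_def_ = std::make_unique<KernelDef>();
};
```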