Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model Builder API #23223

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions cmake/onnxruntime_session.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ endif()
if (onnxruntime_MINIMAL_BUILD)
set(onnxruntime_session_src_exclude
"${ONNXRUNTIME_ROOT}/core/session/provider_bridge_ort.cc"
"${ONNXRUNTIME_ROOT}/core/session/model_builder_c_api.cc"
)

list(REMOVE_ITEM onnxruntime_session_srcs ${onnxruntime_session_src_exclude})
Expand Down
6 changes: 6 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ set (onnxruntime_shared_lib_test_SRC

if (NOT onnxruntime_MINIMAL_BUILD)
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc)
endif()

if(onnxruntime_RUN_ONNX_TESTS)
Expand Down Expand Up @@ -1353,14 +1354,19 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
LIBS ${onnxruntime_shared_lib_test_LIBS}
DEPENDS ${all_dependencies}
)

target_include_directories(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_ROOT})

if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu)
endif()

if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include)
target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__)
endif()

if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_sources(onnxruntime_shared_lib_test PRIVATE
"${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc"
Expand Down
32 changes: 30 additions & 2 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "core/common/span_utils.h"
#include "core/common/status.h"
#include "core/common/logging/logging.h"
#include "core/framework/ort_value.h"
#include "core/framework/prepacked_weights_container.h"
#include "core/graph/onnx_protobuf.h"
#include "core/graph/basic_types.h"
Expand All @@ -39,6 +40,9 @@
#include "core/graph/node_arg.h"
#include "core/graph/ort_format_load_options.h"

// Forward declaration of a type from the Model Editor API. It is part of the ORT C API, so it can't be in a namespace.
struct OrtGraph;

namespace onnxruntime {
class Graph;
struct IndexedSubGraph;
Expand Down Expand Up @@ -763,6 +767,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
*/
bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const;

/** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name.
*/
bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const;

/** Gets all the initializer tensors in this Graph. */
const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; }

Expand Down Expand Up @@ -1430,6 +1438,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Graph>& graph);

static Status LoadFromModelEditorApiModel(const OrtGraph& api_graph,
const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
bool strict_shape_type_inference,
const logging::Logger& logger,
std::unique_ptr<Graph>& graph);

Status UpdateUsingModelEditorApiModel(const OrtModel& api_model);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const {
return runtime_optimizations_;
Expand Down Expand Up @@ -1630,7 +1648,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
// Implementation for initializer replacement
Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external);

std::vector<NodeArg*> CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
template <typename StringRange> // range-initializer returning std::string
std::vector<NodeArg*> CreateNodeArgs(const StringRange& names,
const ArgNameToTypeMap& name_to_type_map);

void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const;
Expand Down Expand Up @@ -1694,6 +1713,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return nodes_[node_index].get();
}

Status LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false);

const Model& owning_model_;

// GraphProto to store name, version, initializer.
Expand All @@ -1708,6 +1729,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi

InitializedTensorSet name_to_initial_tensor_;

// Initializers that are external to the Graph,
// e.g. created from existing memory using CreateTensorWithDataAndDeleterAsOrtValue in the ORT API.
// They must be converted to TensorProto for the optimizers to work, but we also need to keep the deleter
// information, so the OrtValues are stored in the Graph instance and retrieved during session state finalization.
std::unordered_map<std::string, OrtValue> ortvalue_initializers_;

std::unordered_set<std::reference_wrapper<const std::string>,
std::hash<std::string>, std::equal_to<std::string>>
sparse_tensor_names_;
Expand Down Expand Up @@ -1744,6 +1771,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
// In some cases, a fused sub-graph occurs multiple times in one model, so we use a map
// to store the reusable schema for lookup.
InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;

#endif // !defined(ORT_MINIMAL_BUILD)

// Graph nodes.
Expand Down Expand Up @@ -1806,7 +1834,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

const std::unordered_map<std::string, int> domain_to_version_;
std::unordered_map<std::string, int> domain_to_version_;

// Model IR version.
Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/graph/graph_viewer.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); }
#endif

/** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name.
*/
bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const {

Check warning on line 198 in include/onnxruntime/core/graph/graph_viewer.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/graph/graph_viewer.h:198: Add #include <string> for string [build/include_what_you_use] [4]
return graph_->GetOrtValueInitializer(name, value);
}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer);
GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info);
Expand Down
Loading
Loading