From 0b143d07036baae17bc9014ad971c96b2171247b Mon Sep 17 00:00:00 2001 From: jignparm Date: Wed, 6 Mar 2019 18:01:49 -0800 Subject: [PATCH 1/6] Fix parentheses and commas (#560) From af9c554dd354e897cf2fd64714889c4a0639496b Mon Sep 17 00:00:00 2001 From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> Date: Wed, 6 Mar 2019 19:09:55 -0800 Subject: [PATCH 2/6] Ryanunderhill/custom op (#550) * Prototype version that demonstrates it can work * Switched to OrtValue and removed the OrtCustomOpTensor code. * Support multiple outputs and reading of attributes * Add custom domain handling to custom ops * Update documentation * more wording changes --- docs/AddingCustomOp.md | 16 +- .../onnxruntime/core/framework/op_kernel.h | 1 + .../core/session/onnxruntime_c_api.h | 67 ++++++- onnxruntime/core/framework/op_kernel.cc | 7 +- .../framework/op_kernel_context_internal.h | 4 + .../core/framework/tensor_type_and_shape.cc | 50 ++++-- .../core/framework/tensor_type_and_shape.h | 13 ++ onnxruntime/core/providers/cpu/symbols.txt | 4 + .../core/session/abi_session_options_impl.h | 1 + onnxruntime/core/session/inference_session.cc | 167 ++++++++++++++++++ onnxruntime/core/session/inference_session.h | 9 + onnxruntime/core/session/onnxruntime_c_api.cc | 64 ++++--- onnxruntime/test/shared_lib/test_inference.cc | 88 ++++++++- 13 files changed, 441 insertions(+), 50 deletions(-) create mode 100644 onnxruntime/core/framework/tensor_type_and_shape.h diff --git a/docs/AddingCustomOp.md b/docs/AddingCustomOp.md index 6ad742617b5a1..25bbfdef14626 100644 --- a/docs/AddingCustomOp.md +++ b/docs/AddingCustomOp.md @@ -2,16 +2,12 @@ Adding a new op =============== ## A new op can be written and registered with ONNXRuntime in the following 3 ways -### 1. Using a dynamic shared library -* First write the implementation of the op and schema (if required) and assemble them in a shared library. -See [this](../onnxruntime/test/custom_op_shared_lib) for an example. Currently -this is supported for Linux only. - -Example of creating a shared lib using g++ on Linux: -```g++ -std=c++14 -shared test_custom_op.cc -o test_custom_op.so -fPIC -I. -Iinclude/onnxruntime -L. -lonnxruntime -DONNX_ML -DONNX_NAMESPACE=onnx``` - -* Register the shared lib with ONNXRuntime. -See [this](../onnxruntime/test/shared_lib/test_inference.cc) for an example. +### 1. Using the experimental custom op API in the C API (onnxruntime_c_api.h) +Note: These APIs are experimental and will change in the next release. They're released now for feedback and experimentation. +* Create an OrtCustomOpDomain with the domain name used by the custom ops +* Create an OrtCustomOp structure for each op and add them to the OrtCustomOpDomain with OrtCustomOpDomain_Add +* Call OrtAddCustomOpDomain to add the custom domain of ops to the session options +See [this](../onnxruntime/test/custom_op_shared_lib/test_custom_op.cc) for an example. ### 2. Using RegisterCustomRegistry API * Implement your kernel and schema (if required) using the OpKernel and OpSchema APIs (headers are in the include folder). 
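A minimal kernel written against those headers follows the same shape as the `CustomOpKernel` added to `inference_session.cc` later in this patch. The sketch below is illustrative only; `MyKernel`, the `custom_kernels` namespace, and the float pass-through body are invented for the example:

```cpp
// Illustrative sketch of an OpKernel subclass (names here are placeholders).
#include <algorithm>

#include "core/framework/op_kernel.h"

namespace custom_kernels {

class MyKernel final : public onnxruntime::OpKernel {
 public:
  explicit MyKernel(const onnxruntime::OpKernelInfo& info) : OpKernel(info) {}

  // Compute is invoked once per node execution; inputs are read and outputs
  // written through the kernel context.
  onnxruntime::common::Status Compute(onnxruntime::OpKernelContext* ctx) const override {
    const auto* X = ctx->Input<onnxruntime::Tensor>(0);  // first input tensor
    auto* Y = ctx->Output(0, X->Shape());                // allocate a matching output
    std::copy_n(X->Data<float>(), X->Shape().Size(), Y->MutableData<float>());
    return onnxruntime::common::Status::OK();
  }
};

}  // namespace custom_kernels
```

The kernel is then registered through `RegisterCustomRegistry` together with a `KernelDefBuilder`-built definition, as `InferenceSession::AddCustomOpDomains` in this patch demonstrates.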
diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 05393030f33de..d41a0f6da56f6 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -151,6 +151,7 @@ class OpKernelContext { const MLValue* GetInputMLValue(int index) const; const MLValue* GetImplicitInputMLValue(int index) const; MLValue* GetOutputMLValue(int index); + MLValue* OutputMLValue(int index, const TensorShape& shape); // Creates the MLValue* based on the shape, if it does not exist private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index fe06045c734e7..b93451adb8b1b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -66,7 +66,6 @@ extern "C" { #define NO_EXCEPTION #endif - // Copied from TensorProto::DataType // Currently, Ort doesn't support complex64, complex128, bfloat16 types typedef enum ONNXTensorElementDataType { @@ -152,6 +151,7 @@ ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); ORT_RUNTIME_CLASS(SessionOptions); ORT_RUNTIME_CLASS(Callback); +ORT_RUNTIME_CLASS(CustomOpDomain); // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. @@ -511,6 +511,71 @@ ORT_API_STATUS(OrtGetValueCount, const OrtValue* value, size_t* out); ORT_API_STATUS(OrtCreateValue, OrtValue** const in, int num_values, enum ONNXType value_type, OrtValue** out); +/* + * EXPERIMENTAL APIS - Subject to change. Released as a preview to get feedback and enable early testing +*/ + +/* + * Steps to use a custom op: + * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops + * 2 Create an OrtCustomOp structure for each op and add them to the domain + * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options +*/ +struct OrtKernelInfo; +typedef struct OrtKernelInfo OrtKernelInfo; + +/* + * These allow reading node attributes during kernel creation +*/ +ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); +ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); + +/* + * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by + * the implementor of the custom op. +*/ +struct OrtCustomOp { + uint32_t version; // Initialize to ORT_API_VERSION + + // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. 
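+  // The kernel pointer written to op_kernel is handed back as the first argument of the
+  // KernelGetOutputShape and KernelCompute callbacks below, and is released through KernelDestroy.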
+ void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ OrtKernelInfo* info, _Out_ void** op_kernel); + + // Returns the name of the op + const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op); + + // Returns the count and types of the input & output tensors + ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetInputTypeCount)(_In_ struct OrtCustomOp* op); + ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ struct OrtCustomOp* op); + + // Op kernel callbacks + void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output); + void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ OrtValue** outputs, _In_ size_t output_count); + void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); +}; +typedef struct OrtCustomOp OrtCustomOp; + +/* +* Create a custom op domain. After all sessions using it are released, call OrtReleaseCustomOpDomain +*/ +ORT_API(OrtCustomOpDomain*, OrtCreateCustomOpDomain, _In_ const char* domain, _In_ int op_version_start, _In_ int op_version_end); + +/* + * Add custom ops to the OrtCustomOpDomain + * Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released +*/ +ORT_API_STATUS(OrtCustomOpDomain_Add, _In_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op); + +/* + * Add a custom op domain to the OrtSessionOptions + * Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released +*/ +ORT_API_STATUS(OrtAddCustomOpDomain, _In_ OrtSessionOptions* options, OrtCustomOpDomain* custom_op_domain); +/* + * END EXPERIMENTAL +*/ + #ifdef __cplusplus } #endif diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index 6e26b70efbe6d..dfd7489a2d294 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -24,6 +24,11 @@ OpKernelContext::OpKernelContext(IExecutionFrame* frame, } Tensor* OpKernelContext::Output(int index, const TensorShape& shape) { + auto p_ml_value = OutputMLValue(index, shape); + return p_ml_value ? p_ml_value->GetMutable() : nullptr; +} + +MLValue* OpKernelContext::OutputMLValue(int index, const TensorShape& shape) { if (index < 0 || index >= OutputCount()) return nullptr; @@ -34,7 +39,7 @@ Tensor* OpKernelContext::Output(int index, const TensorShape& shape) { MLValue* p_ml_value = nullptr; Status status = execution_frame_->GetOrCreateNodeOutputMLValue(GetOutputArgIndex(index), &shape, p_ml_value); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - return p_ml_value ? 
p_ml_value->GetMutable() : nullptr; + return p_ml_value; } int OpKernelContext::NumVariadicInputs(size_t arg_num) const { diff --git a/onnxruntime/core/framework/op_kernel_context_internal.h b/onnxruntime/core/framework/op_kernel_context_internal.h index 3ec850310d5a6..5cc3a50a96d52 100644 --- a/onnxruntime/core/framework/op_kernel_context_internal.h +++ b/onnxruntime/core/framework/op_kernel_context_internal.h @@ -39,6 +39,10 @@ class OpKernelContextInternal : public OpKernelContext { return OpKernelContext::GetOutputMLValue(index); } + MLValue* OutputMLValue(int index, const TensorShape& shape) { + return OpKernelContext::OutputMLValue(index, shape); + } + std::unordered_map GetImplicitInputs() const { // we need to convert implicit_inputs_ to a name to MLValue map so it can be used in the ExecutionFrame // for a subgraph (the index numbers will be different there). diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 3fd5323536d3e..af3c52c51b33b 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -5,6 +5,7 @@ #include "core/framework/tensor_shape.h" #include "core/framework/ml_value.h" #include "core/framework/onnxruntime_typeinfo.h" +#include "core/framework/tensor_type_and_shape.h" #include #include @@ -15,16 +16,6 @@ using onnxruntime::DataTypeImpl; using onnxruntime::MLFloat16; using onnxruntime::Tensor; -struct OrtTensorTypeAndShapeInfo { - public: - ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - onnxruntime::TensorShape shape; - - OrtTensorTypeAndShapeInfo() = default; - OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; - OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete; -}; - #define API_IMPL_BEGIN try { #define API_IMPL_END \ } \ @@ -72,8 +63,7 @@ ORT_API(int64_t, OrtGetTensorShapeElementCount, _In_ const OrtTensorTypeAndShape struct OrtValue; -namespace { -inline ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( +ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( const onnxruntime::DataTypeImpl* cpp_type) { ONNXTensorElementDataType type; if (cpp_type == onnxruntime::DataTypeImpl::GetType()) { @@ -109,7 +99,41 @@ inline ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( } return type; } -} // namespace + +const onnxruntime::DataTypeImpl* TensorElementDataTypeToMLDataType(ONNXTensorElementDataType type) { + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return 
onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return onnxruntime::DataTypeImpl::GetType(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return onnxruntime::DataTypeImpl::GetType(); + default: + return nullptr; + } +} OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out) { ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(tensor_data_type); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h new file mode 100644 index 0000000000000..9c829215b94f3 --- /dev/null +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +struct OrtTensorTypeAndShapeInfo { + public: + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + onnxruntime::TensorShape shape; + + OrtTensorTypeAndShapeInfo() = default; + OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; + OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete; +}; diff --git a/onnxruntime/core/providers/cpu/symbols.txt b/onnxruntime/core/providers/cpu/symbols.txt index 4f8aaaca4aa45..c8eac9cd59ca5 100644 --- a/onnxruntime/core/providers/cpu/symbols.txt +++ b/onnxruntime/core/providers/cpu/symbols.txt @@ -1,3 +1,4 @@ +OrtAddCustomOpDomain OrtAllocatorAlloc OrtAllocatorFree OrtAllocatorGetInfo @@ -11,6 +12,7 @@ OrtCloneSessionOptions OrtCompareAllocatorInfo OrtCreateAllocatorInfo OrtCreateCpuAllocatorInfo +OrtCreateCustomOpDomain OrtCreateDefaultAllocator OrtCreateEnv OrtCreateEnvWithCustomLogger @@ -21,6 +23,7 @@ OrtCreateTensorAsOrtValue OrtCreateTensorTypeAndShapeInfo OrtCreateTensorWithDataAsOrtValue OrtCreateValue +OrtCustomOpDomain_Add OrtDisableCpuMemArena OrtDisableMemPattern OrtDisableProfiling @@ -48,6 +51,7 @@ OrtGetValueType OrtIsTensor OrtReleaseAllocator OrtReleaseAllocatorInfo +OrtReleaseCustomOpDomain OrtReleaseEnv OrtReleaseRunOptions OrtReleaseSession diff --git a/onnxruntime/core/session/abi_session_options_impl.h b/onnxruntime/core/session/abi_session_options_impl.h index 1af9e5f26819b..ca57a01b8566f 100644 --- a/onnxruntime/core/session/abi_session_options_impl.h +++ b/onnxruntime/core/session/abi_session_options_impl.h @@ -13,6 +13,7 @@ struct OrtSessionOptions { onnxruntime::SessionOptions value; std::vector custom_op_paths; + std::vector custom_op_domains_; std::vector> provider_factories; OrtSessionOptions() = default; ~OrtSessionOptions(); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 562e56fbbb67f..e70ff06e5c750 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -22,6 +22,7 @@ #include "core/framework/allocatormgr.h" #include "core/framework/customregistry.h" #include "core/framework/environment.h" +#include "core/framework/error_code_helper.h" #include "core/framework/execution_frame.h" #include "core/framework/feeds_fetches_manager.h" #include "core/framework/graph_partitioner.h" @@ -31,11 +32,13 @@ #include "core/framework/mldata_type_utils.h" #include "core/framework/mlvalue_name_idx_map.h" #include "core/framework/sequential_executor.h" +#include 
"core/framework/op_kernel_context_internal.h" #include "core/framework/parallel_executor.h" #include "core/framework/path_lib.h" #include "core/framework/session_state.h" #include "core/framework/session_state_initializer.h" #include "core/framework/tensorprotoutils.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/framework/utils.h" #include "core/optimizer/transformer_memcpy.h" #include "core/optimizer/graph_transformer.h" @@ -52,6 +55,77 @@ using namespace ONNX_NAMESPACE; +ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type); +const onnxruntime::DataTypeImpl* TensorElementDataTypeToMLDataType(ONNXTensorElementDataType type); + +namespace onnxruntime { +const char* ElementTypeToString(MLDataType type) { + if (type == DataTypeImpl::GetType()) { + return "tensor(float)"; + } else if (type == DataTypeImpl::GetType()) { + return "tensor(bool)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(int32)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(double)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(string)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(uint8)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(uint16)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(int16)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(int64)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(uint32)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(uint64)"; + } + + else if (type == DataTypeImpl::GetType()) { + return "tensor(MLFloat16)"; + } else if (type == DataTypeImpl::GetType()) { + return "tensor(bfloat16)"; + } else { + return "unknown"; + } +} +} // namespace onnxruntime + +ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { + auto status = reinterpret_cast(info)->GetAttr(name, out); + if (status.IsOK()) + return nullptr; + return onnxruntime::ToOrtStatus(status); +} + +ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { + auto status = reinterpret_cast(info)->GetAttr(name, out); + if (status.IsOK()) + return nullptr; + return onnxruntime::ToOrtStatus(status); +} + namespace onnxruntime { namespace { template @@ -87,6 +161,41 @@ inline std::basic_string GetCurrentTimeString() { return std::basic_string(time_str); } } // namespace +struct CustomOpKernel : OpKernel { + CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) { + op_.CreateKernel(&op_, reinterpret_cast(const_cast(&info)), &op_kernel_); + } + + ~CustomOpKernel() { + op_.KernelDestroy(op_kernel_); + } + + Status Compute(OpKernelContext* ctx) const override { + auto* ictx = static_cast(ctx); + std::vector input_tensors; + auto input_count = ictx->InputCount(); + for (int i = 0; i < input_count; i++) + input_tensors.emplace_back(const_cast(reinterpret_cast(ictx->GetInputMLValue(i)))); + + std::vector output_tensors; + auto output_count = ictx->OutputCount(); + for (int i = 0; i < output_count; i++) { + OrtTensorTypeAndShapeInfo info; + op_.KernelGetOutputShape(op_kernel_, input_tensors.data(), input_tensors.size(), i, &info); + output_tensors.emplace_back(reinterpret_cast(ictx->OutputMLValue(0, info.shape))); + } + + op_.KernelCompute(op_kernel_, input_tensors.data(), input_tensors.size(), 
output_tensors.data(), output_tensors.size()); + return Status::OK(); + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel); + + OrtCustomOp& op_; + void* op_kernel_; +}; + class InferenceSession::Impl { public: Impl(const SessionOptions& session_options, logging::LoggingManager* logging_manager) @@ -156,6 +265,58 @@ class InferenceSession::Impl { return Status::OK(); } + common::Status AddCustomOpDomains(const std::vector& op_domains) { + auto custom_registry = std::make_shared(); + + for (auto& domain : op_domains) { + SchemasContainer schemas_container; + + schemas_container.domain = domain->domain_; + schemas_container.baseline_opset_version = domain->op_version_start_; + schemas_container.opset_version = domain->op_version_end_; + + for (auto& op : domain->custom_ops_) { + ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0); + + auto input_count = op->GetInputTypeCount(op); + for (size_t i = 0; i < input_count; i++) { + auto type = op->GetInputType(op, i); + + schema.Input(i, "A", "Description", ElementTypeToString(TensorElementDataTypeToMLDataType(type))); + } + + auto output_count = op->GetOutputTypeCount(op); + for (size_t i = 0; i < output_count; i++) { + auto type = op->GetOutputType(op, i); + + schema.Output(i, "A", "Description", ElementTypeToString(TensorElementDataTypeToMLDataType(type))); + } + + schema.SinceVersion(domain->op_version_start_); + schema.AllowUncheckedAttributes(); + + schemas_container.schemas_list.push_back(schema); + + KernelDefBuilder def_builder; + def_builder.SetName(op->GetName(op)) + .SetDomain(onnxruntime::kOnnxDomain) + .SinceVersion(domain->op_version_start_) + .Provider(onnxruntime::kCpuExecutionProvider); + KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); }; + KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); + + custom_registry->RegisterCustomKernel(create_info); + } + + ORT_RETURN_IF_ERROR(custom_registry->RegisterOpSet(schemas_container.schemas_list, + schemas_container.domain, + schemas_container.baseline_opset_version, + schemas_container.opset_version)); + } + RegisterCustomRegistry(custom_registry); + return Status::OK(); + } + common::Status RegisterCustomRegistry(std::shared_ptr& custom_registry) { if (custom_registry == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for custom registry"); @@ -163,6 +324,8 @@ class InferenceSession::Impl { // Insert session-level customized kernel registry. 
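+  // Note: AddCustomOpDomains above reuses this path; it wraps the OrtCustomOpDomain
+  // entries in a CustomRegistry and registers that registry here.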
kernel_registry_manager_.RegisterKernelRegistry(custom_registry); + // if (custom_schema_registries_.empty()) + // custom_schema_registries_.push_back(); custom_schema_registries_.push_back(custom_registry); return Status::OK(); } @@ -1041,4 +1204,8 @@ common::Status InferenceSession::Run(IOBinding& io_binding) { common::Status InferenceSession::LoadCustomOps(const std::vector& dso_list) { return impl_->LoadCustomOps(dso_list); } + +common::Status InferenceSession::AddCustomOpDomains(const std::vector& ops) { + return impl_->AddCustomOpDomains(ops); +} } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index aa43b4bbe394a..3d88997935d8f 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -20,6 +20,13 @@ namespace ONNX_NAMESPACE { class ModelProto; } // namespace ONNX_NAMESPACE +struct OrtCustomOpDomain { + std::string domain_; + int op_version_start_{}; + int op_version_end_{}; + std::vector custom_ops_; +}; + namespace onnxruntime { class IExecutionProvider; // forward decl class IOBinding; @@ -132,6 +139,8 @@ class InferenceSession { */ common::Status LoadCustomOps(const std::vector& dso_list); + common::Status AddCustomOpDomains(const std::vector& ops); + /** * Register a custom registry for operator schema and kernels. If you've one to register, * call this before invoking Initialize(). diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f5fc9f545d22b..f359f5b932505 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -337,18 +337,50 @@ ORT_API_STATUS_IMPL(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, API_IMPL_END } -template -static OrtStatus* CreateSessionImpl(_In_ OrtEnv* env, _In_ T model_path, - _In_ const OrtSessionOptions* options, - _Out_ OrtSession** out) { +ORT_API(OrtCustomOpDomain*, OrtCreateCustomOpDomain, _In_ const char* domain, int op_version_start, int op_version_end) { + auto custom_op_domain = std::make_unique(); + custom_op_domain->domain_ = domain; + custom_op_domain->op_version_start_ = op_version_start; + custom_op_domain->op_version_end_ = op_version_end; + return custom_op_domain.release(); +} + +ORT_API(void, OrtReleaseCustomOpDomain, OrtCustomOpDomain* ptr) { + delete ptr; +} + +ORT_API_STATUS_IMPL(OrtCustomOpDomain_Add, _In_ OrtCustomOpDomain* custom_op_domain, OrtCustomOp* op) { + API_IMPL_BEGIN + custom_op_domain->custom_ops_.emplace_back(op); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtAddCustomOpDomain, _In_ OrtSessionOptions* options, OrtCustomOpDomain* custom_op_domain) { + API_IMPL_BEGIN + options->custom_op_domains_.emplace_back(custom_op_domain); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtCreateSession, _In_ OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Out_ OrtSession** out) { API_IMPL_BEGIN auto sess = std::make_unique<::onnxruntime::InferenceSession>(options == nullptr ? 
onnxruntime::SessionOptions() : options->value, env->loggingManager); Status status; - if (options != nullptr && !options->custom_op_paths.empty()) { - status = sess->LoadCustomOps(options->custom_op_paths); - if (!status.IsOK()) - return ToOrtStatus(status); + if (options != nullptr) { + if (!options->custom_op_paths.empty()) { + status = sess->LoadCustomOps(options->custom_op_paths); + if (!status.IsOK()) + return ToOrtStatus(status); + } + if (!options->custom_op_domains_.empty()) { + status = sess->AddCustomOpDomains(options->custom_op_domains_); + if (!status.IsOK()) + return ToOrtStatus(status); + } } + if (options != nullptr) for (auto& factory : options->provider_factories) { auto provider = factory->CreateProvider(); @@ -366,22 +398,6 @@ static OrtStatus* CreateSessionImpl(_In_ OrtEnv* env, _In_ T model_path, API_IMPL_END } -#ifdef _WIN32 -ORT_API_STATUS_IMPL(OrtCreateSession, _In_ OrtEnv* env, _In_ const wchar_t* model_path, - _In_ const OrtSessionOptions* options, _Out_ OrtSession** out) { - API_IMPL_BEGIN - return CreateSessionImpl(env, model_path, options, out); - API_IMPL_END -} -#else -ORT_API_STATUS_IMPL(OrtCreateSession, _In_ OrtEnv* env, _In_ const char* model_path, - _In_ const OrtSessionOptions* options, _Out_ OrtSession** out) { - API_IMPL_BEGIN - return CreateSessionImpl(env, model_path, options, out); - API_IMPL_END -} -#endif - ORT_API_STATUS_IMPL(OrtRun, _In_ OrtSession* sess, _In_ OrtRunOptions* run_options, _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 7737edb68607c..ad1129ec2a447 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -63,7 +63,7 @@ void TestInference(OrtEnv* env, T model_uri, const std::vector& values_x, const std::vector& expected_dims_y, const std::vector& expected_values_y, - int provider_type, bool custom_op) { + int provider_type, bool custom_op, OrtCustomOpDomain* custom_op_domain_ptr = nullptr) { SessionOptionsWrapper sf(env); if (provider_type == 1) { @@ -93,6 +93,10 @@ void TestInference(OrtEnv* env, T model_uri, if (custom_op) { sf.AppendCustomOpLibPath("libonnxruntime_custom_op_shared_lib_test.so"); } + if (custom_op_domain_ptr) { + ORT_THROW_ON_ERROR(OrtAddCustomOpDomain(sf, custom_op_domain_ptr)); + } + std::unique_ptr inference_session(sf.OrtCreateSession(model_uri), OrtReleaseSession); std::unique_ptr default_allocator(std::make_unique()); @@ -169,6 +173,88 @@ TEST_F(CApiTest, DISABLED_custom_op) { } #endif +struct OrtTensorDimensions : std::vector { + OrtTensorDimensions(OrtValue* value) { + OrtTensorTypeAndShapeInfo* info; + ORT_THROW_ON_ERROR(OrtGetTensorShapeAndType(value, &info)); + auto dimensionCount = OrtGetNumOfDimensions(info); + resize(dimensionCount); + OrtGetDimensions(info, data(), dimensionCount); + OrtReleaseTensorTypeAndShapeInfo(info); + } + + size_t ElementCount() const { + int64_t count = 1; + for (int i = 0; i < size(); i++) + count *= (*this)[i]; + return count; + } +}; + +template +constexpr size_t countof(T (&)[N]) { return N; } + +struct MyCustomKernel { + MyCustomKernel(OrtKernelInfo& /*info*/) { + } + + void GetOutputShape(OrtValue** inputs, size_t /*input_count*/, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) { + OrtTensorDimensions dimensions(inputs[0]); + ORT_THROW_ON_ERROR(OrtSetDims(info, dimensions.data(), dimensions.size())); + } + + void Compute(OrtValue** inputs, size_t 
/*input_count*/, OrtValue** outputs, size_t /*output_count*/) { + const float* X; + const float* Y; + ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[0], reinterpret_cast(const_cast(&X)))); + ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[1], reinterpret_cast(const_cast(&Y)))); + + float* out; + ORT_THROW_ON_ERROR(OrtGetTensorMutableData(outputs[0], reinterpret_cast(&out))); + + int64_t size = OrtTensorDimensions(inputs[0]).ElementCount(); + for (int64_t i = 0; i < size; i++) { + out[i] = X[i] + Y[i]; + } + } +}; + +struct MyCustomOp : OrtCustomOp { + MyCustomOp() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*info); }; + OrtCustomOp::GetName = [](OrtCustomOp* /*this_*/) { return "Foo"; }; + + static const ONNXTensorElementDataType c_inputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}; + OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_inputTypes); }; + OrtCustomOp::GetInputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_inputTypes[index]; }; + + static const ONNXTensorElementDataType c_outputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}; + OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_outputTypes); }; + OrtCustomOp::GetOutputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_outputTypes[index]; }; + + OrtCustomOp::KernelGetOutputShape = [](void* op_kernel, OrtValue** inputs, size_t input_count, size_t output_index, OrtTensorTypeAndShapeInfo* output) { static_cast(op_kernel)->GetOutputShape(inputs, input_count, output_index, output); }; + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtValue** inputs, size_t input_count, OrtValue** outputs, size_t output_count) { static_cast(op_kernel)->Compute(inputs, input_count, outputs, output_count); }; + OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; + } +}; + +TEST_F(CApiTest, custom_op_handler) { + std::cout << "Running custom op inference" << std::endl; + std::vector dims_x = {3, 2}; + std::vector values_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; + + MyCustomOp custom_op; + OrtCustomOpDomain* custom_op_domain = OrtCreateCustomOpDomain("", 5, 7); + ORT_THROW_ON_ERROR(OrtCustomOpDomain_Add(custom_op_domain, &custom_op)); + + TestInference(env, CUSTOM_OP_MODEL_URI, dims_x, values_x, expected_dims_y, expected_values_y, false, false, custom_op_domain); +} + #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS TEST_F(CApiTest, create_session_without_session_option) { constexpr PATH_TYPE model_uri = TSTR("../models/opset8/test_squeezenet/model.onnx"); From b68079fe5d334ab525cc80a901efdc5e4dee4833 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Thu, 7 Mar 2019 00:13:11 -0800 Subject: [PATCH 3/6] Support int32_t for Split op (#563) * Support int32_t for Split op * Support int32_t for Split op --- .../core/providers/cpu/tensor/split.cc | 3 + .../providers/cpu/tensor/split_op_test.cc | 113 +++++++++++------- 2 files changed, 72 insertions(+), 44 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/split.cc b/onnxruntime/core/providers/cpu/tensor/split.cc index 2e995647120c1..d4e655a7f20e4 100644 --- a/onnxruntime/core/providers/cpu/tensor/split.cc +++ 
b/onnxruntime/core/providers/cpu/tensor/split.cc
@@ -17,6 +17,7 @@ ONNX_CPU_OPERATOR_KERNEL(
         std::vector<MLDataType>{
             DataTypeImpl::GetTensorType<float>(),
             DataTypeImpl::GetTensorType<double>(),
+            DataTypeImpl::GetTensorType<int32_t>(),
         }),
     Split);

@@ -28,6 +29,8 @@ Status Split::Compute(OpKernelContext* context) const {

   if (data_type == DataTypeImpl::GetType<float>())
     status = ComputeImpl<float>(*context, input);
+  else if (data_type == DataTypeImpl::GetType<int32_t>())
+    status = ComputeImpl<int32_t>(*context, input);
   else if (data_type == DataTypeImpl::GetType<double>()) {
     /* Need to update CopyMatrix to support double...
     status = ComputeImpl<double>(*context, input); */
diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
index c2a4d07365298..a7c4f35d82cb3 100644
--- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -7,12 +7,15 @@ namespace onnxruntime {
 namespace test {

-using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<float>>;
+template <typename T>
+using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<T>>;
+
+using ShapeAndFloatData = ShapeAndData<float>;
+using ShapeAndInt32Data = ShapeAndData<int32_t>;

 using ExpectResult = OpTester::ExpectResult;

-void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData& input,
-             const std::vector<ShapeAndData>& outputs,
-             bool expect_failure = false, const std::string& err_msg = {}) {
+template <typename T>
+void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData<T>& input,
+             const std::vector<ShapeAndData<T>>& outputs,
+             bool expect_failure = false, const std::string& err_msg = {}) {
   OpTester test("Split");

   test.AddAttribute("axis", axis);
@@ -20,7 +23,7 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
   if (!split_sizes.empty())
     test.AddAttribute("split", split_sizes);

-  test.AddInput<float>("input", input.first, input.second);
+  test.AddInput<T>("input", input.first, input.second);

   int i = 0;
   for (auto& output : outputs) {
@@ -28,7 +31,7 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
     auto& shape = output.first;
     auto& data = output.second;
     std::ostringstream oss;
     oss << "output" << i++;
-    test.AddOutput<float>(oss.str().c_str(), shape, data);
+    test.AddOutput<T>(oss.str().c_str(), shape, data);
   }

   test.Run(expect_failure ?
ExpectResult::kExpectFailure : ExpectResult::kExpectSuccess, err_msg); @@ -36,10 +39,10 @@ void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAn TEST(SplitOperatorTest, Axis0EqualSplit) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -53,15 +56,37 @@ TEST(SplitOperatorTest, Axis0EqualSplit) { {5.f, 6.f, 7.f, 8.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); +} + +TEST(SplitOperatorTest, Axis0EqualSplitInt32) { + const int64_t axis = 0; + std::vector outputs; + + // input shape and data + ShapeAndInt32Data input = {{4, 2}, // shape + {1, 2, + 3, 4, + 5, 6, + 7, 8}}; + + outputs.push_back({{2, 2}, + {1, 2, + 3, 4}}); + + outputs.push_back({{2, 2}, + {5, 6, + 7, 8}}); + + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, Axis0UnequalSplit) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -76,15 +101,15 @@ TEST(SplitOperatorTest, Axis0UnequalSplit) { 5.f, 6.f, 7.f, 8.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } TEST(SplitOperatorTest, Axis1EqualSplit) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{2, 4}, + ShapeAndFloatData input = {{2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}; @@ -96,15 +121,15 @@ TEST(SplitOperatorTest, Axis1EqualSplit) { {3.f, 4.f, 7.f, 8.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, Axis1UnequalSplit) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{2, 4}, + ShapeAndFloatData input = {{2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}; @@ -118,10 +143,10 @@ TEST(SplitOperatorTest, Axis1UnequalSplit) { {4.f, 8.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } -ShapeAndData CreateInput(std::vector shape) { +ShapeAndFloatData CreateInput(std::vector shape) { auto size = TensorShape(shape).Size(); float i = 0.f, increment = 1.f; @@ -129,16 +154,16 @@ ShapeAndData CreateInput(std::vector shape) { std::vector data; std::generate_n(std::back_inserter(data), size, [&]() { return i += increment; }); - ShapeAndData input = {shape, data}; + ShapeAndFloatData input = {shape, data}; return input; } TEST(SplitOperatorTest, Axis2EqualSplit) { const int64_t axis = 2; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 2, 6}); + ShapeAndFloatData input = CreateInput({2, 2, 6}); outputs.push_back({{2, 2, 2}, {1.f, 2.f, @@ -161,14 +186,14 @@ TEST(SplitOperatorTest, Axis2EqualSplit) { 17.f, 18.f, 23.f, 24.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, Axis2UnequalSplit) { const int64_t axis = 2; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 2, 6}); + ShapeAndFloatData input = CreateInput({2, 2, 6}); std::vector splits{1, 2, 3}; @@ -193,15 +218,15 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) { 16.f, 17.f, 18.f, 22.f, 23.f, 24.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } // test a split of a dimension that has leading and trailing dimensions 
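+// (axis 1 of a {2, 4, 4} input is split into two {2, 2, 4} outputs)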
TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 4, 4}); + ShapeAndFloatData input = CreateInput({2, 4, 4}); outputs.push_back({{2, 2, 4}, {1.f, 2.f, 3.f, 4.f, @@ -217,15 +242,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } // test a split of a dimension that has leading and trailing dimensions TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 4, 4}); + ShapeAndFloatData input = CreateInput({2, 4, 4}); std::vector splits{1, 3}; @@ -243,15 +268,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } TEST(SplitOperatorTest, NegativeAxis) { const int64_t axis = -1; // split last axis equally - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{2, 4}, + ShapeAndFloatData input = {{2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}; @@ -263,15 +288,15 @@ TEST(SplitOperatorTest, NegativeAxis) { {3.f, 4.f, 7.f, 8.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, InvalidAxis) { const int64_t axis = 2; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -279,16 +304,16 @@ TEST(SplitOperatorTest, InvalidAxis) { outputs.push_back({{1}, {0.f}}); - RunTest(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'"); + RunTest(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'"); } // sum of values in splits is too small TEST(SplitOperatorTest, SplitAttributeSumTooSmall) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -299,15 +324,15 @@ TEST(SplitOperatorTest, SplitAttributeSumTooSmall) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{2, 2}, {3.f, 4.f, 5.f, 6.f}}); - RunTest(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute"); + RunTest(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute"); } TEST(SplitOperatorTest, InvalidValueInSplitAttribute) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -317,7 +342,7 @@ TEST(SplitOperatorTest, InvalidValueInSplitAttribute) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{3, 2}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}); - RunTest(axis, splits, input, outputs, true, "Invalid value in 'split' attribute"); + RunTest(axis, splits, input, outputs, true, "Invalid value in 'split' attribute"); } /* From 4635bcc62461aba1208bfe1acd44f5040107f5f2 Mon Sep 17 00:00:00 2001 From: jignparm Date: Thu, 7 Mar 2019 00:28:15 -0800 Subject: [PATCH 4/6] Updating C_API end-to-end test and user samples (#564) * Updating user sample and C_API unit test * remove debugging info * remove 
precompiled headers
* header file location changed in master...updating
---
 .../C_Api_Sample.cpp                          | 148 +++++++++++++++++
 ....OnnxRuntime.EndToEndTests.RunCapi.vcxproj | 111 +++++++++++++
 .../runtest.bat                               |   7 +-
 docs/CSharp_API.md                            |   6 +
 docs/C_API.md                                 | 155 +-----------------
 5 files changed, 270 insertions(+), 157 deletions(-)
 create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp
 create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/Microsoft.ML.OnnxRuntime.EndToEndTests.RunCapi.vcxproj

diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp
new file mode 100644
index 0000000000000..240c728de5909
--- /dev/null
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp
@@ -0,0 +1,148 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+
+#include
+#include
+#include
+#include
+#include
+
+//*****************************************************************************
+// helper function to check for status
+#define CHECK_STATUS(expr)                               \
+  {                                                      \
+    OrtStatus* onnx_status = (expr);                     \
+    if (onnx_status != NULL) {                           \
+      const char* msg = OrtGetErrorMessage(onnx_status); \
+      fprintf(stderr, "%s\n", msg);                      \
+      OrtReleaseStatus(onnx_status);                     \
+      exit(1);                                           \
+    }                                                    \
+  }
+
+int main(int argc, char* argv[]) {
+  //*************************************************************************
+  // initialize environment...one environment per process
+  // environment maintains thread pools and other state info
+  OrtEnv* env;
+  CHECK_STATUS(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env));
+
+  // initialize session options if needed
+  OrtSessionOptions* session_option = OrtCreateSessionOptions();
+  OrtSetSessionThreadPoolSize(session_option, 1);
+
+  //*************************************************************************
+  // create session and load model into memory
+  // using squeezenet version 1.3
+  // URL = https://github.com/onnx/models/tree/master/squeezenet
+  OrtSession* session;
+  const wchar_t* model_path = L"squeezenet.onnx";
+  CHECK_STATUS(OrtCreateSession(env, model_path, session_option, &session));
+
+  //*************************************************************************
+  // print model input layer (node names, types, shape etc.)
+  size_t num_input_nodes;
+  OrtStatus* status;
+  OrtAllocator* allocator;
+  OrtCreateDefaultAllocator(&allocator);
+
+  // print number of model input nodes
+  status = OrtSessionGetInputCount(session, &num_input_nodes);
+  std::vector<const char*> input_node_names(num_input_nodes);
+  std::vector<int64_t> input_node_dims;  // simplify... this model has only 1 input node {1, 3, 224, 224}.
+ // Otherwise need vector> + + printf("Number of inputs = %zu\n", num_input_nodes); + + // iterate over all input nodes + for (int i = 0; i < num_input_nodes; i++) { + // print input node names + char* input_name; + status = OrtSessionGetInputName(session, i, allocator, &input_name); + printf("Input %d : name=%s\n", i, input_name); + input_node_names[i] = input_name; + + // print input node types + OrtTypeInfo* typeinfo; + status = OrtSessionGetInputTypeInfo(session, i, &typeinfo); + const OrtTensorTypeAndShapeInfo* tensor_info = OrtCastTypeInfoToTensorInfo(typeinfo); + ONNXTensorElementDataType type = OrtGetTensorElementType(tensor_info); + printf("Input %d : type=%d\n", i, type); + + // print input shapes/dims + size_t num_dims = OrtGetNumOfDimensions(tensor_info); + printf("Input %d : num_dims=%zu\n", i, num_dims); + input_node_dims.resize(num_dims); + OrtGetDimensions(tensor_info, (int64_t*)input_node_dims.data(), num_dims); + for (int j = 0; j < num_dims; j++) + printf("Input %d : dim %d=%jd\n", i, j, input_node_dims[j]); + + OrtReleaseTypeInfo(typeinfo); + } + OrtReleaseAllocator(allocator); + + // Results should be... + // Number of inputs = 1 + // Input 0 : name = data_0 + // Input 0 : type = 1 + // Input 0 : num_dims = 4 + // Input 0 : dim 0 = 1 + // Input 0 : dim 1 = 3 + // Input 0 : dim 2 = 224 + // Input 0 : dim 3 = 224 + + //************************************************************************* + // Similar operations to get output node information. + // Use OrtSessionGetOutputCount(), OrtSessionGetOutputName() + // OrtSessionGetOutputTypeInfo() as shown above. + + //************************************************************************* + // Score the model using sample data, and inspect values + + size_t input_tensor_size = 224 * 224 * 3; // simplify ... using known dim values to calculate size + // use OrtGetTensorShapeElementCount() to get official size! + + std::vector input_tensor_values(input_tensor_size); + std::vector output_node_names = {"softmaxout_1"}; + + // initialize input data with values in [0.0, 1.0] + for (unsigned int i = 0; i < input_tensor_size; i++) + input_tensor_values[i] = (float)i / (input_tensor_size + 1); + + // create input tensor object from data values + OrtAllocatorInfo* allocator_info; + CHECK_STATUS(OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info)); + OrtValue* input_tensor = NULL; + CHECK_STATUS(OrtCreateTensorWithDataAsOrtValue(allocator_info, input_tensor_values.data(), input_tensor_size * sizeof(float), input_node_dims.data(), 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); + assert(OrtIsTensor(input_tensor)); + OrtReleaseAllocatorInfo(allocator_info); + + // score model & input tensor, get back output tensor + OrtValue* output_tensor = NULL; + CHECK_STATUS(OrtRun(session, NULL, input_node_names.data(), (const OrtValue* const*)&input_tensor, 1, output_node_names.data(), 1, &output_tensor)); + assert(OrtIsTensor(output_tensor)); + + // Get pointer to output tensor float values + float* floatarr; + OrtGetTensorMutableData(output_tensor, (void**)&floatarr); + assert(abs(floatarr[0] - 0.000045) < 1e-6); + + // score the model, and print scores for first 5 classes + for (int i = 0; i < 5; i++) + printf("Score for class [%d] = %f\n", i, floatarr[i]); + + // Results should be as below... 
+ // Score for class[0] = 0.000045 + // Score for class[1] = 0.003846 + // Score for class[2] = 0.000125 + // Score for class[3] = 0.001180 + // Score for class[4] = 0.001317 + + OrtReleaseValue(output_tensor); + OrtReleaseValue(input_tensor); + OrtReleaseSession(session); + OrtReleaseEnv(env); + printf("Done!\n"); + return 0; +} \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/Microsoft.ML.OnnxRuntime.EndToEndTests.RunCapi.vcxproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/Microsoft.ML.OnnxRuntime.EndToEndTests.RunCapi.vcxproj new file mode 100644 index 0000000000000..ab0e88fd1a99d --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/Microsoft.ML.OnnxRuntime.EndToEndTests.RunCapi.vcxproj @@ -0,0 +1,111 @@ + + + + + $(MSBuildThisFileDirectory)..\.. + + + + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {B8CA7F10-0171-4EA5-8662-5A9942DDF415} + Win32Proj + MicrosoftMLOnnxRuntimeEndToEndTestsRunCapi + 10.0.17763.0 + + + + Application + true + v141 + Unicode + + + Application + false + v141 + true + Unicode + + + + + + + + + + + + + + + true + + + false + + + + NotUsing + Level3 + Disabled + true + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + + + + + NotUsing + Level3 + MaxSpeed + true + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + true + true + + + + + + + + Always + false + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/runtest.bat b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/runtest.bat index c0a042c034a3f..b9d45924c3863 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/runtest.bat +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/runtest.bat @@ -36,7 +36,7 @@ for /f "delims=" %%i in ('type "%templateFile%" ^& break ^> "packages.config" ') echo on REM Restore NuGet Packages -nuget restore -PackagesDirectory ..\packages -Source %LocalNuGetRepo% Microsoft.ML.OnnxRuntime.EndToEndTests.Capi.vcxproj +nuget restore -PackagesDirectory ..\packages -Source %LocalNuGetRepo% Microsoft.ML.OnnxRuntime.EndToEndTests.RunCapi.vcxproj if NOT %ERRORLEVEL% EQU 0 ( echo "Error:Nuget restore failed" popd @@ -44,7 +44,7 @@ if NOT %ERRORLEVEL% EQU 0 ( ) REM Build Native project -msbuild Microsoft.ML.OnnxRuntime.EndToEndTests.Capi.vcxproj +msbuild Microsoft.ML.OnnxRuntime.EndToEndTests.RunCapi.vcxproj if NOT %ERRORLEVEL% EQU 0 ( echo "Error:MSBuild failed to compile project" popd @@ -54,7 +54,8 @@ if NOT %ERRORLEVEL% EQU 0 ( REM Run Unit Tests pushd x64\Debug -vstest.console.exe /platform:x64 Microsoft.ML.OnnxRuntime.EndToEndTests.Capi.dll +REM vstest.console.exe /platform:x64 Microsoft.ML.OnnxRuntime.EndToEndTests.Capi.dll +.\Microsoft.ML.OnnxRuntime.EndToEndTests.RunCapi.exe if NOT %ERRORLEVEL% EQU 0 ( echo "Unit test failure: %ERRORLEVEL%" popd diff --git a/docs/CSharp_API.md b/docs/CSharp_API.md index 92b2a2b195652..c67b3d6d25de3 100644 --- a/docs/CSharp_API.md +++ b/docs/CSharp_API.md @@ -4,6 +4,12 @@ The ONNX runtime provides a C# .Net binding for running inference on ONNX models ## NuGet Package The Microsoft.ML.OnnxRuntime Nuget package includes the precompiled binaries for ONNX runtime, and includes libraries for Windows and Linux platforms 
with X64 CPUs. The APIs conform to .Net Standard 1.1. +## Sample Code + +The unit tests contain several examples of loading models, inspecting input/output node shapes and types, as well as constructing tensors for scoring. + +* [../csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs#L54](../csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs#L54) + ## Getting Started Here is simple tutorial for getting started with running inference on an existing ONNX model for a given input data. The model is typically trained using any of the well-known training frameworks and exported into the ONNX format. To start scoring using the model, open a session using the `InferenceSession` class, passing in the file path to the model as a parameter. diff --git a/docs/C_API.md b/docs/C_API.md index 48470053ba5db..7aca76cfe1180 100644 --- a/docs/C_API.md +++ b/docs/C_API.md @@ -25,158 +25,5 @@ The example below shows a sample run using the SqueezeNet model from ONNX model zoo, including dynamically reading model inputs, outputs, shape and type information, as well as running a sample vector and fetching the resulting class probabilities for inspection. - -```c -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -// - -#include -#include -#include -#include -#include - -//***************************************************************************** -// helper function to check for status -#define CHECK_STATUS(expr) \ - do { \ - OrtStatus* onnx_status = (expr); \ - if (onnx_status != NULL) { \ - const char* msg = OrtGetErrorMessage(onnx_status); \ - fprintf(stderr, "%s\n", msg); \ - OrtReleaseStatus(onnx_status); \ - abort(); \ - } \ - } while (0); - -int main(int argc, char *argv[]) -{ - //************************************************************************* - // initialize enviroment...one enviroment per process - // enviroment maintains thread pools and other state info - OrtEnv* env; - CHECK_STATUS(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env)); - - // initialize session options if needed - OrtSessionOptions* session_option = OrtCreateSessionOptions(); - OrtSetSessionThreadPoolSize(session_option, 1); - - //************************************************************************* - // create session and load model into memory - // using squeezenet version 1.3 - // URL = https://github.com/onnx/models/tree/master/squeezenet - OrtSession* session; - const wchar_t * model_path = L"model.onnx"; - CHECK_STATUS(OrtCreateSession(env, model_path, session_option, &session)); - - //************************************************************************* - // print model input layer (node names, types, shape etc.) 
- - size_t num_inputs; - OrtStatus* status; - OrtAllocator* allocator; - OrtCreateDefaultAllocator(&allocator); - - // print number of model input nodes - status = OrtSessionGetInputCount(session, &num_inputs); - char **input_names = (char**)malloc(num_inputs * sizeof(char*)); - printf("Number of inputs = %zu\n", num_inputs); - - // iterate over all input nodes - for (int i = 0; i < num_inputs; i++) - { - // print input node names - char* input_name; - status = OrtSessionGetInputName(session, i, allocator, &input_name); - printf("Input %d : name=%s\n", i, input_name); - input_names[i] = input_name; - - // print input node types - OrtTypeInfo* typeinfo; - status = OrtSessionGetInputTypeInfo(session, i, &typeinfo); - const OrtTensorTypeAndShapeInfo* tensor_info = OrtCastTypeInfoToTensorInfo(typeinfo); - ONNXTensorElementDataType type = OrtGetTensorElementType(tensor_info); - printf("Input %d : type=%d\n", i, type); - - // print input shapes - size_t num_dims = OrtGetNumOfDimensions(tensor_info); - int64_t* dims = (int64_t*)malloc(num_dims * sizeof(int64_t)); - - printf("Input %d : num_dims=%zu\n", i, num_dims); - - OrtGetDimensions(tensor_info, dims, num_dims); - - for (int j = 0; j < num_dims; j++) - printf("Input %d : dim %d=%jd\n", i, j, dims[j]); - - OrtReleaseTypeInfo(typeinfo); - } - OrtReleaseAllocator(allocator); - - // Results should be... - // Number of inputs = 1 - // Input 0 : name = data_0 - // Input 0 : type = 1 - // Input 0 : num_dims = 4 - // Input 0 : dim 0 = 1 - // Input 0 : dim 1 = 3 - // Input 0 : dim 2 = 224 - // Input 0 : dim 3 = 224 - - //************************************************************************* - // Similar operations to get output node information. - // Use OrtSessionGetOutputCount(), OrtSessionGetOutputName() - // OrtSessionGetOutputTypeInfo() as shown above. - - //************************************************************************* - // Score the model using sample data, and inspect values - - size_t input_dims[] = { 1, 3, 224, 224 }; - size_t input_count = 3 * 224 * 224; // input tensor count = product of dims - float* input_data = (float *) malloc(sizeof(float) * input_count); - const char* output_names[] = { "softmaxout_1"}; - - // initialize input data with values in [0.0, 1.0] - for (unsigned int i = 0; i < input_count; i++) - input_data[i] = (float)i / (float)(input_count + 1); - - // create input tensor object from data values - OrtAllocatorInfo* allocator_info; - CHECK_STATUS(OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info)); - OrtValue* input_tensor = NULL; - CHECK_STATUS(OrtCreateTensorWithDataAsOrtValue(allocator_info, input_data, input_count * sizeof(float), input_dims, 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); - assert(OrtIsTensor(input_tensor)); - OrtReleaseAllocatorInfo(allocator_info); - - // score model & input tensor, get back output tensor - OrtValue* output_tensor = NULL; - CHECK_STATUS(OrtRun(session, NULL, input_names, (const OrtValue* const*)&input_tensor, 1, output_names, 1, &output_tensor)); - assert(OrtIsTensor(output_tensor)); - - // copy output tensor values to float array - // model produces scores for 1000 classes - float* floatarr = (float *) malloc(1000 * sizeof(float)); - OrtGetTensorMutableData(output_tensor, (void **) &floatarr); - - // score the model, and print scores for first 5 classes - for (int i = 0; i < 5; i++) - printf("Score for class [%d] = %f\n", i, floatarr[i]); - - // Results should be as below... 
-
-  // Results should be as below...
-  // Score for class[0] = 0.000045
-  // Score for class[1] = 0.003846
-  // Score for class[2] = 0.000125
-  // Score for class[3] = 0.001180
-  // Score for class[4] = 0.001317
-
-  free(input_data);
-  OrtReleaseValue(output_tensor);
-  OrtReleaseValue(input_tensor);
-  OrtReleaseSession(session);
-  OrtReleaseSessionOptions(session_option);
-  OrtReleaseEnv(env);
-  printf("Done!\n");
-  return 0;
-}
-```
-
+* [../csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp](../csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp)

From b4ffcf8258cc7bcd4a97303efd53d6e6c01904d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 7 Mar 2019 13:08:02 +0100
Subject: [PATCH 5/6] Fixes #31: add options numpy_version and skip_keras_test
 to the build.py parser, and add the PRIVATE flag for the Python bindings
 (#544)

* add option numpy_version to build against a specific numpy version instead of the hardcoded 1.15.0; the default is still 1.15.0 (see the sketch after this list)
* add option skip_keras_test to skip the Keras tests even if keras is installed (the tests are still enabled by default)
* remove the unnecessary warning about unsupported Linux distributions
* enable the PRIVATE option for the compilation of the Python bindings (the setting recommended by the pybind11 documentation)
* tested on Debian 9
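The numpy pin added to install_python_deps below hinges on Python operator precedence: `%` string formatting binds tighter than the conditional expression. Here is a minimal, standalone sketch of that logic; the helper name numpy_requirement is illustrative only and not part of this patch:

```python
def numpy_requirement(numpy_version=""):
    # '%' binds tighter than the conditional expression, so this parses as
    #   ('numpy==%s' % numpy_version) if numpy_version else 'numpy'
    # i.e. pin the requested version when one is given, otherwise let pip
    # resolve a plain 'numpy' requirement.
    return 'numpy==%s' % numpy_version if numpy_version else 'numpy'

assert numpy_requirement("1.15.0") == "numpy==1.15.0"  # --numpy_version 1.15.0
assert numpy_requirement() == "numpy"                  # no version requested
```

Note that calls without an argument, like the install_python_deps() call kept in main(), now install an unpinned numpy rather than 1.15.0.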
") + parser.add_argument("--numpy_version", default='1.15.0', help="Installs a specific version of numpy " + "before building the python binding.") + parser.add_argument("--skip-keras-test", action='store_true', help="Skip tests with Keras if keras is installed") # C-Sharp bindings parser.add_argument("--build_csharp", action='store_true', help="Build C#.Net DLL and NuGet package") @@ -195,8 +198,9 @@ def install_ubuntu_deps(args): except Exception as e: raise BuildError("Error setting up required APT packages. {}".format(str(e))) -def install_python_deps(): - dep_packages = ['setuptools', 'wheel', 'numpy==1.15.0'] +def install_python_deps(numpy_version=""): + dep_packages = ['setuptools', 'wheel'] + dep_packages.append('numpy==%s' % numpy_version if numpy_version else 'numpy') run_subprocess([sys.executable, '-m', 'pip', 'install', '--trusted-host', 'files.pythonhosted.org'] + dep_packages) def check_md5(filename, expected_md5): @@ -465,15 +469,16 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, enab run_subprocess([os.path.join(cwd,'onnx_test_runner'), 'test_models'], cwd=cwd) if config != 'Debug': run_subprocess([sys.executable, 'onnx_backend_test_series.py'], cwd=cwd, dll_path=dll_path) - try: - import onnxmltools - import keras - onnxml_test = True - except ImportError: - warnings.warn("onnxmltools and keras are not installed. Following test cannot be run.") - onnxml_test = False - if onnxml_test: - run_subprocess([sys.executable, 'onnxruntime_test_python_keras.py'], cwd=cwd, dll_path=dll_path) + if not args.skip_keras_test: + try: + import onnxmltools + import keras + onnxml_test = True + except ImportError: + warnings.warn("onnxmltools and keras are not installed. Following test cannot be run.") + onnxml_test = False + if onnxml_test: + run_subprocess([sys.executable, 'onnxruntime_test_python_keras.py'], cwd=cwd, dll_path=dll_path) def run_onnx_tests(build_dir, configs, onnx_test_data_dir, provider, enable_parallel_executor_test, num_parallel_models): for config in configs: @@ -570,7 +575,7 @@ def main(): if not is_docker(): install_python_deps() if (args.enable_pybind and is_windows()): - install_python_deps() + install_python_deps(args.numpy_version) if (not args.skip_submodule_sync): update_submodules(source_dir) From d40a9f894f7161ce8671a8678b282c7af15b12f8 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 7 Mar 2019 11:07:35 -0800 Subject: [PATCH 6/6] Enable Component Detection (#559) * Enable Component Detection --- .../azure-pipelines/linux-ci-pipeline.yml | 4 + .../azure-pipelines/linux-gpu-ci-pipeline.yml | 4 + .../azure-pipelines/mac-ci-pipeline.yml | 6 +- .../azure-pipelines/win-ci-pipeline-cg.yml | 135 ------------------ .../azure-pipelines/win-ci-pipeline.yml | 5 + .../azure-pipelines/win-gpu-ci-pipeline.yml | 5 + 6 files changed, 23 insertions(+), 136 deletions(-) delete mode 100644 tools/ci_build/github/azure-pipelines/win-ci-pipeline-cg.yml diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index 636e8b81f8e40..d3d79cf674c0d 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -7,4 +7,8 @@ jobs: - script: 'tools/ci_build/github/linux/run_dockerbuild.sh -o ubuntu16.04 -d cpu -r $(Build.BinariesDirectory) -x "--use_mklml --use_tvm --test_data_url $(TestDataUrl) --test_data_checksum $(TestDataChecksum)"' displayName: 'Command Line Script' + - task: 
+    displayName: 'Component Detection'
+    condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI'))
+
   - template: templates/clean-agent-build-directory-step.yml
diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
index 382da4fd819be..014bd3d2f117c 100644
--- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
@@ -7,4 +7,8 @@ jobs:
   - script: 'tools/ci_build/github/linux/run_dockerbuild.sh -o ubuntu16.04 -d gpu -r $(Build.BinariesDirectory) -x "--test_data_url $(TestDataUrl) --test_data_checksum $(TestDataChecksum)"'
     displayName: 'Command Line Script'
+
+  - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
+    displayName: 'Component Detection'
+    condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI'))
+
   - template: templates/clean-agent-build-directory-step.yml
diff --git a/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml
index 3383e20f5a31a..0954611e1317f 100644
--- a/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml
@@ -9,4 +9,8 @@ jobs:
       python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --use_openmp --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --build_shared_lib --enable_onnx_tests --test_data_url $(TestDataUrl) --test_data_checksum $(TestDataChecksum)
     displayName: 'Build and Test OnnxRuntime lib for MacOS'
-  - template: templates/clean-agent-build-directory-step.yml
\ No newline at end of file
+  - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
+    displayName: 'Component Detection'
+    condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI'))
+
+  - template: templates/clean-agent-build-directory-step.yml
diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline-cg.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline-cg.yml
deleted file mode 100644
index fab2013fd2e0b..0000000000000
--- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline-cg.yml
+++ /dev/null
@@ -1,135 +0,0 @@
-jobs:
-- job: Windows_CI_Dev
-  variables:
-    buildDirectory: '$(Build.BinariesDirectory)'
-  steps:
-    - template: templates/set-test-data-variables-step.yml
-    - task: NuGetCommand@2
-      displayName: 'NuGet restore'
-      inputs:
-        restoreSolution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln'
-        feedsToUse: config
-        nugetConfigPath: '$(Build.SourcesDirectory)\csharp\Nuget.CSharp.config'
-        restoreDirectory: '$(Build.SourcesDirectory)\csharp'
-    - task: UniversalPackages@0
-      displayName: 'Download python'
-      inputs:
-        command: download
-        vstsFeed: '$(System.TeamProject)'
-        vstsFeedPackage: 'miniconda3_win64'
-        vstsPackageVersion: '4.5.11'
-        downloadDirectory: '$(Build.BinariesDirectory)\python'
-    - task: CmdLine@1
-      displayName: 'Run python installer'
-      inputs:
-        filename: '$(Build.BinariesDirectory)\python\installer.exe'
-        arguments: '/S /NoRegistry=1 /AddToPath=0 /RegisterPython=0 /D=$(Build.BinariesDirectory)\packages\python'
-      timeoutInMinutes: 10
-    - task: BatchScript@1
-      displayName: 'setup env'
-      inputs:
-        filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\setup_env.bat'
-        modifyEnvironment: true
-      workingFolder: '$(Build.BinariesDirectory)'
-    - task: CmdLine@1
-      displayName: 'Install conda modules'
-      inputs:
-        filename: '$(Build.BinariesDirectory)\packages\python\scripts\conda.exe'
-        arguments: 'install -q --insecure -y pyopenssl setuptools wheel numpy'
-      timeoutInMinutes: 10
-
-    - task: CmdLine@1
-      displayName: 'Download cmake'
-      inputs:
-        filename: '$(Build.BinariesDirectory)\packages\python\python.exe'
-        arguments: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\download_cmake.py --build_dir $(Build.BinariesDirectory)'
-    - task: CmdLine@1
-      displayName: 'Download test data and generate cmake config'
-      inputs:
-        filename: '$(Build.BinariesDirectory)\packages\python\python.exe'
-        arguments: '$(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug Release --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_path $(Build.BinariesDirectory)\cmake\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake\bin\ctest.exe --use_tvm --enable_pybind --use_mkldnn --use_mklml --use_openmp --build_shared_lib --enable_onnx_tests --test_data_url $(TestDataUrl) --test_data_checksum $(TestDataChecksum) --update'
-      workingDirectory: "$(Build.BinariesDirectory)"
-
-    - task: VSBuild@1
-      displayName: 'Build Debug'
-      inputs:
-        solution: '$(Build.BinariesDirectory)\Debug\onnxruntime.sln'
-        platform: 'x64'
-        configuration: 'Debug'
-        msbuildArgs: '/m'
-        msbuildArchitecture: 'x64'
-        logProjectEvents: true
-        workingFolder: '$(Build.BinariesDirectory)\Debug'
-    - task: BatchScript@1
-      displayName: 'Test Debug'
-      inputs:
-        filename: '$(Build.BinariesDirectory)\packages\python\python.exe'
-        arguments: '$(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_path $(Build.BinariesDirectory)\cmake\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake\bin\ctest.exe --use_tvm --enable_pybind --use_mkldnn --use_mklml --use_openmp --build_shared_lib --enable_onnx_tests --test_data_url $(TestDataUrl) --test_data_checksum $(TestDataChecksum) --test'
-      workingFolder: '$(Build.BinariesDirectory)'
-    - task: VSBuild@1
-      displayName: 'Build C# Debug'
-      inputs:
-        solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln'
-        platform: 'any cpu'
-        configuration: 'Debug'
-        restoreNugetPackages: false
-        msbuildArchitecture: 'x64'
-        workingFolder: '$(Build.SourcesDirectory)\csharp'
-        msbuildArgs: '/m /p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)'
-
-    - task: VSTest@2
-      displayName: 'VsTest - C# Debug'
-      inputs:
-        testAssemblyVer2: '**\bin\Debug\**\*Tests.dll'
-        searchFolder: '$(Build.SourcesDirectory)\csharp\test'
-        runInParallel: true
-        configuration: Debug
-
-    - task: VSBuild@1
-      displayName: 'Build Release'
-      inputs:
-        solution: '$(Build.BinariesDirectory)\Release\onnxruntime.sln'
-        platform: 'x64'
-        configuration: 'Release'
-        msbuildArgs: '/m'
-        msbuildArchitecture: 'x64'
-        logProjectEvents: true
-        workingFolder: '$(Build.BinariesDirectory)\Release'
-    - task: BatchScript@1
-      displayName: 'Test Release'
-      inputs:
-        filename: '$(Build.BinariesDirectory)\packages\python\python.exe'
-        arguments: '$(Build.SourcesDirectory)\tools\ci_build\build.py --config Release --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_path $(Build.BinariesDirectory)\cmake\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake\bin\ctest.exe --use_tvm --enable_pybind --use_mkldnn --use_mklml --use_openmp --build_shared_lib --enable_onnx_tests --test_data_url $(TestDataUrl) --test_data_checksum $(TestDataChecksum) --test'
-      workingFolder: "$(Build.BinariesDirectory)"
-
-    - task: VSBuild@1
-      displayName: 'Build C# Release'
-      inputs:
-        solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln'
-        platform: 'any cpu'
-        configuration: 'Release'
-        msbuildArchitecture: 'x64'
-        restoreNugetPackages: false
-        workingFolder: '$(Build.SourcesDirectory)\csharp'
-        msbuildArgs: '/m /p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)'
-
-    - task: VSTest@2
-      displayName: 'VsTest - C# Release'
-      inputs:
-        testAssemblyVer2: '**\bin\Release\**\*Tests.dll'
-        searchFolder: '$(Build.SourcesDirectory)\csharp\test'
-        runInParallel: true
-        configuration: Release
-
-    - task: PublishTestResults@2
-      displayName: 'Publish unit test results'
-      inputs:
-        testResultsFiles: '**\*.results.xml'
-        searchFolder: '$(Build.BinariesDirectory)'
-        testRunTitle: 'Unit Test Run'
-      condition: succeededOrFailed()
-
-    - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
-      displayName: 'Component Detection'
-
-    - template: templates/clean-agent-build-directory-step.yml
diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
index 19a653827d74c..2340c150dd6db 100644
--- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
@@ -128,4 +128,9 @@ jobs:
       searchFolder: '$(Build.BinariesDirectory)'
       testRunTitle: 'Unit Test Run'
     condition: succeededOrFailed()
+
+  - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
+    displayName: 'Component Detection'
+    condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI'))
+
   - template: templates/clean-agent-build-directory-step.yml
diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml
index 8658624f69172..0d0c7367e344a 100644
--- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml
@@ -139,5 +139,10 @@ jobs:
       searchFolder: '$(Build.BinariesDirectory)'
      testRunTitle: 'Unit Test Run'
     condition: succeededOrFailed()
+
+  - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
+    displayName: 'Component Detection'
+    condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI'))
+
   - template: templates/clean-agent-build-directory-step.yml
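
A closing note on the test gating introduced in patch 5/6: run_onnxruntime_tests now skips the optional Keras round-trip test either when --skip_keras_test is passed or when its dependencies are missing. Below is a condensed, standalone sketch of that control flow; run_keras_test is a hypothetical stand-in for the onnxruntime_test_python_keras.py subprocess call, not a function from the patch:

```python
import warnings

def maybe_run_keras_test(skip_keras_test, run_keras_test):
    # The explicit opt-out takes precedence over dependency detection.
    if skip_keras_test:
        return
    # The test only runs when both optional dependencies import cleanly.
    try:
        import onnxmltools  # noqa: F401
        import keras        # noqa: F401
    except ImportError:
        warnings.warn("onnxmltools and keras are not installed. "
                      "The following test cannot be run.")
        return
    run_keras_test()
```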