diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5aa53440a06ac..29f83819fe9f0 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -72,6 +72,7 @@ option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_BUILD_HOSTING "Build ONNX hosting service" ON) #TODO: make default off before going into master option(onnxruntime_USE_FULL_PROTOBUF "Use full protobuf" OFF) option(onnxruntime_DISABLE_CONTRIB_OPS "Disable contrib ops" OFF) +option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) #nsync tests failed on Mac Build @@ -584,12 +585,6 @@ if (onnxruntime_BUILD_SHARED_LIB) include(onnxruntime.cmake) endif() -if (onnxruntime_BUILD_CSHARP) - message(STATUS "CSharp Build is enabled") -# set_property(GLOBAL PROPERTY VS_DOTNET_TARGET_FRAMEWORK_VERSION "netstandard2.0") - include(onnxruntime_csharp.cmake) -endif() - # some of the tests rely on the shared libs to be # built; hence the ordering if (onnxruntime_BUILD_UNIT_TESTS) @@ -611,6 +606,12 @@ if (onnxruntime_BUILD_UNIT_TESTS) include(onnxruntime_unittests.cmake) endif() +if (onnxruntime_BUILD_CSHARP) + message(STATUS "CSharp Build is enabled") +# set_property(GLOBAL PROPERTY VS_DOTNET_TARGET_FRAMEWORK_VERSION "netstandard2.0") + include(onnxruntime_csharp.cmake) +endif() + if (onnxruntime_BUILD_HOSTING) include(onnxruntime_hosting.cmake) endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 890f9af16a5ba..ba82f6a70e7e8 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -446,6 +446,7 @@ add_test(NAME onnx_test_pytorch_converted add_test(NAME onnx_test_pytorch_operator COMMAND onnx_test_runner ${PROJECT_SOURCE_DIR}/external/onnx/onnx/backend/test/data/pytorch-operator) +#perf test runner set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest) set(onnxruntime_perf_test_src_patterns "${onnxruntime_perf_test_src_dir}/*.cc" @@ -462,16 +463,30 @@ else () endif() file(GLOB onnxruntime_perf_test_src ${onnxruntime_perf_test_src_patterns}) -add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src}) +add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/framework/path_lib.cc) target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${extra_includes} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/onnx) if (WIN32) target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings}) + SET(SYS_PATH_LIB shlwapi) endif() onnxruntime_add_include_to_target(onnxruntime_perf_test gsl) -target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs}) + +if (onnxruntime_BUILD_SHARED_LIB) + target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_test_utils onnx_test_runner_common onnxruntime_common + onnx_test_data_proto onnx_proto libprotobuf ${GETOPT_LIB_WIDE} onnxruntime + ${SYS_PATH_LIB} ${CMAKE_DL_LIBS} Threads::Threads) + if(tensorflow_C_PACKAGE_PATH) + target_include_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/include) + target_link_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/lib) + target_link_libraries(onnxruntime_perf_test PRIVATE tensorflow) + target_compile_definitions(onnxruntime_perf_test PRIVATE HAVE_TENSORFLOW) + endif() +else() + 
target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs}) +endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") # shared lib diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Microsoft.ML.OnnxRuntime.InferenceSample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Microsoft.ML.OnnxRuntime.InferenceSample.csproj index ae8cb450c2995..8a43842d729c2 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Microsoft.ML.OnnxRuntime.InferenceSample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Microsoft.ML.OnnxRuntime.InferenceSample.csproj @@ -22,6 +22,10 @@ Always false + + Always + false + Always false diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 6e253e7a6a502..1e4cdb2bc6773 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -65,6 +65,13 @@ CopyToOutputDirectory="Always" Visible="false" /> + + + @@ -111,6 +120,8 @@ + + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.props b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.props deleted file mode 100644 index 5ccd0a36728f4..0000000000000 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.props +++ /dev/null @@ -1,36 +0,0 @@ - - - - - - $(MSBuildThisFileDirectory)../../build/native/include/;%(AdditionalIncludeDirectories) - - - $(MSBuildThisFileDirectory)../../build/native/include/;%(AdditionalIncludeDirectories) - - - - - - $(MSBuildThisFileDirectory)../../runtimes/win-x64/native/onnxruntime.lib - - - - - - onnxruntime.dll - PreserveNewest - false - - - mkldnn.dll - PreserveNewest - false - - - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.targets b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.targets deleted file mode 100644 index c4829504eda37..0000000000000 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.targets +++ /dev/null @@ -1,16 +0,0 @@ - - - - - - - - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.Gpu.props b/csharp/src/Microsoft.ML.OnnxRuntime/props.xml similarity index 81% rename from csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.Gpu.props rename to csharp/src/Microsoft.ML.OnnxRuntime/props.xml index 5ccd0a36728f4..c2a3a9b4ae06c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.Gpu.props +++ b/csharp/src/Microsoft.ML.OnnxRuntime/props.xml @@ -32,5 +32,12 @@ PreserveNewest false + + mklml.dll + PreserveNewest + false + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.Gpu.targets b/csharp/src/Microsoft.ML.OnnxRuntime/targets.xml similarity index 100% rename from csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.Gpu.targets rename to csharp/src/Microsoft.ML.OnnxRuntime/targets.xml diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index d98b8d3995eea..43d16ee5d2b3d 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -30,6 +30,32 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence); +// This section includes all opkernel declarations for former experimental ops which have now been removed from onnx. +// To maintain backward compatibility these are added as contrib ops. +// Note: the domain for all contrib ops should be MSDomain. However since these ops started out as onnx domain ops +// we cannot change the domain now as this will break backward compatibility. +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu); + + void RegisterContribKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); @@ -56,6 +82,31 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + + + // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as
+  // contrib ops to maintain backward compatibility
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu)>());
+
 }
 } // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/contrib_op_activations.cc b/onnxruntime/contrib_ops/cpu/contrib_op_activations.cc
new file mode 100644
index 0000000000000..add4ff3551778
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/contrib_op_activations.cc
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cpu/activation/activations.h"
+#include "contrib_op_activations.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_CPU_OPERATOR_KERNEL(
+    ParametricSoftplus,
+    1,
+    KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    ParametricSoftplus<float>);
+
+ONNX_CPU_OPERATOR_KERNEL(
+    ScaledTanh,
+    1,
+    KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    ScaledTanh<float>);
+
+ONNX_CPU_OPERATOR_KERNEL(
+    ThresholdedRelu,
+    1,
+    KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    ThresholdedRelu<float>);
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/contrib_op_activations.h b/onnxruntime/contrib_ops/cpu/contrib_op_activations.h
new file mode 100644
index 0000000000000..fff59b0fc5e5b
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/contrib_op_activations.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/util/math_cpuonly.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+class ScaledTanh final : public OpKernel {
+ public:
+  ScaledTanh(const OpKernelInfo& info)
+      : OpKernel(info), alpha_(info.GetAttrOrDefault("alpha", 1.0f)), beta_(info.GetAttrOrDefault("beta", 1.0f)) {}
+
+  Status Compute(OpKernelContext* context) const override {
+    const Tensor* X = context->Input<Tensor>(0);
+    Tensor* Y = context->Output(0, X->Shape());
+    EIGEN_Y = (T)alpha_ * (EIGEN_X * (T)beta_).tanh();
+    return Status::OK();
+  }
+
+ private:
+  const float alpha_;
+  const float beta_;
+};
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/tensor/crop.cc b/onnxruntime/contrib_ops/cpu/crop.cc
similarity index 86%
rename from onnxruntime/core/providers/cpu/tensor/crop.cc
rename to onnxruntime/contrib_ops/cpu/crop.cc
index 53698123eb8bd..804d7876a8081 100644
--- a/onnxruntime/core/providers/cpu/tensor/crop.cc
+++ b/onnxruntime/contrib_ops/cpu/crop.cc
@@ -1,12 +1,14 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/cpu/tensor/crop.h"
+#include "crop.h"
 
 namespace onnxruntime {
+namespace contrib {
 
 ONNX_CPU_OPERATOR_KERNEL(
     Crop,
     1,
     KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     Crop<float>);
+}
 } // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/tensor/crop.h b/onnxruntime/contrib_ops/cpu/crop.h
similarity index 99%
rename from onnxruntime/core/providers/cpu/tensor/crop.h
rename to onnxruntime/contrib_ops/cpu/crop.h
index 557e76cc8e84d..16f397c717625 100644
--- a/onnxruntime/core/providers/cpu/tensor/crop.h
+++ b/onnxruntime/contrib_ops/cpu/crop.h
@@ -9,6 +9,7 @@
 #include "gsl/gsl_util"
 
 namespace onnxruntime {
+namespace contrib {
 
 class CropBase {
  protected:
@@ -131,4 +132,5 @@ class Crop final : public CropBase, public OpKernel {
   }
 };
 
+}
 } //namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/element_wise_exp_ops.cc b/onnxruntime/contrib_ops/cpu/element_wise_exp_ops.cc
new file mode 100644
index 0000000000000..919ab366eb121
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/element_wise_exp_ops.cc
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "element_wise_exp_ops.h"
+#include "core/providers/cpu/math/element_wise_ops.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_CPU_OPERATOR_KERNEL(
+    Affine,
+    1,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    Affine<float>);
+
+template <>
+Status Affine<float>::Compute(OpKernelContext* ctx) const {
+  auto& X = *ctx->Input<Tensor>(0);
+  auto& Y = *ctx->Output(0, X.Shape());
+  MakeEigenArrayMap<float>(Y) = alpha_ * MakeEigenArrayMap<float>(X) + beta_;
+  return Status::OK();
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/element_wise_exp_ops.h b/onnxruntime/contrib_ops/cpu/element_wise_exp_ops.h
new file mode 100644
index 0000000000000..1837126e74754
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/element_wise_exp_ops.h
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/util/math_cpuonly.h"
+
+namespace onnxruntime {
+namespace contrib {
+template <typename T>
+class Affine final : public OpKernel {
+ public:
+  Affine(const OpKernelInfo& info) : OpKernel(info) {
+    // Either model-supplied or default values should be returned for alpha and beta
+    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
+    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
+  }
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  float alpha_;
+  float beta_;
+};
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/tensor/image_scaler.cc b/onnxruntime/contrib_ops/cpu/image_scaler.cc
similarity index 85%
rename from onnxruntime/core/providers/cpu/tensor/image_scaler.cc
rename to onnxruntime/contrib_ops/cpu/image_scaler.cc
index 98c7240b59b64..875bd3a0a428e 100644
--- a/onnxruntime/core/providers/cpu/tensor/image_scaler.cc
+++ b/onnxruntime/contrib_ops/cpu/image_scaler.cc
@@ -1,12 +1,14 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/cpu/tensor/image_scaler.h"
+#include "image_scaler.h"
 
 namespace onnxruntime {
+namespace contrib {
 ONNX_CPU_OPERATOR_KERNEL(
     ImageScaler,
     1,
     KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     ImageScaler<float>);
+}
 } // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/tensor/image_scaler.h b/onnxruntime/contrib_ops/cpu/image_scaler.h
similarity index 98%
rename from onnxruntime/core/providers/cpu/tensor/image_scaler.h
rename to onnxruntime/contrib_ops/cpu/image_scaler.h
index 221e0a41f34eb..66ee084474044 100644
--- a/onnxruntime/core/providers/cpu/tensor/image_scaler.h
+++ b/onnxruntime/contrib_ops/cpu/image_scaler.h
@@ -8,6 +8,7 @@
 #include "core/util/math_cpuonly.h"
 
 namespace onnxruntime {
+namespace contrib{
 
 template <typename T>
 class ImageScaler final : public OpKernel {
@@ -51,5 +52,5 @@ class ImageScaler final : public OpKernel {
   float scale_;
   std::vector<float> bias_;
 };
-
+}
 } //namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/mean_variance_normalization_exp.cc b/onnxruntime/contrib_ops/cpu/mean_variance_normalization_exp.cc
new file mode 100644
index 0000000000000..f2f054c2a57c7
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/mean_variance_normalization_exp.cc
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cpu/tensor/mean_variance_normalization.h"
+
+namespace onnxruntime {
+namespace contrib {
+// Register MVN operator for backward compatibility.
+// The experimental MVN op was removed. The history has to be kept locally as below.
+// As of (9/26/2018) MVN is a production function in ONNX.
+ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
+    MeanVarianceNormalization,
+    1,
+    8,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    MeanVarianceNormalization_0);
+}
+} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/environment.cc b/onnxruntime/core/framework/environment.cc
index 815214ade52cc..1edda86ebbb55 100644
--- a/onnxruntime/core/framework/environment.cc
+++ b/onnxruntime/core/framework/environment.cc
@@ -40,38 +40,6 @@ Status Environment::Initialize() {
     RegisterOnnxOperatorSetSchema();
     RegisterOnnxMLOperatorSetSchema();
   });
-  //TODO:put all of the following things into call_once
-  // Register MVN operator for backward compatibility.
- // Experimental operator does not have history kept in ONNX. Unfortunately, RS5 takes bunch of experimental operators - // in onnx as production ops. MVN is one of them. Now (9/26/2018) MVN is a production function in ONNX. The experimental - // MVN op was removed. The history has to be kept locally as below. - ORT_ATTRIBUTE_UNUSED ONNX_OPERATOR_SCHEMA(MeanVarianceNormalization) - .SetDoc(R"DOC(Perform mean variance normalization.)DOC") - .Attr("across_channels", "If 1, mean and variance are computed across channels. Default is 0.", AttributeProto::INT, static_cast(0)) - .Attr("normalize_variance", "If 0, normalize the mean only. Default is 1.", AttributeProto::INT, static_cast(1)) - .Input(0, "input", "Input tensor of shape [N,C,H,W]", "T") - .Output(0, "output", "Result, has same shape and type as input", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput); - - ORT_ATTRIBUTE_UNUSED ONNX_OPERATOR_SCHEMA(ScaledTanh) - .Attr("alpha", "Scaling value", AttributeProto::FLOAT, OPTIONAL) - .Attr("beta", "Scaling value", AttributeProto::FLOAT, OPTIONAL) - .Input(0, "input", "Input tensor", "T") - .Output( - 0, - "output", - "The scaled hyperbolic tangent values of the input tensor " - "computed element-wise", - "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput); // Register MemCpy schema; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 2daa648fa3362..90a2468ef815f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -222,7 +222,11 @@ void convPoolShapeInference( } void RegisterContribSchemas() { - // ONNX exp ops(Affine, Crop, ParametricSoftplus, ImageScaler) old version history maintainance + // Register removed experimental ops for backward compatibility. + // Experimental operators do not have version history. However, RS5 takes bunch of experimental operators + // as production ops. In order to maintain backward compatibility when the experimental ops are removed from ONNX + // they need to be added in onnxruntime as contrib ops. + // ONNX exp ops(Affine, Crop, ParametricSoftplus, ImageScaler, ThresholdedRelu, DynamicSlice, ScaledTanh, MVN) old version history maintenance static const char* Affine_ver1_doc = R"DOC( Affine takes one input data (Tensor) and produces one output data (Tensor) where the affine function, y = alpha * x + beta, @@ -344,6 +348,36 @@ Example 2: .Output(0, "output", "Sliced data tensor.", "T") .TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain input and output types to all tensor types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + + ONNX_OPERATOR_SCHEMA(MeanVarianceNormalization) + .SinceVersion(1) + .SetDoc(R"DOC(Perform mean variance normalization.)DOC") + .Attr("across_channels", "If 1, mean and variance are computed across channels. Default is 0.", AttributeProto::INT, static_cast(0)) + .Attr("normalize_variance", "If 0, normalize the mean only. 
Default is 1.", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "Input tensor of shape [N,C,H,W]", "T") + .Output(0, "output", "Result, has same shape and type as input", "T") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + + ONNX_OPERATOR_SCHEMA(ScaledTanh) + .SinceVersion(1) + .Attr("alpha", "Scaling value", AttributeProto::FLOAT, OPTIONAL) + .Attr("beta", "Scaling value", AttributeProto::FLOAT, OPTIONAL) + .Input(0, "input", "Input tensor", "T") + .Output( + 0, + "output", + "The scaled hyperbolic tangent values of the input tensor " + "computed element-wise", + "T") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); ONNX_CONTRIB_OPERATOR_SCHEMA(Affine) .SinceVersion(10) @@ -403,7 +437,25 @@ Example 2: .TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain input and output types to all tensor types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); - // End of ONNX exp ops(Affine, Crop, ParametricSoftplus, ImageScaler) old version history maintainance + ONNX_OPERATOR_SCHEMA(ScaledTanh) + .SinceVersion(10) + .Deprecate() + .Attr("alpha", "Scaling value", AttributeProto::FLOAT, OPTIONAL) + .Attr("beta", "Scaling value", AttributeProto::FLOAT, OPTIONAL) + .Input(0, "input", "Input tensor", "T") + .Output( + 0, + "output", + "The scaled hyperbolic tangent values of the input tensor " + "computed element-wise", + "T") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + + // End of ONNX exp ops(Affine, Crop, ParametricSoftplus, ImageScaler, ThresholdedRelu, DynamicSlice, ScaledTanh, MVN) old version history maintenance ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp) .SetDomain(kMSDomain) diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 78ff664e1eb1a..4a40471a2caf0 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -145,16 +145,6 @@ class PosixEnv : public Env { return Status::OK(); } - static bool GetFileSizeIfUnknown(int fd, size_t& len) { - if(len > 0) return true; - struct stat stbuf; - if ((fstat(fd, &stbuf) != 0) || (!S_ISREG(stbuf.st_mode))) { - return false; - } - len = static_cast(stbuf.st_size); - return true; - } - common::Status ReadFileAsString(const char* fname, off_t offset, void*& p, size_t& len, OrtCallback& deleter) const override { if (!fname) { @@ -169,13 +159,21 @@ class PosixEnv : public Env { deleter.param = nullptr; int fd = open(fname, O_RDONLY); if (fd < 0) { - int err = errno; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", fname, " fail, errcode =", err); + return ReportSystemError("open", fname); } - if (!GetFileSizeIfUnknown(fd, len)) { - (void)close(fd); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Get file '", fname, "' size fail"); + if (len <= 0) { + struct stat stbuf; + if (fstat(fd, &stbuf) != 0) { + return ReportSystemError("fstat", fname); + } + + if (!S_ISREG(stbuf.st_mode)) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "ReadFileAsString: 
input is not a regular file"); + } + len = static_cast(stbuf.st_size); } + if (len == 0) { p = nullptr; } else { @@ -199,10 +197,30 @@ class PosixEnv : public Env { return common::Status::OK(); } + static common::Status ReportSystemError(const char* operation_name, const std::string& path) { + auto e = errno; + char buf[1024]; + const char* msg = ""; + if (e > 0) { +#if defined(_GNU_SOURCE) && !defined(__APPLE__) + msg = strerror_r(e, buf, sizeof(buf)); +#else + // for Mac OS X + if (strerror_r(e, buf, sizeof(buf)) != 0) { + buf[0] = '\0'; + } + msg = buf; +#endif + } + std::ostringstream oss; + oss << operation_name << " file \"" << path << "\" failed: " << msg; + return common::Status(common::SYSTEM, e, oss.str()); + } + common::Status FileOpenRd(const std::string& path, /*out*/ int& fd) const override { fd = open(path.c_str(), O_RDONLY); if (0 > fd) { - return common::Status(common::SYSTEM, errno); + return ReportSystemError("open", path); } return Status::OK(); } @@ -210,7 +228,7 @@ class PosixEnv : public Env { common::Status FileOpenWr(const std::string& path, /*out*/ int& fd) const override { fd = open(path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); if (0 > fd) { - return common::Status(common::SYSTEM, errno); + return ReportSystemError("open", path); } return Status::OK(); } @@ -218,7 +236,7 @@ class PosixEnv : public Env { common::Status FileClose(int fd) const override { int ret = close(fd); if (0 != ret) { - return common::Status(common::SYSTEM, errno); + return ReportSystemError("close", ""); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc index cbdaaceb8d9aa..e2e56291c9546 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.cc +++ b/onnxruntime/core/providers/cpu/activation/activations.cc @@ -19,16 +19,13 @@ namespace onnxruntime { REGISTER_UNARY_ELEMENTWISE_KERNEL(Elu, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(HardSigmoid, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(LeakyRelu, 6); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ParametricSoftplus, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(Relu, 6); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ScaledTanh, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(Selu, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(Sigmoid, 6); // SoftPlus is the default case for ParametricSoftPlus REGISTER_UNARY_ELEMENTWISE_KERNEL_ALIAS(Softplus, ParametricSoftplus, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(Tanh, 6); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 10); template <> diff --git a/onnxruntime/core/providers/cpu/activation/activations.h b/onnxruntime/core/providers/cpu/activation/activations.h index 1283f570a1121..9aaac6b05c6cb 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.h +++ b/onnxruntime/core/providers/cpu/activation/activations.h @@ -102,24 +102,6 @@ class Relu : public OpKernel { } }; -template -class ScaledTanh final : public OpKernel { - public: - ScaledTanh(const OpKernelInfo& info) - : OpKernel(info), alpha_(info.GetAttrOrDefault("alpha", 1.0f)), beta_(info.GetAttrOrDefault("beta", 1.0f)) {} - - Status Compute(OpKernelContext* context) const override { - const Tensor* X = context->Input(0); - Tensor* Y = context->Output(0, X->Shape()); - EIGEN_Y = (T)alpha_ * (EIGEN_X * (T)beta_).tanh(); - return Status::OK(); - } - - private: - const float alpha_; - const float beta_; -}; - template class Selu final : public OpKernel { public: diff --git 
a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 3fc0dc71225d0..c05ed3abdd467 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -18,15 +18,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Cli class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Elu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, HardSigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, LeakyRelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Relu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Selu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softplus); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softsign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Tanh); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, PRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomNormal); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomUniform); @@ -83,7 +80,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Sin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Tan); @@ -156,12 +152,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, MLFloat16, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 4, Concat); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Identity); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Pad); class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, Reshape_1); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape); @@ -180,21 +173,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice); - -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice); - class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split); @@ -269,15 +247,12 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); @@ -334,7 +309,6 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); 
kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); @@ -407,12 +381,9 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); @@ -431,21 +402,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index f12feeb5d3ca5..5f34ef730e8ff 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -300,12 +300,6 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Mean_8); -ONNX_CPU_OPERATOR_KERNEL( - Affine, - 1, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Affine); - ONNX_CPU_OPERATOR_KERNEL( Scale, 1, @@ -660,14 +654,6 @@ Status Mean_8::Compute(OpKernelContext* context) const { return Status::OK(); } -template <> -Status Affine::Compute(OpKernelContext* ctx) const { - auto& X = *ctx->Input(0); - auto& Y = *ctx->Output(0, X.Shape()); - MakeEigenArrayMap(Y) = alpha_ * MakeEigenArrayMap(X) + beta_; - return Status::OK(); -} - template class Sin final : public OpKernel { public: diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index 57bca99b282af..b2f96460e7d12 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -269,22 +269,6 @@ class Mean_8 final : public OpKernel { Status Compute(OpKernelContext* context) const override; }; -template -class Affine final : public OpKernel { - public: - Affine(const OpKernelInfo& info) : OpKernel(info) { - // Either model-supplied or default values should be returned for alpha and beta - 
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK()); - ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK()); - } - - Status Compute(OpKernelContext* context) const override; - - private: - float alpha_; - float beta_; -}; - // PRelu is activation function, but it's closer to binary elementwise ops in implementation template class PRelu final : public OpKernel { diff --git a/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc b/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc index 67c6feab46312..ba600377c89b0 100644 --- a/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc +++ b/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc @@ -4,13 +4,6 @@ #include "core/providers/cpu/tensor/mean_variance_normalization.h" namespace onnxruntime { -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - MeanVarianceNormalization, - 1, - 8, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MeanVarianceNormalization_0); - ONNX_CPU_OPERATOR_KERNEL( MeanVarianceNormalization, 9, diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index 5ceeb630dda30..86e5772a3786e 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -7,7 +7,6 @@ using namespace ::onnxruntime::common; using namespace std; namespace onnxruntime { - #define ADD_TYPED_SLICE_OP(data_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ Slice, \ @@ -30,6 +29,8 @@ ADD_TYPED_SLICE_OP(MLFloat16); ADD_TYPED_SLICE_OP(bool); ADD_TYPED_SLICE_OP(string); +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { #define ADD_TYPED_DYNAMIC_SLICE_OP(data_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ DynamicSlice, \ @@ -54,6 +55,8 @@ ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16); ADD_TYPED_DYNAMIC_SLICE_OP(bool); ADD_TYPED_DYNAMIC_SLICE_OP(string); +} // namespace contrib +#endif namespace { // std::clamp doesn't exist until C++17 so create a local version template diff --git a/onnxruntime/core/providers/cuda/tensor/crop.h b/onnxruntime/core/providers/cuda/tensor/crop.h index 50ef2af0fe1e7..f3326bef008ca 100644 --- a/onnxruntime/core/providers/cuda/tensor/crop.h +++ b/onnxruntime/core/providers/cuda/tensor/crop.h @@ -4,15 +4,15 @@ #pragma once #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cpu/tensor/crop.h" +#include "contrib_ops/cpu/crop.h" namespace onnxruntime { namespace cuda { template -class Crop final : public CropBase, public CudaKernel { +class Crop final : public contrib::CropBase, public CudaKernel { public: - Crop(const OpKernelInfo& info) : CropBase(info), CudaKernel(info) { + Crop(const OpKernelInfo& info) : contrib::CropBase(info), CudaKernel(info) { } Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc b/onnxruntime/test/contrib_ops/dynamic_slice_op_test.cc similarity index 100% rename from onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc rename to onnxruntime/test/contrib_ops/dynamic_slice_op_test.cc diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 460e1cc6b002e..bdf040aed93fc 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -222,14 +222,6 @@ class OnnxModelInfo : public TestModelInfo { const PATH_CHAR_TYPE* GetModelUrl() const override { return model_url_.c_str(); } - std::basic_string GetDir() const override { - std::basic_string test_case_dir; - auto st = 
GetDirNameFromFilePath(model_url_, test_case_dir);
-    if (!st.IsOK()) {
-      ORT_THROW("GetDirNameFromFilePath failed");
-    }
-    return test_case_dir;
-  }
   const std::string& GetNodeName() const override { return node_name_; }
   const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t i) const override { return &output_value_info_[i];
diff --git a/onnxruntime/test/onnx/TestCase.h b/onnxruntime/test/onnx/TestCase.h
index 9f632882c6400..5d7fe59453d78 100644
--- a/onnxruntime/test/onnx/TestCase.h
+++ b/onnxruntime/test/onnx/TestCase.h
@@ -37,7 +37,14 @@ class ITestCase {
 class TestModelInfo {
  public:
   virtual const PATH_CHAR_TYPE* GetModelUrl() const = 0;
-  virtual std::basic_string<PATH_CHAR_TYPE> GetDir() const = 0;
+  virtual std::basic_string<PATH_CHAR_TYPE> GetDir() const {
+    std::basic_string<PATH_CHAR_TYPE> test_case_dir;
+    auto st = onnxruntime::GetDirNameFromFilePath(GetModelUrl(), test_case_dir);
+    if (!st.IsOK()) {
+      ORT_THROW("GetDirNameFromFilePath failed");
+    }
+    return test_case_dir;
+  }
   virtual const std::string& GetNodeName() const = 0;
   virtual const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t i) const = 0;
   virtual int GetInputCount() const = 0;
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 3c78566f0065e..4152c2aa8fdb5 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -204,6 +204,23 @@ int real_main(int argc, char* argv[], OrtEnv** p_env) {
   if (enable_cuda) {
 #ifdef USE_CUDA
     ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0));
+    // Filter out some flaky tests from cuda test runs. Those tests
+    // caused random segfault in CUDA 9.1.
+    // TODO: remove this list once we fully moved to CUDA10
+    // clang-format off
+    std::unordered_set<std::string> cuda_flaky_tests = {
+      "fp16_inception_v1", "fp16_shufflenet", "fp16_tiny_yolov2"
+    };
+    for (auto it = tests.begin(); it != tests.end();) {
+      auto iter = cuda_flaky_tests.find((*it)->GetTestCaseName());
+      if (iter != cuda_flaky_tests.end()) {
+        delete *it;
+        it = tests.erase(it);
+      }
+      else {
+        ++it;
+      }
+    }
 #else
     fprintf(stderr, "CUDA is not supported in this build");
     return -1;
@@ -225,6 +242,7 @@ int real_main(int argc, char* argv[], OrtEnv** p_env) {
     return -1;
 #endif
   }
+
   TestEnv args(tests, stat, sf);
   Status st = RunTests(args, p_models, concurrent_session_runs, static_cast<size_t>(repeat_count),
                        GetDefaultThreadPool(Env::Default()));
diff --git a/onnxruntime/test/onnx/runner.cc b/onnxruntime/test/onnx/runner.cc
index 4c532d8b82816..b80f01793c37a 100644
--- a/onnxruntime/test/onnx/runner.cc
+++ b/onnxruntime/test/onnx/runner.cc
@@ -317,7 +317,7 @@ EXECUTE_RESULT DataRunner::RunTaskImpl(size_t task_id) {
   c_->LoadTestData(task_id, holder, feeds, true);
 
   // Create output feed
-  size_t output_count;
+  size_t output_count = 0;
   ORT_THROW_ON_ERROR(OrtSessionGetOutputCount(session, &output_count));
   std::vector<std::string> output_names(output_count);
   for (size_t i = 0; i != output_count; ++i) {
diff --git a/onnxruntime/test/perftest/TFModelInfo.cc b/onnxruntime/test/perftest/TFModelInfo.cc
new file mode 100644
index 0000000000000..21503846bc36f
--- /dev/null
+++ b/onnxruntime/test/perftest/TFModelInfo.cc
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#include "TFModelInfo.h" +#include + +TestModelInfo* TFModelInfo::Create(_In_ const PATH_CHAR_TYPE* model_url) { + TFModelInfo* ret = new TFModelInfo(); + ret->model_url_ = model_url; + std::basic_string meta_file_path = model_url; + meta_file_path.append(ORT_TSTR(".meta")); + void* p = nullptr; + size_t len = 0; + OrtCallback b; + auto st = onnxruntime::Env::Default().ReadFileAsString(meta_file_path.c_str(), 0, p, len, b); + if (!st.IsOK()) { + ORT_THROW(st.ErrorMessage()); + } + // this string is not null terminated + std::string filecontent(reinterpret_cast(p), len); + std::istringstream is(filecontent); + + std::string line; + while (std::getline(is, line)) { + size_t line_len = 0; + if (!line.empty() && line.back() == '\n') { + line_len = line.length() - 1; + if (line_len > 0 && line[line_len - 1] == '\r') { + --line_len; + } + line.resize(line_len); + } + if (line.empty()) continue; + if (line.compare(0, 6, "input=") == 0) { + ret->input_names_.push_back(line.substr(6)); + } else if (line.compare(0, 7, "output=") == 0) { + ret->output_names_.push_back(line.substr(7)); + } else { + ORT_THROW("unknow line:", line.size()); + } + } + + if (b.f) b.f(b.param); + + return ret; +} + +int TFModelInfo::GetInputCount() const { return static_cast(input_names_.size()); } +int TFModelInfo::GetOutputCount() const { return static_cast(output_names_.size()); } +const std::string& TFModelInfo::GetInputName(size_t i) const { return input_names_[i]; } +const std::string& TFModelInfo::GetOutputName(size_t i) const { return output_names_[i]; } diff --git a/onnxruntime/test/perftest/TFModelInfo.h b/onnxruntime/test/perftest/TFModelInfo.h new file mode 100644 index 0000000000000..2d90c3bd25359 --- /dev/null +++ b/onnxruntime/test/perftest/TFModelInfo.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "TestCase.h" +#include +#include + +class TFModelInfo : public TestModelInfo { + public: + const PATH_CHAR_TYPE* GetModelUrl() const override { return model_url_.c_str(); } + + const std::string& GetNodeName() const override { return node_name_; } + const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t) const override { return nullptr; } + + int GetInputCount() const override; + int GetOutputCount() const override; + const std::string& GetInputName(size_t i) const override; + const std::string& GetOutputName(size_t i) const override; + ~TFModelInfo() override = default; + + static TestModelInfo* Create(_In_ const PATH_CHAR_TYPE* model_url); + + private: + TFModelInfo() = default; + std::basic_string model_url_; + std::vector input_names_; + std::vector output_names_; + std::string node_name_; +}; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 863179bd61fa8..7fd60da7ab19d 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -27,8 +27,10 @@ namespace perftest { "perf_test [options...] model_path result_file\n" "Options:\n" "\t-m [test_mode]: Specifies the test mode. Value coulde be 'duration' or 'times'.\n" - "\t\tProvide 'duration' to run the test for a fix duration, and 'times' to repeated for a certain times. Default:'duration'.\n" + "\t\tProvide 'duration' to run the test for a fix duration, and 'times' to repeated for a certain times. " + "Default:'duration'.\n" "\t-e [cpu|cuda|mkldnn|tensorrt]: Specifies the provider 'cpu','cuda','mkldnn' or 'tensorrt'. 
Default:'cpu'.\n" + "\t-b [tf|ort]: backend to use. Default:ort\n" "\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n" "\t-t [seconds_to_run]: Specifies the seconds to run for 'duration' mode. Default:600.\n" "\t-p [profile_file]: Specifies the profile name to enable profiling and dump the profile data to the file.\n" @@ -41,7 +43,7 @@ namespace perftest { /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:o:vhs"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:o:vhs"))) != -1) { switch (ch) { case 'm': if (!CompareCString(optarg, ORT_TSTR("duration"))) { @@ -52,6 +54,9 @@ namespace perftest { return false; } break; + case 'b': + test_config.backend = optarg; + break; case 'p': test_config.run_config.profile_file = optarg; break; @@ -71,14 +76,14 @@ namespace perftest { } break; case 'r': - test_config.run_config.repeated_times = static_cast(OrtStrtol(optarg, nullptr)); + test_config.run_config.repeated_times = static_cast(OrtStrtol(optarg, nullptr)); if (test_config.run_config.repeated_times <= 0) { return false; } test_config.run_config.test_mode = TestMode::KFixRepeatedTimesMode; break; case 't': - test_config.run_config.duration_in_seconds = static_cast(OrtStrtol(optarg, nullptr)); + test_config.run_config.duration_in_seconds = static_cast(OrtStrtol(optarg, nullptr)); if (test_config.run_config.repeated_times <= 0) { return false; } @@ -100,8 +105,8 @@ namespace perftest { case 'o': test_config.run_config.optimization_level = static_cast(OrtStrtol(optarg, nullptr)); // Valid values are: 0, 1, 2. - if (test_config.run_config.optimization_level > 2 ) { - return false; + if (test_config.run_config.optimization_level > 2) { + return false; } break; case '?': diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 43c0446f3dcfc..8a5cbfe6fa25a 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -2,10 +2,7 @@ // Licensed under the MIT License. 
// onnxruntime dependencies -#include -#include -#include -#include +#include #include "command_args_parser.h" #include "performance_runner.h" @@ -36,7 +33,7 @@ int real_main(int argc, char* argv[], OrtEnv** p_env) { perftest::PerformanceRunner perf_runner(env, test_config); auto status = perf_runner.Run(); if (!status.IsOK()) { - LOGF_DEFAULT(ERROR, "Run failed:%s", status.ErrorMessage().c_str()); + printf("Run failed:%s\n", status.ErrorMessage().c_str()); return -1; } @@ -60,8 +57,6 @@ int main(int argc, char* argv[]) { } if (env) { OrtReleaseEnv(env); - } else { - ::google::protobuf::ShutdownProtobufLibrary(); } return retval; } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc new file mode 100644 index 0000000000000..69f24efd9a05e --- /dev/null +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -0,0 +1,105 @@ +#include "ort_test_session.h" +#include +#include +#include "providers.h" +#include "TestCase.h" + +#ifdef _WIN32 +#define strdup _strdup +#endif + +namespace onnxruntime { +namespace perftest { + +std::chrono::duration OnnxRuntimeTestSession::Run(const OrtValue* const* input) { + auto start = std::chrono::high_resolution_clock::now(); + + ORT_THROW_ON_ERROR(OrtRun(session_object_, nullptr, input_names_.data(), input, input_names_.size(), + output_names_raw_ptr.data(), output_names_raw_ptr.size(), output_values_.data())); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration duration_seconds = end - start; + for (size_t i = 0; i != output_values_.size(); ++i) { + OrtReleaseValue(output_values_[i]); + output_values_[i] = nullptr; + } + return duration_seconds; +} + +OnnxRuntimeTestSession::OnnxRuntimeTestSession(OrtEnv* env, PerformanceTestConfig& performance_test_config, + const TestModelInfo* m) + : input_names_(m->GetInputCount()) { + SessionOptionsWrapper sf(env); + const bool enable_cpu_mem_arena = true; + const std::string& provider_name = performance_test_config.machine_config.provider_type_name; + if (provider_name == onnxruntime::kMklDnnExecutionProvider) { +#ifdef USE_MKLDNN + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, enable_cpu_mem_arena ? 
1 : 0)); +#else + ORT_THROW("MKL-DNN is not supported in this build\n"); +#endif + } else if (provider_name == onnxruntime::kCudaExecutionProvider) { +#ifdef USE_CUDA + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0)); +#else + ORT_THROW("CUDA is not supported in this build\n"); +#endif + } else if (provider_name == onnxruntime::kNupharExecutionProvider) { +#ifdef USE_NUPHAR + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, "")); +#else + ORT_THROW("Nuphar is not supported in this build\n"); +#endif + } else if (provider_name == onnxruntime::kTensorrtExecutionProvider) { +#ifdef USE_TENSORRT + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf)); + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0)); +#else + ORT_THROW("TensorRT is not supported in this build\n"); +#endif + } else if (!provider_name.empty() && provider_name != onnxruntime::kCpuExecutionProvider) { + ORT_THROW("This backend is not included in perf test runner.\n"); + } + + if (enable_cpu_mem_arena) + sf.EnableCpuMemArena(); + else + sf.DisableCpuMemArena(); + if (performance_test_config.run_config.enable_sequential_execution) + sf.EnableSequentialExecution(); + else + sf.DisableSequentialExecution(); + fprintf(stdout, "Setting thread pool size to %d\n", performance_test_config.run_config.session_thread_pool_size); + sf.SetSessionThreadPoolSize(performance_test_config.run_config.session_thread_pool_size); + // Set optimization level. + sf.SetSessionGraphOptimizationLevel(performance_test_config.run_config.optimization_level); + if (!performance_test_config.run_config.profile_file.empty()) + sf.EnableProfiling(performance_test_config.run_config.profile_file.c_str()); + session_object_ = sf.OrtCreateSession(performance_test_config.model_info.model_file_path.c_str()); + + size_t output_count; + ORT_THROW_ON_ERROR(OrtSessionGetOutputCount(session_object_, &output_count)); + output_names_.resize(output_count); + OrtAllocator* a; + ORT_THROW_ON_ERROR(OrtCreateDefaultAllocator(&a)); + for (size_t i = 0; i != output_count; ++i) { + char* output_name = nullptr; + ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session_object_, i, a, &output_name)); + assert(output_name != nullptr); + output_names_[i] = output_name; + a->Free(a, output_name); + } + output_names_raw_ptr.resize(output_count); + for (size_t i = 0; i != output_count; ++i) { + output_names_raw_ptr[i] = output_names_[i].c_str(); + } + OrtReleaseAllocator(a); + output_values_.resize(output_count); + + size_t input_count = static_cast(m->GetInputCount()); + for (size_t i = 0; i != input_count; ++i) { + input_names_[i] = strdup(m->GetInputName(i).c_str()); + } +} + +} // namespace perftest +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h new file mode 100644 index 0000000000000..01d461e89306f --- /dev/null +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h
new file mode 100644
index 0000000000000..01d461e89306f
--- /dev/null
+++ b/onnxruntime/test/perftest/ort_test_session.h
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include
+#include "test_configuration.h"
+#include "test_session.h"
+class TestModelInfo;
+namespace onnxruntime {
+namespace perftest {
+class OnnxRuntimeTestSession : public TestSession {
+ public:
+  OnnxRuntimeTestSession(OrtEnv* env, PerformanceTestConfig& performance_test_config, const TestModelInfo* m);
+
+  ~OnnxRuntimeTestSession() override {
+    if (session_object_ != nullptr) OrtReleaseSession(session_object_);
+    for (char* p : input_names_) {
+      free(p);
+    }
+  }
+  std::chrono::duration<double> Run(const OrtValue* const* input) override;
+
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OnnxRuntimeTestSession);
+
+ private:
+  OrtSession* session_object_ = nullptr;
+  std::vector<std::string> output_names_;
+  // The same size as output_names_.
+  // TODO: implement a customized allocator, then we can remove output_names_ to simplify this code
+  std::vector<const char*> output_names_raw_ptr;
+  std::vector<OrtValue*> output_values_;
+  std::vector<char*> input_names_;
+};
+
+}  // namespace perftest
+}  // namespace onnxruntime
\ No newline at end of file
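Any backend that implements TestSession (introduced later in this diff) can be timed by the runner through the same virtual Run. A minimal sketch of a third implementation, with a hypothetical DummyTestSession that only demonstrates the duration contract:

```cpp
#include <chrono>
#include <thread>

struct OrtValue;  // opaque, as in the ONNX Runtime C API

// Mirrors the TestSession interface introduced in this diff.
class TestSession {
 public:
  virtual std::chrono::duration<double> Run(const OrtValue* const* input) = 0;
  virtual ~TestSession() = default;
};

// Hypothetical backend: does no real inference, just shows the timing contract.
class DummyTestSession : public TestSession {
 public:
  std::chrono::duration<double> Run(const OrtValue* const* /*input*/) override {
    auto start = std::chrono::high_resolution_clock::now();
    std::this_thread::sleep_for(std::chrono::milliseconds(1));  // stand-in for work
    return std::chrono::high_resolution_clock::now() - start;
  }
};
```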
diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc
index c9defc54a7867..8fa23843f7041 100644
--- a/onnxruntime/test/perftest/performance_runner.cc
+++ b/onnxruntime/test/perftest/performance_runner.cc
@@ -2,13 +2,15 @@
 // Licensed under the MIT License.
 
 #include "performance_runner.h"
+#include
+
 #include "TestCase.h"
-#include "core/graph/graph_viewer.h"  //for onnxruntime::NodeArg
-#include "core/session/inference_session.h"
+#include "TFModelInfo.h"
 #include "utils.h"
-#include "testenv.h"
-#include "providers.h"
-
+#include "ort_test_session.h"
+#ifdef HAVE_TENSORFLOW
+#include "tf_test_session.h"
+#endif
 using onnxruntime::Status;
 
 namespace onnxruntime {
@@ -20,10 +22,9 @@ Status PerformanceRunner::Run() {
   // warm up
   RunOneIteration(true /*isWarmup*/);
 
-  InferenceSession* session_object = (InferenceSession*)session_object_;
-
-  if (!performance_test_config_.run_config.profile_file.empty())
-    session_object->StartProfiling(performance_test_config_.run_config.profile_file);
+  // TODO: start profiling
+  // if (!performance_test_config_.run_config.profile_file.empty())
 
   std::unique_ptr<utils::ICPUUsage> p_ICPUUsage = utils::CreateICPUUsage();
   switch (performance_test_config_.run_config.test_mode) {
@@ -39,7 +40,8 @@ Status PerformanceRunner::Run() {
   performance_result_.average_CPU_usage = p_ICPUUsage->GetUsage();
   performance_result_.peak_workingset_size = utils::GetPeakWorkingSetSize();
 
-  if (!performance_test_config_.run_config.profile_file.empty()) session_object->EndProfiling();
+  // TODO: end profiling
+  // if (!performance_test_config_.run_config.profile_file.empty()) session_object->EndProfiling();
 
   std::cout << "Total time cost:" << performance_result_.total_time_cost << std::endl
             << "Total iterations:" << performance_result_.time_costs.size() << std::endl
@@ -48,18 +50,8 @@ Status PerformanceRunner::Run() {
 }
 
 Status PerformanceRunner::RunOneIteration(bool isWarmup) {
-  auto start = std::chrono::high_resolution_clock::now();
-  OrtRunOptions run_options;
-
-  ORT_THROW_ON_ERROR(OrtRun(session_object_, nullptr, input_names_.data(), input_values_.data(), input_names_.size(),
-                            output_names_raw_ptr.data(), output_names_raw_ptr.size(), output_values_.data()));
-  auto end = std::chrono::high_resolution_clock::now();
-  for (size_t i = 0; i != output_values_.size(); ++i) {
-    OrtReleaseValue(output_values_[i]);
-    output_values_[i] = nullptr;
-  }
+  std::chrono::duration<double> duration_seconds = session_->Run(input_values_.data());
   if (!isWarmup) {
-    std::chrono::duration<double> duration_seconds = end - start;
     performance_result_.time_costs.emplace_back(duration_seconds.count());
     performance_result_.total_time_cost += duration_seconds.count();
     if (performance_test_config_.run_config.f_verbose) {
@@ -70,16 +62,16 @@ Status PerformanceRunner::RunOneIteration(bool isWarmup) {
   return Status::OK();
 }
 
+PerformanceRunner::~PerformanceRunner() = default;
+
+PerformanceRunner::PerformanceRunner(OrtEnv* env, const PerformanceTestConfig& test_config)
+    : env_(env), performance_test_config_(test_config) {}
+
 bool PerformanceRunner::Initialize() {
-  bool has_valid_extension = HasExtensionOf(performance_test_config_.model_info.model_file_path, ORT_TSTR("onnx"));
-  if (!has_valid_extension) {
-    LOGF_DEFAULT(ERROR, "input path is not a valid model");
-    return false;
-  }
   std::basic_string<ORTCHAR_T> test_case_dir;
   auto st = GetDirNameFromFilePath(performance_test_config_.model_info.model_file_path, test_case_dir);
   if (!st.IsOK()) {
-    LOGF_DEFAULT(ERROR, "input path is not a valid model");
+    printf("input path is not a valid model\n");
     return false;
   }
   std::basic_string<ORTCHAR_T> model_name = GetLastComponent(test_case_dir);
@@ -90,102 +82,40 @@ bool PerformanceRunner::Initialize() {
   std::string narrow_model_name = ToMBString(model_name);
   performance_result_.model_name = narrow_model_name;
 
-  auto p_model = TestModelInfo::LoadOnnxModel(performance_test_config_.model_info.model_file_path.c_str());
-  std::unique_ptr<ITestCase> test_case(CreateOnnxTestCase(narrow_model_name, p_model, 0.0, 0.0));
-
-  SessionOptionsWrapper sf(env_);
-  const bool enable_cpu_mem_arena = true;
-  const std::string& provider_name = performance_test_config_.machine_config.provider_type_name;
-  if (provider_name == onnxruntime::kMklDnnExecutionProvider) {
-#ifdef USE_MKLDNN
-    ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, enable_cpu_mem_arena ? 1 : 0));
-#else
-    fprintf(stderr, "MKL-DNN is not supported in this build");
-    return false;
-#endif
-  } else if (provider_name == onnxruntime::kCudaExecutionProvider) {
-#ifdef USE_CUDA
-    ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0));
-#else
-    fprintf(stderr, "CUDA is not supported in this build");
-    return false;
-#endif
-  } else if (provider_name == onnxruntime::kNupharExecutionProvider) {
-#ifdef USE_NUPHAR
-    ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, ""));
-#else
-    fprintf(stderr, "Nuphar is not supported in this build");
-    return false;
-#endif
-  } else if (provider_name == onnxruntime::kTensorrtExecutionProvider) {
-#ifdef USE_TENSORRT
-    ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf));
-    ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0));
-#else
-    fprintf(stderr, "TensorRT is not supported in this build");
-    return false;
-#endif
-  } else if (!provider_name.empty() && provider_name != onnxruntime::kCpuExecutionProvider) {
-    fprintf(stderr, "This backend is not included in perf test runner.");
-    return false;
-  }
-
-  if (enable_cpu_mem_arena)
-    sf.EnableCpuMemArena();
-  else
-    sf.DisableCpuMemArena();
-  if (performance_test_config_.run_config.enable_sequential_execution)
-    sf.EnableSequentialExecution();
-  else
-    sf.DisableSequentialExecution();
-  fprintf(stdout, "Setting thread pool size to %d\n", performance_test_config_.run_config.session_thread_pool_size);
-  sf.SetSessionThreadPoolSize(performance_test_config_.run_config.session_thread_pool_size);
-
-  // Set optimization level.
-  sf.SetSessionGraphOptimizationLevel(performance_test_config_.run_config.optimization_level);
-
-  session_object_ = sf.OrtCreateSession(test_case->GetModelUrl());
-
-  auto provider_type = performance_test_config_.machine_config.provider_type_name;
-  // Place input tensor on cpu memory if mkldnn provider type to avoid CopyTensor logic in CopyInputAcrossDevices
-  // TODO: find a better way to do this.
-  if (provider_type == onnxruntime::kMklDnnExecutionProvider) {
-    provider_type = onnxruntime::kCpuExecutionProvider;
+  TestModelInfo* p_model;
+  if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("ort")) == 0) {
+    p_model = TestModelInfo::LoadOnnxModel(performance_test_config_.model_info.model_file_path.c_str());
+  } else if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("tf")) == 0) {
+    p_model = TFModelInfo::Create(performance_test_config_.model_info.model_file_path.c_str());
+  } else {
+    ORT_NOT_IMPLEMENTED(ToMBString(performance_test_config_.backend), " is not supported");
   }
+  test_case_.reset(CreateOnnxTestCase(narrow_model_name, p_model, 0.0, 0.0));
 
-  if (test_case->GetDataCount() <= 0) {
-    LOGS_DEFAULT(ERROR) << "there is no test data for model " << test_case->GetTestCaseName();
+  // TODO: Place input tensor on cpu memory if mkldnn provider type to avoid CopyTensor logic in CopyInputAcrossDevices
+  if (test_case_->GetDataCount() <= 0) {
+    std::cout << "there is no test data for model " << test_case_->GetTestCaseName() << std::endl;
     return false;
   }
-  test_case->LoadTestData(0 /* id */, b_, feeds_, true);
-
-  input_names_.resize(feeds_.size());
+  test_case_->LoadTestData(0 /* id */, b_, feeds_, true);
   input_values_.resize(feeds_.size());
   size_t input_index = 0;
   for (auto& kvp : feeds_) {
-    input_names_[input_index] = kvp.first.c_str();
     input_values_[input_index] = kvp.second;
     ++input_index;
   }
-  size_t output_count;
-  ORT_THROW_ON_ERROR(OrtSessionGetOutputCount(session_object_, &output_count));
-  output_names_.resize(output_count);
-  OrtAllocator* a;
-  ORT_THROW_ON_ERROR(OrtCreateDefaultAllocator(&a));
-  for (size_t i = 0; i != output_count; ++i) {
-    char* output_name = nullptr;
-    ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session_object_, i, a, &output_name));
-    assert(output_name != nullptr);
-    output_names_[i] = output_name;
-    a->Free(a, output_name);
-  }
-  output_names_raw_ptr.resize(output_count);
-  for (size_t i = 0; i != output_count; ++i) {
-    output_names_raw_ptr[i] = output_names_[i].c_str();
+
+  if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("ort")) == 0) {
+    session_ = new OnnxRuntimeTestSession(env_, performance_test_config_, p_model);
+#ifdef HAVE_TENSORFLOW
+  } else if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("tf")) == 0) {
+    session_ = new TensorflowTestSession(performance_test_config_, p_model);
+#endif
+  } else {
+    ORT_NOT_IMPLEMENTED(ToMBString(performance_test_config_.backend), " is not supported");
   }
-  OrtReleaseAllocator(a);
-  output_values_.resize(output_count);
+
   return true;
 }
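Initialize() above branches on the backend string twice: once to load the model and once to build the session. An illustrative alternative is to normalize the string to an enum first; BackendKind and ParseBackend below are hypothetical, not part of the diff:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical helper: map the backend name to an enum once, so model loading
// and session creation can switch on the same value.
enum class BackendKind { kOrt, kTf };

BackendKind ParseBackend(const std::string& name) {
  if (name == "ort") return BackendKind::kOrt;
  if (name == "tf") return BackendKind::kTf;
  throw std::runtime_error(name + " is not supported");
}
```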
diff --git a/onnxruntime/test/perftest/performance_runner.h b/onnxruntime/test/perftest/performance_runner.h
index ec4e97b2f731e..f558a4f93d57a 100644
--- a/onnxruntime/test/perftest/performance_runner.h
+++ b/onnxruntime/test/perftest/performance_runner.h
@@ -10,17 +10,14 @@
 // onnxruntime dependencies
 #include
-#include
-#include
 #include
-#include
-#include
-#include
 #include
-#include
 #include
 #include "test_configuration.h"
 #include "heap_buffer.h"
+#include "test_session.h"
+
+class ITestCase;
 
 namespace onnxruntime {
 namespace perftest {
@@ -36,7 +33,7 @@ struct PerformanceResult {
     std::ofstream outfile;
     outfile.open(path, std::ofstream::out | std::ofstream::app);
     if (!outfile.good()) {
-      LOGF_DEFAULT(ERROR, "failed to open result file");
+      printf("failed to open result file");
       return;
     }
 
@@ -44,7 +41,7 @@ struct PerformanceResult {
       outfile << model_name << "," << time_costs[runs] << "," << peak_workingset_size << "," << average_CPU_usage
               << "," << runs << std::endl;
     }
-    if (time_costs.size() > 0 && f_include_statistics) {
+    if (!time_costs.empty() && f_include_statistics) {
       std::vector<double> sorted_time = time_costs;
 
       size_t total = sorted_time.size();
@@ -70,9 +67,9 @@ struct PerformanceResult {
 
 class PerformanceRunner {
  public:
-  PerformanceRunner(OrtEnv* env, const PerformanceTestConfig& test_config)
-      : env_(env), performance_test_config_(test_config) {}
+  PerformanceRunner(OrtEnv* env, const PerformanceTestConfig& test_config);
 
+  ~PerformanceRunner();
   Status Run();
 
   inline const PerformanceResult& GetResult() const { return performance_result_; }
@@ -81,9 +78,6 @@ class PerformanceRunner {
     performance_result_.DumpToFile(performance_test_config_.model_info.result_file_path,
                                    performance_test_config_.run_config.f_dump_statistics);
   }
-  ~PerformanceRunner() {
-    if (session_object_ != nullptr) OrtReleaseSession(session_object_);
-  }
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerformanceRunner);
 
  private:
@@ -108,17 +102,12 @@ class PerformanceRunner {
   OrtEnv* env_;
   PerformanceResult performance_result_;
   PerformanceTestConfig performance_test_config_;
-  // not owned
-  OrtSession* session_object_ = nullptr;
-  std::vector<const char*> input_names_;
+
   std::unordered_map<std::string, OrtValue*> feeds_;
   std::vector<OrtValue*> input_values_;
   HeapBuffer b_;
-  std::vector<std::string> output_names_;
-  // The same size with output_names_.
-  // TODO: implement a customized allocator, then we can remove output_names_ to simplify this code
-  std::vector<const char*> output_names_raw_ptr;
-  std::vector<OrtValue*> output_values_;
+  std::unique_ptr<ITestCase> test_case_;
+  TestSession* session_;
 };
 }  // namespace perftest
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h
index 182b120789e3e..59c3ce900085d 100644
--- a/onnxruntime/test/perftest/test_configuration.h
+++ b/onnxruntime/test/perftest/test_configuration.h
@@ -49,6 +49,7 @@ struct PerformanceTestConfig {
   ModelInfo model_info;
   MachineConfig machine_config;
   RunConfig run_config;
+  std::basic_string<ORTCHAR_T> backend = ORT_TSTR("ort");
 };
 
 }  // namespace perftest
diff --git a/onnxruntime/test/perftest/test_session.h b/onnxruntime/test/perftest/test_session.h
new file mode 100644
index 0000000000000..11cf1f308e6dc
--- /dev/null
+++ b/onnxruntime/test/perftest/test_session.h
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+namespace onnxruntime {
+namespace perftest {
+class TestSession {
+ public:
+  virtual std::chrono::duration<double> Run(const OrtValue* const* input) = 0;
+  virtual ~TestSession() = default;
+};
+}  // namespace perftest
+}  // namespace onnxruntime
\ No newline at end of file
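DumpToFile's statistics branch (truncated in the hunk above) works on a sorted copy of the per-iteration costs. A self-contained sketch of that kind of summary, assuming time_costs holds seconds as double; the exact percentiles the tool prints are not shown in this hunk:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Report mean and rough percentiles from per-iteration latencies, mirroring
// the sorted-copy approach in PerformanceResult::DumpToFile.
void ReportStats(const std::vector<double>& time_costs) {
  if (time_costs.empty()) return;
  std::vector<double> sorted_time = time_costs;
  std::sort(sorted_time.begin(), sorted_time.end());
  double sum = 0;
  for (double t : sorted_time) sum += t;
  size_t total = sorted_time.size();
  printf("avg: %f s\n", sum / total);
  printf("p50: %f s\n", sorted_time[total / 2]);       // median (upper, for even sizes)
  printf("p90: %f s\n", sorted_time[total * 9 / 10]);  // simple nearest-rank-style index
}
```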
diff --git a/onnxruntime/test/perftest/tf_test_session.h b/onnxruntime/test/perftest/tf_test_session.h
new file mode 100644
index 0000000000000..2af1e847caa3b
--- /dev/null
+++ b/onnxruntime/test/perftest/tf_test_session.h
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include
+#include
+#include "test_configuration.h"
+#include "tensorflow/c/c_api.h"
+#include "test_session.h"
+
+namespace onnxruntime {
+namespace perftest {
+class TensorflowTestSession : public TestSession {
+ private:
+  OrtCallback model_deleter;
+  std::vector<TF_Output> feed_;
+  std::vector<TF_Output> fetches_;
+  TF_Session* sess_;
+  TF_Graph* tf_graph_;
+  // This function is for both graph inputs and outputs
+  static TF_Output GetOutputFromGraph(const char* tensor_name, TF_Graph* tf_graph) {
+    TF_Output ret;
+    const char* start = tensor_name;
+    const char* sep = strchr(start, ':');
+    if (sep == nullptr) {
+      ORT_THROW("invalid name:", tensor_name);
+    }
+    size_t name_len = sep - start;
+    std::string name(name_len, '\0');
+    memcpy(const_cast<char*>(name.data()), start, name_len);
+    ret.oper = TF_GraphOperationByName(tf_graph, name.c_str());
+    if (ret.oper == nullptr) ORT_THROW("input name: \"", name, "\" cannot be found in the graph");
+    start = sep + 1;
+    char* end;
+    ret.index = static_cast<int>(strtol(start, &end, 10));
+    if (start == end) {
+      ORT_THROW("invalid name:", tensor_name);
+    }
+    return ret;
+  }
+
+ public:
+  TensorflowTestSession(PerformanceTestConfig& performance_test_config, const TestModelInfo* m) {
+    TF_Status* s = TF_NewStatus();
+    tf_graph_ = TF_NewGraph();
+    TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
+    TF_ImportGraphDefOptionsSetPrefix(opts, "");
+    TF_Buffer* graph_def = TF_NewBuffer();
+    void* model_data;
+    auto st = Env::Default().ReadFileAsString(performance_test_config.model_info.model_file_path.c_str(), 0,
+                                              model_data, graph_def->length, model_deleter);
+    if (!st.IsOK())
+      ORT_THROW("read file ", performance_test_config.model_info.model_file_path, " failed:", st.ErrorMessage());
+    graph_def->data = model_data;
+    TF_GraphImportGraphDef(tf_graph_, graph_def, opts, s);
+    if (TF_GetCode(s) != TF_OK) ORT_THROW("load TF model failed:", TF_Message(s));
+    TF_SessionOptions* session_opts = TF_NewSessionOptions();
+    sess_ = TF_NewSession(tf_graph_, session_opts, s);
+    if (TF_GetCode(s) != TF_OK) ORT_THROW("load TF model failed:", TF_Message(s));
+    feed_.resize(static_cast<size_t>(m->GetInputCount()));
+    for (size_t i = 0; i != feed_.size(); ++i) {
+      feed_[i] = GetOutputFromGraph(m->GetInputName(i).c_str(), tf_graph_);
+    }
+    fetches_.resize(static_cast<size_t>(m->GetOutputCount()));
+    for (size_t i = 0; i != fetches_.size(); ++i) {
+      fetches_[i] = GetOutputFromGraph(m->GetOutputName(i).c_str(), tf_graph_);
+    }
+  }
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorflowTestSession);
+  std::chrono::duration<double> Run(const OrtValue* const* input) override {
+    size_t input_len = feed_.size();
+    std::vector<TF_Tensor*> feed_tensors(input_len);
+    for (size_t i = 0; i != input_len; ++i) {
+      void* input_buffer = nullptr;
+      ORT_THROW_ON_ERROR(OrtGetTensorMutableData(const_cast<OrtValue*>(input[i]), &input_buffer));
+      assert(input_buffer != nullptr);
+      OrtTensorTypeAndShapeInfo* shape;
+      ORT_THROW_ON_ERROR(OrtGetTensorShapeAndType(input[i], &shape));
+      size_t dim_count = OrtGetNumOfDimensions(shape);
+      std::vector<int64_t> dims(dim_count);
+      OrtGetDimensions(shape, dims.data(), dim_count);
+      int64_t ele_count = OrtGetTensorShapeElementCount(shape);
+      size_t buffer_length = ele_count * sizeof(float);
+      TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims.data(), static_cast<int>(dims.size()), buffer_length);
+      assert(t != nullptr);
+      feed_tensors[i] = t;
+      assert(TF_TensorByteSize(t) == buffer_length);
+      memcpy(TF_TensorData(t), input_buffer, buffer_length);
+    }
+    std::vector<TF_Tensor*> output_tensors(fetches_.size());
+    TF_Status* s = TF_NewStatus();
+    auto start = std::chrono::high_resolution_clock::now();
+    TF_SessionRun(sess_, nullptr, feed_.data(), feed_tensors.data(), static_cast<int>(feed_.size()), fetches_.data(),
+                  output_tensors.data(), static_cast<int>(fetches_.size()), nullptr, 0, nullptr, s);
+    auto end = std::chrono::high_resolution_clock::now();
+    if (TF_GetCode(s) != TF_OK) ORT_THROW("run TF model failed:", TF_Message(s));
+    TF_DeleteStatus(s);
+    return end - start;
+  }
+  ~TensorflowTestSession() override {
+    if (model_deleter.f != nullptr) {
+      model_deleter.f(model_deleter.param);
+    }
+    TF_Status* s = TF_NewStatus();
+    TF_DeleteSession(sess_, s);
+    TF_DeleteStatus(s);
+  }
+};
+
+}  // namespace perftest
+}  // namespace onnxruntime
\ No newline at end of file
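GetOutputFromGraph above parses TensorFlow's `op_name:index` tensor-naming convention. The same parsing rule as a standalone function (ParseTensorName is hypothetical; errors are simplified to exceptions):

```cpp
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <utility>

// Split "op_name:index" (e.g. "input:0") into its two parts,
// mirroring the rule used by GetOutputFromGraph.
std::pair<std::string, int> ParseTensorName(const std::string& tensor_name) {
  size_t sep = tensor_name.find(':');
  if (sep == std::string::npos) throw std::invalid_argument("invalid name: " + tensor_name);
  std::string op_name = tensor_name.substr(0, sep);
  const std::string index_str = tensor_name.substr(sep + 1);
  char* end = nullptr;
  int index = static_cast<int>(std::strtol(index_str.c_str(), &end, 10));
  if (end == index_str.c_str()) throw std::invalid_argument("invalid name: " + tensor_name);
  return {op_name, index};
}
```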
diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
index d4be92b7acd31..a231a0fdaca0d 100644
--- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
+++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
@@ -10,8 +10,9 @@ namespace test {
 void TestUnaryElementwiseOp(const char* szOp, std::vector<float>& input_vals,
                             std::function<float(float)> expected_func,
-                            const std::unordered_map<std::string, float> attribs = {}) {
-  OpTester test(szOp);
+                            const std::unordered_map<std::string, float> attribs = {},
+                            int opset_version = 7) {
+  OpTester test(szOp, opset_version);
   for (auto attr : attribs)
     test.AddAttribute(attr.first, attr.second);
@@ -92,17 +93,7 @@ TEST(ActivationOpTest, ThresholdedRelu) {
   TestUnaryElementwiseOp("ThresholdedRelu",
                          input_vals,
                          [alpha](float x) { return (x >= alpha) ? x : 0; },
-                         {{"alpha", alpha}});
-}
-
-TEST(ActivationOpTest, ScaledTanh) {
-  static constexpr float alpha = 2.0f;
-  static constexpr float beta = 1.5f;
-
-  TestUnaryElementwiseOp("ScaledTanh",
-                         input_vals,
-                         [](float x) { return alpha * tanh(beta * x); },
-                         {{"alpha", alpha}, {"beta", beta}});
+                         {{"alpha", alpha}}, 10);
 }
 
 TEST(ActivationOpTest, Selu) {
@@ -184,6 +175,24 @@ TEST(ActivationOpTest, PRelu_MultiChannel) {
 }
 
 #ifndef DISABLE_CONTRIB_OPS
+TEST(ActivationOpTest, ThresholdedRelu_version_1_to_9) {
+  float alpha = 0.1f;
+  TestUnaryElementwiseOp("ThresholdedRelu",
+                         input_vals,
+                         [alpha](float x) { return (x >= alpha) ? x : 0; },
+                         {{"alpha", alpha}}, 1);
+}
+
+TEST(ActivationOpTest, ScaledTanh) {
+  static constexpr float alpha = 2.0f;
+  static constexpr float beta = 1.5f;
+
+  TestUnaryElementwiseOp("ScaledTanh",
+                         input_vals,
+                         [](float x) { return alpha * tanh(beta * x); },
+                         {{"alpha", alpha}, {"beta", beta}});
+}
+
 TEST(ActivationOpTest, ParametricSoftplus) {
   static constexpr float alpha = 2.0f;
   static constexpr float beta = 1.5f;
diff --git a/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc
index bf7fcf2e70dab..884a9dc2a359a 100644
--- a/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc
@@ -3,7 +3,7 @@
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
-#include "core/providers/cpu/tensor/crop.h"
+#include "contrib_ops/cpu/crop.h"
 #include "core/util/math.h"
 
 using namespace ONNX_NAMESPACE;
@@ -302,6 +302,33 @@ TEST(TensorOpTest, CastToString) {
   TestCastOp(int_16_input, int_string_data, shape, TensorProto::STRING);
 }
 
+std::pair<float, float> MeanStdev(std::vector<float>& v) {
+  float sum = std::accumulate(v.begin(), v.end(), 0.0f);
+  float mean = sum / v.size();
+
+  std::vector<float> diff(v.size());
+  std::transform(v.begin(), v.end(), diff.begin(),
+                 std::bind(std::minus<float>(), std::placeholders::_1, mean));
+  float sq_sum = std::inner_product(diff.begin(), diff.end(), diff.begin(), 0.0f);
+  float stdev = std::sqrt(sq_sum / v.size());
+
+  return std::make_pair(mean, stdev);
+}
+
+void Normalize(std::vector<float>& v,
+               std::pair<float, float>& mean_stdev, bool normalize_variance) {
+  float mean = mean_stdev.first;
+  float stdev = mean_stdev.second;
+
+  std::transform(v.begin(), v.end(), v.begin(),
+                 std::bind(std::minus<float>(), std::placeholders::_1, mean));
+
+  if (normalize_variance) {
+    std::transform(v.begin(), v.end(), v.begin(),
+                   std::bind(std::divides<float>(), std::placeholders::_1, stdev));
+  }
+}
+
 #ifndef DISABLE_CONTRIB_OPS
 TEST(TensorOpTest, CropBorderOnly) {
   const int N = 2, C = 1, H = 3, W = 4;
@@ -353,33 +380,32 @@ TEST(TensorOpTest, CropBorderAndScale) {
   test.AddOutput("output", {N, C, scale[0], scale[1]}, output);
   test.Run();
 }
-#endif
 
-std::pair<float, float> MeanStdev(std::vector<float>& v) {
-  float sum = std::accumulate(v.begin(), v.end(), 0.0f);
-  float mean = sum / v.size();
+TEST(TensorOpTest, ImageScalerTest) {
+  const int64_t N = 1, C = 2, H = 2, W = 2;
+  std::vector<float> X = {
+      1.0f, 3.0f,
+      3.0f, 5.0f,
 
-  std::vector<float> diff(v.size());
-  std::transform(v.begin(), v.end(), diff.begin(),
-                 std::bind(std::minus<float>(), std::placeholders::_1, mean));
-  float sq_sum = std::inner_product(diff.begin(), diff.end(), diff.begin(), 0.0f);
-  float stdev = std::sqrt(sq_sum / v.size());
+      3.0f, 5.0f,
+      7.0f, 9.0f};
 
-  return std::make_pair(mean, stdev);
-}
+  float scale = 2.0f;
+  std::vector<float> bias = {1.0f, 2.0f};
 
-void Normalize(std::vector<float>& v,
-               std::pair<float, float>& mean_stdev, bool normalize_variance) {
-  float mean = mean_stdev.first;
-  float stdev = mean_stdev.second;
+  std::vector<float> result = {
+      3.0f, 7.0f,
+      7.0f, 11.0f,
 
-  std::transform(v.begin(), v.end(), v.begin(),
-                 std::bind(std::minus<float>(), std::placeholders::_1, mean));
+      8.0f, 12.0f,
+      16.0f, 20.0f};
 
-  if (normalize_variance) {
-    std::transform(v.begin(), v.end(), v.begin(),
-                   std::bind(std::divides<float>(), std::placeholders::_1, stdev));
-  }
+  OpTester test("ImageScaler");
+  test.AddAttribute("scale", scale);
+  test.AddAttribute("bias", bias);
+  test.AddInput<float>("input", {N, C, H, W}, X);
+  test.AddOutput<float>("output", {N, C, H, W}, result);
+  test.Run();
 }
 
 void MeanVarianceNormalizationAcrossChannels(bool across_channels, bool normalize_variance) {
@@ -475,6 +501,21 @@ void MeanVarianceNormalizationPerChannel(bool across_channels, bool normalize_va
   test.Run();
 }
 
+TEST(TensorOpTest, MeanVarianceNormalizationCPUTest_Version1_TO_8) {
+  // across_channels: true, normalize_variance: true
+  MeanVarianceNormalizationAcrossChannels(true, true);
+
+  // across_channels: true, normalize_variance: false
+  MeanVarianceNormalizationAcrossChannels(true, false);
+
+  // across_channels: false, normalize_variance: false
+  MeanVarianceNormalizationPerChannel(false, false);
+
+  // across_channels: false, normalize_variance: true
+  MeanVarianceNormalizationPerChannel(false, true);
+}
+#endif
+
 void MeanVarianceNormalizationFunctionDefaultPerChannel() {
   const int64_t N = 2, C = 2, H = 2, W = 3;
@@ -562,18 +603,7 @@ void MeanVarianceNormalizationFunctionAcrossChannels(std::vector<int64_t> axes)
 }
 
 TEST(TensorOpTest, MeanVarianceNormalizationCPUTest) {
-  // across_channels: true, normalize_variance: true
-  MeanVarianceNormalizationAcrossChannels(true, true);
-
-  // across_channels: true, normalize_variance: false
-  MeanVarianceNormalizationAcrossChannels(true, false);
-
-  // across_channels: false, normalize_variance: false
-  MeanVarianceNormalizationPerChannel(false, false);
-
-  // across_channels: false, normalize_variance: true
-  MeanVarianceNormalizationPerChannel(false, true);
-
+
   // axes: {0, 1, 2, 3} for across_channels
   MeanVarianceNormalizationFunctionAcrossChannels({0, 1, 2, 3});
 
@@ -581,33 +611,5 @@ TEST(TensorOpTest, MeanVarianceNormalizationCPUTest) {
   MeanVarianceNormalizationFunctionDefaultPerChannel();
 }
 
-#ifndef DISABLE_CONTRIB_OPS
-TEST(TensorOpTest, ImageScalerTest) {
-  const int64_t N = 1, C = 2, H = 2, W = 2;
-  std::vector<float> X = {
-      1.0f, 3.0f,
-      3.0f, 5.0f,
-
-      3.0f, 5.0f,
-      7.0f, 9.0f};
-
-  float scale = 2.0f;
-  std::vector<float> bias = {1.0f, 2.0f};
-
-  std::vector<float> result = {
-      3.0f, 7.0f,
-      7.0f, 11.0f,
-
-      8.0f, 12.0f,
-      16.0f, 20.0f};
-
-  OpTester test("ImageScaler");
-  test.AddAttribute("scale", scale);
-  test.AddAttribute("bias", bias);
-  test.AddInput<float>("input", {N, C, H, W}, X);
-  test.AddOutput<float>("output", {N, C, H, W}, result);
-  test.Run();
-}
-#endif
 }  // namespace test
 }  // namespace onnxruntime
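As a sanity check on the relocated ImageScalerTest: ImageScaler computes output = scale * input + bias[c] per channel, so with scale = 2 and bias = {1, 2} the first channel maps {1, 3, 3, 5} to {3, 7, 7, 11} and the second maps {3, 5, 7, 9} to {8, 12, 16, 20}, which matches the expected result tensor in the test.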