diff --git a/src/Conversion/KrnlToLLVM/CMakeLists.txt b/src/Conversion/KrnlToLLVM/CMakeLists.txt
index e0cacb6a75..313dd717be 100644
--- a/src/Conversion/KrnlToLLVM/CMakeLists.txt
+++ b/src/Conversion/KrnlToLLVM/CMakeLists.txt
@@ -3,6 +3,7 @@ add_onnx_mlir_library(OMKrnlToLLVM
   ConvertKrnlToLLVM.cpp
   KrnlFindIndex.cpp
+  KrnlCall.cpp
   KrnlEntryPoint.cpp
   KrnlGetRef.cpp
   KrnlGlobal.cpp
diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
index 1ed675781e..b29c57f977 100644
--- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
+++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
@@ -577,6 +577,7 @@ std::unique_ptr<Pass> createConvertKrnlToLLVMPass() {
 void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,
     RewritePatternSet &patterns, MLIRContext *ctx,
     ArrayRef<bool> outputOMTensorOwnerships, bool singleEntryPoint) {
+  krnl::populateLoweringKrnlCallOpPattern(typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlEntryPointOpPattern(
       typeConverter, patterns, ctx, outputOMTensorOwnerships, singleEntryPoint);
   krnl::populateLoweringKrnlFindIndexOpPattern(typeConverter, patterns, ctx);
diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp
index 93b1212d18..7533ac11e4 100644
--- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp
+++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp
@@ -32,6 +32,9 @@ void populateKrnlToLLVMConversion(mlir::LLVMTypeConverter &typeConverter,
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
     llvm::ArrayRef<bool> constantOutputs, bool singleEntryPoint);
 
+void populateLoweringKrnlCallOpPattern(mlir::TypeConverter &typeConverter,
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
+
 void populateLoweringKrnlEntryPointOpPattern(mlir::TypeConverter &typeConverter,
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
     llvm::ArrayRef<bool> constantOutputs, bool singleEntryPoint);
diff --git a/src/Conversion/KrnlToLLVM/KrnlCall.cpp b/src/Conversion/KrnlToLLVM/KrnlCall.cpp
new file mode 100644
index 0000000000..ee1c817c34
--- /dev/null
+++ b/src/Conversion/KrnlToLLVM/KrnlCall.cpp
@@ -0,0 +1,212 @@
+
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===-------------- KrnlCall.cpp - Lower KrnlCallOp -----------------------===//
+//
+// Copyright 2022 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the KrnlCallOp operator.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
+#include "src/Dialect/Krnl/DialectBuilder.hpp"
+#include "src/Dialect/Krnl/KrnlHelper.hpp"
+#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "krnl_to_llvm"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace krnl {
+
+class KrnlCallOpLowering : public ConversionPattern {
+public:
+  explicit KrnlCallOpLowering(
+      TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(
+            typeConverter, KrnlCallOp::getOperationName(), 1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    KrnlCallOpAdaptor krnlCallAdaptor(operands);
+    Location loc = op->getLoc();
+    KrnlCallOp krnlCallOp = llvm::cast<KrnlCallOp>(op);
+
+    // Get a symbol reference to the function, inserting it if necessary.
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+    llvm::SmallVector<Type, 4> parameterTypeList;
+    llvm::SmallVector<Value, 4> parameterList;
+    handleOneParameter(rewriter, op, krnlCallAdaptor.result(),
+        krnlCallOp.result(), parameterTypeList, parameterList);
+
+    // Some of the operands have already been converted to LLVM types. It is
+    // better to check the types of the original operands, so the converted
+    // and the original operands are iterated over together.
+    auto itConverted = krnlCallAdaptor.parameters().begin();
+    auto itOriginal = krnlCallOp.parameters().begin();
+    for (; itConverted != krnlCallAdaptor.parameters().end();
+         itConverted++, itOriginal++) {
+      handleOneParameter(rewriter, op, *itConverted, *itOriginal,
+          parameterTypeList, parameterList);
+    }
+
+    // Handle the attributes.
+    for (auto namedAttr : op->getAttrs()) {
+      // Skip the funcName attribute; it selects the callee, not a parameter.
+      if (namedAttr.getName().getValue().equals("funcName"))
+        continue;
+      handleOneAttribute(rewriter, getTypeConverter(), op,
+          namedAttr.getValue(), parameterTypeList, parameterList);
+    }
+
+    auto callRef = getOrInsertCall(
+        rewriter, module, krnlCallOp.funcName(), parameterTypeList);
+    rewriter.create<LLVM::CallOp>(
+        loc, callRef, ArrayRef<Type>({}), parameterList);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  static void handleOneParameter(PatternRewriter &rewriter, Operation *op,
+      Value parameter, Value original,
+      llvm::SmallVector<Type, 4> &parameterTypeList,
+      llvm::SmallVector<Value, 4> &parameterList) {
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+    const auto &apiRegistry = RuntimeAPIRegistry::build(module, rewriter);
+
+    // Check the original type, not the type after conversion.
+    Type ty = original.getType();
+    if (ty.isa<MemRefType>()) {
+      auto int64Ty = IntegerType::get(context, 64);
+      auto memRefTy = parameter.getType().dyn_cast<LLVM::LLVMStructType>();
+      auto memRefRank = krnl::getRankFromMemRefType(memRefTy);
+      auto memRefRankVal = rewriter.create<LLVM::ConstantOp>(
+          loc, int64Ty, rewriter.getI64IntegerAttr(memRefRank));
+      Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry,
+          RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal});
+
+      krnl::fillOMTensorWithMemRef(parameter, omTensor, false /*outOwning*/,
+          rewriter, loc, apiRegistry, module);
+      auto int8Ty = IntegerType::get(context, 8);
+      auto opaquePtrTy = LLVM::LLVMPointerType::get(int8Ty);
+      parameterTypeList.emplace_back(opaquePtrTy);
+      parameterList.emplace_back(omTensor);
+    } else {
+      parameterTypeList.emplace_back(parameter.getType());
+      parameterList.emplace_back(parameter);
+    }
+  }
+
+  static void handleOneAttribute(PatternRewriter &rewriter,
+      TypeConverter *typeConverter, Operation *op, Attribute attribute,
+      llvm::SmallVector<Type, 4> &parameterTypeList,
+      llvm::SmallVector<Value, 4> &parameterList) {
+    auto *context = op->getContext();
+    auto loc = op->getLoc();
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    TypeSwitch<Attribute>(attribute)
+        .Case<StringAttr>([&](StringAttr strAttr) {
+          StringRef attrValue = strAttr.getValue();
+          LLVM::GlobalOp globalStr =
+              krnl::getOrCreateGlobalString(attrValue, loc, rewriter, module,
+                  static_cast<LLVMTypeConverter *>(typeConverter));
+          Value strPtr = krnl::getPtrToGlobalString(globalStr, loc, rewriter);
+          auto int8Ty = IntegerType::get(context, 8);
+          auto opaquePtrTy = LLVM::LLVMPointerType::get(int8Ty);
+          parameterTypeList.emplace_back(opaquePtrTy);
+          parameterList.emplace_back(strPtr);
+        })
+        .Case<IntegerAttr>([&](IntegerAttr integerAttr) {
+          auto int64Ty = IntegerType::get(context, 64);
+          Value cst =
+              rewriter.create<LLVM::ConstantOp>(loc, int64Ty, integerAttr);
+          parameterTypeList.emplace_back(int64Ty);
+          parameterList.emplace_back(cst);
+        })
+        .Case<FloatAttr>([&](FloatAttr floatAttr) {
+          auto f64Ty = rewriter.getF64Type();
+          Value cst =
+              rewriter.create<LLVM::ConstantOp>(loc, f64Ty, floatAttr);
+          parameterTypeList.emplace_back(f64Ty);
+          parameterList.emplace_back(cst);
+        })
+        .Case<DenseElementsAttr>([&](DenseElementsAttr denseAttr) {
+          // Use krnl.global to handle the dense attribute. Since the
+          // attribute still has a tensor type, the code has to cross both
+          // the onnx-to-krnl and the krnl-to-llvm conversions. In the
+          // future, such attributes should be converted in the krnl.call
+          // builder. This code passes the onnx-mlir-opt
+          // --convert-krnl-to-llvm test case, but fails in onnx-mlir
+          // because of the tensor type of the attribute.
+          const auto &apiRegistry = RuntimeAPIRegistry::build(module, rewriter);
+          auto tensorTy = denseAttr.getType().cast<TensorType>();
+          auto memRefTy =
+              MemRefType::get(tensorTy.getShape(), tensorTy.getElementType());
+          MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);
+          Value constantGlobal =
+              create.krnl.constant(memRefTy, "constant_", denseAttr);
+          Value convertedConstantGlobal =
+              rewriter
+                  .create<UnrealizedConversionCastOp>(
+                      loc, typeConverter->convertType(memRefTy), constantGlobal)
+                  .getResult(0);
+
+          auto int64Ty = IntegerType::get(context, 64);
+          auto memRefRank = memRefTy.getRank();
+          auto memRefRankVal = rewriter.create<LLVM::ConstantOp>(
+              loc, int64Ty, rewriter.getI64IntegerAttr(memRefRank));
+          Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry,
+              RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal});
+
+          krnl::fillOMTensorWithMemRef(convertedConstantGlobal, omTensor,
+              false /*outOwning*/, rewriter, loc, apiRegistry, module);
+          auto int8Ty = IntegerType::get(context, 8);
+          auto opaquePtrTy = LLVM::LLVMPointerType::get(int8Ty);
+          parameterTypeList.emplace_back(opaquePtrTy);
+          parameterList.emplace_back(omTensor);
+        })
+        .Default([&](Attribute attr) {
+          llvm_unreachable("This type of Attribute used by krnl.call is not "
+                           "yet implemented");
+        });
+  }
+
+  FlatSymbolRefAttr getOrInsertCall(PatternRewriter &rewriter, ModuleOp module,
+      llvm::StringRef funcName, ArrayRef<Type> parameterTypeList) const {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName))
+      return SymbolRefAttr::get(context, funcName);
+    auto llvmVoidTy = LLVM::LLVMVoidType::get(context);
+    auto llvmFnType =
+        LLVM::LLVMFunctionType::get(llvmVoidTy, parameterTypeList, false);
+
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, llvmFnType);
+    return SymbolRefAttr::get(context, funcName);
+  }
+};
+
+void populateLoweringKrnlCallOpPattern(TypeConverter &typeConverter,
+    RewritePatternSet &patterns, MLIRContext *ctx) {
+  patterns.insert<KrnlCallOpLowering>(typeConverter, ctx);
+}
+
+} // namespace krnl
+} // namespace onnx_mlir
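For reference, this is the C-side shape of the call that the lowering above synthesizes. It is a sketch derived from handleOneParameter/handleOneAttribute, using the Resize_Scales entry point added later in this patch: the result MemRef comes first, then the MemRef parameters, then the attributes, each marshalled into a plain C type.

    /* Sketch of the external signature synthesized by getOrInsertCall for
     *   "krnl.call"(%alloc, %data, %scales)
     *       {funcName = "Resize_Scales", mode = "linear"}              */
    #include <stdint.h>
    #include "onnx-mlir/Runtime/OMTensor.h"

    void Resize_Scales(OMTensor *output, /* result MemRef -> OMTensor* (i8*) */
        OMTensor *data,                  /* MemRef param  -> OMTensor* (i8*) */
        OMTensor *scales,                /* MemRef param  -> OMTensor* (i8*) */
        char *mode);                     /* StringAttr    -> ptr to global    */
    /* An IntegerAttr parameter would arrive as int64_t, a FloatAttr as
     * double. */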
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp b/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp
index c86aa78e2a..fad3d65942 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp
@@ -129,12 +129,27 @@ struct ONNXResizeOpLowering : public ConversionPattern {
 
     // Call external function when the mode is not "nearest"
     // Create KrnlCallOp and replace the du chain
+    // One of the inputs, scales() and size(), has to be None. For now, the
+    // None input is filtered out by the KrnlCall builder, and a different
+    // external function is called accordingly.
+    // Another issue is attributes with default values. Currently, it is
+    // assumed that all the optional attributes have their default values
+    // and do not appear in the attribute dictionary.
+    // ToFix: handle the attributes in the general case.
     if (resizeOp.mode() != "nearest") {
-      Value resizeCall =
-          rewriter.create<KrnlCallOp>(loc, alloc, op, operands, true);
-      rewriter.replaceOp(op, resizeCall);
+      assert(op->getAttrs().size() == 1 &&
+             "ResizeOp: runtime lib is not supported for this case");
+      if (!isFromNone(resizeOp.scales())) {
+        rewriter.create<KrnlCallOp>(
+            loc, "Resize_Scales", alloc, op, operands, true);
+      } else {
+        rewriter.create<KrnlCallOp>(
+            loc, "Resize_Size", alloc, op, operands, true);
+      }
+      rewriter.replaceOp(op, alloc);
       return success();
     }
+
+    // It is much more efficient to generate the code directly if possible.
 
     // Constants used in the loop body
     Value zero = create.math.constant(rewriter.getIntegerType(64), 0);
diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td
index 069705ca6c..0a93ef43fc 100644
--- a/src/Dialect/Krnl/Krnl.td
+++ b/src/Dialect/Krnl/Krnl.td
@@ -36,31 +36,36 @@ def StringType : Type<CPred<"$_self.isa<StringType>()">, "string type">;
 // Require regions to have krnl.terminate terminator operation.
 def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
 
-def KrnlCallOp : Op<Krnl_Dialect, "call"> {
+def KrnlCallOp : Op<Krnl_Dialect, "call", [MemRefsNormalizable,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]
+    > {
   let summary = "call operation";
   let description = [{
    The call operation provides a generic way to call an external function
    at Krnl level. The `funcName` determines which function to call.
-   The `alloc` is the Value to store the function return. Since allocation
-   of the return MemRef involves shape inference usually with IndexExpr.
-   Thus most of time the allocation should stay in compiler, not in runtime library.
+   The `result` is the Value that stores the function's return. Currently
+   only one output is supported. `result` has to be an allocated MemRef.
+   Since allocation of the output MemRef involves shape inference on the ONNX
+   Op, allocation should be done while lowering the ONNX Op, not within
+   krnl.call. Another reason is that KrnlCallOp would need to be defined with
+   the AllocationOp interface if `result` were allocated inside this Op.
    The parameters can be of any type: MemRef, NoneType or any llvm type.
    Different types of parameters will be converted, if needed, when KrnlCallOp
    is lowered. Attributes will be converted to parameters too (To be Added).
    The function signature will be determined with the types of parameters.
    An LLVM::CallOp to either a runtime library or a llvm intrinsic function
    will be generated.
+   The krnl.call op will be lowered to llvm at the krnl-to-llvm conversion.
  }];
 
-  let arguments = (ins StrAttr:$funcName, AnyType:$alloc, Variadic<AnyType>:$parameters);
-  let results =(outs AnyType:$result);
+  let arguments = (ins StrAttr:$funcName, AnyType:$result, Variadic<AnyType>:$parameters);
 
  // builders to build KrnlCallOp from op and operands, helping conversion from onnx to krnl.
  // The name of function can be determined by the op name and elemnt type of the return,
  // or given to builder if the simple rule does not work
  // Attributes of the op will be propagated to KrnlCallOp if the copyAttrs is true
-  let builders = [OpBuilder<(ins "mlir::Value":$alloc, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>,
-                  OpBuilder<(ins "std::string":$funcNameStr, "mlir::Value":$alloc, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>];
+  let builders = [OpBuilder<(ins "mlir::Value":$result, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>,
+                  OpBuilder<(ins "std::string":$funcNameStr, "mlir::Value":$result, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>];
 }
 
 def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
@@ -208,11 +213,18 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
 
 def KrnlRegionOp : Op<Krnl_Dialect, "region", [AffineScope]> {
-  let summary = "Affine Region for Krnl";
+  let summary = "Affine boundary for krnl loops";
   let description = [{
-    This Op has AffineScope trait and is used to limit the scope of affine.for.
-    Otherwise, the affine.for might be illegal if a symbol is not defined at
-    the top of function, which has the AffineScope trait.
+    This Op has a region with the AffineScope trait and is used to limit the
+    scope of `affine.for`. A loop inside krnl.region can be affine if its
+    bounds are defined at the level of krnl.region. krnl.region does not
+    guarantee or require the loops inside it to be affine.
+    Without krnl.region, a krnl loop may not be affine if its bound symbol
+    is not defined at the top level of an enclosing region with the
+    AffineScope trait. In MLIR, FuncOp has the AffineScope trait.
+    The `krnl.region` op will be removed after affine.for is lowered.
+    ToFix: the current krnl.region has no inputs or outputs. A memref
+    created inside the region cannot be used outside of the region.
   }];
 
   let regions = (region SizedRegion<1>:$bodyRegion);
diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp
index 5e72717da9..5acc9e160d 100644
--- a/src/Dialect/Krnl/KrnlOps.cpp
+++ b/src/Dialect/Krnl/KrnlOps.cpp
@@ -34,6 +34,7 @@
 
 #include "src/Dialect/Krnl/KrnlHelper.hpp"
 #include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
 
 using namespace mlir;
 using namespace onnx_mlir;
@@ -89,25 +90,37 @@ static std::string typeToString(Type ty) {
 void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
     std::string funcNameStr, Value resultVal, Operation *op,
     ValueRange operands, bool copyAttrs) {
-  // Creates inputs
+  // Creates the parameters for KrnlCall, handling optional inputs (with
+  // NoneType). The semantics of an optional input are ONNX Op specific and
+  // should be handled when lowering the ONNX Op, not when lowering KrnlCall.
+  // For now, None inputs are filtered out of the KrnlCall parameters.
+  // The Op decides which external function to call based on the inputs.
+  // For future work: it might be possible to assume that NoneType always
+  // stands for a tensor and to implement it as a nullptr in llvm.
+  // Then the None input could be handled inside the external function.
+  // Currently, onnx-mlir's NoneType is not handled by the typeConverter of
+  // the ONNXToKrnl conversion.
   SmallVector<Value, 4> allInputs;
   allInputs.emplace_back(resultVal);
-  for (auto operand : operands)
-    allInputs.emplace_back(operand);
+  for (auto operand : operands) {
+    if (!isFromNone(operand))
+      allInputs.emplace_back(operand);
+  }
 
   StringAttr funcNameAttr = builder.getStringAttr(funcNameStr);
   auto namedAttr = builder.getNamedAttr("funcName", funcNameAttr);
   if (!copyAttrs) {
-    build(builder, odsState, resultVal.getType(), funcNameAttr, resultVal,
-        operands);
+    // allInputs[0] is the result value, which is passed separately; only
+    // the remaining values are the call parameters.
+    build(builder, odsState, funcNameAttr, resultVal,
+        ValueRange(allInputs).drop_front());
   } else {
     std::vector<NamedAttribute> attributes;
     attributes.emplace_back(namedAttr);
     for (auto namedAttr : op->getAttrs()) {
       attributes.emplace_back(namedAttr);
     }
-    build(builder, odsState, resultVal.getType(), ValueRange(allInputs),
-        attributes);
+    build(builder, odsState, TypeRange(), ValueRange(allInputs), attributes);
   }
 }
 
@@ -123,6 +136,17 @@ void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
   build(builder, odsState, funcNameStr, resultVal, op, operands, copyAttrs);
 }
 
+void KrnlCallOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  for (auto parameter : parameters()) {
+    effects.emplace_back(MemoryEffects::Read::get(), parameter,
+        SideEffects::DefaultResource::get());
+  }
+  effects.emplace_back(MemoryEffects::Write::get(), result(),
+      SideEffects::DefaultResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // KrnlDefineLoopsOp
 //===----------------------------------------------------------------------===//
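The builder above filters None inputs out and lets the ONNX lowering pick a distinct entry point per input combination (Resize_Scales vs. Resize_Size below). The future-work alternative mentioned in the comment would keep a single entry point and lower NoneType to a null pointer; a purely hypothetical sketch of that variant, not part of this patch:

    /* Hypothetical single-entry-point Resize: dispatch on NULL at runtime
     * instead of choosing the function name at compile time. */
    #include <assert.h>
    #include <stddef.h>
    #include "onnx-mlir/Runtime/OMTensor.h"

    void Resize_Scales(OMTensor *, OMTensor *, OMTensor *, char *);
    void Resize_Size(OMTensor *, OMTensor *, OMTensor *, char *);

    void Resize(OMTensor *output, OMTensor *data, OMTensor *scales,
        OMTensor *sizes, char *mode) {
      assert((scales == NULL) != (sizes == NULL) &&
             "exactly one of scales/sizes must be given");
      if (scales != NULL)
        Resize_Scales(output, data, scales, mode);
      else
        Resize_Size(output, data, sizes, mode);
    }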
diff --git a/src/Runtime/CMakeLists.txt b/src/Runtime/CMakeLists.txt
index 7648b62650..3d4c5803b9 100644
--- a/src/Runtime/CMakeLists.txt
+++ b/src/Runtime/CMakeLists.txt
@@ -13,6 +13,7 @@ add_onnx_mlir_library(cruntime STATIC
   OMIndexLookup.c
   OMInstrument.c
   OMRandomNormal.c
+  OMResize.c
   OMTensor.c
   OMTensorList.c
   OnnxDataType.c
@@ -35,6 +36,7 @@ add_onnx_mlir_library(OMTensorUtils
   OMIndexLookup.cpp
   OMInstrument.cpp
   OMRandomNormal.cpp
+  OMResize.cpp
   OMTensor.cpp
   OMTensorList.cpp
   OnnxDataType.cpp
diff --git a/src/Runtime/OMResize.c b/src/Runtime/OMResize.c
new file mode 100644
index 0000000000..ccea164c63
--- /dev/null
+++ b/src/Runtime/OMResize.c
@@ -0,0 +1,15 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===-------------- OMResize.c - OMResize C Implementation ----------------===//
+//
+// Copyright 2022 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains the implementation of the OMResize functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "OMResize.inc"
diff --git a/src/Runtime/OMResize.cpp b/src/Runtime/OMResize.cpp
new file mode 100644
index 0000000000..24cba65bcb
--- /dev/null
+++ b/src/Runtime/OMResize.cpp
@@ -0,0 +1,15 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===------------- OMResize.cpp - OMResize C++ Implementation -------------===//
+//
+// Copyright 2022 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains the implementation of the OMResize functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "OMResize.inc"
diff --git a/src/Runtime/OMResize.inc b/src/Runtime/OMResize.inc
new file mode 100644
index 0000000000..526e2535b8
--- /dev/null
+++ b/src/Runtime/OMResize.inc
@@ -0,0 +1,353 @@
+#ifdef __cplusplus
+#include <cmath>
+#else
+#include <math.h>
+#endif
+
+#include <assert.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "onnx-mlir/Runtime/OMTensor.h"
+
+/**
+ * The runtime implementation of ONNXResizeOp is manually translated from the
+ * Python implementation: onnx/onnx/backend/test/case/node/resize.py
+ * (onnx v1.9.0). The efficiency of this implementation can be improved.
+ **/
+
+static void linear_coeffs(float ratio, float coeffs_buffer[2], int mode) {
+  coeffs_buffer[0] = 1 - ratio;
+  coeffs_buffer[1] = ratio;
+}
+
+static void nearest_coeffs(float ratio, float coeffs_buffer[2], int mode) {
+  /* integer ratio is handled outside */
+  switch (mode) {
+  case 0: // round_prefer_floor
+    coeffs_buffer[0] = ratio <= 0.5 ? 1 : 0;
+    coeffs_buffer[1] = ratio > 0.5 ? 1 : 0;
+    break;
+  case 1: // round_prefer_ceil
+    coeffs_buffer[0] = ratio < 0.5 ? 1 : 0;
+    coeffs_buffer[1] = ratio >= 0.5 ? 1 : 0;
+    break;
+  case 2: // floor
+    coeffs_buffer[0] = 1;
+    coeffs_buffer[1] = 0;
+    break;
+  case 4: // ceil
+    coeffs_buffer[0] = 0;
+    coeffs_buffer[1] = 1;
+    break;
+  }
+}
+
+static void cubic_coeffs(float ratio, float coeffs_buffer[4], int mode) {
+  float A = -0.75;
+  // A may have a different value for each coordinate_transformation_mode.
+  // Currently, only the default mode is supported.
+  coeffs_buffer[0] =
+      ((A * (ratio + 1) - 5 * A) * (ratio + 1) + 8 * A) * (ratio + 1) - 4 * A;
+  coeffs_buffer[1] = ((A + 2) * ratio - (A + 3)) * ratio * ratio + 1;
+  coeffs_buffer[2] =
+      ((A + 2) * (1 - ratio) - (A + 3)) * (1 - ratio) * (1 - ratio) + 1;
+  coeffs_buffer[3] =
+      ((A * ((1 - ratio) + 1) - 5 * A) * ((1 - ratio) + 1) + 8 * A) *
+          ((1 - ratio) + 1) -
+      4 * A;
+}
+
+static void get_neighbor(float x, int64_t n, int limit, float *data,
+    float *points, int exclude_outside) {
+
+  // Inline the Python get_neighbor_idx function without real padding, to
+  // avoid malloc/free.
+  // Nearest index: identify the central index first, then select from both
+  // sides. If the central index is to the right of x (>= x), favor the
+  // left one.
+
+  int pad_width = ceil(n / 2.0);
+  x += pad_width;
+  float r = x - floor(x);
+
+  int start, end;
+  int c;
+  if (r > 0.5) {
+    c = (int)(floor(x)) + 1;
+  } else {
+    c = (int)(floor(x));
+  }
+
+  int rest = n - 1;
+  int half = rest / 2;
+  if (rest == 0) {
+    start = end = c;
+  } else if (rest % 2 == 0) {
+    start = c - half;
+    end = c + half;
+  } else if (r == 0) {
+    start = c - half - 1;
+    end = c + half;
+  } else {
+    if (r > 0.5) {
+      end = c + half;
+      start = c - half - 1;
+    } else {
+      end = c + half + 1;
+      start = c - half;
+    }
+  }
+
+  start -= pad_width;
+  end -= pad_width;
+
+  for (int i = start; i <= end; i++) {
+    if (i < 0) {
+      if (exclude_outside)
+        points[i - start] = 0;
+      else
+        points[i - start] = data[0];
+    } else if (i >= limit) {
+      if (exclude_outside)
+        points[i - start] = 0;
+      else
+        points[i - start] = data[limit - 1];
+    } else {
+      points[i - start] = data[i];
+    }
+  }
+}
+
+typedef void (*Coeff_Func_t)(float, float *, int mode);
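As a quick sanity check on the coefficient helpers above, a minimal stand-alone driver (assuming it is compiled in a translation unit that includes OMResize.inc, since the helpers are static). linear_coeffs splits a fractional ratio into the two neighbor weights, so a sample 30% of the way between values 10 and 20 yields 0.7*10 + 0.3*20 = 13:

    #include <stdio.h>
    #include "OMResize.inc" /* pulls in the static coefficient helpers */

    int main(void) {
      float w[4];
      float data[2] = {10.0f, 20.0f};

      linear_coeffs(0.3f, w, 0); /* w = {0.7, 0.3} */
      printf("linear:  %f\n", w[0] * data[0] + w[1] * data[1]); /* 13.0 */

      nearest_coeffs(0.3f, w, 0); /* round_prefer_floor: w = {1, 0} */
      printf("nearest: %f\n", w[0] * data[0] + w[1] * data[1]); /* 10.0 */

      cubic_coeffs(0.5f, w, 0); /* four cubic weights; they sum to 1 */
      printf("cubic:   %f %f %f %f\n", w[0], w[1], w[2], w[3]);
      return 0;
    }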
+static float interpolate_1d_with_x(OMTensor *data, float scale_factor, float x,
+    Coeff_Func_t get_coeffs, float *coeffs_buffer, int coeffs_n, float roi,
+    float extrapolation_value, int coordinate_transformation_mode,
+    int exclude_outside, int mode) {
+
+  int64_t input_width = omTensorGetShape(data)[0];
+  float x_ori;
+  switch (coordinate_transformation_mode) {
+  case 0: // half_pixel
+    x_ori = (x + 0.5) / scale_factor - 0.5;
+    break;
+  case 1: // asymmetric
+    x_ori = x / scale_factor;
+    break;
+  }
+  int64_t x_ori_int = floor(x_ori);
+
+  float ratio = x_ori - x_ori_int;
+  if (ratio == 0)
+    ratio = 1;
+
+  get_coeffs(ratio, coeffs_buffer, mode);
+  int64_t n = coeffs_n;
+
+  // float points[coeffs_n];
+  float *points = (float *)malloc(sizeof(float) * coeffs_n);
+
+  get_neighbor(x_ori, n, input_width, (float *)omTensorGetDataPtr(data),
+      points, exclude_outside);
+  float sum = 0.;
+  for (int i = 0; i < n; i++) {
+    sum += coeffs_buffer[i] * points[i];
+  }
+  free(points);
+  return sum;
+}
+
+static float interpolate_nd_with_x(OMTensor *data, int n, float *scale_factors,
+    float *xs, Coeff_Func_t get_coeffs, float *coeffs_buffer, int coeffs_n,
+    float roi, float extrapolation_value, int coordinate_transformation_mode,
+    int exclude_outside, int mode) {
+  if (n == 1) {
+    return interpolate_1d_with_x(data, scale_factors[0], xs[0], get_coeffs,
+        coeffs_buffer, coeffs_n, roi, extrapolation_value,
+        coordinate_transformation_mode, exclude_outside, mode);
+  } else {
+    // Interpolate each dim-0 slice recursively, then combine the results
+    // with a final 1-D interpolation along dim 0.
+    int64_t input_width = omTensorGetShape(data)[0];
+    float *tempData = (float *)malloc(sizeof(float) * input_width);
+    int64_t tempShape[] = {input_width};
+
+    int64_t stride = 1;
+    for (int i = 1; i < n; i++) {
+      stride *= omTensorGetShape(data)[i];
+    }
+    for (int i = 0; i < input_width; i++) {
+      float *dataPtr = (float *)omTensorGetDataPtr(data) + i * stride;
+      OMTensor *data1 = omTensorCreate(
+          dataPtr, omTensorGetShape(data) + 1, n - 1, ONNX_TYPE_FLOAT);
+      tempData[i] = interpolate_nd_with_x(data1, n - 1, scale_factors + 1,
+          xs + 1, get_coeffs, coeffs_buffer, coeffs_n, roi,
+          extrapolation_value, coordinate_transformation_mode,
+          exclude_outside, mode);
+      omTensorDestroy(data1);
+    }
+    OMTensor *tempT = omTensorCreate(tempData, tempShape, 1, ONNX_TYPE_FLOAT);
+    float ret = interpolate_1d_with_x(tempT, scale_factors[0], xs[0],
+        get_coeffs, coeffs_buffer, coeffs_n, roi, extrapolation_value,
+        coordinate_transformation_mode, exclude_outside, mode);
+    omTensorDestroy(tempT);
+    free(tempData);
+    return ret;
+  }
+}
+
+static void coordinate_step(int64_t rank, int64_t *output_size,
+    int64_t *allCoordinates, int64_t currentRank, int64_t *currentIter,
+    int64_t *currentPosition) {
+  for (int i = 0; i < output_size[currentRank]; i++) {
+    if (currentRank == rank - 1) {
+      for (int j = 0; j < currentRank; j++) {
+        *(allCoordinates + (*currentPosition) * rank + j) = currentIter[j];
+      }
+      *(allCoordinates + (*currentPosition) * rank + currentRank) = i;
+      (*currentPosition)++;
+    } else {
+      currentIter[currentRank] = i;
+      coordinate_step(rank, output_size, allCoordinates, currentRank + 1,
+          currentIter, currentPosition);
+    }
+  }
+}
+
+static void generate_coordinates(
+    int64_t rank, int64_t *output_size, int64_t *allCoordinates) {
+  int64_t position = 0;
+  int64_t *currentIter = (int64_t *)malloc(sizeof(int64_t) * rank);
+  coordinate_step(rank, output_size, allCoordinates, 0, currentIter, &position);
+  free(currentIter);
+}
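generate_coordinates enumerates every output index tuple in row-major order into a flat [outputSize x rank] array. A small driver (again assuming OMResize.inc is included so the static helpers are visible) for a 2x3 output:

    #include <inttypes.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include "OMResize.inc"

    int main(void) {
      int64_t output_size[2] = {2, 3};
      /* 6 coordinates of rank 2, stored as a flat [6 x 2] array. */
      int64_t *coords = (int64_t *)malloc(6 * 2 * sizeof(int64_t));
      generate_coordinates(2, output_size, coords);
      for (int i = 0; i < 6; i++)
        printf("(%" PRId64 ", %" PRId64 ")\n",
            coords[i * 2], coords[i * 2 + 1]);
      /* Prints (0,0) (0,1) (0,2) (1,0) (1,1) (1,2). */
      free(coords);
      return 0;
    }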
+static void interpolate_nd_OMTensor(OMTensor *output_OMT, OMTensor *data,
+    int64_t mode, OMTensor *output_size_OMT, OMTensor *scale_factor_OMT,
+    Coeff_Func_t get_coeffs, int coeffs_n, OMTensor *roi,
+    float *extrapolation_value, int coordinate_transformation_mode,
+    int exclude_outside) {
+  assert(omTensorGetDataType(data) == ONNX_TYPE_FLOAT &&
+         "Resize runtime: only float type is supported currently");
+
+  int64_t rank = omTensorGetRank(data);
+  int64_t *inputShape = omTensorGetShape(data);
+  float *scale_factor = NULL;
+  int64_t *output_size = NULL;
+  if (scale_factor_OMT != NULL)
+    scale_factor = (float *)omTensorGetDataPtr(scale_factor_OMT);
+  if (output_size_OMT != NULL)
+    output_size = (int64_t *)omTensorGetDataPtr(output_size_OMT);
+  if (scale_factor == NULL) {
+    scale_factor = (float *)malloc(sizeof(float) * rank);
+    for (int i = 0; i < rank; i++) {
+      scale_factor[i] = ((float)output_size[i]) / inputShape[i];
+    }
+  } else {
+    output_size = (int64_t *)malloc(sizeof(int64_t) * rank);
+    for (int i = 0; i < rank; i++) {
+      output_size[i] = scale_factor[i] * inputShape[i];
+    }
+  }
+
+  int64_t outputSize = 1;
+  for (int i = 0; i < rank; i++) {
+    outputSize *= output_size[i];
+  }
+  float *outputData = (float *)omTensorGetDataPtr(output_OMT);
+
+  // int64_t allCoordinates[outputSize][rank];
+  int64_t *allCoordinates =
+      (int64_t *)malloc(outputSize * rank * sizeof(int64_t));
+  generate_coordinates(rank, output_size, allCoordinates);
+
+  // float coeffs_buffer[coeffs_n]; // = {1.0, 0.};
+  float *coeffs_buffer = (float *)malloc(sizeof(float) * coeffs_n);
+
+  for (int i = 0; i < outputSize; i++) {
+    float *Xs = (float *)malloc(sizeof(float) * rank);
+    for (int j = 0; j < rank; j++) {
+      Xs[j] = *(allCoordinates + i * rank + j);
+    }
+    float r = interpolate_nd_with_x(
+        /*OMTensor *data*/ data,
+        /*int n*/ rank,
+        /*float *scale_factors*/ scale_factor,
+        /*float *xs*/ Xs,
+        /*Coeff_Func_t*/ get_coeffs,
+        /*float *coeffs_buffer*/ coeffs_buffer,
+        /*int coeffs_n*/ coeffs_n,
+        /*float roi*/ 0.,
+        /*float extrapolation_value*/ 0.,
+        /*int coordinate_transformation_mode*/ 0,
+        /*exclude_outside*/ 0,
+        /*mode*/ 0);
+    outputData[i] = r;
+    free(Xs);
+  }
+  if (output_size_OMT == NULL)
+    free(output_size);
+  if (scale_factor_OMT == NULL)
+    free(scale_factor);
+  free(allCoordinates);
+  free(coeffs_buffer);
+}
+
+void Resize_Scales(
+    OMTensor *output, OMTensor *data, OMTensor *scales, char *mode_str) {
+  Coeff_Func_t coeffs_f = NULL;
+  int coeffs_n = 0;
+  if (strcmp(mode_str, "nearest") == 0) {
+    coeffs_f = nearest_coeffs;
+    coeffs_n = 2;
+  } else if (strncmp(mode_str, "linear", 6) == 0) {
+    coeffs_f = linear_coeffs;
+    coeffs_n = 2;
+  } else if (strcmp(mode_str, "cubic") == 0) {
+    coeffs_f = cubic_coeffs;
+    coeffs_n = 4;
+  } else {
+    assert(0 && "Resize runtime: unsupported mode");
+  }
+  interpolate_nd_OMTensor(
+      /*OMTensor *output*/ output,
+      /*OMTensor *data*/ data,
+      /*mode*/ 0,
+      /*OMTensor *output_size*/ NULL,
+      /*OMTensor *scales*/ scales,
+      /*Coeff_Func_t*/ coeffs_f,
+      /*int coeffs_n*/ coeffs_n,
+      /*OMTensor *roi*/ NULL,
+      /*float *extrapolation_value*/ 0,
+      /*int coordinate_transformation_mode*/ 0,
+      /*exclude_outside*/ 0);
+}
+
+void Resize_Size(
+    OMTensor *output, OMTensor *data, OMTensor *size, char *mode_str) {
+  Coeff_Func_t coeffs_f = NULL;
+  int coeffs_n = 0;
+  if (strcmp(mode_str, "nearest") == 0) {
+    coeffs_f = nearest_coeffs;
+    coeffs_n = 2;
+  } else if (strncmp(mode_str, "linear", 6) == 0) {
+    coeffs_f = linear_coeffs;
+    coeffs_n = 2;
+  } else if (strcmp(mode_str, "cubic") == 0) {
+    coeffs_f = cubic_coeffs;
+    coeffs_n = 4;
+  } else {
+    assert(0 && "Resize runtime: unsupported mode");
+  }
+  interpolate_nd_OMTensor(
+      /*OMTensor *output*/ output,
+      /*OMTensor *data*/ data,
+      /*mode*/ 0,
+      /*OMTensor *output_size*/ size,
+      /*OMTensor *scales*/ NULL,
+      /*Coeff_Func_t*/ coeffs_f,
+      /*int coeffs_n*/ coeffs_n,
+      /*OMTensor *roi*/ NULL,
+      /*float *extrapolation_value*/ 0,
+      /*int coordinate_transformation_mode*/ 0,
+      /*exclude_outside*/ 0);
+}
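For reference, a host-side sketch of how these entry points can be exercised directly, which is the same call the lowered krnl.call makes. The buffer values are illustrative; note that the output tensor is pre-allocated by the caller (the compiler-side alloc in the Resize lowering above) and the runtime only fills it:

    #include <stdio.h>
    #include "onnx-mlir/Runtime/OMTensor.h"

    void Resize_Scales(OMTensor *output, OMTensor *data, OMTensor *scales,
        char *mode_str);

    int main(void) {
      /* 1x1x2x2 input upsampled 2x in the two spatial dimensions. */
      float in[4] = {1.0f, 2.0f, 3.0f, 4.0f};
      int64_t inShape[4] = {1, 1, 2, 2};
      float scales[4] = {1.0f, 1.0f, 2.0f, 2.0f};
      int64_t scalesShape[1] = {4};
      float out[16] = {0}; /* pre-allocated 1x1x4x4 output */
      int64_t outShape[4] = {1, 1, 4, 4};

      OMTensor *dataT = omTensorCreate(in, inShape, 4, ONNX_TYPE_FLOAT);
      OMTensor *scalesT =
          omTensorCreate(scales, scalesShape, 1, ONNX_TYPE_FLOAT);
      OMTensor *outT = omTensorCreate(out, outShape, 4, ONNX_TYPE_FLOAT);

      Resize_Scales(outT, dataT, scalesT, "linear");
      for (int i = 0; i < 16; i++)
        printf("%f%c", out[i], (i % 4 == 3) ? '\n' : ' ');

      omTensorDestroy(outT);
      omTensorDestroy(scalesT);
      omTensorDestroy(dataT);
      return 0;
    }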
diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py
index 4d9dcee995..caf8d8339f 100644
--- a/test/backend/inference_backend.py
+++ b/test/backend/inference_backend.py
@@ -679,12 +679,44 @@ def get_test_models():
         "test_reshape_zero_dim_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{0:{-1}}, CONSTANT_INPUT:{-1}},
 
         # Resize
+
+        # All test cases in onnx v1.11.0; "yes" marks the cases that are
+        # currently supported.
+        #yes name='test_resize_upsample_scales_nearest'
+        #yes name='test_resize_downsample_scales_nearest'
+        #yes name='test_resize_upsample_sizes_nearest'
+        #yes name='test_resize_downsample_sizes_nearest'
+        #yes name='test_resize_upsample_scales_linear'
+        #    name='test_resize_upsample_scales_linear_align_corners'
+        #yes name='test_resize_downsample_scales_linear'
+        #    name='test_resize_downsample_scales_linear_align_corners'
+        #yes name='test_resize_upsample_scales_cubic'
+        #    name='test_resize_upsample_scales_cubic_align_corners'
+        #yes name='test_resize_downsample_scales_cubic'
+        #    name='test_resize_downsample_scales_cubic_align_corners'
+        #yes name='test_resize_upsample_sizes_cubic'
+        #yes name='test_resize_downsample_sizes_cubic'
+        #    name='test_resize_upsample_scales_cubic_A_n0p5_exclude_outside'
+        #    name='test_resize_downsample_scales_cubic_A_n0p5_exclude_outside'
+        #    name='test_resize_upsample_scales_cubic_asymmetric'
+        #    name='test_resize_tf_crop_and_resize'
+        #    name='test_resize_downsample_sizes_linear_pytorch_half_pixel'
+        #    name='test_resize_upsample_sizes_nearest_floor_align_corners'
+        #yes name='test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric'
+        #yes name='test_resize_upsample_sizes_nearest_ceil_half_pixel'
+
         "test_resize_upsample_scales_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
         "test_resize_downsample_scales_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
         "test_resize_upsample_sizes_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
         "test_resize_downsample_sizes_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
         "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
         "test_resize_upsample_sizes_nearest_ceil_half_pixel_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
+        "test_resize_upsample_scales_linear_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
+        "test_resize_downsample_scales_linear_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
+        "test_resize_upsample_scales_cubic_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
+        "test_resize_downsample_scales_cubic_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
+        "test_resize_upsample_sizes_cubic_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
+        "test_resize_downsample_sizes_cubic_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
 
         # Reverse Sequence
         "test_reversesequence_time_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index f2d070c801..b27620a76f 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2559,7 +2559,7 @@ func @test_resize2(%arg0 : tensor<3x4xf32>) -> tensor<*xf32> {
   %cst = "onnx.NoValue"() {value} : () -> none
   %0 = "onnx.Constant"() {value = dense<[0.000000e+00, 0.000000e+00, 1.000000e+00, 1.000000e+00]> : tensor<4xf32>} : () -> tensor<4xf32>
   %1 = "onnx.Constant"() {value = dense<[1.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-  %2 = "onnx.Resize"(%arg0, %0, %1, %cst) {coordinate_transformation_mode = "asymmetric", mode = "linear", nearest_mode = "round_prefer_floor"} : (tensor<3x4xf32>, tensor<4xf32>, tensor<2xf32>, none) -> tensor<*xf32>
+  %2 = "onnx.Resize"(%arg0, %0, %1, %cst) {mode = "linear"} : (tensor<3x4xf32>, tensor<4xf32>, tensor<2xf32>, none) -> tensor<*xf32>
   "func.return"(%2) : (tensor<*xf32>) -> ()
 // CHECK-LABEL:  func @test_resize2
 // CHECK-SAME:   ([[PARAM_0_:%.+]]: memref<3x4xf32>) -> memref<3x12xf32> {
@@ -2569,8 +2569,8 @@ func @test_resize2(%arg0 : tensor<3x4xf32>) -> tensor<*xf32> {
 // CHECK-DAG:    [[VAR_cst_:%.+]] = arith.constant 1.000000e+00 : f32
 // CHECK-DAG:    [[VAR_cst_0_:%.+]] = arith.constant 3.000000e+00 : f32
 // CHECK-DAG:    [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x12xf32>
-// CHECK:        [[VAR_4_:%.+]] = "krnl.call"([[RES_]], [[PARAM_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_0_]]) {coordinate_transformation_mode = "asymmetric", funcName = "onnx_Resize_f32", mode = "linear", nearest_mode = "round_prefer_floor"} : (memref<3x12xf32>, memref<3x4xf32>, memref<4xf32>, memref<2xf32>, none) -> memref<3x12xf32>
-// CHECK:        return [[VAR_4_]] : memref<3x12xf32>
+// CHECK:        "krnl.call"([[RES_]], [[PARAM_0_]], [[VAR_1_]], [[VAR_2_]]) {funcName = "Resize_Scales", mode = "linear"} : (memref<3x12xf32>, memref<3x4xf32>, memref<4xf32>, memref<2xf32>) -> ()
+// CHECK:        return [[RES_]] : memref<3x12xf32>
 // CHECK:        }
 }