diff --git a/CMakeLists.txt b/CMakeLists.txt
index f17048d169..32e39a0e1a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -116,6 +116,12 @@ elseif ((ONNX_USE_PROTOBUF_SHARED_LIBS AND Protobuf_USE_STATIC_LIBS)
     "ONNX_USE_PROTOBUF_SHARED_LIBS and Protobuf_USE_STATIC_LIBS must be opposites of each other.")
 endif()
 
+# Use the new MSVC preprocessor to improve standard conformance.
+if (CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Zc:preprocessor")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor")
+endif()
+
 # Suppress warnings in third party code.
 if (ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS)
   set(CMAKE_C_FLAGS_COPY ${CMAKE_C_FLAGS})
diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index a852bf30ae..1591c6497d 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -42,8 +42,9 @@ add_onnx_mlir_library(OMONNXToKrnl
   Tensor/DepthToSpace.cpp
   Tensor/Expand.cpp
   Tensor/Flatten.cpp
-  Tensor/Gather.cpp
-  Tensor/GatherElements.cpp
+  Tensor/Gather.cpp
+  Tensor/GatherElements.cpp
+  Tensor/GatherND.cpp
   Tensor/Identity.cpp
   Tensor/NonZero.cpp
   Tensor/OneHot.cpp
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index 84fce39f9f..3962a6282f 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -87,6 +87,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
   populateLoweringONNXTransposeOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXGatherOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXGatherElementsOpPattern(patterns, typeConverter, ctx);
+  populateLoweringONNXGatherNDOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXIdentityOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXConstantOfShapeOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXConstantOpPattern(patterns, typeConverter, ctx);
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
index 2ec7832fd4..c83c9b8756 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -453,10 +453,9 @@ Value foldOrEmitONNXTransposeOp(ConversionPatternRewriter &rewriter,
 }
 
 /// Emit MemRef ReinterpretCastOp to create a new view for 'data'.
-/// The new view is created using the given 'memRefType' and 'outputDims'.
+/// The new view is created using the given 'outputDims'.
 Value emitMemRefReinterpretCastOp(ConversionPatternRewriter &rewriter,
-    Location loc, Value data, const MemRefType &memRefType,
-    SmallVectorImpl<IndexExpr> &outputDims) {
+    Location loc, Value data, SmallVectorImpl<IndexExpr> &outputDims) {
   MemRefBuilder createMemRef(rewriter, loc);
   return createMemRef.reinterpretCast(data, outputDims);
 }
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 249fce812d..92819a9620 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -145,10 +145,9 @@ Value foldOrEmitONNXTransposeOp(ConversionPatternRewriter &rewriter,
     Location loc, Type resultType, Value input, ArrayAttr permAttr);
 
 /// Emit MemRef ReinterpretCastOp to create a new view for 'data'.
-/// The new view is created using the given 'memRefType' and 'outputDims'.
+/// The new view is created using the given 'outputDims'.
 Value emitMemRefReinterpretCastOp(ConversionPatternRewriter &rewriter,
-    Location loc, Value data, const MemRefType &memRefType,
-    SmallVectorImpl<IndexExpr> &outputDims);
+    Location loc, Value data, SmallVectorImpl<IndexExpr> &outputDims);
 
 /// Emit krnl iterate to compute argsort of a given MemRef along a given axis.
 /// Output MemRef has the same shape as the input MemRef but is of IndexType.
@@ -321,6 +320,8 @@ void populateLoweringONNXGatherOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
 void populateLoweringONNXGatherElementsOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
+void populateLoweringONNXGatherNDOpPattern(
+    RewritePatternSet &, TypeConverter &, MLIRContext *);
 void populateLoweringONNXPadConstantValuePadOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
 void populateLoweringONNXPadOpPattern(
diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
new file mode 100644
index 0000000000..7c524bfdd7
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
@@ -0,0 +1,269 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===---------------- GatherND.cpp - Lowering GatherND Op -----------------===//
+//
+// Copyright 2022 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX GatherND Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
+#include "llvm/Support/Debug.h"
+#include <numeric>
+
+#define DEBUG_TYPE "gather_nd_onnx_to_krnl"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+struct ONNXGatherNDOpLowering : public ConversionPattern {
+  ONNXGatherNDOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(
+            typeConverter, ONNXGatherNDOp::getOperationName(), 1, ctx) {}
+
+  // When true, causes injection of print statements in the generated code.
+  static constexpr bool emitPrintStmts = false;
+
+  // Debug function used to emit code to print the supplied 'indices'.
+  static void printIndices(
+      StringRef title, const DimsExpr &indices, KrnlBuilder &createKrnl) {
+    llvm::Twine msg(title + ": (");
+    createKrnl.printf(msg.str());
+    int64_t n = (int64_t)indices.size();
+    for (int64_t i = 0; i < n; ++i) {
+      Value val = indices[i].getValue();
+      createKrnl.printf(val, val.getType());
+    }
+    createKrnl.printf(")\n");
+  }
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXGatherNDOpAdaptor operandAdaptor(operands);
+    ONNXGatherNDOp gatherNDOp = cast<ONNXGatherNDOp>(op);
+    Location loc = op->getLoc();
+    MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder> create(
+        rewriter, loc);
+    IndexExprScope outerScope(&rewriter, loc);
+
+    ONNXGatherNDOpShapeHelper shapeHelper(&gatherNDOp, &rewriter,
+        krnl::getDenseElementAttributeFromKrnlValue,
+        krnl::loadDenseElementArrayValueAtIndex);
+    auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
+    assert(succeeded(shapecomputed) && "Could not compute output shape");
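+
+    // Lowering strategy (sketch): view 'indices' as a 3-D tensor and 'data'
+    // as a tensor whose leading dimension collects the batch dimensions,
+    // gather element by element into a flat 1-D buffer, and finally
+    // reinterpret-cast that buffer to the expected output shape.
+    // E.g. (the first example in the ONNX GatherND spec): data = [[0,1],[2,3]],
+    // indices = [[0,0],[1,1]], batch_dims = 0 ==> output = [0,3].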
+
+    // Operands and attributes.
+    Value data = operandAdaptor.data();
+    Value indices = operandAdaptor.indices();
+    int64_t b = gatherNDOp.batch_dims();
+    auto indicesType = indices.getType().cast<ShapedType>();
+    auto dataType = data.getType().cast<ShapedType>();
+    ArrayRef<int64_t> indicesShape = indicesType.getShape();
+    ArrayRef<int64_t> dataShape = dataType.getShape();
+    int64_t dataRank = dataShape.size();
+    int64_t indicesRank = indicesShape.size();
+    int64_t indicesLastDim = indicesShape[indicesRank - 1];
+
+    // Convert the output type to MemRefType.
+    Type convertedType = typeConverter->convertType(*op->result_type_begin());
+    assert(convertedType && convertedType.isa<MemRefType>() &&
+           "Failed to convert type to MemRefType");
+    MemRefType outputMemRefType = convertedType.cast<MemRefType>();
+    ArrayRef<int64_t> outputShape = outputMemRefType.getShape();
+    int64_t outputRank = outputShape.size();
+
+    // Ensure the operation constraints are satisfied.
+    assert(dataRank >= 1 && "The rank of 'data' must be >= 1");
+    assert(indicesRank >= 1 && "The rank of 'indices' must be >= 1");
+    assert((outputRank == dataRank + indicesRank - indicesLastDim - 1 - b) &&
+           "Incorrect output rank");
+    assert(b >= 0 && "batch_dims should not be negative");
+    assert(b < std::min(dataRank, indicesRank) &&
+           "batch_dims must be smaller than min(dataRank, indicesRank)");
+    assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
+           "indices.shape[-1] must be in the range [1, dataRank - b]");
+
+    // Reshape 'indices' into a 3-D tensor of shape
+    //   [batchDimsSize, indicesDimsSize / (batchDimsSize * indicesLastDim),
+    //    indices.shape[-1]].
+    const int64_t batchDimsSize = std::accumulate(indicesShape.begin(),
+        indicesShape.begin() + b, 1, std::multiplies<int64_t>());
+    const int64_t indicesDimsSize = std::accumulate(indicesShape.begin(),
+        indicesShape.end(), 1, std::multiplies<int64_t>());
+    assert(batchDimsSize >= 0 && "batchDimsSize must be non-negative");
+    assert(indicesDimsSize >= 0 && "indicesDimsSize must be non-negative");
+
+    LiteralIndexExpr BDS(batchDimsSize),
+        IDS(indicesDimsSize / (batchDimsSize * indicesLastDim)),
+        ILD(indicesLastDim);
+    DimsExpr newIndicesShape = {BDS, IDS, ILD};
+    Value reshapedIndices =
+        create.mem.reinterpretCast(indices, newIndicesShape);
+    LLVM_DEBUG(llvm::dbgs() << "reshapedIndices: " << reshapedIndices << "\n");
+
+    // Reshape 'data' to shape [batchDimsSize, data.shape[b:]].
+    DimsExpr newDataShape = {BDS};
+    for (int64_t i = b; i < dataRank; ++i) {
+      assert(dataShape[i] > 0 && "Cannot support data with dynamic dimensions");
+      LiteralIndexExpr dataDim(dataShape[i]);
+      newDataShape.emplace_back(dataDim);
+    }
+    int64_t reshapedDataRank = newDataShape.size();
+    Value reshapedData = create.mem.reinterpretCast(data, newDataShape);
+    LLVM_DEBUG(llvm::dbgs() << "reshapedData: " << reshapedData << "\n");
+
+    // Allocate a 1D output buffer.
+    const int64_t outputDimsSize = std::accumulate(
+        outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
+    Value outputDataBuffer = create.mem.alloc(
+        MemRefType::get({outputDimsSize}, outputMemRefType.getElementType()));
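+
+    // The gathered values are written contiguously into this flat buffer and
+    // only reinterpreted as the final output shape at the end; e.g. for an
+    // output of shape 2x1x2 the buffer holds 4 elements (see test_gather_nd_2
+    // in test/mlir/onnx/onnx_lowering.mlir).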
+
+    // Initialize the index used to store the result values.
+    Value iZero = create.math.constantIndex(0);
+    Value iOne = create.math.constantIndex(1);
+    Value storeIndex =
+        create.mem.alloca(MemRefType::get({}, rewriter.getIndexType()));
+    create.krnl.store(iZero, storeIndex);
+
+    // for (i,j) in (0..reshapedIndices.shape[0], 0..reshapedIndices.shape[1]) {
+    //   idx = tuple(reshapedIndices[i][j])
+    //   output.append(reshapedData[(i,) + idx])
+    // }
+    // output.reshape(outputShape)
+    ValueRange loopDef = create.krnl.defineLoops(2);
+    DimsExpr lbs(2, LiteralIndexExpr(0)),
+        ubs = {newIndicesShape[0], newIndicesShape[1]};
+
+    if (emitPrintStmts) {
+      create.krnl.printTensor("reshapedIndices: ", reshapedIndices);
+      create.krnl.printTensor("reshapedData: ", reshapedData);
+    }
+
+    create.krnl.iterateIE(loopDef, loopDef, lbs, ubs,
+        [&](KrnlBuilder &createKrnl, ValueRange loopInd) {
+          // Insert code inside the loop.
+          IndexExprScope innerLoopScope(createKrnl);
+
+          // Access function for 'reshapedIndices'. The first 2 indices are
+          // equal to the loop indices.
+          DimsExpr reshapedIndicesAccessFct;
+          getIndexExprList<DimIndexExpr>(loopInd, reshapedIndicesAccessFct);
+
+          // Access function for 'reshapedData'. The first index is equal to
+          // the first loop index.
+          DimsExpr reshapedDataAccessFct;
+          IndexExpr ind = SymbolIndexExpr(loopInd[0]);
+          reshapedDataAccessFct.emplace_back(ind);
+
+          // The last index of the access function for 'reshapedIndices' is
+          // given by the values of indices.shape[-1].
+          // The loaded values from 'reshapedIndices' are the next set of
+          // indices to push to the `reshapedDataAccessFct`.
+          for (unsigned i = 0; i < indicesLastDim; ++i) {
+            IndexExpr ind = LiteralIndexExpr(i);
+            reshapedIndicesAccessFct.emplace_back(ind);
+
+            if (emitPrintStmts)
+              printIndices("indices", reshapedIndicesAccessFct, createKrnl);
+
+            Value indexVal =
+                createKrnl.loadIE(reshapedIndices, reshapedIndicesAccessFct);
+            reshapedIndicesAccessFct.pop_back();
+
+            if (emitPrintStmts) {
+              createKrnl.printf("index = ", indexVal, indexVal.getType());
+              createKrnl.printf("\n");
+            }
+
+            IndexExpr index = NonAffineIndexExpr(indexVal);
+            reshapedDataAccessFct.emplace_back(index);
+          }
+
+          if (indicesLastDim == dataRank - b) {
+            // When indices.shape[-1] is equal to (rank(data) - b), the
+            // `reshapedDataAccessFct` computed so far has the same number of
+            // indices as the rank of 'reshapedData'.
+            assert((int64_t)reshapedDataAccessFct.size() == reshapedDataRank &&
+                   "Access function should have the same rank as reshapedData");
+
+            if (emitPrintStmts)
+              printIndices("data indices", reshapedDataAccessFct, createKrnl);
+
+            // Gather value from the 'data' tensor and store it into
+            // 'outputDataBuffer'.
+            Value val = createKrnl.loadIE(reshapedData, reshapedDataAccessFct);
+            Value storeIndexVal = createKrnl.load(storeIndex);
+            createKrnl.store(val, outputDataBuffer, storeIndexVal);
+
+            // Bump up the storeIndex.
+            createKrnl.store(create.math.add(storeIndexVal, iOne), storeIndex);
+          } else {
+            assert((indicesLastDim < dataRank - b) &&
+                   "Expecting indices.shape[-1] to be smaller than "
+                   "rank(data) - b");
+
+            // When indices.shape[-1] is less than (rank(data) - b), the
+            // `reshapedDataAccessFct` computed so far yields a slice which
+            // needs to be inserted into the output buffer.
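+            // E.g. for data of shape 2x2x2, indices of shape 2x1x2 and
+            // batch_dims = 0, each index tuple selects a slice of
+            // data.shape[-1] = 2 elements, copied out one by one below.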
+            int64_t reshapedDataLastDim = dataShape[dataRank - 1];
+            for (int64_t i = 0; i < reshapedDataLastDim; ++i) {
+              IndexExpr ind = LiteralIndexExpr(i);
+              reshapedDataAccessFct.emplace_back(ind);
+              assert(
+                  (int64_t)reshapedDataAccessFct.size() == reshapedDataRank &&
+                  "Access function should have the same rank as reshapedData");
+
+              if (emitPrintStmts)
+                printIndices("data indices", reshapedDataAccessFct, createKrnl);
+
+              // Gather value from the 'data' tensor and store it into
+              // 'outputDataBuffer'.
+              Value val =
+                  createKrnl.loadIE(reshapedData, reshapedDataAccessFct);
+              reshapedDataAccessFct.pop_back();
+
+              if (emitPrintStmts) {
+                createKrnl.printf("val = ", val, val.getType());
+                createKrnl.printf("\n");
+              }
+
+              Value storeIndexVal = createKrnl.load(storeIndex);
+              createKrnl.store(val, outputDataBuffer, storeIndexVal);
+
+              // Bump up the storeIndex.
+              createKrnl.store(
+                  create.math.add(storeIndexVal, iOne), storeIndex);
+            }
+          }
+        });
+
+    // Finally reshape 'outputDataBuffer' to the shape of the output.
+    DimsExpr newOutputShape;
+    for (int64_t dim : outputShape) {
+      LiteralIndexExpr outputDim(dim);
+      newOutputShape.emplace_back(outputDim);
+    }
+
+    Value reshapedOutput =
+        create.mem.reinterpretCast(outputDataBuffer, newOutputShape);
+    LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n");
+
+    rewriter.replaceOp(op, reshapedOutput);
+
+    return success();
+  }
+};
+
+void populateLoweringONNXGatherNDOpPattern(RewritePatternSet &patterns,
+    TypeConverter &typeConverter, MLIRContext *ctx) {
+  patterns.insert<ONNXGatherNDOpLowering>(typeConverter, ctx);
+}
+
+} // namespace onnx_mlir
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
index 3ecaa75a1f..583b6beffe 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
@@ -51,7 +51,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
 
   // Lower to ReinterpretCastOp so that the data is never copied or modified.
   Value newView = emitMemRefReinterpretCastOp(
-      rewriter, loc, data, memRefType, shapeHelper.dimsForOutput());
+      rewriter, loc, data, shapeHelper.dimsForOutput());
   LLVM_DEBUG(llvm::dbgs() << "newView: " << newView << "\n");
 
   rewriter.replaceOp(op, newView);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
index 2d0f99d523..d674d2b12a 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
@@ -26,15 +26,9 @@ LogicalResult ONNXSqueezeOpLoweringCommon(Operation *op,
   Adaptor operandAdaptor(operands);
   Op squeezeOp = dyn_cast_or_null<Op>(op);
 
-  auto loc = op->getLoc();
+  Location loc = op->getLoc();
   Value data = operandAdaptor.data();
 
-  // Convert the output type to MemRefType.
-  Type convertedType = typeConverter->convertType(*op->result_type_begin());
-  assert(convertedType && convertedType.isa<MemRefType>() &&
-         "Failed to convert type to MemRefType");
-  MemRefType memRefType = convertedType.cast<MemRefType>();
-
   ShapeHelper shapeHelper(&squeezeOp, &rewriter,
       krnl::getDenseElementAttributeFromKrnlValue,
       krnl::loadDenseElementArrayValueAtIndex);
@@ -43,7 +37,7 @@ LogicalResult ONNXSqueezeOpLoweringCommon(Operation *op,
 
   // Lower to ReinterpretCastOp so that the data is never copied or modified.
   Value newView = emitMemRefReinterpretCastOp(
-      rewriter, loc, data, memRefType, shapeHelper.dimsForOutput());
+      rewriter, loc, data, shapeHelper.dimsForOutput());
   rewriter.replaceOp(op, newView);
   return success();
 }
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp
index 7ce98f3e4f..76ad2e80b1 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp
@@ -26,15 +26,9 @@ LogicalResult ONNXUnsqueezeOpLoweringCommon(Operation *op,
   Adaptor operandAdaptor(operands);
   Op unsqueezeOp = dyn_cast_or_null<Op>(op);
 
-  auto loc = op->getLoc();
+  Location loc = op->getLoc();
   Value data = operandAdaptor.data();
 
-  // Convert the output type to MemRefType.
-  Type convertedType = typeConverter->convertType(*op->result_type_begin());
-  assert(convertedType && convertedType.isa<MemRefType>() &&
-         "Failed to convert type to MemRefType");
-  MemRefType memRefType = convertedType.cast<MemRefType>();
-
   ShapeHelper shapeHelper(&unsqueezeOp, &rewriter,
       krnl::getDenseElementAttributeFromKrnlValue,
      krnl::loadDenseElementArrayValueAtIndex);
@@ -43,7 +37,7 @@ LogicalResult ONNXUnsqueezeOpLoweringCommon(Operation *op,
 
   // Lower to ReinterpretCastOp so that the data is never copied or modified.
   Value newView = emitMemRefReinterpretCastOp(
-      rewriter, loc, data, memRefType, shapeHelper.dimsForOutput());
+      rewriter, loc, data, shapeHelper.dimsForOutput());
   rewriter.replaceOp(op, newView);
   return success();
 }
diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp
index 5c7af54d06..cb743b2326 100644
--- a/src/Dialect/Krnl/DialectBuilder.cpp
+++ b/src/Dialect/Krnl/DialectBuilder.cpp
@@ -25,6 +25,38 @@ using namespace mlir;
 
 namespace onnx_mlir {
 
+static StringRef getFormat(const Type &inputType) {
+  StringRef format;
+  TypeSwitch<Type>(inputType)
+      .Case<Float16Type>([&](Float16Type) { format = "%g"; })
+      .Case<Float32Type>([&](Float32Type) { format = "%f"; })
+      .Case<Float64Type>([&](Float64Type) { format = "%f"; })
+      .Case<IntegerType>([&](IntegerType type) {
+        switch (type.getWidth()) {
+        case 1:
+        case 8:
+        case 16:
+        case 32:
+          format = type.isUnsigned() ? "%u" : "%d";
+          break;
+        case 64:
+          format = type.isUnsigned() ? "%llu" : "%lld";
+          break;
+        }
+      })
+      .Case<IndexType>([&](IndexType) { format = "%lld"; })
+      .Case<onnx_mlir::krnl::StringType>(
+          [&](onnx_mlir::krnl::StringType) { format = "%s"; })
+      .Case<LLVM::LLVMPointerType>(
+          [&](LLVM::LLVMPointerType) { format = "%s"; })
+      .Default([&](Type type) {
+        llvm::errs() << "type: " << type << "\n";
+        llvm_unreachable("Unhandled type");
+      });
+
+  return format;
+}
+
 //====---------------- Support for Krnl Builder ----------------------===//
 
 Value KrnlBuilder::load(Value memref, ValueRange indices) const {
@@ -208,37 +240,14 @@ void KrnlBuilder::printf(StringRef msg) const {
 }
 
 void KrnlBuilder::printf(StringRef msg, Value input, Type inputType) const {
-  StringRef format;
-  TypeSwitch<Type>(inputType)
-      .Case<mlir::Float16Type>([&](mlir::Float16Type) { format = "%g\n"; })
-      .Case<mlir::Float32Type>([&](mlir::Float32Type) { format = "%g\n"; })
-      .Case<mlir::Float64Type>([&](mlir::Float64Type) { format = "%g\n"; })
-      .Case<IntegerType>([&](IntegerType type) {
-        switch (type.getWidth()) {
-        case 1:
-        case 8:
-        case 16:
-        case 32:
-          format = type.isUnsigned() ? "%u\n" : "%d\n";
-          break;
-        case 64:
-          format = type.isUnsigned() ? "%llu\n" : "%lld\n";
"%llu\n" : "%lld\n"; - break; - } - }) - .Case<IndexType>([&](IndexType) { format = "%lld\n"; }) - .Case<onnx_mlir::krnl::StringType>( - [&](onnx_mlir::krnl::StringType) { format = "%s\n"; }) - .Case<LLVM::LLVMPointerType>( - [&](LLVM::LLVMPointerType) { format = "%s\n"; }) - .Default([&](Type type) { - llvm::errs() << "type: " << type << "\n"; - llvm_unreachable("Unhandled type"); - }); - + StringRef format = getFormat(inputType); std::string concat(msg.str() + format.str()); StringRef newFormat(concat); b.create<KrnlPrintOp>(loc, newFormat, input); } +void KrnlBuilder::printf(Value input, Type inputType) const { + StringRef format = getFormat(inputType); + b.create<KrnlPrintOp>(loc, format, input); +} } // namespace onnx_mlir diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp index c2c2317bda..f253979996 100644 --- a/src/Dialect/Krnl/DialectBuilder.hpp +++ b/src/Dialect/Krnl/DialectBuilder.hpp @@ -136,6 +136,7 @@ struct KrnlBuilder : public DialectBuilder { void printf(mlir::StringRef msg) const; void printf( mlir::StringRef msg, mlir::Value input, mlir::Type inputType) const; + void printf(mlir::Value input, mlir::Type inputType) const; // Onnx-mlir runtime functions. void randomNormal(mlir::Value alloc, mlir::Value numberOfRandomValues, diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 49d5634a15..a0779def68 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -31,6 +31,7 @@ add_onnx_mlir_library(OMONNXOps ShapeInference/Flatten.cpp ShapeInference/Gather.cpp ShapeInference/GatherElements.cpp + ShapeInference/GatherND.cpp ShapeInference/Gemm.cpp ShapeInference/LRN.cpp ShapeInference/MatMul.cpp diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 3d35575c98..c10260f919 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -31,6 +31,7 @@ #include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp" #include "src/Support/Diagnostic.hpp" +#include <algorithm> #include <string> using namespace mlir; @@ -3374,6 +3375,114 @@ LogicalResult ONNXGatherElementsOp::inferShapes( ONNXGatherElementsOp, ONNXGatherElementsOpAdaptor>(*this, elementType); } +//===----------------------------------------------------------------------===// +// GatherND +//===----------------------------------------------------------------------===// + +LogicalResult ONNXGatherNDOp::verify() { + ONNXGatherNDOpAdaptor operandAdaptor(*this); + if (llvm::any_of(operandAdaptor.getOperands(), + [](const Value &op) { return !hasShapeAndRank(op); })) + return success(); // Won't be able to do any checking at this stage. + + // Get operands and attributes. + Value data = operandAdaptor.data(); + Value indices = operandAdaptor.indices(); + auto dataType = data.getType().cast<ShapedType>(); + auto indicesType = indices.getType().cast<ShapedType>(); + int64_t dataRank = dataType.getRank(); + int64_t indicesRank = indicesType.getRank(); + int64_t b = batch_dims(); + + // 'data' and 'indices' must have rank strictly greater than zero. 
+ if (dataRank < 1) + return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError( + *this->getOperation(), data, dataRank, "> 0"); + if (indicesRank < 1) + return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError( + *this->getOperation(), indices, indicesRank, "> 0"); + + ArrayRef<int64_t> dataShape = dataType.getShape(); + ArrayRef<int64_t> indicesShape = indicesType.getShape(); + int64_t indicesLastDim = indicesShape[indicesRank - 1]; + + // b must be smaller than min(rank(data), rank(indices). + int64_t minDataAndIndicesRank = std::min(dataRank, indicesRank); + if (b >= minDataAndIndicesRank) + return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError( + *this->getOperation(), "batch_dims", b, + onnx_mlir::Diagnostic::Range<int64_t>(0, minDataAndIndicesRank - 1)); + + // The first b dimensions of the shape of 'indices' and 'data' must be equal. + for (int64_t i = 0; i < b; ++i) { + int64_t dataDim = dataShape[i]; + int64_t indicesDim = indicesShape[i]; + if (indicesDim < 0 || dataDim < 0) + continue; + if (indicesDim != dataDim) + return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError( + *this->getOperation(), indices, i, indicesShape[i], + std::to_string(dataShape[i])); + } + + // Let r = rank(data), indices.shape[-1] must be in the range [1, r-b]. + if (indicesLastDim == 0) + return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError( + *this->getOperation(), indices, indicesRank - 1, indicesLastDim, + ">= 1"); + if (indicesLastDim > dataRank - b) + return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError( + *this->getOperation(), indices, indicesRank - 1, indicesLastDim, + "<= " + std::to_string(dataRank - b)); + + // All values in 'indices' are expected to satisfy the inequality: + // -data.shape[i] <= indices[...,i] <= (data.shape[i]-1)]. + for (int64_t i = 0; i < indicesRank; ++i) { + int64_t dataDimAtAxis = dataShape[i]; + if (dataDimAtAxis < 0) + continue; + + if (DenseElementsAttr valueAttribute = + getDenseElementAttributeFromONNXValue(indices)) + for (IntegerAttr value : valueAttribute.getValues<IntegerAttr>()) { + static int n = 0; + int64_t index = value.getInt(); + if (index < -dataDimAtAxis || index > dataDimAtAxis - 1) + return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError( + *this->getOperation(), "indices[" + std::to_string(n) + "]", + index, + onnx_mlir::Diagnostic::Range<int64_t>( + -dataDimAtAxis, dataDimAtAxis - 1)); + n++; + } + } + + return success(); +} + +LogicalResult ONNXGatherNDOp::inferShapes( + std::function<void(mlir::Region &)> doShapeInference) { + // Cannot infer the shape of the output if the inputs shape is not yet known. + if (llvm::any_of( + this->getOperands(), [](Value op) { return !hasShapeAndRank(op); })) + return success(); + + // The output rank is given by: + // rank(output) = rank(indices) + rank(data) - indices_shape[-1] - 1 - b. + // Therefore 'indices.shape[-1]' must be known in order to compute the output + // shape. + ArrayRef<int64_t> indicesShape = + indices().getType().cast<ShapedType>().getShape(); + int64_t indicesRank = indicesShape.size(); + if (indicesShape[indicesRank - 1] < 0) + return success(); // cannot infer the oputput shape yet. 
+
+  auto elementType = data().getType().cast<ShapedType>().getElementType();
+  return shapeHelperInferShapes<ONNXGatherNDOpShapeHelper, ONNXGatherNDOp,
+      ONNXGatherNDOpAdaptor>(*this, elementType);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOfShape
 //===----------------------------------------------------------------------===//
@@ -3892,10 +4001,6 @@ LogicalResult ONNXFloorOp::inferShapes(
   return success();
 }
 
-LogicalResult ONNXGatherNDOp::inferShapes(
-    std::function<void(mlir::Region &)> doShapeInference) {
-  return emitError(NOT_IMPLEMENTED_MESSAGE);
-}
 LogicalResult ONNXGreaterOp::inferShapes(
     std::function<void(mlir::Region &)> doShapeInference) {
   Builder b(getContext());
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index d9c3df2e83..b25f1135ea 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -1910,6 +1910,7 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND",
     return {20};
     }
   }];
+  let hasVerifier = 1;
 }
 
 def ONNXGemmOp:ONNX_Op<"Gemm",
diff --git a/src/Dialect/ONNX/ShapeInference/GatherND.cpp b/src/Dialect/ONNX/ShapeInference/GatherND.cpp
new file mode 100644
index 0000000000..cffb7fe7fc
--- /dev/null
+++ b/src/Dialect/ONNX/ShapeInference/GatherND.cpp
@@ -0,0 +1,70 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===---------- GatherND.cpp - Shape Inference for GatherND Op ------------===//
+//
+// This file implements shape inference for the ONNX GatherND Operator.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
+#include <algorithm>
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+LogicalResult ONNXGatherNDOpShapeHelper::computeShape(
+    ONNXGatherNDOpAdaptor operandAdaptor) {
+  Value data = operandAdaptor.data();
+  Value indices = operandAdaptor.indices();
+  MemRefBoundsIndexCapture dataBounds(data);
+  MemRefBoundsIndexCapture indicesBounds(indices);
+  DimsExpr dataDims, indicesDims;
+  dataBounds.getDimList(dataDims);
+  indicesBounds.getDimList(indicesDims);
+
+  int64_t dataRank = dataDims.size();
+  int64_t indicesRank = indicesDims.size();
+  int64_t b = op->batch_dims();
+
+  assert(indices.getType().isa<ShapedType>() && "Expecting a shaped type");
+  auto indicesType = indices.getType().cast<ShapedType>();
+  ArrayRef<int64_t> indicesShape = indicesType.getShape();
+  int64_t indicesLastDim = indicesShape[indicesRank - 1];
+  int64_t outputRank = dataRank + indicesRank - indicesLastDim - 1 - b;
+
+  // Ensure the operator constraints are satisfied.
+  assert(dataRank >= 1 && "dataRank should be >= 1");
+  assert(indicesRank >= 1 && "indicesRank should be >= 1");
+  assert(b >= 0 && "batch_dims should not be negative");
+  assert(b < std::min(dataRank, indicesRank) &&
+         "batch_dims must be smaller than min(dataRank, indicesRank)");
+  assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
+         "indices.shape[-1] must be in the range [1, dataRank - b]");
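+
+  // E.g. data of shape 2x2x2 and indices of shape 2x1 with batch_dims = 1:
+  // output.shape = [2] + [] + [2] = [2,2]; see test_gather_nd_5 in
+  // test/mlir/onnx/onnx_shape_inference.mlir.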
+
+  // Save the first 'b' dimensions of the shape of the 'indices' tensor.
+  DimsExpr batchDims;
+  for (int64_t i = 0; i < b; ++i)
+    batchDims.emplace_back(indicesDims[i]);
+
+  // output.shape = batchDims + list(indices.shape)[b:-1]
+  for (int64_t i = 0; i < b; ++i)
+    dimsForOutput().emplace_back(batchDims[i]);
+  for (int64_t i = b; i < indicesRank - 1; ++i)
+    dimsForOutput().emplace_back(indicesDims[i]);
+
+  // When indices.shape[-1] < data_rank - b,
+  //   output_shape += list(data.shape)[batch_dims + indices.shape[-1]:]
+  if (indicesLastDim < dataRank - b)
+    for (int64_t i = b + indicesLastDim; i < dataRank; ++i)
+      dimsForOutput().emplace_back(dataDims[i]);
+
+  assert((int64_t)dimsForOutput().size() == outputRank &&
+         "Incorrect shape computation");
+
+  return success();
+}
+
+} // namespace onnx_mlir
diff --git a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
index 46ba4c4c4a..fec182616e 100644
--- a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
+++ b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
@@ -417,6 +417,7 @@ template struct ONNXOpShapeHelper<ONNXExpandOp>;
 template struct ONNXOpShapeHelper<ONNXFlattenOp>;
 template struct ONNXOpShapeHelper<ONNXGatherOp>;
 template struct ONNXOpShapeHelper<ONNXGatherElementsOp>;
+template struct ONNXOpShapeHelper<ONNXGatherNDOp>;
 template struct ONNXOpShapeHelper<ONNXGemmOp>;
 template struct ONNXOpShapeHelper<ONNXMatMulOp>;
 template struct ONNXOpShapeHelper<ONNXMaxPoolSingleOutOp>;
diff --git a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
index 5cfa392762..3c2f6e8160 100644
--- a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
+++ b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
@@ -230,6 +230,7 @@ DECLARE_SHAPE_HELPER(ONNXDepthToSpaceOp)
 DECLARE_SHAPE_HELPER(ONNXFlattenOp)
 DECLARE_SHAPE_HELPER(ONNXGatherOp)
 DECLARE_SHAPE_HELPER(ONNXGatherElementsOp)
+DECLARE_SHAPE_HELPER(ONNXGatherNDOp)
 DECLARE_SHAPE_HELPER(ONNXLRNOp)
 DECLARE_SHAPE_HELPER(ONNXReduceSumOp)
 DECLARE_SHAPE_HELPER(ONNXReshapeOp)
diff --git a/src/Runtime/OMTensor.inc b/src/Runtime/OMTensor.inc
index a31d87284e..6658d7a54f 100644
--- a/src/Runtime/OMTensor.inc
+++ b/src/Runtime/OMTensor.inc
@@ -14,6 +14,7 @@
 //===----------------------------------------------------------------------===//
 
 #ifdef __cplusplus
+#include <array>
 #include <cassert>
 #include <complex>
 #include <map>
@@ -161,10 +162,10 @@ static inline void printElement(
     printf("%lld", (long long)((int64_t *)dataPtr)[elemOffset]);
     break;
   case ONNX_TYPE_FLOAT:
-    printf("%g", ((float *)dataPtr)[elemOffset]);
+    printf("%f", ((float *)dataPtr)[elemOffset]);
     break;
   case ONNX_TYPE_DOUBLE:
-    printf("%g", ((double *)dataPtr)[elemOffset]);
+    printf("%f", ((double *)dataPtr)[elemOffset]);
     break;
   case ONNX_TYPE_STRING:
     printf("%s", ((const char **)dataPtr)[elemOffset]);
@@ -423,63 +424,78 @@ void omTensorPrint(const char *msg, const OMTensor *tensor) {
   printf("\trank = %lld\n", (long long)rank);
   printf("\tdataType = %s\n", getDataTypeName(dataType));
   printf("\tnumElems = %lld\n", (long long)omTensorGetNumElems(tensor));
+  printf("\tshape: ");
+  for (int64_t i = 0; i < rank; i++)
+    printf("[%lld]", (long long)shape[i]);
+  printf("\n");
   printf("\tstrides: ");
   for (int64_t i = 0; i < rank; i++)
     printf("[%lld]", (long long)strides[i]);
   printf("\n");
-  printf("\tdata: ([");
+
+#define LOOP_1(INDEX, IV, UB) \
+  printf("["); \
+  for (int64_t IV = 0; IV < UB; ++IV) { \
+    if (IV) \
+      printf(", "); \
+    indexes[INDEX] = IV; \
+    int64_t elemOffset = computeElemOffset(tensor->_strides, indexes, rank); \
+    printElement(dataPtr, elemOffset, dataType); \
+  } \
+  printf("]");
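+
+// Note: these nested variadic macros rely on standard-conforming __VA_ARGS__
+// expansion; that is presumably why the conforming MSVC preprocessor
+// (/Zc:preprocessor) is enabled in CMakeLists.txt at the top of this patch.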
+
+#define LOOP_2(INDEX, IV, UB, ...) \
+  printf("["); \
+  for (int64_t IV = 0; IV < UB; ++IV) { \
+    if (IV) \
+      printf(", "); \
+    indexes[INDEX] = IV; \
+    LOOP_1(INDEX + 1, __VA_ARGS__) \
+  } \
+  printf("]");
+
+#define LOOP_3(INDEX, IV, UB, ...) \
+  printf("["); \
+  for (int64_t IV = 0; IV < UB; ++IV) { \
+    if (IV) \
+      printf(", "); \
+    indexes[INDEX] = IV; \
+    LOOP_2(INDEX + 1, __VA_ARGS__) \
+  } \
+  printf("]");
+
+#define LOOP_4(INDEX, IV, UB, ...) \
+  printf("["); \
+  for (int64_t IV = 0; IV < UB; ++IV) { \
+    if (IV) \
+      printf(", "); \
+    indexes[INDEX] = IV; \
+    LOOP_3(INDEX + 1, __VA_ARGS__) \
+  } \
+  printf("]");
+
+  printf("\tdata: (");
   switch (rank) {
-  case 1:
-    for (int64_t i = 0; i < shape[0]; ++i) {
-      if (i)
-        printf(", ");
-      int64_t indexes[] = {i};
-      int64_t elemOffset = computeElemOffset(tensor->_strides, indexes, rank);
-      printElement(dataPtr, elemOffset, dataType);
-    }
-    break;
-  case 2:
-    for (int64_t i = 0; i < shape[0]; ++i) {
-      if (i)
-        printf(", ");
-      printf("[");
-      for (int64_t j = 0; j < shape[1]; ++j) {
-        if (j)
-          printf(", ");
-        int64_t indexes[] = {i, j};
-        int64_t elemOffset = computeElemOffset(tensor->_strides, indexes, rank);
-        printElement(dataPtr, elemOffset, dataType);
-      }
-      printf("]");
-    }
-    break;
-  case 3:
-    for (int64_t i = 0; i < shape[0]; ++i) {
-      if (i)
-        printf(", ");
-      printf("[");
-      for (int64_t j = 0; j < shape[1]; ++j) {
-        if (j)
-          printf(", ");
-        printf("[");
-        for (int64_t k = 0; k < shape[2]; ++k) {
-          if (k)
-            printf(", ");
-          int64_t indexes[] = {i, j, k};
-          int64_t elemOffset =
-              computeElemOffset(tensor->_strides, indexes, rank);
-          printElement(dataPtr, elemOffset, dataType);
-        }
-        printf("]");
-      }
-      printf("]");
-    }
-    break;
+  case 1: {
+    int64_t indexes[1];
+    LOOP_1(0, i, shape[0])
+  } break;
+  case 2: {
+    int64_t indexes[2];
+    LOOP_2(0, i, shape[0], j, shape[1])
+  } break;
+  case 3: {
+    int64_t indexes[3];
+    LOOP_3(0, i, shape[0], j, shape[1], k, shape[2])
+  } break;
+  case 4: {
+    int64_t indexes[4];
+    LOOP_4(0, i, shape[0], j, shape[1], k, shape[2], l, shape[3])
+  } break;
   default:
     assert(false && "not implemented");
   }
-  printf("])\n");
+  printf(")\n");
 }
 
 #ifdef __cplusplus
@@ -657,8 +673,8 @@ inline bool omTensorAreTwoOmtsClose(
       eqAllclose.begin(), eqAllclose.end(), [&](T eq) { return eq >= 0; });
 
   if (!satisfied) {
-    // Figure out where and what went wrong, this can be slow; but hopefully we
-    // don't need this often.
+    // Figure out where and what went wrong; this can be slow, but hopefully
+    // we don't need this often.
for (const auto &idx : omTensorComputeIndexSet(a)) { T aElem = omTensorGetElem<T>(a, idx); T bElem = omTensorGetElem<T>(b, idx); diff --git a/src/Support/Diagnostic.hpp b/src/Support/Diagnostic.hpp index 3b44e6a5da..a044b452fa 100644 --- a/src/Support/Diagnostic.hpp +++ b/src/Support/Diagnostic.hpp @@ -36,7 +36,7 @@ class Diagnostic { public: Range(T min, T max) : min(min), max(max) { - assert(min < max && "Illegal range"); + assert(min <= max && "Illegal range"); } }; diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 5bef3c7757..d0b42d94b6 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -298,6 +298,11 @@ def get_test_models(): "test_gather_elements_1_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_gather_elements_negative_indices_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + # GatherND + "test_gathernd_example_int32_cpu": {STATIC_SHAPE:{}, CONSTANT_INPUT:{-1}}, + "test_gathernd_example_float32_cpu": {STATIC_SHAPE:{}, CONSTANT_INPUT:{-1}}, + "test_gathernd_example_int32_batch_dim1_cpu": {STATIC_SHAPE:{}, CONSTANT_INPUT:{-1}}, + # Gemm "test_gemm_all_attributes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_gemm_alpha_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index 04f9ae252b..e05e6013a6 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -164,6 +164,7 @@ func @test_hardmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> { // ----- +// COM: Rank of 'data' has to be >=1 func @test_gather_elements_verifier_1(%arg0 : tensor<f32>, %arg1 : tensor<5xi64>) -> tensor<*xf32> { // expected-error @+1 {{onnx.GatherElements: operand '<block argument> of type 'tensor<f32>' at index: 0' has rank 0, rank should be > 0}} %1 = "onnx.GatherElements"(%arg0, %arg1) {axis = 4 : si64} : (tensor<f32>, tensor<5xi64>) -> tensor<*xf32> @@ -172,6 +173,7 @@ func @test_gather_elements_verifier_1(%arg0 : tensor<f32>, %arg1 : tensor<5xi64> // ----- +// COM: Rank of 'indices' must be equal to the rank of `data`. func @test_gather_elements_verifier_2(%arg0 : tensor<5xf32>, %arg1 : tensor<5x3xi64>) -> tensor<*xf32> { // expected-error @+1 {{onnx.GatherElements: operand '<block argument> of type 'tensor<5x3xi64>' at index: 1' has rank 2, rank should be 1}} %1 = "onnx.GatherElements"(%arg0, %arg1) {axis = 4 : si64} : (tensor<5xf32>, tensor<5x3xi64>) -> tensor<*xf32> @@ -180,6 +182,7 @@ func @test_gather_elements_verifier_2(%arg0 : tensor<5xf32>, %arg1 : tensor<5x3x // ----- +// COM: 'axis' valid range is [-r, r-1], where r = rank(data). func @test_gather_elements_verifier_3(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x1x32xi64>) -> tensor<*xf32> { // expected-error @+1 {{onnx.GatherElements: 'axis' value is 4, accepted range is [-4, 3]}} %1 = "onnx.GatherElements"(%arg0, %arg1) {axis = 4 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x1x32xi64>) -> tensor<*xf32> @@ -188,6 +191,7 @@ func @test_gather_elements_verifier_3(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tens // ----- +// COM: All index values in 'indices' are expected to be within bounds [-s, s-1] along axis of size s. 
 func @test_gather_elements_verifier_4(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor<*xf32> {
   // expected-error @+2 {{onnx.GatherElements: 'indices' value is 3, accepted range is [-3, 2]}}
   %indices = "onnx.Constant"() {value = dense<[3]> : tensor<1xi64>} : () -> tensor<1xi64>
@@ -197,6 +201,58 @@
 
 // -----
 
+// COM: Rank of 'data' has to be >= 1.
+func @test_gatherND_verifier_1(%arg0 : tensor<f32>, %arg1 : tensor<5xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<f32>' at index: 0' has rank 0, rank should be > 0}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) : (tensor<f32>, tensor<5xi64>) -> tensor<*xf32>
+}
+
+// -----
+
+// COM: Rank of 'indices' has to be >= 1.
+func @test_gatherND_verifier_2(%arg0 : tensor<2xf32>, %arg1 : tensor<i64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<i64>' at index: 1' has rank 0, rank should be > 0}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) : (tensor<2xf32>, tensor<i64>) -> tensor<*xf32>
+}
+
+// -----
+
+// COM: The value of batch_dims must be smaller than the minimum of rank(data) and rank(indices).
+func @test_gatherND_verifier_3(%arg0 : tensor<1x2x3xf32>, %arg1 : tensor<2x2x2x2xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: 'batch_dims' value is 3, accepted range is [0, 2]}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 3 : si64}: (tensor<1x2x3xf32>, tensor<2x2x2x2xi64>) -> tensor<*xf32>
+}
+
+// -----
+
+// COM: The first 'batch_dims' dimensions of the shapes of the 'indices' and 'data' tensors must be equal.
+func @test_gatherND_verifier_4(%arg0 : tensor<2x2x3x4xf32>, %arg1 : tensor<2x3x2xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<2x3x2xi64>' at index: 1' has dimension at index 1 with value 3, value should be 2}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 2 : si64} : (tensor<2x2x3x4xf32>, tensor<2x3x2xi64>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+}
+
+// -----
+
+// COM: The last dimension of the 'indices' shape must be a value in the range [1, rank(data)-batch_dims].
+func @test_gatherND_verifier_5(%arg0 : tensor<1x2x3x4xf32>, %arg1 : tensor<1x4xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<1x4xi64>' at index: 1' has dimension at index 1 with value 4, value should be <= 3}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 1 : si64} : (tensor<1x2x3x4xf32>, tensor<1x4xi64>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+}
+
+// -----
+
+// COM: All values in 'indices' are expected to satisfy the inequality:
+// COM:   -data.shape[i] <= indices[...,i] <= (data.shape[i]-1).
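+// COM: E.g. below, data.shape[0] = 3, so the index value 3 in 'indices' is
+// COM: rejected; the accepted range is [-3, 2].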
+func @test_gatherND_verifier_6(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32> {
+  // expected-error @+2 {{onnx.GatherND: 'indices[0]' value is 3, accepted range is [-3, 2]}}
+  %indices = "onnx.Constant"() {value = dense<[3,2,2]> : tensor<3xi64>} : () -> tensor<3x3x2xi64>
+  %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32>
+}
+
+// -----
+
 func @test_onehotencoder_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> {
   // expected-error @+1 {{'onnx.OneHotEncoder' op input is a tensor of float, int32, or double, but no cats_int64s attribute}}
   %1 = "onnx.OneHotEncoder"(%arg0) { cats_string = ["a","b","c"]} : (tensor<2x2xf32>) -> tensor<*xf32>
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 6e3a2b65a9..ca2e30669f 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2532,6 +2532,82 @@ func @test_resize2(%arg0 : tensor<3x4xf32>) -> tensor<*xf32> {
 // CHECK: return [[RES]] : memref<2xi64>
 }
 
+//-----
+
+// COM: Test GatherND with indices_shape[-1] == rank(data) - batch_dims.
+func @test_gather_nd_1(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2x2xi64>) -> tensor<2xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x2xi64>) -> tensor<2xf32>
+  "std.return"(%0) : (tensor<2xf32>) -> ()
+// CHECK-LABEL: @test_gather_nd_1
+// CHECK-SAME: ([[PARAM_0:%.+]]: memref<2x2xf32>, [[PARAM_1:%.+]]: memref<2x2xi64>) -> memref<2xf32> {
+// CHECK: [[RESHAPED_INDICES:%.+]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x2xi64> to memref<1x2x2xi64>
+// CHECK: [[RESHAPED_DATA:%.+]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x2xf32> to memref<1x2x2xf32>
+// CHECK-DAG: [[RES_BUFFER:%.+]] = memref.alloc() : memref<2xf32>
+// CHECK-DAG: [[RES_BUFFER_INDEX:%.+]] = memref.alloca() : memref<index>
+// CHECK-DAG: [[CST_0_0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_1_0:%.+]] = arith.constant 1 : index
+// CHECK: krnl.store [[CST_0_0]], [[RES_BUFFER_INDEX]][] : memref<index>
+// CHECK: [[LOOP:%.+]]:2 = krnl.define_loops 2
+// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1) with ([[LOOP]]#0 -> [[I_0:%.+]] = 0 to 1, [[LOOP]]#1 -> [[I_1:%.+]] = 0 to 2){
+// CHECK-DAG: [[IV:%.+]]:2 = krnl.get_induction_var_value([[LOOP]]#0, [[LOOP]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK: [[CST_0_1:%.+]] = arith.constant 0 : index
+// CHECK: [[LOAD_INDEX_1:%.+]] = krnl.load [[RESHAPED_INDICES]][[[IV]]#0, [[IV]]#1, [[CST_0_1]]] : memref<1x2x2xi64>
+// CHECK-DAG: [[INDEX_1:%.+]] = arith.index_cast [[LOAD_INDEX_1]] : i64 to index
+// CHECK-DAG: [[CST_1_1:%.+]] = arith.constant 1 : index
+// CHECK: [[LOAD_INDEX_2:%.+]] = krnl.load [[RESHAPED_INDICES]][[[IV]]#0, [[IV]]#1, [[CST_1_1]]] : memref<1x2x2xi64>
+// CHECK: [[INDEX_2:%.+]] = arith.index_cast [[LOAD_INDEX_2]] : i64 to index
+// CHECK-DAG: [[DATA_VAL:%.+]] = krnl.load [[RESHAPED_DATA]][[[IV]]#0, [[INDEX_1]], [[INDEX_2]]] : memref<1x2x2xf32>
+// CHECK-DAG: [[RES_BUFFER_INDEX_VAL:%.+]] = krnl.load [[RES_BUFFER_INDEX]][] : memref<index>
+// CHECK: krnl.store [[DATA_VAL]], [[RES_BUFFER]][[[RES_BUFFER_INDEX_VAL]]] : memref<2xf32>
+// CHECK: [[PLUS_ONE:%.+]] = arith.addi [[RES_BUFFER_INDEX_VAL]], [[CST_1_0]] : index
+// CHECK: krnl.store [[PLUS_ONE]], [[RES_BUFFER_INDEX]][] : memref<index>
+// CHECK: }
+// CHECK: [[RES:%.+]] = memref.reinterpret_cast [[RES_BUFFER]] to offset: [0], sizes: [2], strides: [1] : memref<2xf32> to memref<2xf32>
+// CHECK: return [[RES]] : memref<2xf32>
+}
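+
+// COM: Note that the final reinterpret_cast above only creates a new view of
+// COM: the flat result buffer; no elements are copied.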
memref<2xf32> to memref<2xf32> +// CHECK: return [[RES]] : memref<2xf32> +} + +//----- + +// COM: Test GatherND with indices_shape[-1] < rank(data) - batch_dims +func @test_gather_nd_2(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1x2xi64>) -> tensor<2x1x2xf32> { + %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x1x2xi64>) -> tensor<2x1x2xf32> + "std.return"(%0) : (tensor<2x1x2xf32>) -> () +// CHECK-LABEL: func @test_gather_nd_2 +// CHECK-SAME: ([[PARAM_0:%.+]]: memref<2x2x2xf32>, [[PARAM_1:%.+]]: memref<2x1x2xi64>) -> memref<2x1x2xf32> { +// CHECK-DAG: [[RESHAPED_INDICES:%.+]] = memref.reinterpret_cast [[PARAM_1]] to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x1x2xi64> to memref<1x2x2xi64> +// CHECK-DAG: [[RESHAPED_DATA:%.+]] = memref.reinterpret_cast [[PARAM_0]] to offset: [0], sizes: [1, 2, 2, 2], strides: [8, 4, 2, 1] : memref<2x2x2xf32> to memref<1x2x2x2xf32> +// CHECK-DAG: [[RES_BUFFER:%.+]] = memref.alloc() : memref<4xf32> +// CHECK: [[CST_0_0:%.+]] = arith.constant 0 : index +// CHECK: [[CST_1_0:%.+]] = arith.constant 1 : index +// CHECK: [[RES_INDEX_BUFFER:%.+]] = memref.alloca() : memref<index> +// CHECK: krnl.store [[CST_0_0]], [[RES_INDEX_BUFFER]][] : memref<index> +// CHECK: [[LOOP_0:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0]]#0, [[LOOP_0]]#1) with ([[LOOP_0]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0]]#1 -> [[I_1_:%.+]] = 0 to 2){ +// CHECK-DAG: [[IV:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0]]#0, [[LOOP_0]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[CST_0_1:%.+]] = arith.constant 0 : index +// CHECK: [[LOAD_INDEX_1:%.+]] = krnl.load [[RESHAPED_INDICES]]{{.}}[[IV]]#0, [[IV]]#1, [[CST_0_1]]{{.}} : memref<1x2x2xi64> +// CHECK-DAG: [[INDEX_1:%.+]] = arith.index_cast [[LOAD_INDEX_1]] : i64 to index +// CHECK-DAG: [[CST_1_1:%.+]] = arith.constant 1 : index +// CHECK: [[LOAD_INDEX_2:%.+]] = krnl.load [[RESHAPED_INDICES]]{{.}}[[IV]]#0, [[IV]]#1, [[CST_1_1]]{{.}} : memref<1x2x2xi64> +// CHECK-DAG: [[INDEX_2:%.+]] = arith.index_cast [[LOAD_INDEX_2]] : i64 to index +// CHECK-DAG: [[CST_0_2:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[DATA_1:%.+]] = krnl.load [[RESHAPED_DATA]]{{.}}[[IV]]#0, [[INDEX_1]], [[INDEX_2]], [[CST_0_2]]{{.}} : memref<1x2x2x2xf32> +// CHECK-DAG: [[RES_INDEX_1:%.+]] = krnl.load [[RES_INDEX_BUFFER]][] : memref<index> +// CHECK: krnl.store [[DATA_1]], [[RES_BUFFER]]{{.}}[[RES_INDEX_1]]{{.}} : memref<4xf32> +// CHECK: [[PLUS_ONE:%.+]] = arith.addi [[RES_INDEX_1]], [[CST_1_0]] : index +// CHECK: krnl.store [[PLUS_ONE]], [[RES_INDEX_BUFFER]][] : memref<index> +// CHECK: [[CST_1_2:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[DATA_2:%.+]] = krnl.load [[RESHAPED_DATA]]{{.}}[[IV]]#0, [[INDEX_1]], [[INDEX_2]], [[CST_1_2]]{{.}} : memref<1x2x2x2xf32> +// CHECK-DAG: [[RES_INDEX_2:%.+]] = krnl.load [[RES_INDEX_BUFFER]][] : memref<index> +// CHECK: krnl.store [[DATA_2]], [[RES_BUFFER]]{{.}}[[RES_INDEX_2]]{{.}} : memref<4xf32> +// CHECK: [[PLUS_ONE_1:%.+]] = arith.addi [[RES_INDEX_2]], [[CST_1_0]] : index +// CHECK: krnl.store [[PLUS_ONE_1]], [[RES_INDEX_BUFFER]][] : memref<index> +// CHECK: } +// CHECK: [[RES:%.+]] = memref.reinterpret_cast [[RES_BUFFER]] to offset: [0], sizes: [2, 1, 2], strides: [2, 2, 1] : memref<4xf32> to memref<2x1x2xf32> +// CHECK: return [[RES]] : memref<2x1x2xf32> +} + //----- func @test_reversesequence_1(%arg0: tensor<10x?xf32>, %arg1: tensor<10xi64>) -> tensor<*xf32> { diff --git 
index eecd97512b..e0f9cc9a70 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -1568,6 +1568,62 @@ func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>
 // CHECK: return [[RES]] : tensor<3x1x2xf32>
 }
 
+
+// -----
+
+func @test_gather_nd_1(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2x2xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x2xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_nd_1
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x2xi64>) -> tensor<2xf32>
+  // CHECK: return [[RES]] : tensor<2xf32>
+}
+
+// -----
+
+func @test_gather_nd_2(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2x1xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x1xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_nd_2
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x1xi64>) -> tensor<2x2xf32>
+  // CHECK: return [[RES]] : tensor<2x2xf32>
+}
+
+// -----
+
+func @test_gather_nd_3(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x2xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x2xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_nd_3
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32>
+  // CHECK: return [[RES]] : tensor<2x2xf32>
+}
+
+// -----
+
+func @test_gather_nd_4(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1x2xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x1x2xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_nd_4
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x1x2xi64>) -> tensor<2x1x2xf32>
+  // CHECK: return [[RES]] : tensor<2x1x2xf32>
+}
+
+// -----
+
+func @test_gather_nd_5(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 1 : si64} : (tensor<2x2x2xf32>, tensor<2x1xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_nd_5
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 1 : si64} : (tensor<2x2x2xf32>, tensor<2x1xi64>) -> tensor<2x2xf32>
+  // CHECK: return [[RES]] : tensor<2x2xf32>
+}
+
 // -----
 
 func @test_constant_of_shape_empty_tensor(%arg0 : tensor<0xi64>) -> tensor<*xf32> {
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 5d6ff15d5a..a15e7f254f 100755
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -294,6 +294,7 @@
     'Flatten',
     'Gather',
     'GatherElements',
+    'GatherND',
     'Hardmax',
     'InstanceNormalization',
     'LogSoftmax',