From 4a7a7d76f8870cad43a1803312efce7a8ae8643b Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 27 Feb 2024 22:48:07 -0800 Subject: [PATCH] [onnx] Fix ReduceMean lowering to torch (#2956) Torch lowering only supported the most recent version. Refactored the lowering so more easily handle default values and optional operands / attributes. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 228 ++++---- lib/Conversion/TorchToLinalg/Reduction.cpp | 97 ++-- .../Torch/Transforms/DecomposeComplexOps.cpp | 52 ++ .../TorchConversion/Transforms/Passes.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 493 ++++++++++-------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 186 ++++--- 6 files changed, 608 insertions(+), 449 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index bc2cde573967..adf6d1cb639a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1104,129 +1104,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value axes; int64_t keepDims; int64_t noop_with_empty_axes; - // Deal with case when no axes arg is passed - if (binder.op->getNumOperands() == 1) { - if (binder.tensorOperand(data) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); - if (noop_with_empty_axes == 0) { - Value keepDimsConstInt = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); - Value keepDimsBool = rewriter.create( - binder.getLoc(), keepDimsConstInt); - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - SmallVector axesList; - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } - Value axesValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get( - Torch::IntType::get(binder.op->getContext())), - axesList); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, axesValueList, keepDimsBool); - } else { - rewriter.replaceOp(binder.op, data); - } - return success(); - } - if (binder.tensorOperands(data, axes) || + if (binder.tensorOperandAtIndex(data, 0) || binder.tensorResultType(resultType) || binder.s64IntegerAttr(keepDims, "keepdims", 1) || binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", 0)) return failure(); - Torch::BaseTensorType axesType = - axes.getType().cast(); - SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = axesType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - auto sizes = - dyn_cast(axes.getType()).getSizes(); - // deal with case when axes is empty - if (sizes.size() == 1 && sizes[0] == 0) { - if (noop_with_empty_axes == 0) { - // create dims list with all dims [0, data.getSizes().size()) - Value keepDimsConstInt = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); - Value keepDimsBool = rewriter.create( - binder.getLoc(), keepDimsConstInt); - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - dimList.push_back(curr); + + auto dataTy = cast(data.getType()); + Torch::IntType torchIntTy = rewriter.getType(); + + // If any of the input dims are 0 we set to the upper limit: + if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) && + (llvm::any_of(dataTy.getSizes(), + [](int64_t d) { return d == Torch::kUnknownSize; }) || + keepDims)) { + auto dty = dataTy.getDtype(); + Value scalar; + if (FloatType fpTy = dyn_cast(dty)) { + auto inf = APFloat::getInf(fpTy.getFloatSemantics()); + scalar = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), + inf.convertToDouble())); + } + + if (IntegerType intTy = dyn_cast(dty)) { + auto mx = + intTy.isSigned() + ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); + scalar = rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + mx.getSExtValue())); + } + + llvm::SmallVector fillDims; + for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) { + auto staticDim = resultType.getSizes()[i]; + if (staticDim != Torch::kUnknownSize) { + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(staticDim))); + continue; } - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get( - Torch::IntType::get(binder.op->getContext())), - dimList); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimsBool); - } else { - rewriter.replaceOp(binder.op, data); + + Value iv = rewriter.create( + binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, data, iv)); } + + Value none = rewriter.create(binder.getLoc()); + Value fillDimsList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, fillDimsList, scalar, none, none, none, + none); return success(); } + + // Previous version of the operation had the axes as an attribute: + SmallVector axesList; + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Extract the axes values from the axes operand: + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector selectSizes{1}; + Type selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + auto sizes = axesType.getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + // Extract the value of each axes: + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Deal with case when no axes arg is passed but not a noop: + if (axesList.empty()) { + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - int64_t adjustmentInt = - cast(data.getType()).getSizes().size(); - Value adjustment = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); - // convert axes (tensor) into torch int list while dealing with neg axis - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - // deal with neg axis: if (axis < 0) axis += rank + rewriter.getI64IntegerAttr(0)); + for (Value &axes : axesList) { Value isNegative = - rewriter.create(binder.getLoc(), dim, zero); + rewriter.create(binder.getLoc(), axes, zero); isNegative = rewriter.create(binder.getLoc(), isNegative); Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, adjustment); - Value finalDim = rewriter.create( - binder.getLoc(), dim, finalOffset); - dimList.push_back(finalDim); + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); } + Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - dimList); - Value keepDimBool; - if (keepDims == 1) { - keepDimBool = - rewriter.create(binder.getLoc(), true); - } else { - keepDimBool = - rewriter.create(binder.getLoc(), false); - } + binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); rewriter.replaceOpWithNewOp( binder.op, resultType, data, dimValueList, keepDimBool); return success(); diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index e050764993e6..92f50523c764 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -60,18 +60,15 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - RankedTensorType valResultType = - getTypeConverter() - ->convertType(op.getResult(0).getType()) - .template cast(); - - RankedTensorType idxResultType = - this->getTypeConverter() - ->convertType(op.getResult(1).getType()) - .template cast(); + auto typec = this->getTypeConverter(); + auto valResultType = + cast(typec->convertType(op.getResult(0).getType())); + auto idxResultType = + cast(typec->convertType(op.getResult(1).getType())); RankedTensorType inputType = input.getType().template cast(); - Type idxElementType = idxResultType.getElementType(); + Type idxElementType = + getElementTypeOrSelf(typec->convertType(idxResultType)); if (!idxElementType.isa()) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires integer-like result type"); @@ -109,14 +106,12 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { } // Constant op to account for the reduction along dim. - auto c1 = rewriter.create(loc, /*value=*/1); SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { if (dim != i) { auto currentDimSize = rewriter.create(loc, input, i); resultShape.push_back(currentDimSize); - } else if (keepDim) - resultShape.push_back(c1); + } } // First fill the output buffer for the index. Value filledTensorIdx = @@ -146,27 +141,23 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Value filledTensorVal = rewriter.create(loc, fillValue, initTensorVal).result(); + SmallVector iteratorTypes( + inputType.getRank(), utils::IteratorType::parallel); + iteratorTypes[dim] = utils::IteratorType::reduction; + // Create the affine expressions that will be used to // iterate over the input and output tensors. // Here we also set the type of iterator: parallel or reduction. + SmallVector exprs; - SmallVector iteratorTypes; SmallVector resultExprs; for (auto size : llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) { exprs.push_back(rewriter.getAffineDimExpr(size.index())); - - if (unsigned(dim) == size.index()) { - iteratorTypes.push_back(utils::IteratorType::reduction); - // If `keepDim`, create affine map to the first element - // in the current dimension. - if (keepDim) - resultExprs.push_back(rewriter.getAffineConstantExpr(0)); - } else { - iteratorTypes.push_back(utils::IteratorType::parallel); + if (unsigned(dim) != size.index()) resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); - } } + auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}, rewriter.getContext()); auto linalgOp = rewriter.create( @@ -219,12 +210,58 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { nestedLoc, ValueRange({resultVal, resultIndex})); }); - // This cast is required to fix the shape in the case of keepDim=True - Value valuesCast = rewriter.create(loc, valResultType, - linalgOp.getResult(0)); - Value idxCast = rewriter.create(loc, idxResultType, - linalgOp.getResult(1)); - rewriter.replaceOp(op, {valuesCast, idxCast}); + if (!keepDim) { + Value rVal = rewriter.create(loc, valResultType, + linalgOp.getResult(0)); + Value rIdx = rewriter.create(loc, idxResultType, + linalgOp.getResult(1)); + llvm::SmallVector res{rVal, rIdx}; + rewriter.replaceOp(op, res); + return success(); + } + + llvm::SmallVector valShape(valResultType.getShape()); + llvm::SmallVector idxShape(idxResultType.getShape()); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i] = valShape[i + 1]; + idxShape[i] = idxShape[i + 1]; + } + + valShape.resize(valShape.size() - 1); + idxShape.resize(idxShape.size() - 1); + + Value rVal = rewriter.create( + loc, valResultType.clone(valShape), linalgOp.getResult(0)); + Value rIdx = rewriter.create( + loc, idxResultType.clone(idxShape), linalgOp.getResult(1)); + + SmallVector reassociation(valShape.size()); + if (reassociation.size() > 0) { + for (int i = 0; i < dim; ++i) + reassociation[i].push_back(i); + reassociation[std::max(0, dim - 1)].push_back(dim); + for (int i = dim, s = reassociation.size(); i < s; ++i) + reassociation[i].push_back(i + 1); + } + + valShape.push_back(0); + idxShape.push_back(0); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i + 1] = valShape[i]; + idxShape[i + 1] = idxShape[i]; + } + + valShape[dim] = 1; + idxShape[dim] = 1; + + Value unsqueezeVal = rewriter.create( + loc, valResultType, rVal, reassociation); + + Value unsqueezeIdx = rewriter.create( + loc, idxResultType, rIdx, reassociation); + + llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; + rewriter.replaceOp(op, unsqueezes); return success(); } }; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f9c1f63b568c..51a710d940e9 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1316,6 +1316,57 @@ class DecomposeAten_LogSoftmaxBackwardDataOp }; } // namespace +namespace { +class DecomposeAtenAMinMaxOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Torch::AtenAminOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector dimList; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { + return rewriter.notifyMatchFailure(op, "dims not foldable constants"); + } + + bool keepdim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) { + return rewriter.notifyMatchFailure(op, "keepdims not foldable constants"); + } + + auto loc = op.getLoc(); + std::sort(dimList.begin(), dimList.end(), std::greater()); + + Value reduction = op.getSelf(); + auto resultTy = cast(op.getType()); + auto reductionTy = cast(reduction.getType()); + llvm::SmallVector reductionShape(reductionTy.getSizes()); + + for (auto dim : dimList) { + auto dimValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(dim)); + reductionShape[dim] = 1; + if (!keepdim) { + for (int i = dim, s = reductionShape.size() - 1; i < s; ++i) + reductionShape[i] = reductionShape[i + 1]; + reductionShape.resize(reductionShape.size() - 1); + } + + reductionTy = rewriter.getType( + reductionShape, resultTy.getOptionalDtype()); + auto idxTy = rewriter.getType( + reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); + llvm::SmallVector types{reductionTy, idxTy}; + reduction = rewriter + .create(loc, types, reduction, + dimValue, op.getKeepdim()) + .getResult(0); + } + + rewriter.replaceOp(op, reduction); + return success(); + } +}; +} // namespace + // Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into // `AtenMinDimOp` namespace { @@ -6867,6 +6918,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 7ac95ab6c4e9..55bedc1192eb 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -77,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToTMTensorPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToLinalgPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); pm.addNestedPass(createConvertTorchToTensorPass()); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 74f7300c9274..a4ac58b1d909 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1472,6 +1472,62 @@ } ONNX_XFAIL_SET = { + # Failure - cast error + "MeanDimNoneDimModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "MseLossMeanReductionModule_basic", + "PermuteNegativeIndexModule_basic", + "StdBiasedModule_basic", + "VarBiasedModule_basic", + "VarMeanBiasedModule_basic", + + # Failure - constant int lowering + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + + # Failure - incorrect numerics + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseSeluModule_basic", + "FlipModuleStaticShape_basic", + "FlipNegativeIndexModule_basic", + "HardsigmoidModule_basic", + "HardsigmoidRandomModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "IndexTensorStaticModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "ResNet18Module_basic", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopy_Module_basic", + "TupleModule_basic", + + # Failure - incorrect shape + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutViewModule_basic", + "BroadcastDynamicDimModule_basic", + "BroadcastToModule_basic", + "ExpandModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "ReduceAmaxKeepDim_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxNegativeDim_basic", + "ViewSizeFromOtherTensor_basic", + # Failure - onnx_export "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", @@ -1594,6 +1650,7 @@ "EmptyStridedSizeIntStrideModule_basic", "EqIntModule_basic", "ExponentialModule_basic", + "FloatImplicitModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", @@ -1613,6 +1670,7 @@ "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", "IntFloatModule_basic", + "IntImplicitModule_basic", "IouOfModule_basic", "IsFloatingPointFloat_True", "IsFloatingPointInt_False", @@ -1818,13 +1876,8 @@ "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", - + # Failure - onnx_import - "BucketizeTensorFloatModule_basic", - "BucketizeTensorModule_basic", - "BucketizeTensorOutInt32RightModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -1832,31 +1885,6 @@ "DiagonalModule_with_dims_and_offset", "DiagonalModule_with_negative_dims", "DiagonalModule_with_offset", - "ElementwiseClampMaxModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampMinTensorFloatModule_basic", - "ElementwiseClampMinTensorIntModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampTensorFloatModule_basic", - "ElementwiseClampTensorInt8Module_basic", - "ElementwiseClampTensorIntModule_basic", - "HBC_basic", - "IndexPut1DFloatAccumulateModule_basic", - "IndexPut1DIntAccumulateModule_basic", - "IndexPut2DFloatAccumulateModule_basic", - "IndexPut2DIntAccumulateModule_basic", - "IndexPut3DFloatAccumulateModule_basic", - "IndexPut3DIntAccumulateModule_basic", - "IndexPutHackedTwin1DFloatAccumulateModule_basic", - "IndexPutHackedTwin1DIntAccumulateModule_basic", - "IndexPutHackedTwin2DFloatAccumulateModule_basic", - "IndexPutHackedTwin2DIntAccumulateModule_basic", - "IndexPutHackedTwin3DFloatAccumulateModule_basic", - "IndexPutHackedTwin3DIntAccumulateModule_basic", - "NormalizeModule_basic", - "PadWithNoneValModule_basic", - "QuantizedMLP_basic", - "RandModule_basic", "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", "ScatterReduceFloatProdModuleIncludeSelf", @@ -1867,21 +1895,11 @@ "ScatterReduceIntSumModuleIncludeSelf", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dStaticSize_basic", - - # Failure - onnx_lowering + + # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AtenMmFloatTypes_basic", - "AtenMmIntTypes_basic", - "AtenTrilModule_basic", - "AtenTrilWithNegDiagonalModule_basic", - "AtenTrilWithPosDiagonalModule_basic", - "AtenTriuModule_basic", - "AtenTriuWithNegDiagonalModule_basic", - "AtenTriuWithPosDiagonalModule_basic", "AvgPool1dFloatModule_basic", "AvgPool1dIntModule_basic", "AvgPool1dStaticModule_basic", @@ -1890,78 +1908,73 @@ "AvgPool2dFloatModule_basic", "AvgPool2dIntModule_basic", "AvgPool2dStaticModule_basic", - "BernoulliFloatModule_basic", - "BernoulliModule_basic", - "BernoulliPModule_basic", - "BernoulliTensorModule_basic", - "ConstantPad2dStaticModule_basic", - "ConstantPadNdModule_basic", - "ConstantPadNdPartialStaticModule_basic", - "ConstantPadNdStaticModule_basic", - "CrossEntropyLossModule_basic", - "CrossEntropyLossNoReductionModule_basic", - "DropoutTrainModule_basic", - "DropoutTrainStaticShapeModule_basic", + + # Failure - onnx_lowering: onnx.Cast + "BucketizeTensorOutInt32RightModule_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "HBC_basic", + "QuantizedMLP_basic", + "TypeConversionI1ToI32Module_basic", + "TypeConversionI64ToI32Module_basic", + + # Failure - onnx_lowering: onnx.Clip + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorInt8Module_basic", + "ElementwiseClampTensorIntModule_basic", + "NormalizeModule_basic", + + # Failure - onnx_lowering: onnx.Einsum "EinsumStaticContractRhsModule_basic", "EinsumStaticFourDimensionModule_basic", "EinsumStaticModule_basic", - "ElementwiseMishModule_basic", - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "ElementwiseToDtypeI64ToI8Module_basic", - "ElementwiseToDtypeI64ToUI8Module_basic", - "GroupNormModule_basic", - "GroupNormNoWeightAndBiasModule_basic", + + # Failure - onnx_lowering: onnx.Gemm + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "MmDagModule_basic", + "MmModule_basic", + "MmModule_chained", + "MmTanhModule_basic", + + # Failure - onnx_lowering: onnx.HardSwish "HardswishModule_basic", "HardswishRandomModule_basic", - "IndexPut1DFloatNonAccumulateModule_basic", - "IndexPut1DIntNonAccumulateModule_basic", - "IndexPut2DFloatNonAccumulateModule_basic", - "IndexPut2DIntNonAccumulateModule_basic", - "IndexPut3DFloatNonAccumulateModule_basic", - "IndexPut3DIntNonAccumulateModule_basic", - "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin1DIntNonAccumulateModule_basic", - "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin2DIntNonAccumulateModule_basic", - "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "MobilenetV3Module_basic", + + # Failure - onnx_lowering: onnx.LogSoftmax "LogSoftmaxIntModule_basic", + "_LogSoftmaxModuleStable_basic", + "_LogSoftmaxModule_basic", + + # Failure - onnx_lowering: onnx.MaxPool "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", "MaxPool2dWithIndicesStaticModule_basic", - "MmDagModule_basic", - "MmModule_basic", - "MmModule_chained", - "MmTanhModule_basic", - "MobilenetV3Module_basic", - "MseLossSumReductionWithDifferentElemTypeModule_basic", - "NativeDropoutTrainModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", + + # Failure - onnx_lowering: onnx.Mod + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + + # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", + + # Failure - onnx_lowering: onnx.Pad + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", "PadModule_basic", - "RandIntLowDtypeModule_basic", - "RandIntLowModule_basic", - "RandLikeDtypeModule_basic", - "RandLikeModule_basic", - "RandnDtypeDeviceModule_basic", - "RandnGeneratorF64Module_basic", - "RandnGeneratorModule_basic", - "RandnLikeDtypeModule_basic", - "RandnLikeModule_basic", - "RandnModule_basic", - "ReduceL1NormModule_basic", - "ReduceL1NormWithDTypeModule_basic", - "ReduceL2NormModule_basic", - "ReduceL3NormAllDimsModule_basic", - "ReduceL3NormKeepDimModule_basic", - "ReduceProdDimIntFloatModule_basic", - "ReduceSumDtypeFloatModule_basic", - "ReduceSumDtypeIntModule_basic", - "ReduceSumElementTypeBoolModule_basic", - "ReduceSumFloatModule_basic", - "ReduceSumSignedIntModule_basic", - "ReduceSumUnsignedIntModule_basic", + "PadWithNoneValModule_basic", "ReflectionPad1dModule2dInput_Right", "ReflectionPad1dModule2dInput_basic", "ReflectionPad1dModule3dInput_Left", @@ -1976,19 +1989,43 @@ "ReplicationPad2dModule_left0", "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", - "ScatterSrcModule_basic", - "ScatterSrcStaticModule_basic", - "ScatterValueFloatModule_basic", - "ScatterValueIntModule_basic", - "SoftplusModule_basic", - "SortTensorDescending_basic", - "SortTensorInteger_basic", - "SortTensorNegativeDimension_basic", - "SortTensorSpecificDimension_basic", - "SortTensor_basic", - "SqueezeModule_allUnitDim", - "SqueezeModule_broadcast", - "SqueezeModule_static", + + # Failure - onnx_lowering: onnx.RandomNormal + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnModule_basic", + + # Failure - onnx_lowering: onnx.RandomNormalLike + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + + # Failure - onnx_lowering: onnx.RandomUniform + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + + # Failure - onnx_lowering: onnx.RandomUniformLike + "BernoulliFloatModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RandModule_basic", + + # Failure - onnx_lowering: onnx.ReduceL1 + "ReduceL1NormModule_basic", + "ReduceL1NormWithDTypeModule_basic", + + # Failure - onnx_lowering: onnx.ReduceL2 + "ReduceL2NormModule_basic", + + # Failure - onnx_lowering: onnx.ReduceProd + "BernoulliModule_basic", + "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "ReduceProdDimIntFloatModule_basic", "StdCorrectionAllDimReduceModule_basic", "StdCorrectionKeepDimModule_basic", "StdCorrectionLargeInputModule_basic", @@ -1999,14 +2036,6 @@ "StdDimKeepDimTrueModule_basic", "StdDimNoneDimModule_basic", "StdUnbiasedModule_basic", - "TriuBroadcastModule_basic", - "TriuModule_basic", - "TypeConversionI1ToI32Module_basic", - "TypeConversionI64ToI32Module_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenStaticModule_basic", "VarCorrectionAllDimReduceModule_basic", "VarCorrectionKeepDimModule_basic", "VarCorrectionLargeInputModule_basic", @@ -2025,58 +2054,85 @@ "VarMeanDimModule_basic", "VarMeanUnbiasedModule_basic", "VarUnbiasedModule_basic", - "_LogSoftmaxModuleStable_basic", - "_LogSoftmaxModule_basic", - - # Failure - cast_error - "MeanDimNoneDimModule_basic", - "MeanDtypeModule_basic", - "MeanDynamicSizesModule_basic", - "MeanModule_basic", - "MseLossMeanReductionModule_basic", - "StdBiasedModule_basic", - "VarBiasedModule_basic", - "VarMeanBiasedModule_basic", - - # Failure - constant_int - "ReduceMinAlongDimNegative_basic", - "ReduceMinAlongDimSignedInt_basic", - "ReduceMinAlongDim_basic", - "ReduceMinFloatModule_basic", - "ReduceMinKeepDimReturnBoth_basic", - "ReduceMinSignedIntModule_basic", - "ReduceMinUnsignedIntModule_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", - "UnbindIntGetItem_Module_basic", - "UnbindIntListUnpack_Module_basic", - - # Failure - operand_type - "ElementwiseAcosIntModule_basic", - "ElementwiseAsinIntModule_basic", - "ElementwiseAtanTensorIntModule_basic", - "ElementwiseCosIntModule_basic", - "ElementwiseErfIntModule_basic", - "ElementwiseExpIntModule_basic", - "ElementwiseLog10IntModule_basic", - "ElementwiseLog2IntModule_basic", - "ElementwiseLogIntModule_basic", - "ElementwiseSinIntModule_basic", - "ElementwiseTanIntModule_basic", - "ElementwiseUnaryIntModule_basic", - - # Failure - expand_multidim - "IndexTensorHackedTwinModule3dInput_basic", - "IndexTensorHackedTwinModule_basic", - "IndexTensorModule3dInput_basic", - "IndexTensorModule_basic", - "IndexTensorMultiInputContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", - - # Failure - rankless_return + + # Failure - onnx_lowering: onnx.ReduceSum + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + + # Failure - onnx_lowering: onnx.Resize + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticSize_basic", + + # Failure - onnx_lowering: onnx.ScatterElements + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + + # Failure - onnx_lowering: onnx.ScatterND + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + + # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + + # Failure - onnx_lowering: onnx.Softplus + "ElementwiseMishModule_basic", + "SoftplusModule_basic", + + # Failure - onnx_lowering: onnx.Squeeze + "SqueezeModule_allUnitDim", + "SqueezeModule_broadcast", + "SqueezeModule_static", + + # Failure - onnx_lowering: onnx.TopK + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + + # Failure - onnx_lowering: onnx.Trilu + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "AtenTriuModule_basic", + "AtenTriuWithNegDiagonalModule_basic", + "AtenTriuWithPosDiagonalModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + + # Failure - rankless return "ReduceAmaxMultiDim_basic", "ReduceAmaxOutOfOrderDim_basic", "ReduceAmaxSingleDim_basic", @@ -2088,8 +2144,8 @@ "ReduceMaxFloatModule_basic", "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", - - # Failure - view_lowering + + # Failure - torch.aten.view lower "AddSizeIntModule_basic", "ElementwiseFlattenBroadcastModule_basic", "FlattenRank0Module_basic", @@ -2097,13 +2153,11 @@ "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorMultiInputContiguousCenter_basic", - "IndexTensorMultiInputNonContiguousDynamic_basic", "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorMultiInputNonContiguous_basic", "IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", - "IndexTensorSelectDimModule_basic", "IndexTensorStaticContiguousWithNoneModule_basic", "RepeatModule_basic", "SelectIntModule_basic", @@ -2116,63 +2170,50 @@ "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", "ViewSizeDimLedByCollapsedOnesModule_basic", "ViewSizeDimLedByExpandedOnesModule_basic", - - # Failure - numerical - "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", - "ElementwiseSeluModule_basic", - "EmbeddingModule1DIndices_basic", - "FlipNegativeIndexModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectDynamicInputSizeModule_basic", - "IndexSelectDynamicModulebasic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", - "IndexTensorStaticModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", - "PixelShuffleModuleStaticRank4Float32_basic", - "ResNet18Module_basic", - "SliceCopyEndGreaterThanDimSize_Module_basic", - "SliceCopyNegative_Module_basic", - "SliceCopyNonZeroDim_Module_basic", - "SliceCopy_Module_basic", - "TupleModule_basic", - - # Failure - shape - "ArangeStartOutDtypeModule_basic", - "ArangeStartOutViewModule_basic", - "BroadcastDynamicDimModule_basic", - "BroadcastToModule_basic", - "EmbeddingModuleF16_basic", - "EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", - "ExpandModule_basic", - "MoveDimIntNegativeIndexModule_basic", - "PermuteNegativeIndexModule_basic", - "ReduceAmaxKeepDim_basic", - "ReduceMaxKeepDimReturnBoth_basic", - "ReduceMaxNegativeDim_basic", - "ViewSizeFromOtherTensor_basic", - # Failure - onnx traces differently - "ElementwiseSigmoidIntModule_basic", - # Failure - unknown + "BucketizeTensorFloatModule_basic", + "BucketizeTensorModule_basic", + "BucketizeTensorStaticFloatModule_basic", + "BucketizeTensorStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", - "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseCosIntModule_basic", "ElementwiseDivRoundingModeTruncModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLogIntModule_basic", "ElementwisePreluModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseWhereScalarModule_basic", + "EmbeddingModule1DIndices_basic", + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", "FlattenDynamicModule_basic", - "FlipModuleStaticShape_basic", "GluStaticModule_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexTensorHackedTwinModule3dInput_basic", + "IndexTensorHackedTwinModule_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorSelectDimModule_basic", "MaskedFillTensorFloatValueModule_basic", "ReduceAllDimEmpty_basic", "ReduceAllDimFloat_basic", @@ -2180,8 +2221,6 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", - "FloatImplicitModule_basic", - "IntImplicitModule_basic", } ONNX_CRASHING_SET = { } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 704e03acc1e2..42be32166c5f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -926,107 +926,121 @@ func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor // ----- +// CHECK-LABEL: func.func @test_reduce_min_empty_set_fp +func.func @test_reduce_min_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> + return %0 : !torch.vtensor<[2,1,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_empty_set_int +func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> + return %0 : !torch.vtensor<[2,1,4],si32> +} + +// ----- + + // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1> %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> return %0 : !torch.vtensor<[4,1],i1> } -// CHECK-LABEL: func.func @test_reduce_min_default_axes_keepdims_example -func.func @test_reduce_min_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.prim.ListConstruct %int0, %int1_0, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.amin %arg0, %1, %0 : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[1,1,1],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> - return %0 : !torch.vtensor<[1,1,1],f32> -} +// ----- -// CHECK-LABEL: func.func @test_reduce_min_do_not_keepdims_example -func.func @test_reduce_min_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list +// CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims +func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: torch.aten.amin %arg0, %6, %false : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,2],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> - return %0 : !torch.vtensor<[3,2],f32> + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> } -// CHECK-LABEL: func.func @test_reduce_min_empty_set -func.func @test_reduce_min_empty_set(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,1,4],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> - return %0 : !torch.vtensor<[2,1,4],f32> -} +// ----- -// CHECK-LABEL: func.func @test_reduce_min_keepdims_example -func.func @test_reduce_min_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,2],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> - return %0 : !torch.vtensor<[3,1,2],f32> +// CHECK-LABEL: func.func @test_reduce_all_dims_default +func.func @test_reduce_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[MIN]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> } -// CHECK-LABEL: func.func @test_reduce_min_negative_axes_keepdims_example -func.func @test_reduce_min_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// ----- + +func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,2],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> - return %0 : !torch.vtensor<[3,1,2],f32> + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMIN]] + %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> } // -----