From d65925a8b465d4a84be947d37197178d5c5cc6d2 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 16 Feb 2024 13:35:25 -0800 Subject: [PATCH] [onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914) Sample compilation crashes due to sigmoid with integer inputs/outputs. This fix avoids crashing but still experiences an error. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 5 +- .../TorchToLinalg/Uncategorized.cpp | 106 +++++++++--------- projects/pt1/e2e_testing/xfail_sets.py | 4 +- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 525335161db8..405a02bb3c58 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1615,9 +1615,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector intermediateShape(operandTy.getShape()); for (int i = 0, s = operandTy.getRank(); i < s; ++i) { - if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) { + if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) intermediateShape[i] = -1; - } + if (intermediateShape[i] == ShapedType::kDynamic) + intermediateShape[i] = Torch::kUnknownSize; } auto intermediateType = Torch::ValueTensorType::get( context, intermediateShape, resultTorchType.getOptionalDtype()); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0019acfc2944..e8e671955835 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { } template -static Value -createCalculationForMathOpWithDtypeConversion(OpBuilder &b, - const TypeConverter *converter, - Value payloadArg, Operation *op) { - Type dtype = converter->convertType(op->getResult(0).getType()) - .template cast() - .getElementType(); +static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, + Value payloadArg, Operation *op) { + Type inTTy = cast(op->getOperand(0).getType()).getDtype(); + Type outTTy = cast(op->getResult(0).getType()).getDtype(); + Type outTy = + cast(converter->convertType(op->getResult(0).getType())) + .getElementType(); + Type computeTy = outTy; + if (isa(computeTy)) + computeTy = b.getF32Type(); Location loc = op->getLoc(); - Value arg = convertScalarToDtype(b, loc, payloadArg, dtype); - return b.create(loc, arg); + Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy); + auto newOp = b.create(loc, arg); + return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy); } template @@ -217,92 +221,70 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; @@ -453,13 +435,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createEqual(b, loc, abs.getType(), abs, infinity); } if (isa(op)) { - auto negate = createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + Type inTTy = cast(op->getOperand(0).getType()).getDtype(); + Type outTTy = cast(op->getResult(0).getType()).getDtype(); + Type outTy = cast( + converter->convertType(op->getResult(0).getType())) + .getElementType(); + Type computeTy = outTy; + if (isa(computeTy)) + computeTy = b.getF32Type(); + + Value arg = payloadArgs[0]; + arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy); + auto negate = b.create(loc, arg); auto one = b.create(loc, FloatAttr::get(negate.getType(), 1)); auto exp = b.create(loc, negate); auto added = b.create(loc, exp, one); - return b.create(loc, one, added); + auto div = b.create(loc, one, added); + outTy.dump(); + return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy); } if (auto relu = dyn_cast(op)) { if (!relu.getType() diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 66fbc41588e6..a1cee9037933 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2165,6 +2165,9 @@ "ReduceMaxKeepDimReturnBoth_basic", "ReduceMaxNegativeDim_basic", "ViewSizeFromOtherTensor_basic", + + # Failure - onnx traces differently + "ElementwiseSigmoidIntModule_basic", # Failure - unknown "ChunkListUnpackUneven_Module_basic", @@ -2192,7 +2195,6 @@ } ONNX_CRASHING_SET = { - "ElementwiseSigmoidIntModule_basic", "FlipModule_basic", "IndexTensorNegativeIndexModule_basic", "MoveDimIntNegativeIndexModule_basic",