Skip to content

Commit

Permalink
[onnx] Fix onnx.sigmoid for integer inputs/outputs (llvm#2914)
Browse files Browse the repository at this point in the history
Sample compilation crashed because `onnx.sigmoid` was lowered with integer inputs/outputs. This fix avoids the crash, though compilation still fails later with a separate error.
  • Loading branch information
rsuderman authored Feb 16, 2024
1 parent 7a0d0e9 commit d65925a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 59 deletions.
5 changes: 3 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1615,9 +1615,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
  // Any dim that differs between operand and result cannot be assumed
  // static in the intermediate type; mark it dynamic.
  if (operandTy.getDimSize(i) != resultTy.getDimSize(i))
    intermediateShape[i] = -1;
  // Translate the builtin tensor dynamic-dim sentinel into the Torch
  // dialect's unknown-size sentinel before building the Torch type.
  if (intermediateShape[i] == ShapedType::kDynamic)
    intermediateShape[i] = Torch::kUnknownSize;
}
auto intermediateType = Torch::ValueTensorType::get(
context, intermediateShape, resultTorchType.getOptionalDtype());
Expand Down
106 changes: 50 additions & 56 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
}

template <typename MathOpTy>
/// Builds a unary floating-point math op for an elementwise payload.
/// Integer element types are handled by computing in f32 and converting
/// the result back to the integer output type, since the math-dialect ops
/// used here only accept floating-point operands.
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
                                 Value payloadArg, Operation *op) {
  // Torch-level dtypes of the operand/result, and the converted builtin
  // element type of the result.
  Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
  Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
  Type outTy =
      cast<RankedTensorType>(converter->convertType(op->getResult(0).getType()))
          .getElementType();

  // Compute in f32 when the output element type is an integer.
  Type computeTy = outTy;
  if (isa<IntegerType>(computeTy))
    computeTy = b.getF32Type();

  Location loc = op->getLoc();
  Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy);
  auto newOp = b.create<MathOpTy>(loc, arg);
  return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
}

template <typename OpTy>
Expand Down Expand Up @@ -217,92 +221,70 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
// Unary floating-point math ops. createFpOpWithDtype supports integer
// element types by computing in f32 and converting the result back.
if (isa<AtenCeilOp>(op))
  return b.create<math::CeilOp>(loc, payloadArgs[0]);
if (isa<AtenExpOp>(op)) {
  return createFpOpWithDtype<math::ExpOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenExpm1Op>(op)) {
  return createFpOpWithDtype<math::ExpM1Op>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLogOp>(op)) {
  return createFpOpWithDtype<math::LogOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLog2Op>(op)) {
  return createFpOpWithDtype<math::Log2Op>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLog10Op>(op)) {
  return createFpOpWithDtype<math::Log10Op>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLog1pOp>(op)) {
  return createFpOpWithDtype<math::Log1pOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenErfOp>(op)) {
  return createFpOpWithDtype<math::ErfOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenSqrtOp>(op)) {
  return createFpOpWithDtype<math::SqrtOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenRsqrtOp>(op)) {
  return createFpOpWithDtype<math::RsqrtOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenNegOp>(op)) {
  return createFpOpWithDtype<arith::NegFOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenSinOp>(op)) {
  return createFpOpWithDtype<math::SinOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenSinhOp>(op)) {
  return createFpOpWithDtype<math::SinhOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAsinOp>(op)) {
  return createFpOpWithDtype<math::AsinOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAsinhOp>(op)) {
  return createFpOpWithDtype<math::AsinhOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenCosOp>(op)) {
  return createFpOpWithDtype<math::CosOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenCoshOp>(op)) {
  return createFpOpWithDtype<math::CoshOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAcosOp>(op)) {
  return createFpOpWithDtype<math::AcosOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAcoshOp>(op)) {
  return createFpOpWithDtype<math::AcoshOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenTanOp>(op)) {
  return createFpOpWithDtype<math::TanOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenTanhOp>(op)) {
  return createFpOpWithDtype<math::TanhOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAtanOp>(op)) {
  return createFpOpWithDtype<math::AtanOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAtanhOp>(op)) {
  return createFpOpWithDtype<math::AtanhOp>(b, converter, payloadArgs[0], op);
}
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
int64_t memoryFormat;
Expand Down Expand Up @@ -453,13 +435,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createEqual(b, loc, abs.getType(), abs, infinity);
}
if (isa<AtenSigmoidOp>(op)) {
  // Torch-level dtypes of the operand/result and the converted builtin
  // element type of the result.
  Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
  Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
  Type outTy = cast<RankedTensorType>(
                   converter->convertType(op->getResult(0).getType()))
                   .getElementType();
  // Sigmoid is a floating-point computation; for integer element types
  // compute in f32 and convert the result back afterwards.
  Type computeTy = outTy;
  if (isa<IntegerType>(computeTy))
    computeTy = b.getF32Type();

  // sigmoid(x) = 1 / (1 + exp(-x))
  Value arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy);
  auto negate = b.create<arith::NegFOp>(loc, arg);
  auto one =
      b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
  auto exp = b.create<math::ExpOp>(loc, negate);
  auto added = b.create<arith::AddFOp>(loc, exp, one);
  auto div = b.create<arith::DivFOp>(loc, one, added);
  return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
}
if (auto relu = dyn_cast<AtenReluOp>(op)) {
if (!relu.getType()
Expand Down
4 changes: 3 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2165,6 +2165,9 @@
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
"ViewSizeFromOtherTensor_basic",

# Failure - onnx traces differently
"ElementwiseSigmoidIntModule_basic",

# Failure - unknown
"ChunkListUnpackUneven_Module_basic",
Expand Down Expand Up @@ -2192,7 +2195,6 @@
}

ONNX_CRASHING_SET = {
"ElementwiseSigmoidIntModule_basic",
"FlipModule_basic",
"IndexTensorNegativeIndexModule_basic",
"MoveDimIntNegativeIndexModule_basic",
Expand Down

0 comments on commit d65925a

Please sign in to comment.