From 8d6075012408368a4efb31f228054698fef16df2 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Tue, 13 Feb 2024 15:31:07 +0000
Subject: [PATCH] [MLIR][Torch] Add OnnxToTorch and TorchToLinalg support for
 trig ops

This commit adds the OnnxToTorch lowering for the cosh, acosh, asin, asinh,
and atanh ops. This commit also adds the TorchToLinalg lowering for the
acosh, asin, asinh, and atanh ops.

Signed-Off By: Vivek Khandelwal
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td   | 360 +++++++++++++-----
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp  |  59 ++-
 .../TorchToLinalg/Uncategorized.cpp         |  97 +++--
 .../Transforms/AbstractInterpLibrary.cpp    |  70 +++-
 .../build_tools/abstract_interp_lib_gen.py  |  56 ++-
 .../build_tools/torch_ods_gen.py            |   8 +-
 .../test_suite/elementwise.py               | 176 +++++++++
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir |  81 ++++
 8 files changed, 746 insertions(+), 161 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index adf5e8396751..0becb668636e 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -19,96 +19,6 @@
 //===----------------------------------------------------------------------===//
 
-def Torch_AtenTanhOp : Torch_Op<"aten.tanh", [
- AllowsTypeRefinement,
- HasValueSemantics,
- ReadOnly
- ]> {
- let summary = "Generated op for `aten::tanh : (Tensor) -> (Tensor)`";
- let arguments = (ins
- AnyTorchTensorType:$self
- );
- let results = (outs
- AnyTorchTensorType:$result
- );
- let hasCustomAssemblyFormat = 1;
- let extraClassDefinition = [{
- ParseResult AtenTanhOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseDefaultTorchOp(parser, result, 1, 1);
- }
- void AtenTanhOp::print(OpAsmPrinter &printer) {
- printDefaultTorchOp(printer, *this, 1, 1);
- }
- }];
-}
-
-def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [
- IsTrailingUnderscoreInplaceVariant,
- AllowsTypeRefinement
- ]> {
- let summary = "Generated op for `aten::tanh_ : (Tensor) -> (Tensor)`";
- let arguments = (ins
- Torch_NonValueTensorType:$self
- );
- let results = (outs
- Torch_NonValueTensorType:$result
- );
- let hasCustomAssemblyFormat = 1;
- let extraClassDefinition = [{
- ParseResult AtenTanh_Op::parse(OpAsmParser &parser, OperationState &result) {
- return parseDefaultTorchOp(parser, result, 1, 1);
- }
- void AtenTanh_Op::print(OpAsmPrinter &printer) {
- printDefaultTorchOp(printer, *this, 1, 1);
- }
- }];
-}
-
-def Torch_AtenCoshOp : Torch_Op<"aten.cosh", [
- AllowsTypeRefinement,
- HasValueSemantics,
- ReadOnly
- ]> {
- let summary = "Generated op for `aten::cosh : (Tensor) -> (Tensor)`";
- let arguments = (ins
- AnyTorchTensorType:$self
- );
- let results = (outs
- AnyTorchTensorType:$result
- );
- let hasCustomAssemblyFormat = 1;
- let extraClassDefinition = [{
- ParseResult AtenCoshOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseDefaultTorchOp(parser, result, 1, 1);
- }
- void AtenCoshOp::print(OpAsmPrinter &printer) {
- printDefaultTorchOp(printer, *this, 1, 1);
- }
- }];
-}
-
-def Torch_AtenCosh_Op : Torch_Op<"aten.cosh_", [
- IsTrailingUnderscoreInplaceVariant,
- AllowsTypeRefinement
- ]> {
- let summary = "Generated op for `aten::cosh_ : (Tensor) -> (Tensor)`";
- let arguments = (ins
- Torch_NonValueTensorType:$self
- );
- let results = (outs
- Torch_NonValueTensorType:$result
- );
- let hasCustomAssemblyFormat = 1;
- let extraClassDefinition = [{
- ParseResult AtenCosh_Op::parse(OpAsmParser
&parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenCosh_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenHardtanhOp : Torch_Op<"aten.hardtanh", [ AllowsTypeRefinement, HasValueSemantics, @@ -886,6 +796,96 @@ def Torch_AtenSin_Op : Torch_Op<"aten.sin_", [ }]; } +def Torch_AtenAsinOp : Torch_Op<"aten.asin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::asin : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAsin_Op : Torch_Op<"aten.asin_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::asin_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsin_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsin_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAsinhOp : Torch_Op<"aten.asinh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::asinh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsinhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsinhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAsinh_Op : Torch_Op<"aten.asinh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::asinh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsinh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsinh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenExpOp : Torch_Op<"aten.exp", [ AllowsTypeRefinement, HasValueSemantics, @@ -1021,6 +1021,51 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ }]; } +def Torch_AtenCoshOp : Torch_Op<"aten.cosh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cosh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCoshOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCoshOp::print(OpAsmPrinter 
&printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenCosh_Op : Torch_Op<"aten.cosh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::cosh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCosh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCosh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ AllowsTypeRefinement, HasValueSemantics, @@ -1066,6 +1111,51 @@ def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ }]; } +def Torch_AtenAcoshOp : Torch_Op<"aten.acosh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::acosh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcoshOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcoshOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAcosh_Op : Torch_Op<"aten.acosh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::acosh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcosh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcosh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenTanOp : Torch_Op<"aten.tan", [ AllowsTypeRefinement, HasValueSemantics, @@ -1111,6 +1201,51 @@ def Torch_AtenTan_Op : Torch_Op<"aten.tan_", [ }]; } +def Torch_AtenTanhOp : Torch_Op<"aten.tanh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::tanh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTanhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTanhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::tanh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTanh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTanh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ 
AllowsTypeRefinement, HasValueSemantics, @@ -1156,6 +1291,51 @@ def Torch_AtenAtan_Op : Torch_Op<"aten.atan_", [ }]; } +def Torch_AtenAtanhOp : Torch_Op<"aten.atanh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::atanh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtanh_Op : Torch_Op<"aten.atanh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::atanh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index e39c42b50422..e8c36d8cad54 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -103,7 +103,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - // TODO: Acosh unimplemented in torch-mlir // Add became forward compatible with Torch in version 7. 
 patterns.onOp("Add", 7,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -203,9 +202,28 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 binder.op, resultType, operand, constAxis, constKeepDims);
 return success();
 });
- // TODO: Asin unimplemented in torch-mlir
- // TODO: Asinh unimplemented in torch-mlir
- // TODO: Atanh unimplemented in torch-mlir
+ patterns.onOp("Asin", 7,
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+ Torch::ValueTensorType resultType;
+ Value operand;
+ if (binder.tensorOperand(operand) ||
+ binder.tensorResultType(resultType))
+ return failure();
+ rewriter.replaceOpWithNewOp<Torch::AtenAsinOp>(
+ binder.op, resultType, operand);
+ return success();
+ });
+ patterns.onOp("Asinh", 9,
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+ Torch::ValueTensorType resultType;
+ Value operand;
+ if (binder.tensorOperand(operand) ||
+ binder.tensorResultType(resultType))
+ return failure();
+ rewriter.replaceOpWithNewOp<Torch::AtenAsinhOp>(
+ binder.op, resultType, operand);
+ return success();
+ });
 patterns.onOp("Atan", 7,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
 Value operand;
@@ -217,6 +235,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 binder.op, resultType, operand);
 return success();
 });
+ patterns.onOp("Atanh", 9,
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+ Torch::ValueTensorType resultType;
+ Value operand;
+ if (binder.tensorOperand(operand) ||
+ binder.tensorResultType(resultType))
+ return failure();
+ rewriter.replaceOpWithNewOp<Torch::AtenAtanhOp>(
+ binder.op, resultType, operand);
+ return success();
+ });
 patterns.onOp("Acos", 7,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
 Value operand;
@@ -228,6 +257,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 binder.op, resultType, operand);
 return success();
 });
+ patterns.onOp("Acosh", 9,
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+ Torch::ValueTensorType resultType;
+ Value operand;
+ if (binder.tensorOperand(operand) ||
+ binder.tensorResultType(resultType))
+ return failure();
+ rewriter.replaceOpWithNewOp<Torch::AtenAcoshOp>(
+ binder.op, resultType, operand);
+ return success();
+ });
 patterns.onOp("BatchNormalization", 15,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
@@ -1041,6 +1081,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 binder.op, resultType, operand);
 return success();
 });
+ patterns.onOp("Cosh", 9,
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+ Torch::ValueTensorType resultType;
+ Value operand;
+ if (binder.tensorOperand(operand) ||
+ binder.tensorResultType(resultType))
+ return failure();
+ rewriter.replaceOpWithNewOp<Torch::AtenCoshOp>(
+ binder.op, resultType, operand);
+ return success();
+ });
 patterns.onOp(
 "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Location loc = binder.getLoc();
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index e4b683d41cee..25c2a93d2797 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -216,22 +216,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
 return b.create(loc, payloadArgs[0]);
 if (isa(op))
 return b.create(loc, payloadArgs[0]);
- if (isa(op)) {
- return createCalculationForMathOpWithDtypeConversion(
- b, converter, payloadArgs[0], op);
- }
- if (isa(op)) {
- return createCalculationForMathOpWithDtypeConversion(
- b, converter,
payloadArgs[0], op);
- }
- if (isa(op)) {
- return createCalculationForMathOpWithDtypeConversion(
- b, converter, payloadArgs[0], op);
- }
- if (isa(op)) {
- return createCalculationForMathOpWithDtypeConversion(
- b, converter, payloadArgs[0], op);
- }
 if (isa(op)) {
 return createCalculationForMathOpWithDtypeConversion(
 b, converter, payloadArgs[0], op);
 }
@@ -276,18 +260,50 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
 return createCalculationForMathOpWithDtypeConversion(
 b, converter, payloadArgs[0], op);
 }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
 if (isa(op)) {
 return createCalculationForMathOpWithDtypeConversion(
 b, converter, payloadArgs[0], op);
 }
- if (isa(op)) {
- return createCalculationForMathOpWithDtypeConversion(
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
 b, converter, payloadArgs[0], op);
 }
 if (isa(op)) {
 return createCalculationForMathOpWithDtypeConversion(
 b, converter, payloadArgs[0], op);
 }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
+ if (isa(op)) {
+ return createCalculationForMathOpWithDtypeConversion(
+ b, converter, payloadArgs[0], op);
+ }
 if (auto clone = dyn_cast<AtenCloneOp>(op)) {
 int64_t memoryFormat;
 if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
@@ -1505,7 +1521,8 @@ class ConvertElementwiseOp : public ConversionPattern {
 AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
 AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
 AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
- AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp,
+ AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
+ AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
 AtenDequantizeSelfOp, AtenDequantizeTensorOp,
 AtenQuantizePerTensorOp>(op))
 return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -2350,27 +2367,27 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
 ConversionTarget &target) {
 MLIRContext *context = patterns.getContext();
 target.addIllegalOp<
- AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenGeluOp,
- AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp,
- AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
- AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
- AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
- AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
- AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
- AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
- AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
- AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
- AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
- AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
-
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
- AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
- AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
- AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
- AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
- AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
- AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
- AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
- AtenQuantizePerTensorOp>();
+ AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
+ AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
+ AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
+ AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
+ AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp,
+ AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
+ AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp,
+ AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
+ AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
+ AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
+ AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
+ AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+ AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
+ AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
+ AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
+ AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
+ AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
+ AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
+ AtenRemainderScalarOp, AtenRemainderTensorOp, AtenBitwiseNotOp,
+ AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp,
+ AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
 patterns.add<ConvertElementwiseOp>(typeConverter, context);
 target.addIllegalOp();
 patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 320f53f0b7b6..29c94304288b 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6306,11 +6306,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %18 = torch.aten.append.t %7, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 " return %7 : !torch.list<int>\n"
 " }\n"
-" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
-" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.asinh\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.cos\"(%arg0:
!torch.list<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
@@ -6318,10 +6326,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.acosh\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.atanh\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
@@ -6358,18 +6386,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
-" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
-" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-" return %0 : !torch.list<int>\n"
-" }\n"
-" func.func @\"__torch_mlir_shape_fn.aten.cos\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
-" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-" return %0 : !torch.list<int>\n"
-" }\n"
-" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
-" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-" return %0 : !torch.list<int>\n"
-" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
 " %none = torch.constant.none\n"
 " %int1 = torch.constant.int 1\n"
@@ -9371,6 +9387,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.acosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) :
(!torch.int) -> !torch.int\n"
+" return %1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 " return %1 : !torch.int\n"
 " }\n"
@@ -9391,6 +9412,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 " return %1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+" return %1 : !torch.int\n"
+" }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.asinh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+" return %1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
@@ -12473,6 +12504,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " return %2 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.atanh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+" %int6 = torch.constant.int 6\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+" %2 = torch.prim.If %1 -> (!torch.int) {\n"
+" torch.prim.If.yield %int6 : !torch.int\n"
+" } else {\n"
+" torch.prim.If.yield %0#1 : !torch.int\n"
+" }\n"
+" return %2 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 7fe6e8457fe8..c014808af97a 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -89,18 +89,39 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim
 return diagonal
 
-def aten〇tan〡shape(self: List[int]) -> List[int]:
+def aten〇sin〡shape(self: List[int]) -> List[int]:
 return upstream_shape_functions.unary(self)
 
-def aten〇atan〡shape(self: List[int]) -> List[int]:
+def aten〇asin〡shape(self: List[int]) -> List[int]:
+ return upstream_shape_functions.unary(self)
+
+def aten〇asinh〡shape(self: List[int]) -> List[int]:
+ return upstream_shape_functions.unary(self)
+
+def aten〇cos〡shape(self: List[int]) -> List[int]:
 return upstream_shape_functions.unary(self)
 
 def aten〇cosh〡shape(self: List[int]) -> List[int]:
 return upstream_shape_functions.unary(self)
 
+def aten〇acos〡shape(self: List[int]) -> List[int]:
+ return
upstream_shape_functions.unary(self) + +def aten〇acosh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇tan〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇tanh〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇atan〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇atanh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇erf〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -128,15 +149,6 @@ def aten〇exp〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇sin〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - -def aten〇cos〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - -def aten〇acos〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]: broadcast = upstream_shape_functions.broadcast(x1, x2) return broadcast[:dim] + broadcast[dim + 1:] @@ -1856,6 +1868,11 @@ def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇acosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: @@ -1878,6 +1895,16 @@ def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇asin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇asinh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -4191,6 +4218,13 @@ def aten〇atan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.float32 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇atanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.float32 + return self_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = input_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 6f674601393d..65e9f44c1126 100644 --- 
a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -255,8 +255,6 @@ def emit_with_mutating_variants(key, **kwargs): # Elementwise tensor compute ops for key in [ - "aten::tanh : (Tensor) -> (Tensor)", - "aten::cosh : (Tensor) -> (Tensor)", "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", "aten::relu : (Tensor) -> (Tensor)", @@ -274,12 +272,18 @@ def emit_with_mutating_variants(key, **kwargs): "aten::erfinv : (Tensor) -> (Tensor)", "aten::silu : (Tensor) -> (Tensor)", "aten::sin : (Tensor) -> (Tensor)", + "aten::asin : (Tensor) -> (Tensor)", + "aten::asinh : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", + "aten::cosh : (Tensor) -> (Tensor)", "aten::acos : (Tensor) -> (Tensor)", + "aten::acosh : (Tensor) -> (Tensor)", "aten::tan : (Tensor) -> (Tensor)", + "aten::tanh : (Tensor) -> (Tensor)", "aten::atan : (Tensor) -> (Tensor)", + "aten::atanh : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", "aten::ceil : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index c1a827ffe108..2f74ceb84416 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -107,6 +107,182 @@ def ElementwiseCoshIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAcoshModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.acosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcoshModule()) +def ElementwiseAcoshModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAcoshIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.acosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcoshIntModule()) +def ElementwiseAcoshIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAsinModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinModule()) +def ElementwiseAsinModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAsinIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinIntModule()) +def 
ElementwiseAsinIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAsinhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.asinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinhModule()) +def ElementwiseAsinhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAsinhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.asinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinhIntModule()) +def ElementwiseAsinhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAtanhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.atanh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanhModule()) +def ElementwiseAtanhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAtanhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.atanh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanhIntModule()) +def ElementwiseAtanhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseBinaryModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 2ee21c1e3841..3e4a476dbfbb 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -127,6 +127,15 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // ----- +// CHECK-LABEL: @test_atanh +func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: @test_acos func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.acos %arg0 : 
!torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -558,6 +567,78 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 // ----- +// CHECK-LABEL: @test_cosh_example +func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_cosh +func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_acosh_example +func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_acosh +func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_asin_example +func.func @test_asin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Asin"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_asin +func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asin %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Asin"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_asinh_example +func.func 
@test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_asinh +func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: @test_dequantizelinear_si8 func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32>
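
A note on the dtype rules in this patch: the new asin/asinh/acosh dtype functions
use _get_dtype_of_floating_point_op, and aten〇atanh〡dtype returns torch.float32
(dtype code 6) for integer inputs. This mirrors eager PyTorch, which promotes
integer tensors to the default float dtype for these ops. A minimal standalone
sanity check of that behavior (plain PyTorch, not part of the patch, assuming the
default dtype is float32):

import torch

x = torch.randint(1, 10, (3, 4), dtype=torch.int32)
for fn in (torch.asin, torch.asinh, torch.acosh, torch.atanh, torch.cosh):
    # Integer inputs promote to float32 -- the rule the dtype functions encode
    # and the Elementwise*IntModule e2e tests above exercise.
    assert fn(x).dtype == torch.float32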