[torch] Lower torch.aten.sinh to linalg (#2662)
rsuderman authored Dec 18, 2023
1 parent 9c655d0 commit 791c666
Showing 2 changed files with 41 additions and 19 deletions.
42 changes: 23 additions & 19 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -220,6 +220,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
         b, converter, payloadArgs[0], op);
   }
+  if (isa<AtenSinhOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::SinhOp>(
+        b, converter, payloadArgs[0], op);
+  }
   if (isa<AtenCoshOp>(op)) {
     return createCalculationForMathOpWithDtypeConversion<math::CoshOp>(
         b, converter, payloadArgs[0], op);
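For orientation: the case added above only builds the scalar body; the surrounding ConvertElementwiseOp pattern wraps it in a linalg.generic over the type-converted tensor operands. A rough sketch of the post-conversion IR for a rank-1 f32 input (illustrative only; the value names and exact attribute spelling are assumptions and vary across MLIR versions):

#map = affine_map<(d0) -> (d0)>
func.func @sinh_sketch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
  // Init tensor for the elementwise result.
  %init = tensor.empty() : tensor<3xf32>
  %res = linalg.generic
      {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%arg0 : tensor<3xf32>) outs(%init : tensor<3xf32>) {
  ^bb0(%in: f32, %out: f32):
    // Scalar payload emitted by the new AtenSinhOp case above.
    %s = math.sinh %in : f32
    linalg.yield %s : f32
  } -> tensor<3xf32>
  return %res : tensor<3xf32>
}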
@@ -1315,8 +1319,8 @@ class ConvertElementwiseOp : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenCoshOp, AtenReluOp, AtenPreluOp, AtenGeluOp,
-             AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
+    if (!isa<AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenPreluOp,
+             AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
              AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op,
              AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
              AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
@@ -1968,23 +1972,23 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<
-      AtenTanhOp, AtenCoshOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
-      AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
-      AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op,
-      AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp,
-      AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
-      AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp,
-      AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
-      AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-      AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
-      AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
-      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
-      AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
-      AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
-      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
-      AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
-      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
+      AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenGeluOp,
+      AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp,
+      AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
+      AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
+      AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
+      AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
+      AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
+      AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
+      AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
+      AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
+      AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
+      AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
+      AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
+      AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
+      AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
+      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
+      AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
       AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
       AtenFillTensorOp, AtenRealOp, AtenImagOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
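Taken together, this file's three hunks follow the usual recipe for enabling a new unary elementwise op: (1) a scalar-payload case, (2) membership in the isa<...> guard of ConvertElementwiseOp, and (3) registration as illegal so the conversion must apply. As a hedged sketch, a future op would need the same three touch points; AtenAsinhOp below is purely hypothetical and assumes the math dialect provides a matching math::AsinhOp:

// 1) Scalar payload in createLinalgPayloadCalculationForElementwiseOp.
if (isa<AtenAsinhOp>(op)) { // hypothetical op, for illustration only
  return createCalculationForMathOpWithDtypeConversion<math::AsinhOp>(
      b, converter, payloadArgs[0], op);
}
// 2) Add AtenAsinhOp to the isa<...> guard in ConvertElementwiseOp.
// 3) Add AtenAsinhOp to target.addIllegalOp<...> in
//    populateUncategorizedPatternsAndLegality.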
18 changes: 18 additions & 0 deletions test/Conversion/TorchToLinalg/elementwise.mlir
@@ -19,6 +19,8 @@ func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
   return %0 : !torch.vtensor<[],f32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @elementwise$binary(
 // CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
 // CHECK-SAME:      %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -46,6 +48,8 @@ func.func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
   return %0 : !torch.vtensor<[?,?],f32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @elementwise$ternary(
 // CHECK:           linalg.generic {indexing_maps = [
 // CHECK-SAME:        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
@@ -57,6 +61,8 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
   return %0 : !torch.vtensor<[?,?,?],f32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @elementwise$with_scalar_capture(
 // CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
 // CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
@@ -75,6 +81,8 @@ func.func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
   return %0 : !torch.vtensor<[?],f32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @elementwise$static_1(
 // CHECK:           linalg.generic {indexing_maps = [
 // CHECK-SAME:        affine_map<(d0) -> (d0)>,
@@ -84,3 +92,13 @@ func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?],f32> {
   %1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32>
   return %1 : !torch.vtensor<[?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @elementwise_sinh
+// CHECK: linalg.generic
+// CHECK: math.sinh
+func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> {
+  %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  return %0 : !torch.vtensor<[3],f32>
+}
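One note on the test changes: the added // ----- markers are split-input-file separators, so each function is parsed and FileCheck'd in isolation. The file's RUN line sits above the shown hunks; it presumably resembles the line below (an assumption — the diff does not include it):

// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s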
