From b3b8e2e96a6af8b9e838c07b3095b8633c701526 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 23 Aug 2024 20:27:18 -0700 Subject: [PATCH] [torch] Fix lowerings of rshift and lshift (#3665) I missed adding second operand conversion and adding them to the set of rewrite patterns. --- .../TorchToLinalg/Uncategorized.cpp | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 1d13c2700c62..29e1e80d9732 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -850,7 +850,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( cast(converter->convertType(lshiftScalar.getType())) .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value other = convertScalarToDtype(b, loc, operands[1], dtype); + Value other = + convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/operands[1].getType(), + /*dstOriginalDtype=*/dtype); return b.create(loc, self, other); } if (auto rshiftScalar = dyn_cast(op)) { @@ -858,7 +861,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( cast(converter->convertType(rshiftScalar.getType())) .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value other = convertScalarToDtype(b, loc, operands[1], dtype); + Value other = + convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/operands[1].getType(), + /*dstOriginalDtype=*/dtype); return b.create(loc, self, other); } if (auto subScalar = dyn_cast(op)) { @@ -1610,7 +1616,8 @@ class ConvertElementwiseOp : public ConversionPattern { AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, - AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, + AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, + Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, @@ -3304,10 +3311,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, - AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, - AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, - AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, + AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, + Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, + AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, + AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, + AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,