Skip to content

Commit

Permalink
[torch] Fix lowerings of rshift and lshift (#3665)
Browse files Browse the repository at this point in the history
I missed adding second operand conversion and adding them to the set of
rewrite patterns.
  • Loading branch information
rsuderman authored Aug 24, 2024
1 parent 9a4c8c6 commit b3b8e2e
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,15 +850,21 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
cast<RankedTensorType>(converter->convertType(lshiftScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value other =
convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/operands[1].getType(),
/*dstOriginalDtype=*/dtype);
return b.create<arith::ShLIOp>(loc, self, other);
}
if (auto rshiftScalar = dyn_cast<Aten__Rshift__ScalarOp>(op)) {
Type dtype =
cast<RankedTensorType>(converter->convertType(rshiftScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value other =
convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/operands[1].getType(),
/*dstOriginalDtype=*/dtype);
return b.create<arith::ShRUIOp>(loc, self, other);
}
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
Expand Down Expand Up @@ -1610,7 +1616,8 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
Expand Down Expand Up @@ -3304,10 +3311,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
Expand Down

0 comments on commit b3b8e2e

Please sign in to comment.