[torch] Support torch.convolution quantized lowering to linalg (#2811)

Linalg has quantized-specific operations. We can lower to these
operations when the zero point and scale are known. This allows the
`convolution` to occur at lower bitwidths, improving overall
performance.
rsuderman authored Jan 30, 2024
1 parent 4c55784 commit 25a5a22
Showing 8 changed files with 362 additions and 153 deletions.
2 changes: 1 addition & 1 deletion include/torch-mlir/Conversion/TorchToLinalg/Utils.h
@@ -36,7 +36,7 @@ Value getZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input,
// padding value is zero.
Value getDynamicZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<Value> &padding,
int unpaddedDims = 0);
int unpaddedDims = 0, Value pad = {});

// Helper function to calculate the output tensor dims for convolution-like ops.
// Along each dim:
382 changes: 265 additions & 117 deletions lib/Conversion/TorchToLinalg/Linear.cpp

Large diffs are not rendered by default.

28 changes: 16 additions & 12 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1342,11 +1342,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
auto valueTy = value.getType();
auto qtensor = op->getOperand(0);
auto qtensorTy = qtensor.getType().cast<ValueTensorType>().getDtype();
auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
if (!makeQTensor) {
op->emitWarning(
"unimplemented: dequantizing tensor of unknown scale / zero-point");

Value zp, scale;
if (auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
zp = makeQTensor.getZeroPoint();
scale = makeQTensor.getScale();
}

if (auto quant = qtensor.getDefiningOp<AtenQuantizePerTensorOp>()) {
zp = quant.getZeroPoint();
scale = quant.getScale();
}

if (!zp || !scale) {
return nullptr;
}

@@ -1362,10 +1371,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
}

Value zp = makeQTensor.getZeroPoint();
zp = converter->materializeTargetConversion(
b, loc, converter->convertType(zp.getType()),
makeQTensor.getZeroPoint());
b, loc, converter->convertType(zp.getType()), zp);
auto zpTy = zp.getType();

if (zpTy != outIntTy) {
Expand All @@ -1380,10 +1387,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
value = b.create<arith::SIToFPOp>(loc, outFpTy, value);
}

Value scale = makeQTensor.getScale();
scale = converter->materializeTargetConversion(
b, loc, converter->convertType(scale.getType()),
makeQTensor.getScale());
b, loc, converter->convertType(scale.getType()), scale);
if (scale.getType() != value.getType()) {
scale = b.create<arith::TruncFOp>(loc, value.getType(), scale);
}
@@ -2233,7 +2238,6 @@ class ConvertDequantizePerChannel
auto qoperand = op.getOperand();
auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
if (!make) {
llvm::errs() << "Did not find make per channel\n";
return rewriter.notifyMatchFailure(op, "did not find per channel qint");
}

6 changes: 2 additions & 4 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -70,7 +70,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
// padding value is zero.
Value torch_to_linalg::getDynamicZeroPaddedTensor(
Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
int unpaddedDims) {
int unpaddedDims, Value pad) {
assert(input.getType().isa<RankedTensorType>() &&
"input must be RankedTensorType");
unsigned int inRank = input.getType().cast<RankedTensorType>().getRank();
@@ -93,12 +93,10 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
SmallVector<int64_t>(inRank, kUnknownSize))),
elementType);

Value cf0 =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
SmallVector<OpFoldResult> paddingValues =
getAsOpFoldResult(paddingIncludingUnchanged);
return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
/*high=*/paddingValues, cf0);
/*high=*/paddingValues, pad);
}

Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
32 changes: 27 additions & 5 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -30,7 +30,7 @@ class QuantizeOperands : public OpRewritePattern<SrcOp> {
llvm::SmallVector<Value> operands(op->getOperands());

bool dequanted = false;
for (auto &operand : operands) {
auto f = [&dequanted](Value operand) {
if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
operand = dequant.getOperand();
dequanted = true;
@@ -39,7 +39,11 @@ class QuantizeOperands : public OpRewritePattern<SrcOp> {
operand = dequant.getOperand();
dequanted = true;
}
}
return operand;
};

operands[0] = f(operands[0]);
operands[1] = f(operands[1]);

if (!dequanted) {
return rewriter.notifyMatchFailure(op, "no dequantizations found");
@@ -77,6 +81,7 @@ template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
if (!rhsScale || !lhsScale)
return failure();

auto resultTy = cast<ValueTensorType>(op.getType());
auto biasTy = bias.getType().cast<ValueTensorType>();
auto biasETy = biasTy.getOptionalDtype();
if (!biasETy || !isa<mlir::FloatType>(biasETy))
@@ -95,9 +100,27 @@ template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
bias = rewriter.create<AtenQuantizePerTensorOp>(
op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
bias = rewriter.create<AtenIntReprOp>(
op.getLoc(),
rewriter.getType<ValueTensorType>(
biasTy.getOptionalSizes(),
rewriter.getIntegerType(32, IntegerType::Signed)),
bias);

operands[2] = bias;
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);

auto convTy = rewriter.getType<ValueTensorType>(
resultTy.getOptionalSizes(),
rewriter.getIntegerType(32, IntegerType::Signed));
auto conv = rewriter.create<SrcOp>(op.getLoc(), convTy, operands);

auto convQTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
auto makeOut = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), convQTy, conv, biasScale, zero);
rewriter.replaceOpWithNewOp<AtenDequantizeTensorOp>(op, op.getType(),
makeOut);

return success();
}
};
@@ -151,7 +174,7 @@ class QuantizeAccumulator : public OpRewritePattern<SrcOp> {
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
auto conv = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);

// Attach the quantize information to the resulting quint32:
// Attach the quantize information to the resulting qint32:
auto intReprTy = rewriter.getType<ValueTensorType>(
resultTy.getOptionalSizes(),
rewriter.getIntegerType(32, IntegerType::Signed));
@@ -194,7 +217,6 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
RemoveUnused<AtenDequantizeTensorOp>,
RemoveUnused<AtenQuantizePerTensorOp>,
QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
QuantizeAccumulator<AtenConvolutionOp>,
QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
context);

2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -317,6 +317,7 @@
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"AtenMmQuint8_basic",
"Conv2dQInt8Module_basic",

# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
@@ -1541,4 +1542,5 @@
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv2dQInt8Module_basic",
}
35 changes: 35 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -857,3 +857,38 @@ def forward(self, x, weight, bias):
@register_test_case(module_factory=lambda: ConvTbcModule())
def ConvTbcModule_basic(module, tu: TestUtils):
module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6))

class Conv2dQInt8Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
])
def forward(self, inputVec, weight, bias):
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
inputVec = torch.dequantize(inputVec)

weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3)
weight = torch.dequantize(weight)

bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
bias = torch.dequantize(bias)

return torch.ops.aten.conv2d(inputVec,
weight,
bias=bias,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
groups=1)

@register_test_case(module_factory=lambda: Conv2dQInt8Module())
def Conv2dQInt8Module_basic(module, tu: TestUtils):
inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, bias)
28 changes: 14 additions & 14 deletions test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -43,20 +43,20 @@ func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtens
%15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
%16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>

// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
// CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
// CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
// CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
// CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
// CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[QBIAS]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],!torch.qint32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
// CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[CONV]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],si32>
// CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
// CHECK: %[[DTYPE:.+]] = torch.constant.int 14
// CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
// CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
// CHECK: %[[ONE:.+]] = torch.constant.int 1
// CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
// CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
// CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
// CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
// CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
// CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
return %16 : !torch.vtensor<[1,3,7,7],f32>
}
