Skip to content

Commit

Permalink
[TorchToLinalg] Adds Quantization Support for ConvTranspose (llvm#3240)
Browse files Browse the repository at this point in the history
I spent a little while debugging numerics issues with some tests similar
to the ones in quantized_models.py, only to find that pytorch's
quantized conv transpose is catastrophically inaccurate. I'll upstream
the issue and only leave the tests here which are of the form quantize
-> dequantize -> op.
  • Loading branch information
zjgarvey authored Apr 30, 2024
1 parent 9442c66 commit 72349f7
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 25 deletions.
59 changes: 34 additions & 25 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
if (!isUnsignedType)
return;
int64_t minSI = -(1 << (numBits - 1));
Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
Expand Down Expand Up @@ -797,6 +798,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
auto resultTy = cast<ValueTensorType>(op.getType());

Value inputZp, weightZp;
bool inputUnsigned = false;
bool weightUnsigned = false;
if (auto make = op.getInput()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
input = make.getSelf();
Expand All @@ -806,6 +809,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
inputZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(inputZp.getType()),
inputZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}

if (auto make = op.getWeight()
Expand All @@ -818,6 +823,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
weightZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(weightZp.getType()),
weightZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}

if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
Expand Down Expand Up @@ -916,15 +923,35 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

// convert any uint8 quantization to int8 quantization
if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
int64_t width = integerType.getWidth();
signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
}
if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
int64_t width = integerType.getWidth();
signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
}
// Pad the input tensor according to padding.
SmallVector<Value> outDims{inBatch, weightBatch};
Value paddedInput;
if (transposed) {
if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
!isa<mlir::FloatType>(resultDTy))
return rewriter.notifyMatchFailure(
op, "transpose does not support non-fp type yet");
Value pad = inputZp;
if (!pad) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
}
if (pad.getType() != inputDTy) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);

if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}
if (transposed) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value c1 =
Expand Down Expand Up @@ -994,7 +1021,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

// Allocate padded input tensor
Value initTensor =
createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);

// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
Expand All @@ -1017,24 +1044,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
strideInts.clear();
strideInts.append(numSpatialDims, 1);
} else {
Value pad = inputZp;
if (!pad) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
}

if (pad.getType() != inputDTy) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);

if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}

// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
Expand Down
5 changes: 5 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
"Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
"FloatImplicitModule_basic",
Expand Down Expand Up @@ -372,6 +373,7 @@
"Conv2dQInt8Module_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
Expand Down Expand Up @@ -544,6 +546,7 @@
"ContainsIntList_True",
"Conv2dQInt8Module_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
Expand Down Expand Up @@ -2100,6 +2103,7 @@
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
}

ONNX_XFAIL_SET = {
Expand Down Expand Up @@ -2254,6 +2258,7 @@
"Conv2dWithPaddingModule_basic",
"Conv3dModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
"Convolution2DModule_basic",
"Convolution2DStridedModule_basic",
Expand Down
53 changes: 53 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, bias)


N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2


class ConvTranspose2DQInt8Module(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
]
)
def forward(self, input, weight, bias):
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
qinput = torch.dequantize(qinput)
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
qweight = torch.dequantize(qweight)
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
qbias = torch.dequantize(qbias)
qz = torch.ops.aten.convolution(
qinput,
qweight,
bias=qbias,
stride=[2, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=True,
output_padding=[0, 0],
groups=1,
)
return qz


@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
module.forward(
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
torch.rand(Cout),
)

0 comments on commit 72349f7

Please sign in to comment.