From 6466130960df5ae22f4a2c5bd4db40b0c0e35ce9 Mon Sep 17 00:00:00 2001 From: default Date: Tue, 3 Dec 2024 17:27:37 +0000 Subject: [PATCH] feat: support lowering of channelwise quantization to linalg --- include/torch-mlir/Conversion/Utils/Utils.h | 2 +- lib/Conversion/TorchToLinalg/Linear.cpp | 339 ++++++++++++++++-- lib/Conversion/Utils/Utils.cpp | 2 +- .../Conversion/TorchToLinalg/convolution.mlir | 39 +- 4 files changed, 347 insertions(+), 35 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index d21dd5504dcd..043d60714620 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -79,7 +79,7 @@ SmallVector getAsConstantIndexValues(OpBuilder &b, Location loc, // TODO: remove this when list gets full support. SmallVector getTypeConvertedValues(OpBuilder &b, Location loc, const TypeConverter *converter, - SmallVectorImpl &vs); + const SmallVectorImpl &vs); mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, mlir::Type elementType, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 6a1ba9e5f907..a6008740fdaf 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -40,9 +40,32 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, if (!isUnsignedType) return; int64_t minSI = -(1 << (numBits - 1)); - Value minSIValue = rewriter.create( - loc, minSI, cast(zp.getType()).getWidth()); - zp = rewriter.create(loc, zp, minSIValue); + // get width of the zero point value if it is a tensor + int64_t bitWidth = 0; + if (isa(zp.getType())) { + auto zpType = cast(zp.getType()); + bitWidth = zpType.getElementType().getIntOrFloatBitWidth(); + } else { + bitWidth = zp.getType().getIntOrFloatBitWidth(); + } + Value minSIValue = + rewriter.create(loc, minSI, bitWidth); + + // Use a linalg.generic op to add the minSIValue to the zero 
point if it is a + // tensor. + if (isa(zp.getType())) { + zp = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, zp, + cast(zp.getType()).getElementType(), + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value result = + rewriter.create(loc, payloadArgs[0], minSIValue); + b.create(loc, result); + }); + } else { + zp = rewriter.create(loc, zp, minSIValue); + } + minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, ValueRange{arg}, @@ -742,6 +765,85 @@ class ConvertAtenBmmOp : public OpConversionPattern { } // namespace namespace { +struct QuantizationValues { + Value self; + Value zeroPoint; + bool isUnsigned; +}; + +QuantizationValues getQuantizationPerTensorValues( + ConversionPatternRewriter &rewriter, Location loc, + Aten_MakePerTensorQuantizedTensorOp makePerTensorQuantizedTensorOp, + const TypeConverter *const typeConverter) { + QuantizationValues values; + Value self = makePerTensorQuantizedTensorOp.getSelf(); + Value zeroPoint = makePerTensorQuantizedTensorOp.getZeroPoint(); + self = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(self.getType()), self); + zeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(zeroPoint.getType()), + zeroPoint); + zeroPoint = + rewriter.create(loc, rewriter.getI32Type(), zeroPoint); + auto torchDtype = + cast(makePerTensorQuantizedTensorOp.getType()) + .getDtype(); + + values.self = self; + values.zeroPoint = zeroPoint; + values.isUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); + + return values; +} + +QuantizationValues getQuantizationPerChannelValues( + ConversionPatternRewriter &rewriter, Location loc, + Aten_MakePerChannelQuantizedTensorOp makePerChannelQuantizedTensorOp, + const TypeConverter *const typeConverter) { + QuantizationValues values; + Value self = makePerChannelQuantizedTensorOp.getSelf(); + Value zeroPoint = 
makePerChannelQuantizedTensorOp.getZeroPoint(); + self = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(self.getType()), self); + zeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(zeroPoint.getType()), + zeroPoint); + + // create a linalg op since we need to do some arithmetic on the zero point + // as it is a tensor. + auto zeroPointType = cast(zeroPoint.getType()); + auto selfType = cast(self.getType()); + if (zeroPointType.getElementTypeBitWidth() > + selfType.getElementTypeBitWidth()) { + + zeroPoint = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, zeroPoint, rewriter.getI32Type(), + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value result = rewriter.create( + loc, rewriter.getI32Type(), payloadArgs[0]); + b.create(loc, result); + }); + } else if (zeroPointType.getElementTypeBitWidth() < + selfType.getElementTypeBitWidth()) { + zeroPoint = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, zeroPoint, rewriter.getI32Type(), + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value result = rewriter.create( + loc, rewriter.getI32Type(), payloadArgs[0]); + b.create(loc, result); + }); + } + auto torchDtype = + cast(makePerChannelQuantizedTensorOp.getType()) + .getDtype(); + + values.self = self; + values.zeroPoint = zeroPoint; + values.isUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); + + return values; +} + struct ConvolutionAttributes { SmallVector padding; SmallVector outputPadding; @@ -964,6 +1066,144 @@ LogicalResult handleUngroupedConv(ConversionPatternRewriter &rewriter, Location return success(); } +LogicalResult handleUngroupedConvChannelwiseQuantized(ConversionPatternRewriter &rewriter, Location loc, + Value &weight, Value &paddedInput, + Value &outputTensor, Value &inputZp, Value &weightZp, + size_t numSpatialDims, size_t inRank, + DenseIntElementsAttr stridesAttr, DenseIntElementsAttr 
dilationAttr, + bool inputIsChannelwiseQuantized, bool weightIsChannelwiseQuantized, + Type inputDTy, Type weightDTy, Type accumulatorDType, Type resultDTy, + AtenConvolutionOp op, const TypeConverter *typeConverter) { + Value conv; + // There is no linalg.conv op for channelwise quantized arguments. + // Generate generic linalg ops to first subtract the zero points and then + // perform the convolution. The zeropoint is either a scalar or a tensor. + if (inputIsChannelwiseQuantized) { + // create zero init tensor with the same shape as input + Value zeroTensor = rewriter.create( + loc, tensor::getMixedSizes(rewriter, loc, paddedInput), inputDTy); + SmallVector addedDimensions; + for (size_t i = 0; i < inRank; ++i) + if (i != 1) + addedDimensions.push_back(i); + // broadcast the zeropoint to match the input shape + Value broadcastedInputZeroPoint = + rewriter + .create(loc, inputZp, zeroTensor, + addedDimensions) + ->getResult(0); + + // create output zero tensor + Value initTensor = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, paddedInput), + inputDTy); + // subtract the zeropoint from the input + paddedInput = + rewriter + .create( + loc, ValueRange{paddedInput, broadcastedInputZeroPoint}, + ValueRange{initTensor}) + ->getResult(0); + } else { + // else just subtract the value + // subtract the zeropoint from the input + auto paddedInputElementType = + cast(paddedInput.getType()).getElementType(); + inputZp = rewriter.create(loc, paddedInputElementType, + inputZp); + paddedInput = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, ValueRange{paddedInput}, paddedInputElementType, + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value result = + b.create(loc, payloadArgs[0], inputZp); + b.create(loc, result); + }); + } + + // Do the same for the weight + if (weightIsChannelwiseQuantized) { + // create zero init tensor with the same shape as weight + Value zeroTensor = rewriter.create( + loc, 
tensor::getMixedSizes(rewriter, loc, weight), weightDTy); + SmallVector addedDimensions; + for (size_t i = 1; i < inRank; ++i) + addedDimensions.push_back(i); + // broadcast the zeropoint to match the weight shape + Value broadcastedWeightZeroPoint = + rewriter + .create(loc, weightZp, zeroTensor, + addedDimensions) + ->getResult(0); + + // create output zero tensor + Value initTensor = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, weight), weightDTy); + + // subtract the zeropoint from the weight + weight = rewriter + .create( + loc, weight.getType(), + ValueRange{weight, broadcastedWeightZeroPoint}, + ValueRange{initTensor}) + .getResult(0); + } else { + // else just subtract the value + // subtract the zeropoint from the weight + auto paddedWeightElementType = + cast(weight.getType()).getElementType(); + weightZp = rewriter.create( + loc, paddedWeightElementType, weightZp); + weight = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, ValueRange{weight}, + cast(weight.getType()).getElementType(), + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value result = + b.create(loc, payloadArgs[0], weightZp); + b.create(loc, result); + }); + } + + // Perform the convolution with the zero-point subtracted inputs + switch (numSpatialDims) { + case 1: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 2: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 3: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D, 2D, and 3D convolution supported"); + } + Type newResultType = 
typeConverter->convertType(op.getType()); + if (accumulatorDType != resultDTy) { + Type resultElementType = + cast(newResultType).getElementType(); + conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, + resultElementType); + } + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); +} + LogicalResult handleUngroupedConvQuantized(ConversionPatternRewriter &rewriter, Location loc, Value &weight, Value &paddedInput, Value &outputTensor, Value &inputZp, Value &weightZp, @@ -1185,35 +1425,55 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value inputZp, weightZp; bool inputUnsigned = false; bool weightUnsigned = false; - if (auto make = op.getInput() - .getDefiningOp()) { - input = make.getSelf(); - inputZp = make.getZeroPoint(); - input = typeConverter->materializeTargetConversion( - rewriter, loc, typeConverter->convertType(input.getType()), input); - inputZp = typeConverter->materializeTargetConversion( - rewriter, loc, typeConverter->convertType(inputZp.getType()), - inputZp); - inputZp = - rewriter.create(loc, rewriter.getI32Type(), inputZp); - auto torchDtype = cast(make.getType()).getDtype(); - inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); + if (auto makeInputPerTensorQuantizedOp = + op.getInput() + .getDefiningOp()) { + + QuantizationValues quantizationParameters = + getQuantizationPerTensorValues( + rewriter, loc, makeInputPerTensorQuantizedOp, typeConverter); + input = quantizationParameters.self; + inputZp = quantizationParameters.zeroPoint; + inputUnsigned = quantizationParameters.isUnsigned; } - if (auto make = op.getWeight() - .getDefiningOp()) { - weight = make.getSelf(); - weightZp = make.getZeroPoint(); - - weight = typeConverter->materializeTargetConversion( - rewriter, loc, typeConverter->convertType(weight.getType()), weight); - weightZp = typeConverter->materializeTargetConversion( - rewriter, loc, typeConverter->convertType(weightZp.getType()), - weightZp); - weightZp = 
rewriter.create(loc, rewriter.getI32Type(), - weightZp); - auto torchDtype = cast(make.getType()).getDtype(); - weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); + if (auto makeWeightPerTensorQuantizedOp = + op.getWeight() + .getDefiningOp()) { + QuantizationValues quantizationParameters = + getQuantizationPerTensorValues( + rewriter, loc, makeWeightPerTensorQuantizedOp, typeConverter); + weight = quantizationParameters.self; + weightZp = quantizationParameters.zeroPoint; + weightUnsigned = quantizationParameters.isUnsigned; + } + + // check if input is channelwise quantized + bool inputIsChannelwiseQuantized = false; + if (auto makeInputPerChannelQuantizedOp = + op.getInput() + .getDefiningOp()) { + inputIsChannelwiseQuantized = true; + QuantizationValues quantizationParameters = + getQuantizationPerChannelValues( + rewriter, loc, makeInputPerChannelQuantizedOp, typeConverter); + input = quantizationParameters.self; + inputZp = quantizationParameters.zeroPoint; + inputUnsigned = quantizationParameters.isUnsigned; + } + + // check if weight is channelwise quantized + bool weightIsChannelwiseQuantized = false; + if (auto makeWeightPerChannelQuantizedOp = + op.getWeight() + .getDefiningOp()) { + weightIsChannelwiseQuantized = true; + QuantizationValues quantizationParameters = + getQuantizationPerChannelValues( + rewriter, loc, makeWeightPerChannelQuantizedOp, typeConverter); + weight = quantizationParameters.self; + weightZp = quantizationParameters.zeroPoint; + weightUnsigned = quantizationParameters.isUnsigned; } if (static_cast(inputZp) != static_cast(weightZp)) { @@ -1242,6 +1502,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { !isa(weightDTy) || !isa(resultDTy)) return op.emitError("unimplemented: non-fp not-int type"); + size_t inRank = cast(input.getType()).getRank(); size_t numSpatialDims = inRank - 2; if (numSpatialDims < 1 || numSpatialDims > 3) @@ -1383,7 +1644,6 @@ class ConvertAtenConvolutionOp : public 
OpConversionPattern { SmallVector weightSliceSizes{weightStride, weightChannels}; weightSliceSizes.append(weightDims); - Value conv; // the code so far is able to respect all numSpatialDims // the code below this point is numSpatialDims specific and // convolutionAttributes->groups specific @@ -1405,6 +1665,23 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } } + + if (convolutionAttributes->groups == 1 && + (inputIsChannelwiseQuantized || weightIsChannelwiseQuantized)) { + if (failed(handleUngroupedConvChannelwiseQuantized(rewriter, loc, + weight, paddedInput, + outputTensor, inputZp, weightZp, + numSpatialDims, inRank, + stridesAttr, dilationAttr, + inputIsChannelwiseQuantized, weightIsChannelwiseQuantized, + inputDTy, weightDTy, accumulatorDType, resultDTy, + op, getTypeConverter()))){ + return failure(); + } + else { + return success(); + } + } if (convolutionAttributes->groups == 1 && inputZp) { if (failed(handleUngroupedConvQuantized(rewriter, loc, diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index e3f5b6d0299a..39bda598cd6a 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -247,7 +247,7 @@ SmallVector getAsConstantIndexValues(OpBuilder &b, Location loc, // TODO: remove this when list gets full support. 
SmallVector getTypeConvertedValues(OpBuilder &b, Location loc, const TypeConverter *converter, - SmallVectorImpl &vs) { + const SmallVectorImpl &vs) { return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) { return converter->materializeTargetConversion( b, loc, converter->convertType(v.getType()), v); diff --git a/test/Conversion/TorchToLinalg/convolution.mlir b/test/Conversion/TorchToLinalg/convolution.mlir index 480b1eeb9ed2..2d6206b8d744 100644 --- a/test/Conversion/TorchToLinalg/convolution.mlir +++ b/test/Conversion/TorchToLinalg/convolution.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope | FileCheck %s // CHECK-LABEL: func @torch.aten.convolution$nobias( // CHECK: %[[CONSTANT:.*]] = arith.constant 0.000000e+00 : f32 @@ -53,6 +53,41 @@ func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtens // ----- +/// Checking only the important stuff. We should get one linalg.generic for the zero point scalar of the input and one linalg.sub for the weight zeropoint tensor. 
+// CHECK-LABEL: func.func @test_qlinearconv_nobias( +// CHECK: %[[VAL_38:.*]] = linalg.fill +// CHECK: %[[VAL_39:.*]] = arith.trunci +// CHECK: %[[VAL_41:.*]] = linalg.generic +// CHECK: ^bb0 +// CHECK: %[[VAL_44:.*]] = arith.subi +// CHECK: %[[VAL_46:.*]] = linalg.broadcast +// CHECK: %[[VAL_48:.*]] = linalg.fill +// CHECK: %[[VAL_49:.*]] = linalg.sub +// CHECK: %[[VAL_50:.*]] = linalg.conv_2d_nchw_fchw +// CHECK: } +func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[4,1,1,1],ui8>, %arg4: !torch.vtensor<[4],f32>, %arg5: !torch.vtensor<[4],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4,7,7],!torch.qint32> { + %int13 = torch.constant.int 13 + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %none = torch.constant.none + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %1 = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + %2 = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + %3 = torch.aten._make_per_tensor_quantized_tensor %arg0, %2, %1 : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + %4 = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %arg5, %int0 : !torch.vtensor<[4,1,1,1],ui8>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4,1,1,1],!torch.quint8> + %5 = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int + %6 = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float + %7 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %8 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %9 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %10 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, 
!torch.int) -> !torch.list + %11 = torch.aten.convolution %3, %4, %none, %9, %7, %8, %false, %10, %int1 : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[4,1,1,1],!torch.quint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,4,7,7],!torch.qint32> + return %11 : !torch.vtensor<[1,4,7,7],!torch.qint32> +} + +// ----- + // CHECK-LABEL: func.func @conv_broadcast( // CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,80,3000],f32>, // CHECK-SAME: %[[arg1:.*]]: !torch.vtensor<[1024,80,3],f32>, @@ -75,4 +110,4 @@ func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch. %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list %2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1024,3000],f32> return %2 : !torch.vtensor<[1,1024,3000],f32> -} +} \ No newline at end of file