feat: support lowering of channelwise quantization to linalg #10

Closed
2 changes: 1 addition & 1 deletion include/torch-mlir/Conversion/Utils/Utils.h
@@ -79,7 +79,7 @@ SmallVector<Value> getAsConstantIndexValues(OpBuilder &b, Location loc,
// TODO: remove this when list gets full support.
SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
const TypeConverter *converter,
SmallVectorImpl<Value> &vs);
const SmallVectorImpl<Value> &vs);

mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
mlir::Type elementType,
339 changes: 308 additions & 31 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -40,9 +40,32 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
if (!isUnsignedType)
return;
int64_t minSI = -(1 << (numBits - 1));
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
loc, minSI, cast<mlir::IntegerType>(zp.getType()).getWidth());
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
// get the bit width of the zero point, which may be a scalar or a tensor
int64_t bitWidth = 0;
if (isa<RankedTensorType>(zp.getType())) {
auto zpType = cast<RankedTensorType>(zp.getType());
bitWidth = zpType.getElementType().getIntOrFloatBitWidth();
} else {
bitWidth = zp.getType().getIntOrFloatBitWidth();
}
Value minSIValue =
rewriter.create<arith::ConstantIntOp>(loc, minSI, bitWidth);

// Use a linalg.generic op to add the minSIValue to the zero point if it is a
// tensor.
if (isa<RankedTensorType>(zp.getType())) {
zp = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, zp,
cast<RankedTensorType>(zp.getType()).getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result =
rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue);
b.create<linalg::YieldOp>(loc, result);
});
} else {
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
}
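// For example, with numBits == 8 the constant is -(1 << 7) == -128, so an
// unsigned zero point of 130 becomes 2 after the shift into the signed
// domain.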

minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{arg},
@@ -742,6 +765,85 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
} // namespace

namespace {
struct QuantizationValues {
Value self;
Value zeroPoint;
bool isUnsigned;
};

QuantizationValues getQuantizationPerTensorValues(
ConversionPatternRewriter &rewriter, Location loc,
Aten_MakePerTensorQuantizedTensorOp makePerTensorQuantizedTensorOp,
const TypeConverter *const typeConverter) {
QuantizationValues values;
Value self = makePerTensorQuantizedTensorOp.getSelf();
Value zeroPoint = makePerTensorQuantizedTensorOp.getZeroPoint();
self = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(self.getType()), self);
zeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(zeroPoint.getType()),
zeroPoint);
zeroPoint =
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), zeroPoint);
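// the zero point is normalized to i32 here so that downstream arithmetic
// sees a uniform width regardless of the source dtype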
auto torchDtype =
cast<ValueTensorType>(makePerTensorQuantizedTensorOp.getType())
.getDtype();

values.self = self;
values.zeroPoint = zeroPoint;
values.isUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);

return values;
}

QuantizationValues getQuantizationPerChannelValues(
ConversionPatternRewriter &rewriter, Location loc,
Aten_MakePerChannelQuantizedTensorOp makePerChannelQuantizedTensorOp,
const TypeConverter *const typeConverter) {
QuantizationValues values;
Value self = makePerChannelQuantizedTensorOp.getSelf();
Value zeroPoint = makePerChannelQuantizedTensorOp.getZeroPoint();
self = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(self.getType()), self);
zeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(zeroPoint.getType()),
zeroPoint);

// create a linalg op, since we need to do some arithmetic on the zero point
// and it is a tensor
auto zeroPointType = cast<RankedTensorType>(zeroPoint.getType());
auto selfType = cast<RankedTensorType>(self.getType());
if (zeroPointType.getElementTypeBitWidth() >
selfType.getElementTypeBitWidth()) {

zeroPoint = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, zeroPoint, rewriter.getI32Type(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), payloadArgs[0]);
b.create<linalg::YieldOp>(loc, result);
});
} else if (zeroPointType.getElementTypeBitWidth() <
selfType.getElementTypeBitWidth()) {
zeroPoint = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, zeroPoint, rewriter.getI32Type(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = rewriter.create<arith::ExtUIOp>(
loc, rewriter.getI32Type(), payloadArgs[0]);
b.create<linalg::YieldOp>(loc, result);
});
}
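// e.g. an i64 zero-point tensor paired with i8 data is truncated to i32,
// while an i8 zero-point tensor paired with i32 data is zero-extended to i32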
auto torchDtype =
cast<ValueTensorType>(makePerChannelQuantizedTensorOp.getType())
.getDtype();

values.self = self;
values.zeroPoint = zeroPoint;
values.isUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);

return values;
}

struct ConvolutionAttributes {
SmallVector<Value> padding;
SmallVector<Value> outputPadding;
@@ -964,6 +1066,144 @@ LogicalResult handleUngroupedConv(ConversionPatternRewriter &rewriter, Location
return success();
}

LogicalResult handleUngroupedConvChannelwiseQuantized(ConversionPatternRewriter &rewriter, Location loc,
Value &weight, Value &paddedInput,
Value &outputTensor, Value &inputZp, Value &weightZp,
size_t numSpatialDims, size_t inRank,
DenseIntElementsAttr stridesAttr, DenseIntElementsAttr dilationAttr,
bool inputIsChannelwiseQuantized, bool weightIsChannelwiseQuantized,
Type inputDTy, Type weightDTy, Type accumulatorDType, Type resultDTy,
AtenConvolutionOp op, const TypeConverter *typeConverter) {
Value conv;
// There is no linalg.conv op for channelwise quantized arguments.
// Generate generic linalg ops to first subtract the zero points and then
// perform the convolution. Each zero point is either a scalar or a tensor.
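// In effect, each output element accumulates
//   sum_k (x_k - zp_x(c)) * (w_k - zp_w(f))
// where the per-channel zero points are broadcast along all non-channel
// dimensions before the subtraction.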
if (inputIsChannelwiseQuantized) {
// create an init tensor (tensor.empty) with the same shape as the input
Value zeroTensor = rewriter.create<tensor::EmptyOp>(
loc, tensor::getMixedSizes(rewriter, loc, paddedInput), inputDTy);
SmallVector<int64_t, 4> addedDimensions;
for (size_t i = 0; i < inRank; ++i)
if (i != 1)
addedDimensions.push_back(i);
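// e.g. for an NCHW input (inRank == 4) addedDimensions is {0, 2, 3}, i.e.
// the per-channel zero point is broadcast along the batch and spatial dims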
// broadcast the zero point to match the input shape
Value broadcastedInputZeroPoint =
rewriter
.create<linalg::BroadcastOp>(loc, inputZp, zeroTensor,
addedDimensions)
->getResult(0);

// create output zero tensor
Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, paddedInput),
inputDTy);
// subtract the zero point from the input
paddedInput =
rewriter
.create<linalg::SubOp>(
loc, ValueRange{paddedInput, broadcastedInputZeroPoint},
ValueRange{initTensor})
->getResult(0);
} else {
// the zero point is a scalar; subtract it from the input directly
auto paddedInputElementType =
cast<RankedTensorType>(paddedInput.getType()).getElementType();
inputZp = rewriter.create<arith::TruncIOp>(loc, paddedInputElementType,
inputZp);
paddedInput = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{paddedInput}, paddedInputElementType,
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result =
b.create<arith::SubIOp>(loc, payloadArgs[0], inputZp);
b.create<linalg::YieldOp>(loc, result);
});
}

// Do the same for the weight
if (weightIsChannelwiseQuantized) {
// create an init tensor (tensor.empty) with the same shape as the weight
Value zeroTensor = rewriter.create<tensor::EmptyOp>(
loc, tensor::getMixedSizes(rewriter, loc, weight), weightDTy);
SmallVector<int64_t, 4> addedDimensions;
for (size_t i = 1; i < inRank; ++i)
addedDimensions.push_back(i);
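// e.g. for FCHW weights (inRank == 4) addedDimensions is {1, 2, 3}, i.e. the
// per-output-channel zero point is broadcast along the remaining dims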
// broadcast the zero point to match the weight shape
Value broadcastedWeightZeroPoint =
rewriter
.create<linalg::BroadcastOp>(loc, weightZp, zeroTensor,
addedDimensions)
->getResult(0);

// create output zero tensor
Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, weight), weightDTy);

// subtract the zero point from the weight
weight = rewriter
.create<linalg::SubOp>(
loc, weight.getType(),
ValueRange{weight, broadcastedWeightZeroPoint},
ValueRange{initTensor})
.getResult(0);
} else {
// the zero point is a scalar; subtract it from the weight directly
auto paddedWeightElementType =
cast<RankedTensorType>(weight.getType()).getElementType();
weightZp = rewriter.create<arith::TruncIOp>(
loc, paddedWeightElementType, weightZp);
weight = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{weight},
cast<RankedTensorType>(weight.getType()).getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result =
b.create<arith::SubIOp>(loc, payloadArgs[0], weightZp);
b.create<linalg::YieldOp>(loc, result);
});
}

// Perform the convolution with the zero-point subtracted inputs
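// Each case below maps to the corresponding named linalg op; e.g. the 2-D
// case emits linalg.conv_2d_nchw_fchw.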
switch (numSpatialDims) {
case 1:
conv = rewriter
.create<linalg::Conv1DNcwFcwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 3:
conv = rewriter
.create<linalg::Conv3DNcdhwFcdhwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
}
Type newResultType = typeConverter->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

LogicalResult handleUngroupedConvQuantized(ConversionPatternRewriter &rewriter, Location loc,
Value &weight, Value &paddedInput,
Value &outputTensor, Value &inputZp, Value &weightZp,
@@ -1185,35 +1425,55 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Value inputZp, weightZp;
bool inputUnsigned = false;
bool weightUnsigned = false;
if (auto make = op.getInput()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
input = make.getSelf();
inputZp = make.getZeroPoint();
input = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(input.getType()), input);
inputZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(inputZp.getType()),
inputZp);
inputZp =
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), inputZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
if (auto makeInputPerTensorQuantizedOp =
op.getInput()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {

QuantizationValues quantizationParameters =
getQuantizationPerTensorValues(
rewriter, loc, makeInputPerTensorQuantizedOp, typeConverter);
input = quantizationParameters.self;
inputZp = quantizationParameters.zeroPoint;
inputUnsigned = quantizationParameters.isUnsigned;
}

if (auto make = op.getWeight()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
weight = make.getSelf();
weightZp = make.getZeroPoint();

weight = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(weight.getType()), weight);
weightZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(weightZp.getType()),
weightZp);
weightZp = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(),
weightZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
if (auto makeWeightPerTensorQuantizedOp =
op.getWeight()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
QuantizationValues quantizationParameters =
getQuantizationPerTensorValues(
rewriter, loc, makeWeightPerTensorQuantizedOp, typeConverter);
weight = quantizationParameters.self;
weightZp = quantizationParameters.zeroPoint;
weightUnsigned = quantizationParameters.isUnsigned;
}

// check if input is channelwise quantized
bool inputIsChannelwiseQuantized = false;
if (auto makeInputPerChannelQuantizedOp =
op.getInput()
.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>()) {
inputIsChannelwiseQuantized = true;
QuantizationValues quantizationParameters =
getQuantizationPerChannelValues(
rewriter, loc, makeInputPerChannelQuantizedOp, typeConverter);
input = quantizationParameters.self;
inputZp = quantizationParameters.zeroPoint;
inputUnsigned = quantizationParameters.isUnsigned;
}

// check if weight is channelwise quantized
bool weightIsChannelwiseQuantized = false;
if (auto makeWeightPerChannelQuantizedOp =
op.getWeight()
.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>()) {
weightIsChannelwiseQuantized = true;
QuantizationValues quantizationParameters =
getQuantizationPerChannelValues(
rewriter, loc, makeWeightPerChannelQuantizedOp, typeConverter);
weight = quantizationParameters.self;
weightZp = quantizationParameters.zeroPoint;
weightUnsigned = quantizationParameters.isUnsigned;
}

if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
@@ -1242,6 +1502,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
return op.emitError("unimplemented: non-fp not-int type");

size_t inRank = cast<RankedTensorType>(input.getType()).getRank();
size_t numSpatialDims = inRank - 2;
if (numSpatialDims < 1 || numSpatialDims > 3)
@@ -1383,7 +1644,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> weightSliceSizes{weightStride, weightChannels};
weightSliceSizes.append(weightDims);

Value conv;
// the code so far handles all values of numSpatialDims;
// the code below this point is specific to numSpatialDims and to
// convolutionAttributes->groups
@@ -1405,6 +1665,23 @@
return success();
}
}

if (convolutionAttributes->groups == 1 &&
    (inputIsChannelwiseQuantized || weightIsChannelwiseQuantized)) {
  if (failed(handleUngroupedConvChannelwiseQuantized(
          rewriter, loc, weight, paddedInput, outputTensor, inputZp,
          weightZp, numSpatialDims, inRank, stridesAttr, dilationAttr,
          inputIsChannelwiseQuantized, weightIsChannelwiseQuantized,
          inputDTy, weightDTy, accumulatorDType, resultDTy, op,
          getTypeConverter())))
    return failure();
  return success();
}

if (convolutionAttributes->groups == 1 && inputZp) {
if (failed(handleUngroupedConvQuantized(rewriter, loc,