Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: added struct for ConvolutionAttributes and function to preprocessPadding #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 137 additions & 90 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,77 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
} // namespace

namespace {
// Aggregates all attributes of an AtenConvolutionOp that the lowering needs,
// extracted once up front by getConvolutionAttributes().
struct ConvolutionAttributes {
  // Per-spatial-dimension padding values, type-converted to the target
  // dialect (one Value per spatial dim).
  SmallVector<Value> padding;
  // Per-spatial-dimension output_padding values (used for transposed
  // convolution), type-converted to the target dialect.
  SmallVector<Value> outputPadding;
  // Constant integer strides, one per spatial dimension.
  SmallVector<int64_t> stride;
  // The same strides materialized as constant int Values.
  SmallVector<Value, 4> strideValues;
  // Constant integer dilations, one per spatial dimension.
  SmallVector<int64_t> dilation;
  // The same dilations materialized as constant int Values.
  SmallVector<Value, 4> dilationValues;
  // Group count as an index-typed Value (built from the adaptor operand).
  Value groupsValue;
  // Group count as a compile-time constant.
  int64_t groups;
  // True if this is a transposed convolution.
  // NOTE(review): not populated by getConvolutionAttributes() in this diff —
  // confirm callers set it before use.
  bool transposed;
};

/// Extracts the convolution attributes (padding, output_padding, stride,
/// dilation, groups) from `op` into a ConvolutionAttributes struct.
///
/// Fails the match (via notifyMatchFailure) when any attribute is not in the
/// supported form: padding/output_padding must come from list constructs, and
/// stride/dilation/groups must be compile-time constant integers.
FailureOr<ConvolutionAttributes>
getConvolutionAttributes(ConversionPatternRewriter &rewriter, Location loc,
                         const TypeConverter *const typeConverter,
                         AtenConvolutionOp op,
                         AtenConvolutionOpAdaptor &adaptor) {
  ConvolutionAttributes attrs;

  // Padding: unpack the torch list construct, then convert each element to
  // the target dialect's types.
  if (!getListConstructElements(op.getPadding(), attrs.padding))
    return rewriter.notifyMatchFailure(
        op, "only support padding from a list construct");
  attrs.padding =
      getTypeConvertedValues(rewriter, loc, typeConverter, attrs.padding);

  // output_padding: same treatment as padding.
  if (!getListConstructElements(op.getOutputPadding(), attrs.outputPadding))
    return rewriter.notifyMatchFailure(
        op, "only support output_padding from a list construct");
  attrs.outputPadding = getTypeConvertedValues(rewriter, loc, typeConverter,
                                               attrs.outputPadding);

  // Strides: constant ints, also materialized as constant Values for later
  // index arithmetic.
  if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(attrs.stride)))
    return rewriter.notifyMatchFailure(op, "only support constant int strides");
  attrs.strideValues = getAsConstantIntValues(rewriter, loc, attrs.stride);

  // Dilations: same treatment as strides.
  if (!matchPattern(op.getDilation(),
                    m_TorchListOfConstantInts(attrs.dilation)))
    return rewriter.notifyMatchFailure(op,
                                       "only support constant int dilations");
  attrs.dilationValues = getAsConstantIntValues(rewriter, loc, attrs.dilation);

  // Groups: keep both the compile-time constant and an index-typed Value
  // built from the already-converted adaptor operand.
  if (!matchPattern(op.getGroups(), m_TorchConstantInt(&attrs.groups)))
    return rewriter.notifyMatchFailure(op,
                                       "only constant group size supported.");
  attrs.groupsValue = castIntToIndex(rewriter, loc, adaptor.getGroups());
  return attrs;
}

/// Materializes or normalizes the scalar value used to pad the convolution
/// input.
///
/// If `pad` is null (no input zero-point was supplied), a zero constant of
/// `inputDataType` is created (float or integer element types only). If a
/// non-null `pad` has a type different from `inputDataType`, it is truncated
/// to that type.
/// NOTE(review): the Trunc ops assume `pad`'s type is strictly wider than
/// `inputDataType`; arith.truncf/trunci are invalid otherwise — confirm this
/// holds for all callers.
///
/// Returns the padding value, or a null Value if `inputDataType` is neither
/// a float nor an integer type.
Value preprocessPadding(ConversionPatternRewriter &rewriter, Value pad,
                        AtenConvolutionOp op, Type inputDataType) {
  Location loc = op.getLoc();
  if (!pad) {
    // The two branches are mutually exclusive; a type cannot be both float
    // and integer.
    if (isa<mlir::FloatType>(inputDataType))
      pad = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(inputDataType, 0.0));
    else if (isa<mlir::IntegerType>(inputDataType))
      pad = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(inputDataType, 0));
    // Unsupported element type: bail out before dereferencing a null Value
    // below (the original code would hit pad.getType() on a null Value).
    if (!pad)
      return pad;
  }
  if (pad.getType() != inputDataType) {
    if (isa<mlir::FloatType>(inputDataType))
      pad = rewriter.create<arith::TruncFOp>(loc, inputDataType, pad);

    if (isa<mlir::IntegerType>(inputDataType))
      pad = rewriter.create<arith::TruncIOp>(loc, inputDataType, pad);
  }
  return pad;
}

class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand All @@ -750,10 +821,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.getInput(); /* in form of N*C*H*W */
Value weight = adaptor.getWeight(); /* in form of F*C/G*H*W */
Value input = adaptor.getInput(); /* in form of N*C*IH*IW */
Value weight = adaptor.getWeight(); /* in form of F*C/G*KH*KW */
Value bias = adaptor.getBias();
auto resultTy = cast<ValueTensorType>(op.getType());
auto resultTy =
cast<ValueTensorType>(op.getType()); /* in form of N*F*OH*OW */

Value inputZp, weightZp;
bool inputUnsigned = false;
Expand Down Expand Up @@ -826,47 +898,25 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.createOrFold<arith::IndexCastOp>(loc, intType, v);
};

SmallVector<Value> paddingIntValues;
if (!getListConstructElements(op.getPadding(), paddingIntValues))
return rewriter.notifyMatchFailure(
op, "only support padding from a list construct");
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
paddingIntValues);
SmallVector<Value> outputPaddingIntValues;
if (!getListConstructElements(op.getOutputPadding(),
outputPaddingIntValues))
return rewriter.notifyMatchFailure(
op, "only support output_padding from a list construct");
outputPaddingIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), outputPaddingIntValues);
SmallVector<int64_t> strideInts;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
SmallVector<int64_t> dilationInts;
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
FailureOr<ConvolutionAttributes> convolutionAttributes =
getConvolutionAttributes(rewriter, loc, getTypeConverter(), op,
adaptor);
if (failed(convolutionAttributes))
return failure();

Value inBatch = getDimOp(rewriter, loc, input, 0);
Value inChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inDims;
Value groups = convolutionAttributes->groupsValue;

Value inputBatch = getDimOp(rewriter, loc, input, 0);
Value inputChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inputSpatialDimensions;
for (size_t i = 2; i < inRank; i++)
inDims.push_back(getDimOp(rewriter, loc, input, i));
inputSpatialDimensions.push_back(getDimOp(rewriter, loc, input, i));
Value weightBatch = getDimOp(rewriter, loc, weight, 0);
Value weightChannels = getDimOp(rewriter, loc, weight, 1);
SmallVector<Value> weightDims;
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

// Checks for valid group size
int64_t numGroups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());

auto validate = [&](Value toValidate, std::string err) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Expand All @@ -876,14 +926,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
rewriter.create<cf::AssertOp>(loc, inputValid,
rewriter.getStringAttr(err));
};
validate(inChannels,
validate(inputChannels,
"invalid: groups must divide input channel size evenly.");
validate(weightBatch,
"invalid: groups must divide weight batch size evenly.");
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

// convert any uint8 quantization to int8 quantization
if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
Expand All @@ -895,24 +941,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
}
// Pad the input tensor according to padding.
SmallVector<Value> outDims{inBatch, weightBatch};
SmallVector<Value> outDims{inputBatch, weightBatch};
Value paddedInput;
Value pad = inputZp;
if (!pad) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
}
if (pad.getType() != inputDTy) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);

if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}
Value pad = inputZp;
pad = preprocessPadding(rewriter, pad, op, inputDTy);
if (transposed) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Expand Down Expand Up @@ -956,26 +989,26 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
.getResult(0);

// Calculate padded input size, allocate tensor
SmallVector<Value> outerSizes{inBatch, inChannels};
SmallVector<Value> innerSizes{inBatch, inChannels};
SmallVector<Value> outerSizes{inputBatch, inputChannels};
SmallVector<Value> innerSizes{inputBatch, inputChannels};
SmallVector<Value> offsets{c0, c0};
for (size_t i = 0; i < numSpatialDims; i++) {
Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
Value innerSize = rewriter.create<arith::SubIOp>(loc, inputSpatialDimensions[i], c1);
innerSize = rewriter.create<arith::MulIOp>(
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
loc, innerSize, castIntToIndex(rewriter, loc, convolutionAttributes->strideValues[i]));
innerSize = rewriter.create<arith::AddIOp>(loc, innerSize, c1);

Value offset = rewriter.create<arith::SubIOp>(loc, weightDims[i], c1);
offset = rewriter.create<arith::MulIOp>(
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
loc, offset, castIntToIndex(rewriter, loc, convolutionAttributes->dilationValues[i]));
offset = rewriter.create<arith::SubIOp>(
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));
loc, offset, castIntToIndex(rewriter, loc, convolutionAttributes->padding[i]));

Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);
outerSize = rewriter.create<arith::AddIOp>(
loc, outerSize,
castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
castIntToIndex(rewriter, loc, convolutionAttributes->outputPadding[i]));

outerSizes.push_back(outerSize);
offsets.push_back(offset);
Expand All @@ -987,7 +1020,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
for (auto stride : strideIntValues)
for (auto stride : convolutionAttributes->strideValues)
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);

Expand All @@ -998,30 +1031,39 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
// Calculate output dims
for (size_t i = 0; i < numSpatialDims; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i],
outputPaddingIntValues[i]));
rewriter, loc, inputSpatialDimensions[i],
convolutionAttributes->padding[i],
convolutionAttributes->dilationValues[i],
castIndexToInt(weightDims[i]),
convolutionAttributes->strideValues[i],
convolutionAttributes->outputPadding[i]));

// Set stride to 1
strideInts.clear();
strideInts.append(numSpatialDims, 1);
convolutionAttributes->stride.clear();
convolutionAttributes->stride.append(numSpatialDims, 1);
} else {
// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
op, rewriter, input, convolutionAttributes->padding,
/*unpaddedDims=*/2, pad);

// Calculate output dims
for (size_t i = 0; i < numSpatialDims; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));
rewriter, loc, inputSpatialDimensions[i],
convolutionAttributes->padding[i],
convolutionAttributes->dilationValues[i],
castIndexToInt(weightDims[i]),
convolutionAttributes->strideValues[i]));
}

Type accumulatorDType = getDefaultAccType(rewriter, inputDTy);
Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outDims), accumulatorDType);

Value outputTensor;
// Bias is optional, if it is not provided, we initialize the output tensor
// with 0.
if (accumulatorDType != resultDTy && !isa<Torch::NoneType>(bias.getType()))
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
accumulatorDType);
Expand All @@ -1038,6 +1080,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);

} else {
// If bias is provided, we initialize the output tensor with bias. This
// saves the need to add bias later.
auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
Expand All @@ -1055,11 +1099,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
->getResult(0);
}

auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
auto stridesAttr = rewriter.getI64VectorAttr(convolutionAttributes->stride);
auto dilationAttr = rewriter.getI64VectorAttr(convolutionAttributes->dilation);

Value inputStride =
rewriter.create<arith::FloorDivSIOp>(loc, inChannels, groups);
rewriter.create<arith::FloorDivSIOp>(loc, inputChannels, groups);
Value weightStride =
rewriter.create<arith::FloorDivSIOp>(loc, weightBatch, groups);

Expand All @@ -1069,21 +1113,21 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
loc, rewriter.getIndexAttr(1)));
SmallVector<Value> outDimSlice(outDims);
outDimSlice[1] = weightStride;
SmallVector<Value> inputSliceSizes{inBatch, inputStride};
inputSliceSizes.append(inDims);
SmallVector<Value> inputSliceSizes{inputBatch, inputStride};
inputSliceSizes.append(inputSpatialDimensions);
SmallVector<Value> weightSliceSizes{weightStride, weightChannels};
weightSliceSizes.append(weightDims);

Value conv;
// the code so far is able to respect all numSpatialDims
// the code below this point is numSpatialDims specific and numGroups
// specific
// the code below this point is numSpatialDims specific and
// convolutionAttributes->groups specific
// TODO: factor out the above code into a helper function, and then separate
// convolution into:
// - grouped 1d-3d
// - grouped 1d-3d (quantized)
// - ungrouped 1d-3d
if (numGroups == 1 && !inputZp) {
if (convolutionAttributes->groups == 1 && !inputZp) {
switch (numSpatialDims) {
case 1:
conv = rewriter
Expand Down Expand Up @@ -1124,7 +1168,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numGroups == 1 && inputZp) {
if (convolutionAttributes->groups == 1 && inputZp) {
switch (numSpatialDims) {
case 2:
conv = rewriter
Expand Down Expand Up @@ -1188,14 +1232,15 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
}

// Special depthwise case: Cin = Cout = groups.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
// of groups) to be depthwise in their documentation, but the linalg ops
// apparently disagree.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero
// multiple of groups) to be depthwise in their documentation, but the
// linalg ops apparently disagree.
auto inShape = makeShapeTorchCompatible(
cast<RankedTensorType>(input.getType()).getShape());
auto weightShape = makeShapeTorchCompatible(
cast<RankedTensorType>(weight.getType()).getShape());
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
if (inShape[1] == convolutionAttributes->groups &&
weightShape[0] == convolutionAttributes->groups &&
weightShape[1] == 1) {
// Collapse weight shape (C/G == 1)
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
Expand Down Expand Up @@ -1295,12 +1340,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<int64_t> outShape;
for (auto i = 0; i < (long)inShape.size(); i++) {
if (i == 1) {
outShape.push_back(numGroups);
outShape.push_back(convolutionAttributes->groups);
}
if (i == (long)dim) {
outShape.push_back(inShape[i] == kUnknownSize
? kUnknownSize
: inShape[i] / numGroups);
: inShape[i] / convolutionAttributes->groups);
} else {
outShape.push_back(inShape[i]);
}
Expand All @@ -1326,8 +1371,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
auto inShape = makeShapeTorchCompatible(inType.getShape());

SmallVector<int64_t> outShape{
numGroups,
(inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / numGroups)};
convolutionAttributes->groups,
(inShape[0] == kUnknownSize
? kUnknownSize
: inShape[0] / convolutionAttributes->groups)};
outShape.append(inShape.begin() + 1, inShape.end());

SmallVector<ReassociationIndices> indices{{0, 1}};
Expand Down