Skip to content

Commit

Permalink
refactor: extract handleUngroupedConv and handleUngroupedConvQuantized helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
KaiJPl committed Jan 9, 2025
1 parent 1a5cbae commit f57a99f
Showing 1 changed file with 137 additions and 95 deletions.
232 changes: 137 additions & 95 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,123 @@ void handleTranspose(ConversionPatternRewriter &rewriter, Location loc,
convolutionAttributes.stride.append(numSpatialDims, 1);
}

/// Lowers an ungrouped (groups == 1), non-quantized convolution to the
/// corresponding named linalg op:
///   1 spatial dim -> linalg.conv_1d_ncw_fcw
///   2 spatial dims -> linalg.conv_2d_nchw_fchw
///   3 spatial dims -> linalg.conv_3d_ncdhw_fcdhw
///
/// \p paddedInput and \p weight are expected in NC(spatial) / FC(spatial)
/// layout; \p outputTensor is the pre-initialized accumulator. If the
/// accumulator element type differs from the result element type, the
/// result is converted before the final cast that replaces \p op.
///
/// Returns failure (via notifyMatchFailure) for unsupported ranks.
LogicalResult handleUngroupedConv(ConversionPatternRewriter &rewriter,
                                  Location loc, Value &weight,
                                  Value &paddedInput, Value &outputTensor,
                                  size_t numSpatialDims,
                                  DenseIntElementsAttr stridesAttr,
                                  DenseIntElementsAttr dilationAttr,
                                  Type accumulatorDType, Type resultDTy,
                                  AtenConvolutionOp op,
                                  const TypeConverter *typeConverter) {
  Value conv;
  switch (numSpatialDims) {
  case 1:
    conv = rewriter
               .create<linalg::Conv1DNcwFcwOp>(
                   loc, outputTensor.getType(),
                   ValueRange{paddedInput, weight}, outputTensor, stridesAttr,
                   dilationAttr)
               .getResult(0);
    break;
  case 2:
    conv = rewriter
               .create<linalg::Conv2DNchwFchwOp>(
                   loc, outputTensor.getType(),
                   ValueRange{paddedInput, weight}, outputTensor, stridesAttr,
                   dilationAttr)
               .getResult(0);
    break;
  case 3:
    conv = rewriter
               .create<linalg::Conv3DNcdhwFcdhwOp>(
                   loc, outputTensor.getType(),
                   ValueRange{paddedInput, weight}, outputTensor, stridesAttr,
                   dilationAttr)
               .getResult(0);
    break;
  default:
    return rewriter.notifyMatchFailure(
        op, "unimplemented: only 1D, 2D, and 3D convolution supported");
  }
  Type newResultType = typeConverter->convertType(op.getType());
  // The convolution accumulates in a (possibly wider) accumulator type;
  // convert back to the expected result element type when they differ.
  if (accumulatorDType != resultDTy) {
    Type resultElementType =
        cast<RankedTensorType>(newResultType).getElementType();
    conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
                                                       resultElementType);
  }
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
  return success();
}

/// Lowers an ungrouped (groups == 1), quantized convolution (input/weight
/// zero points present) to the corresponding named linalg quantized op:
///   2 spatial dims -> linalg.conv_2d_nchw_fchw_q
///   3 spatial dims -> linalg.conv_3d_ndhwc_dhwcf_q (via layout transposes,
///                     since the 3D quantized op uses channels-last)
///
/// \p inputZp / \p weightZp are the zero-point scalars forwarded to the
/// quantized linalg op. The 3D case transposes input/weight/output into the
/// op's NDHWC/DHWCF layout and transposes the result back to NCDHW.
///
/// Note: unlike the non-quantized path, 1D is NOT supported here.
LogicalResult handleUngroupedConvQuantized(
    ConversionPatternRewriter &rewriter, Location loc, Value &weight,
    Value &paddedInput, Value &outputTensor, Value &inputZp, Value &weightZp,
    size_t numSpatialDims, DenseIntElementsAttr stridesAttr,
    DenseIntElementsAttr dilationAttr, Type accumulatorDType, Type resultDTy,
    AtenConvolutionOp op, const TypeConverter *typeConverter) {
  Value conv;
  switch (numSpatialDims) {
  case 2:
    conv = rewriter
               .create<linalg::Conv2DNchwFchwQOp>(
                   loc, outputTensor.getType(),
                   ValueRange{paddedInput, weight, inputZp, weightZp},
                   outputTensor, stridesAttr, dilationAttr)
               .getResult(0);
    break;
  case 3: {
    // The quantized version uses a different channel ordering so we need to
    // permute the tensors in order to use the existing path. We should
    // eventually directly support this channel ordering.
    llvm::SmallVector<int64_t> inPerms, weightPerms;
    inPerms.push_back(0); // N stays at the front for input.
    // Then we expect the spatial dimensions
    for (size_t i = 0; i < numSpatialDims; ++i) {
      inPerms.push_back(i + 2);
      weightPerms.push_back(i + 2);
    }
    inPerms.push_back(1);          // NCDHW -> NDHWC
    weightPerms.append({1, 0});    // FCDHW -> DHWCF

    paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
    weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
    outputTensor =
        transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);

    conv = rewriter
               .create<linalg::Conv3DNdhwcDhwcfQOp>(
                   loc, outputTensor.getType(),
                   ValueRange{paddedInput, weight, inputZp, weightZp},
                   outputTensor, stridesAttr, dilationAttr)
               .getResult(0);

    // Invert the input permutation to restore NCDHW on the result.
    llvm::SmallVector<int64_t> outPerms;
    outPerms.push_back(0);
    outPerms.push_back(inPerms.size() - 1);
    for (size_t i = 0; i < numSpatialDims; ++i) {
      outPerms.push_back(i + 1);
    }

    conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);

    break;
  }
  default:
    // Only the 2D and 3D quantized named ops exist; 1D quantized
    // convolution is not implemented on this path.
    return rewriter.notifyMatchFailure(
        op, "unimplemented: only 2D and 3D quantized convolution supported");
  }

  Type newResultType = typeConverter->convertType(op.getType());
  // Convert from the accumulator element type back to the result element
  // type when they differ.
  if (accumulatorDType != resultDTy) {
    Type resultElementType =
        cast<RankedTensorType>(newResultType).getElementType();
    conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
                                                       resultElementType);
  }
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
  return success();
}

class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -1142,107 +1259,32 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
// - grouped 1d-3d (quantized)
// - ungrouped 1d-3d
if (convolutionAttributes->groups == 1 && !inputZp) {
switch (numSpatialDims) {
case 1:
conv = rewriter
.create<linalg::Conv1DNcwFcwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 3:
conv = rewriter
.create<linalg::Conv3DNcdhwFcdhwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
};
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
if (failed(handleUngroupedConv(rewriter, loc,
weight, paddedInput,
outputTensor, numSpatialDims,
stridesAttr, dilationAttr,
accumulatorDType, resultDTy,
op, getTypeConverter()))){
return failure();
}
else {
return success();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

if (convolutionAttributes->groups == 1 && inputZp) {
switch (numSpatialDims) {
case 2:
conv = rewriter
.create<linalg::Conv2DNchwFchwQOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight, inputZp, weightZp},
outputTensor, stridesAttr, dilationAttr)
.getResult(0);
break;
case 3: {
// The quantized version uses a different channel ordering so we need to
// permute the tensors in order to use the existing path. We should
// eventually directly support this channel ordering.
llvm::SmallVector<int64_t> inPerms, weightPerms;
inPerms.push_back(0); // N stays at the front for input.
// Then we expect the spatial dimensions
for (size_t i = 0; i < numSpatialDims; ++i) {
inPerms.push_back(i + 2);
weightPerms.push_back(i + 2);
}
inPerms.push_back(1);
weightPerms.append({1, 0});

paddedInput =
transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
outputTensor =
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);

conv = rewriter
.create<linalg::Conv3DNdhwcDhwcfQOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight, inputZp, weightZp},
outputTensor, stridesAttr, dilationAttr)
.getResult(0);

llvm::SmallVector<int64_t> outPerms;
outPerms.push_back(0);
outPerms.push_back(inPerms.size() - 1);
for (size_t i = 0; i < numSpatialDims; ++i) {
outPerms.push_back(i + 1);
}
conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);

break;
if (failed(handleUngroupedConvQuantized(rewriter, loc,
weight, paddedInput,
outputTensor, inputZp, weightZp,
numSpatialDims,
stridesAttr, dilationAttr,
accumulatorDType, resultDTy,
op, getTypeConverter()))){
return failure();
}
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
};

Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
else {
return success();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

// Special depthwise case: Cin = Cout = groups.
Expand Down

0 comments on commit f57a99f

Please sign in to comment.