Skip to content

Commit

Permalink
Adds accumulator types in TorchToLinalg for `AtenMmOp` and `AtenConvolutionOp` (llvm#3027)
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah authored Mar 14, 2024
1 parent 0b2f9c8 commit 798bfd7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
55 changes: 48 additions & 7 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,14 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
"mismatching contracting dimension for torch.aten.mm"));
}

auto resultTy = op.getType().cast<ValueTensorType>();
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<TensorType>().getElementType();
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
if (accumulatorDType != resultDTy) {
elementType = accumulatorDType;
}
Value zeroFill = createZeroInitTensor(
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);

Expand Down Expand Up @@ -163,6 +169,13 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
ValueRange{lhs, rhs}, zeroFill)
.getResult(0);
}

if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
matmul = torch_to_linalg::convertTensorToElementType(
rewriter, loc, matmul, resultElementType);
}
// When constructed with just dynamic sizes, EmptyOp will have a result
// type which has all `?`'s for dimensions, which might not be the result
// type of `op`. The constraints on later linalg ops means that the result
Expand Down Expand Up @@ -875,18 +888,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
castIndexToInt(weightDims[i]), strideIntValues[i]));
}

Type accumulatorDType = getDefaultAccType(rewriter, resultDTy);
Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outDims), resultDTy);
loc, getAsOpFoldResult(outDims), accumulatorDType);

Value outputTensor;
if (accumulatorDType != resultDTy)
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
accumulatorDType);
if (bias.getType().isa<Torch::NoneType>()) {
Value c0;
if (resultDTy.isa<mlir::FloatType>()) {
c0 = rewriter.create<arith::ConstantOp>(loc,
FloatAttr::get(resultDTy, 0.0));
} else if (resultDTy.isa<mlir::IntegerType>()) {
c0 = rewriter.create<arith::ConstantOp>(loc,
IntegerAttr::get(resultDTy, 0));
if (accumulatorDType.isa<mlir::FloatType>()) {
c0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(accumulatorDType, 0.0));
} else if (accumulatorDType.isa<mlir::IntegerType>()) {
c0 = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(accumulatorDType, 0));
}
outputTensor =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
Expand Down Expand Up @@ -973,6 +990,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
};
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down Expand Up @@ -1027,6 +1050,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);

Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down Expand Up @@ -1065,6 +1094,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
.getResult(0);

Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down Expand Up @@ -1137,6 +1172,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
loc, outputTensor.getType(), conv,
expandOutputTensor.getReassociationIndices());
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,5 +553,5 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
return rewriter.getI64Type();
if (inputType.isSignedInteger(64))
return rewriter.getI64Type();
llvm::report_fatal_error("unhandled type for getDefaultAccType");
return inputType;
}

0 comments on commit 798bfd7

Please sign in to comment.