[TorchToLinalg] Add aten.fft_rfft and lowering (llvm#3857)
- Add `AtenFftRfftOp` to Torch dialect.
- Add conversion of `AtenFftRfftOp` to Linalg: the DFT coefficient matrix is
  materialized as a constant and contracted with the input, one matmul-style
  reduction per output component (real and imaginary). Computing the DFT this
  way is _O(n^2)_ in the signal length (see the sketch below).
- Add decomposition of `AtenFftRfftOp` into Torch-level ops (same matmul-based
  paradigm as above).
- Add unit and end-to-end tests.
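
For reference, a minimal NumPy sketch of the matmul-based algorithm (illustrative only; `rfft_via_matmul` is a hypothetical helper, not code from this patch):

```python
import numpy as np

def rfft_via_matmul(x: np.ndarray, dim: int = -1) -> np.ndarray:
    """O(n^2) rfft: contract the input with an explicit DFT coefficient matrix."""
    x = np.moveaxis(x, dim, -1)            # bring the FFT dimension last
    n = x.shape[-1]
    m = n // 2 + 1                         # rfft keeps only the non-redundant bins
    k = np.arange(n)[:, None] * np.arange(m)[None, :]
    coeff = np.exp(-2j * np.pi * k / n)    # coeff[i, j] = cos(.) - i*sin(.)
    out = x @ coeff                        # one inner product per output bin
    return np.moveaxis(out, -1, dim)

x = np.random.rand(3, 8)
assert np.allclose(rfft_via_matmul(x), np.fft.rfft(x))
```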
giacs-epic authored Nov 27, 2024 · 1 parent 4498569 · commit 46a5772
Showing 12 changed files with 646 additions and 0 deletions.
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13323,6 +13323,32 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [
}];
}

def Torch_AtenFftRfftOp : Torch_Op<"aten.fft_rfft", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$n,
Torch_IntType:$dim,
AnyTorchOptionalStringType:$norm
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFftRfftOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenFftRfftOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
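
At the Python level this op corresponds to `torch.fft.rfft`, whose arguments map one-to-one onto the operands above:

```python
import torch

x = torch.rand(8)  # float32 input
# aten::fft_rfft : (Tensor self, int? n, int dim, str? norm) -> (Tensor)
y = torch.fft.rfft(x, n=None, dim=-1, norm=None)
print(y.shape, y.dtype)  # torch.Size([5]) torch.complex64  (8 // 2 + 1 = 5)
```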

def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [
AllowsTypeRefinement,
HasValueSemantics,
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -20,6 +20,8 @@ namespace Torch {

int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);
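/// Creates a !torch.list<int> value containing the given constant integers.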
Value toIntListConstruct(PatternRewriter &rewriter, Location loc,
ArrayRef<int64_t> cstInput);
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
/// Returns the index indicated by `v` for a list of given `length`.
/// If the index is negative, it is adjusted to `length` + `v`.
191 changes: 191 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -11,6 +11,7 @@

#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
@@ -1376,6 +1377,194 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
};
} // namespace

namespace {

/// Creates coefficients based on DFT definition, see
/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
Value getDFTMatmulCoeff(OpBuilder b, Location loc,
RankedTensorType matrixType) {

ComplexType complexTy = llvm::cast<ComplexType>(matrixType.getElementType());
mlir::FloatType floatType =
llvm::cast<mlir::FloatType>(complexTy.getElementType());

// scale = 2 * pi / N
double scale = 2 * M_PI / matrixType.getDimSize(0);

SmallVector<std::complex<APFloat>> values;
for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
double v = scale * i * j;
double realV = cos(v);
double imagV = -sin(v);

bool unused;
APFloat real(realV);
real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&unused);
APFloat imag(imagV);
imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&unused);

values.push_back(std::complex<APFloat>(real, imag));
}
}
return b.create<arith::ConstantOp>(
loc, matrixType, DenseElementsAttr::get(matrixType, values));
}
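
For reference, the matrix built above is the standard DFT kernel, kept only for the non-redundant columns: a real input has a Hermitian-symmetric spectrum, so bins beyond N/2 carry no new information.

```latex
W_{nm} = \cos\!\left(\frac{2\pi n m}{N}\right) - i\,\sin\!\left(\frac{2\pi n m}{N}\right)
       = e^{-2\pi i\, n m / N},
\qquad n = 0,\dots,N-1, \quad m = 0,\dots,\left\lfloor N/2 \right\rfloor,
\qquad X_m = \sum_{n=0}^{N-1} x_n\, W_{nm}.
```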

struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenFftRfftOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = adaptor.getSelf();

int64_t dim;
auto dimVal = op.getDim();
if (isa<torch::Torch::NoneType>(dimVal.getType())) {
dim = -1;
} else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: requires dim to be constant");
}

if (!isa<torch::Torch::NoneType>(op.getN().getType())) {
return rewriter.notifyMatchFailure(op, "unimplemented: parameter n");
}

if (!isa<torch::Torch::NoneType>(op.getNorm().getType())) {
return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm");
}

RankedTensorType inputType =
cast<RankedTensorType>(adaptor.getSelf().getType());
if (!inputType.hasRank()) {
return rewriter.notifyMatchFailure(
op, "unsupported: only ranked tensors are supported");
}

const ArrayRef<int64_t> inputShape = inputType.getShape();
dim += dim < 0 ? inputShape.size() : 0;

const int64_t fftLength = inputShape[dim];
if (fftLength == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op, "unsupported: FFT signal length must be static");
}
const int64_t rank = inputType.getRank();
const int64_t lastDim = rank - 1;
const int64_t outputFftDim = fftLength / 2 + 1;
const bool needTranspose = dim != lastDim;

// Transpose if FFT dimension is not the last one
llvm::SmallVector<int64_t> perms = llvm::to_vector(llvm::seq(rank));
std::swap(perms[dim], perms[lastDim]);
if (needTranspose) {
self = transposeValue(loc, self, perms, rewriter);
}

RankedTensorType newResultType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op.getType()));
ComplexType complexElemType =
llvm::cast<ComplexType>(newResultType.getElementType());
Type elemType = complexElemType.getElementType();

// coeffMatrix : tensor<fftLength x outputFftDim x complex<f32>>
RankedTensorType coeffType =
RankedTensorType::get({fftLength, outputFftDim}, complexElemType);
// coeffMatrix(n,m) = cos(2 pi n m / N) - j sin(2 pi n m / N)
Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, coeffType);

// #matmul_trait = {
// indexing_maps = [
// affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, f)>,
// affine_map<(d_0, ... d_m, f, o) -> (f, o)>,
// affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, o)>
// ],
// iterator_types = ["parallel", ..., "parallel", "reduction", "parallel"]
// }
// linalg.generic #matmul_trait
// ins(%A, %B : tensor<D_0 x ... x D_m x fftLength x f32>,
// tensor<fftLength x outputFftDim x complex<f32>>)
// outs(%C : tensor<D_0 x ... x D_m x outputFftDim x complex<f32>>) {
// ^bb0(%a: f32, %b: complex<f32>, %c: complex<f32>) :
// %re = complex.re %b : f32
// %im = complex.im %b : f32
// %mulre = arith.mulf %a, %re: f32
// %mulim = arith.mulf %a, %im: f32
// %mulcplx = complex.create %mulre, %mulim : complex<f32>
// %add = complex.add %c, %mulcplx: complex<f32>
// linalg.yield %add : complex<f32>
// } -> (tensor<D_0 x ... x D_m x outputFftDim x complex<f32>>)

Value lhs = self;
Value rhs = coeffMatrix;
RankedTensorType lhsType = llvm::cast<RankedTensorType>(lhs.getType());
ArrayRef<int64_t> lhsShape(lhsType.getShape());
ArrayRef<int64_t> rhsShape(coeffType.getShape());

unsigned batchRank = lhsShape.size() - 1;

SmallVector<AffineExpr> lhsExpr;
SmallVector<AffineExpr> rhsExpr;
SmallVector<AffineExpr> outExpr;
SmallVector<utils::IteratorType> iteratorTypes(
batchRank, utils::IteratorType::parallel);
SmallVector<Value> resultShape;
for (unsigned i = 0; i < batchRank; i++) {
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
outExpr.push_back(rewriter.getAffineDimExpr(i));
resultShape.push_back(getDimOp(rewriter, loc, lhs, i));
}
unsigned fIdx = batchRank, oIdx = batchRank + 1;
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(fIdx)});
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(fIdx),
rewriter.getAffineDimExpr(oIdx)});
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(oIdx)});
resultShape.insert(resultShape.end(),
{getDimOp(rewriter, loc, rhs, rhsShape.size() - 1)});

Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, complexElemType);
auto indexingMaps = AffineMap::inferFromExprList(
{lhsExpr, rhsExpr, outExpr}, rewriter.getContext());
iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::reduction,
utils::IteratorType::parallel});

Value complexRes =
rewriter
.create<linalg::GenericOp>(
loc, zeroTensor.getType(),
/*inputs=*/ValueRange{lhs, rhs},
/*outputs=*/zeroTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value l = args[0], r = args[1], res = args[2];
Value re = b.create<complex::ReOp>(loc, elemType, r);
Value im = b.create<complex::ImOp>(loc, elemType, r);
Value mulRe = b.create<arith::MulFOp>(loc, l, re);
Value mulIm = b.create<arith::MulFOp>(loc, l, im);
Value mulCplx = b.create<complex::CreateOp>(
loc, complexElemType, mulRe, mulIm);
Value add = b.create<complex::AddOp>(loc, mulCplx, res);
b.create<linalg::YieldOp>(loc, add);
})
.getResult(0);

// Transpose back
if (needTranspose) {
complexRes = transposeValue(loc, complexRes, perms, rewriter);
}

rewriter.replaceOp(op, complexRes);
return success();
}
};

} // namespace

void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -1390,4 +1579,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
target.addIllegalOp<AtenFftRfftOp>();
patterns.add<ConvertAtenFftRfftOp>(typeConverter, context);
}
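
A quick way to cross-check the lowering's semantics against eager PyTorch, in the spirit of the patch's end-to-end tests (module and shapes here are illustrative, not the actual tests added):

```python
import torch

class RfftModule(torch.nn.Module):
    def forward(self, x):
        # dim=-2 is not the last dimension, so this exercises the
        # transpose -> matmul -> transpose-back path of the lowering.
        return torch.fft.rfft(x, dim=-2)

x = torch.rand(4, 6, 5)
out = RfftModule()(x)
print(out.shape)  # torch.Size([4, 4, 5]): size 6 -> 6 // 2 + 1 = 4 along dim=-2
```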
82 changes: 82 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10936,6 +10936,50 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fft_rfft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Expected dim in [-rank, rank-1]\"\n"
" %false = torch.constant.bool false\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %11 = torch.aten.add.int %arg2, %10 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %11 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg2 : !torch.int\n"
" }\n"
" %2 = torch.aten.ge.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %11 = torch.aten.lt.int %1, %10 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %5, %true, init() {\n"
" ^bb0(%arg4: !torch.int):\n"
" %10 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %11 = torch.aten.append.t %4, %10 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %8 = torch.aten.add.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.aten._set_item.t %4, %1, %8 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n"
" %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n"
@@ -13077,6 +13121,44 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fft_rfft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
" %int10 = torch.constant.int 10\n"
" %int7 = torch.constant.int 7\n"
" %int9 = torch.constant.int 9\n"
" %int6 = torch.constant.int 6\n"
" %int8 = torch.constant.int 8\n"
" %int5 = torch.constant.int 5\n"
" %0 = torch.prim.Uninitialized : !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" torch.prim.If.yield %int8 : !torch.int\n"
" } else {\n"
" %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.int) {\n"
" torch.prim.If.yield %int9 : !torch.int\n"
" } else {\n"
" %6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" torch.prim.If.yield %int10 : !torch.int\n"
" } else {\n"
" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.int) {\n"
" torch.prim.If.yield %int9 : !torch.int\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.int\n"
" }\n"
" torch.prim.If.yield %9 : !torch.int\n"
" }\n"
" torch.prim.If.yield %7 : !torch.int\n"
" }\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
" %int7 = torch.constant.int 7\n"
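
This library is generated from Python shape/dtype functions; the sketch below is a hedged reconstruction of what the `fft_rfft` entries above encode, assuming the project's usual generator conventions (function names here are illustrative):

```python
from typing import List, Optional

def fft_rfft_shape(self: List[int], n: Optional[int] = None,
                   dim: int = -1, norm: Optional[str] = None) -> List[int]:
    dim = dim + len(self) if dim < 0 else dim
    assert 0 <= dim < len(self), "Expected dim in [-rank, rank-1]"
    out = list(self)
    out[dim] = self[dim] // 2 + 1  # the transformed dimension is halved
    return out

# Dtype rule encoded above, using torch ScalarType codes:
#   float16 (5)  -> complex32 (8)
#   float32 (6)  -> complex64 (9)
#   float64 (7)  -> complex128 (10)
#   any integer dtype -> complex64 (9); anything else raises "Unsupported dtype".
```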
(Diffs for the remaining 8 changed files are not shown.)
