Skip to content

Commit

Permalink
[torch] Folders for torch.aten.*.tensor operators [add, sub, mul] (l…
Browse files Browse the repository at this point in the history
…lvm#2878)

Simple folder for limited size aten tensor operations. This is primarily
useful for shape computation folding as they unfortunately can use
`aten` operators. Add, sub, mul are common examples of these folders.
  • Loading branch information
rsuderman authored Feb 19, 2024
1 parent cea5189 commit e80054a
Show file tree
Hide file tree
Showing 5 changed files with 364 additions and 6 deletions.
3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3790,6 +3790,7 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

Expand Down Expand Up @@ -3839,6 +3840,7 @@ def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

Expand Down Expand Up @@ -3889,6 +3891,7 @@ def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

Expand Down
213 changes: 213 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,177 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
return success();
}

//===----------------------------------------------------------------------===//
// NAry folder helpers
//===----------------------------------------------------------------------===//

static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
bool allFp = true;
bool allInt = true;

for (auto attr : attrs) {
if (!attr)
return false;

Type attrty;
if (auto dense = dyn_cast_or_null<ElementsAttr>(attr))
attrty = dense.getType();
if (auto fp = dyn_cast_or_null<mlir::FloatAttr>(attr))
attrty = fp.getType();
if (auto integer = dyn_cast_or_null<mlir::IntegerAttr>(attr))
attrty = integer.getType();
if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
attrty = shaped.getElementType();
allFp &= isa<mlir::FloatType>(attrty);
allInt &= isa<mlir::IntegerType>(attrty);
}

return allFp || allInt;
}

static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
for (auto attr : attrs) {
if (auto dense = dyn_cast_or_null<ElementsAttr>(attr)) {
if (!dense.isSplat())
return false;
}
}

return true;
}

llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
int64_t idx = 0) {
llvm::SmallVector<double> splattrs;

for (auto attr : attrs) {
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
} else {
splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
}
} else if (auto intattr = dyn_cast<FloatAttr>(attr)) {
splattrs.push_back(intattr.getValueAsDouble());
} else {
return {};
}
}

return splattrs;
}

llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
int64_t bitwidth,
int64_t idx = 0) {
llvm::SmallVector<APInt> splattrs;

for (auto attr : attrs) {
bool isunsigned = false;
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APInt>());
} else {
splattrs.push_back(dense.getValues<APInt>()[idx]);
}
} else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
splattrs.push_back(intattr.getValue());
} else {
return {};
}

auto &apint = splattrs.back();
if (apint.getBitWidth() < bitwidth) {
if (isunsigned) {
apint = apint.zextOrTrunc(bitwidth);
} else {
apint = apint.sextOrTrunc(bitwidth);
}
}
}

return splattrs;
}

using NAryFoldFpOperator = std::function<double(ArrayRef<double>)>;
using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;

static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
NAryFoldFpOperator fpFolder,
NAryFoldIntOperator intFolder) {
constexpr int64_t maxFold = 16;
if (!checkSameDTypes(operands))
return nullptr;

auto resultTy = dyn_cast<ValueTensorType>(ty);
if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
return nullptr;

auto dty = resultTy.getDtype();
auto resultBTy = resultTy.toBuiltinTensor().clone(dty);

auto fpTy = dyn_cast<mlir::FloatType>(dty);
auto intTy = dyn_cast<mlir::IntegerType>(dty);
if (!fpTy && !intTy)
return nullptr;

bool allSplats = checkAllSplats(operands);
bool withinMaxFold =
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold;

if (!allSplats && !withinMaxFold)
return nullptr;

// We do not support broadcasting in the non-splat case so validate same
// shaped inputs / outputs:
if (!allSplats) {
auto resultShape = resultBTy.getShape();
for (int i = 0, s = operands.size(); i < s; ++i) {
if (auto dense = dyn_cast<DenseElementsAttr>(operands[i])) {
if (dense.isSplat())
continue;
auto operandShape = cast<ShapedType>(dense.getType()).getShape();
if (operandShape.size() != resultShape.size())
return nullptr;
for (int i = 0, s = operandShape.size(); i < s; ++i)
if (operandShape[i] != resultShape[i])
return nullptr;
}
}
}

const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements();

if (fpTy) {
llvm::SmallVector<APFloat> folded;
for (int i = 0, s = numValues; i < s; ++i) {
auto inputs = getFoldValueAtIndexFp(operands, i);
double fold = fpFolder(inputs);

APFloat val(fold);
bool unused;
val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&unused);
folded.push_back(val);
}
return DenseElementsAttr::get(resultBTy, folded);
}

if (intTy) {
llvm::SmallVector<APInt> folded;
for (int i = 0, s = numValues; i < s; ++i) {
auto inputs =
getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
folded.push_back(intFolder(inputs));
}
return DenseElementsAttr::get(resultBTy, folded);
}

return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//
Expand All @@ -1116,6 +1287,20 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 3);
return inputs[0] + (inputs[1] * inputs[2]);
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 3);
return inputs[0] + (inputs[1] * inputs[2]);
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenAddScalarOp
//===----------------------------------------------------------------------===//
Expand All @@ -1136,6 +1321,20 @@ void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 3);
return inputs[0] - (inputs[1] * inputs[2]);
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 3);
return inputs[0] - (inputs[1] * inputs[2]);
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenSubScalarOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1166,6 +1365,20 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
return inputs[0] * inputs[1];
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
return inputs[0] * inputs[1];
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenEqTensorOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ def emit_with_mutating_variants(key, **kwargs):
# Elementwise tensor compute ops that don't have the standard mutating
# variants.
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
Expand Down
5 changes: 2 additions & 3 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1916,9 +1916,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
}

// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6]] = torch.constant.int 6
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[INT6]] : !torch.vtensor<[],si64>
func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
Expand Down
Loading

0 comments on commit e80054a

Please sign in to comment.