[torch] Folders for torch.aten.*.tensor operators [add, sub, mul] (#2878)

Simple folders for limited-size aten tensor operations. These are primarily
useful for folding shape computations, which unfortunately can be expressed
with `aten` operators. Add, sub, and mul are common examples of these folders.
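For example (illustrative; mirrors the updated canonicalize.mlir test below),
multiplying two constant 0-d tensors now folds to a single literal:

  %lhs = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
  %rhs = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
  %0 = torch.aten.mul.Tensor %lhs, %rhs : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
  // after folding:
  %0 = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>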
rsuderman authored Feb 19, 2024
1 parent cea5189 commit e80054a
Showing 5 changed files with 364 additions and 6 deletions.
3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3790,6 +3790,7 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

@@ -3839,6 +3840,7 @@ def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

@@ -3889,6 +3891,7 @@ def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

213 changes: 213 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1106,6 +1106,177 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
  return success();
}

//===----------------------------------------------------------------------===//
// NAry folder helpers
//===----------------------------------------------------------------------===//

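// Returns true only when every operand attribute is present and all of them
// share a floating-point element type or all share an integer element type.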
static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
  bool allFp = true;
  bool allInt = true;

  for (auto attr : attrs) {
    if (!attr)
      return false;

    Type attrty;
    if (auto dense = dyn_cast_or_null<ElementsAttr>(attr))
      attrty = dense.getType();
    if (auto fp = dyn_cast_or_null<mlir::FloatAttr>(attr))
      attrty = fp.getType();
    if (auto integer = dyn_cast_or_null<mlir::IntegerAttr>(attr))
      attrty = integer.getType();
    if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
      attrty = shaped.getElementType();
    allFp &= isa<mlir::FloatType>(attrty);
    allInt &= isa<mlir::IntegerType>(attrty);
  }

  return allFp || allInt;
}

static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
  for (auto attr : attrs) {
    if (auto dense = dyn_cast_or_null<ElementsAttr>(attr)) {
      if (!dense.isSplat())
        return false;
    }
  }

  return true;
}

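// Collects, for each operand, the value at `idx` (or its splat value) as a
// double; returns an empty vector if an operand is neither an ElementsAttr
// nor a FloatAttr.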
llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
                                                int64_t idx = 0) {
  llvm::SmallVector<double> splattrs;

  for (auto attr : attrs) {
    if (auto dense = dyn_cast<ElementsAttr>(attr)) {
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
      } else {
        splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
      }
    } else if (auto fpattr = dyn_cast<FloatAttr>(attr)) {
      splattrs.push_back(fpattr.getValueAsDouble());
    } else {
      return {};
    }
  }

  return splattrs;
}

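// Collects, for each operand, the APInt value at `idx` (or its splat value),
// zero- or sign-extending narrower values to `bitwidth` according to the
// operand's signedness.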
llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
                                                int64_t bitwidth,
                                                int64_t idx = 0) {
  llvm::SmallVector<APInt> splattrs;

  for (auto attr : attrs) {
    bool isunsigned = false;
    if (auto dense = dyn_cast<ElementsAttr>(attr)) {
      isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APInt>());
      } else {
        splattrs.push_back(dense.getValues<APInt>()[idx]);
      }
    } else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
      isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
      splattrs.push_back(intattr.getValue());
    } else {
      return {};
    }

    auto &apint = splattrs.back();
    if (apint.getBitWidth() < bitwidth) {
      if (isunsigned) {
        apint = apint.zextOrTrunc(bitwidth);
      } else {
        apint = apint.sextOrTrunc(bitwidth);
      }
    }
  }

  return splattrs;
}

using NAryFoldFpOperator = std::function<double(ArrayRef<double>)>;
using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;

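// Shared fold implementation: folds when all operands are splats (any size),
// or when the result is statically shaped with at most `maxFold` elements.
// Broadcasting of non-splat operands is not supported.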
static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
                                     NAryFoldFpOperator fpFolder,
                                     NAryFoldIntOperator intFolder) {
  constexpr int64_t maxFold = 16;
  if (!checkSameDTypes(operands))
    return nullptr;

  auto resultTy = dyn_cast<ValueTensorType>(ty);
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
    return nullptr;

  auto dty = resultTy.getDtype();
  auto resultBTy = resultTy.toBuiltinTensor().clone(dty);

  auto fpTy = dyn_cast<mlir::FloatType>(dty);
  auto intTy = dyn_cast<mlir::IntegerType>(dty);
  if (!fpTy && !intTy)
    return nullptr;

  bool allSplats = checkAllSplats(operands);
  bool withinMaxFold =
      resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold;

  if (!allSplats && !withinMaxFold)
    return nullptr;

  // We do not support broadcasting in the non-splat case, so validate that
  // the inputs and the output all have the same shape:
  if (!allSplats) {
    auto resultShape = resultBTy.getShape();
    for (int i = 0, s = operands.size(); i < s; ++i) {
      if (auto dense = dyn_cast<DenseElementsAttr>(operands[i])) {
        if (dense.isSplat())
          continue;
        auto operandShape = cast<ShapedType>(dense.getType()).getShape();
        if (operandShape.size() != resultShape.size())
          return nullptr;
        for (int i = 0, s = operandShape.size(); i < s; ++i)
          if (operandShape[i] != resultShape[i])
            return nullptr;
      }
    }
  }

  const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements();

  if (fpTy) {
    llvm::SmallVector<APFloat> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs = getFoldValueAtIndexFp(operands, i);
      double fold = fpFolder(inputs);

      APFloat val(fold);
      bool unused;
      val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                  &unused);
      folded.push_back(val);
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  if (intTy) {
    llvm::SmallVector<APInt> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs =
          getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
      folded.push_back(intFolder(inputs));
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//
@@ -1116,6 +1287,20 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

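// aten.add.Tensor(self, other, alpha) computes self + alpha * other, so the
// folder receives three inputs.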
OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenAddScalarOp
//===----------------------------------------------------------------------===//
@@ -1136,6 +1321,20 @@ void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

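// aten.sub.Tensor(self, other, alpha) computes self - alpha * other.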
OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenSubScalarOp
//===----------------------------------------------------------------------===//
@@ -1166,6 +1365,20 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

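// aten.mul.Tensor takes no alpha operand, so only two inputs are folded.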
OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenEqTensorOp
//===----------------------------------------------------------------------===//
6 changes: 3 additions & 3 deletions torch_ods_gen.py
@@ -340,9 +340,9 @@ def emit_with_mutating_variants(key, **kwargs):
    # Elementwise tensor compute ops that don't have the standard mutating
    # variants.
    emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
    emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
    emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
    emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
5 changes: 2 additions & 3 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1916,9 +1916,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
}

// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[INT6]] : !torch.vtensor<[],si64>
func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
  %0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
  %1 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
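The add/sub folders also fold a constant alpha operand. A sketch of the
expected behavior (hypothetical function name and values, not part of this
commit's tests):

  func.func @fold_add_alpha() -> !torch.vtensor<[],si64> {
    %int4 = torch.constant.int 4
    %lhs = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
    %rhs = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
    // Folds to dense<14> : self + alpha * other = 2 + 4 * 3.
    %0 = torch.aten.add.Tensor %lhs, %rhs, %int4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
    return %0 : !torch.vtensor<[],si64>
  }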
