Commit

Merge branch 'main' into Added-support-for-torch-arange-float-module

Abhishek-TyRnT committed Feb 27, 2024
2 parents c7e6780 + ba6ba92 commit 142d14e
Showing 30 changed files with 1,106 additions and 328 deletions.
54 changes: 54 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenNormScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}
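
The `hasVerifier = 1` flag only declares the hook; a matching C++ `AtenNormScalarOp::verify()` has to be supplied elsewhere (not shown in this diff). As a rough sketch of the kind of check such a verifier performs — illustrative only, not the one this commit adds, and assuming torch-mlir's usual `BaseTensorType` helpers are in scope:

    LogicalResult AtenNormScalarOp::verify() {
      // Hedged sketch: reject non-numeric inputs, since a p-norm is only
      // meaningful for float or integer element types.
      auto selfType = cast<BaseTensorType>(getSelf().getType());
      if (!selfType.hasDtype())
        return success(); // nothing to check until a dtype is known
      if (!isa<mlir::FloatType, mlir::IntegerType>(selfType.getDtype()))
        return emitOpError("expected a float or integer input element type");
      return success();
    }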

def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -11206,6 +11231,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
@@ -11229,6 +11255,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}
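
Both `aten.IntImplicit` and `aten.FloatImplicit` gain `hasCanonicalizer = 1` here; as with the verifier above, the actual patterns live in C++. A hedged sketch of the obvious candidate pattern — fold the op away when its operand is a freshly wrapped scalar; the patterns this commit registers may differ. This pairs with the Dropout change below, which emits `aten.IntImplicit` on its non-constant path:

    void AtenIntImplicitOp::getCanonicalizationPatterns(
        RewritePatternSet &patterns, MLIRContext *context) {
      patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) {
        // aten.IntImplicit(prim.NumToTensor.Scalar(%x)) ==> %x
        auto numToTensor = op.getA().getDefiningOp<PrimNumToTensorScalarOp>();
        if (!numToTensor)
          return failure();
        rewriter.replaceOp(op, numToTensor.getA());
        return success();
      });
    }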

def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
@@ -12353,6 +12380,33 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at
}];
}

def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$grid,
Torch_IntType:$interpolation_mode,
Torch_IntType:$padding_mode,
Torch_BoolType:$align_corners
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGridSamplerOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenGridSamplerOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics,
23 changes: 0 additions & 23 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -309,29 +309,6 @@ inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) {
return intAttr.getValue().getSExtValue();
}

/// Returns the value from an `IntegerAttr` as an integral index.
///
/// @param intAttr the `IntegerAttr` from which to extract the index
/// @param dimSize the size of the dimension that the attribute indexes into
/// @return the index value
///
/// Use this function when the given `IntegerAttr` represents an index into
/// a range, such as an index into a tensor dimension. If `dimSize` is given,
/// negative index values are converted into positive values by counting
/// elements from the "right" side of the dimension, as in python, numpy, etc.
/// For example, an index of -2 and a dimSize of 10 returns 8 because 8 is the
/// 2nd index from the high end of the range 0 to 9. If `dimSize` is not
/// given, any negative indices are returned as negative numbers.
///
/// No bounds checking is performed on the index to ensure that it is within
/// the legal range for `dimSize`.
inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) {
int64_t signedIndex = getIntAttrAsSigned(intAttr);
if (dimSize < 0 || signedIndex > 0)
return signedIndex;
return dimSize + signedIndex; // count backwards from dimSize
}

} // namespace Torch
} // namespace torch
} // namespace mlir
109 changes: 79 additions & 30 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1339,12 +1339,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value ratio, trainingMode;
if (numOperands == 3) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
Value trainingModeScalar =
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
trainingMode = rewriter.create<Torch::AtenEqIntOp>(
loc, trainingModeScalar, cstOne);
Value trainVal = operands[2];
auto trainTensorType =
trainVal.getType().dyn_cast<Torch::BaseTensorType>();
if (!trainTensorType)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have a type");

Type inputDtype = trainTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(1))
return rewriter.notifyMatchFailure(
binder.op,
"train tensor must have an integer dtype of width 1");

std::optional<unsigned> inputRank = Torch::getTensorRank(trainVal);
if (!inputRank || *inputRank != 0)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have rank 0");

if (auto valueTensorLiteralOp =
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<bool>();
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
} else {
Value trainingModeScalar =
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
trainingMode = rewriter.create<Torch::AtenEqIntOp>(
loc, trainingModeScalar, cstOne);
}
} else if (numOperands == 2) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -1571,7 +1597,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"ConstantOfShape", 20,
"ConstantOfShape", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value shape;
@@ -1582,15 +1608,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
auto shapeSizes =
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
SmallVector<Value> dimList;
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Torch::BaseTensorType shapeType =
shape.getType().cast<Torch::BaseTensorType>();
Type selectResultType = shapeType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

for (int i = 0; i < shapeSizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -1601,6 +1626,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
dimList.push_back(dim);
}

Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
@@ -1609,7 +1635,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

// Get fill_value if it is present.
// Assumption : resultDType and value attr type match.
Value value_const;
auto attr = binder.op->getAttr("torch.onnx.value");
auto resultDType = resultType.getDtype();

@@ -1620,34 +1645,58 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
resultType.toBuiltinTensor().clone(resultDType),
rewriter.getFloatAttr(resultDType, 0.0));
}
if (!isa<DenseElementsAttr>(attr)) {
return rewriter.notifyMatchFailure(
binder.op, "`value` attr needs to be a tensor.");

// If it's a dense resource attr we need to convert to a dense type:
if (DenseResourceElementsAttr rattr =
attr.dyn_cast_or_null<DenseResourceElementsAttr>()) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
binder.op->emitError(
"unimplemented: importing on big endian systems");
return failure();
}

auto ty = cast<ShapedType>(rattr.getType());
auto ptr = rattr.getRawHandle().getBlob()->getData();
auto denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
attr = dyn_cast_or_null<SplatElementsAttr>(denseAttr);
}

Attribute splattr;
if (isa<SplatElementsAttr>(attr)) {
auto denseAttr = attr.cast<DenseElementsAttr>();
splattr = denseAttr.getSplatValue<Attribute>();
}

auto denseAttr = attr.cast<DenseElementsAttr>();
auto denseAttrEleType = denseAttr.getElementType();
if (!isa<FloatType, IntegerType>(denseAttrEleType)) {
if (!isa<FloatAttr, IntegerAttr>(splattr)) {
return rewriter.notifyMatchFailure(
binder.op,
"`value` attr tensor only supports types int and float for now.");
}

// Create constant op for value
if (denseAttrEleType.isa<IntegerType>()) {
int64_t intVal = denseAttr.getSplatValue<IntegerAttr>().getSInt();
value_const = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(intVal));
}
if (denseAttrEleType.isa<FloatType>()) {
float floatVal =
denseAttr.getSplatValue<FloatAttr>().getValue().convertToFloat();
value_const = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(floatVal));
Value splatvalue;
if (auto intattr = dyn_cast<IntegerAttr>(splattr)) {
IntegerType intty = cast<IntegerType>(intattr.getType());
int64_t value;
if (intty.isUnsignedInteger()) {
value = intattr.getUInt();
} else if (intty.isSignedInteger()) {
value = intattr.getSInt();
} else {
value = intattr.getInt();
}
splatvalue =
rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), value);
}

if (auto fpattr = dyn_cast<FloatAttr>(splattr))
splatvalue = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getF64FloatAttr(fpattr.getValueAsDouble()));

rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
binder.op, resultType, dimValueList, value_const, /*dtype=*/noneVal,
binder.op, resultType, dimValueList, splatvalue, /*dtype=*/noneVal,
/*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
return success();
});
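
The new `DenseResourceElementsAttr` branch is the interesting part: an ONNX importer may attach the `value` payload as a resource blob rather than an inline dense attribute, so the raw bytes are reinterpreted through `DenseElementsAttr::getFromRawBuffer` before the splat scalar is read. The same step, isolated as a self-contained helper — a sketch under the same little-endian assumption as the code above, not a function in the tree:

    static Attribute splatFromResource(DenseResourceElementsAttr rattr) {
      // Reinterpret the raw resource bytes using the tensor's element type.
      auto ty = cast<ShapedType>(rattr.getType());
      ArrayRef<char> data = rattr.getRawHandle().getBlob()->getData();
      auto dense = DenseElementsAttr::getFromRawBuffer(ty, data);
      // Only splats are useful here; return a null Attribute otherwise.
      return dense.isSplat() ? dense.getSplatValue<Attribute>() : Attribute();
    }
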
39 changes: 31 additions & 8 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -51,6 +51,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,

Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
@@ -76,27 +77,49 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
builtinTypeStart, zero, dimSize);
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
dimSize, dimSize);

// end >= start ? end : start
Value endSgeStart = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, end, start);
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
// We cannot use toPositiveValidDim here since for negative strides we need
// to clamp to `-1` so that the full tensor bounds are available:
Value end = builtinTypeEnd;
if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
end = dimSize;
} else {
end = castIntToIndex(rewriter, loc, end);
Value endcmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, end, zero);
Value endadd = rewriter.create<arith::AddIOp>(loc, end, dimSize);
end = rewriter.create<arith::SelectOp>(loc, endcmp, endadd, end);
endcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, end,
zero);
end = rewriter.create<arith::SelectOp>(loc, endcmp, negone, end);
endcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, end,
dimSize);
end = rewriter.create<arith::SelectOp>(loc, endcmp, dimSize, end);
}

// Slice logic: resultSize = floordiv(end - start + step - 1, step)
resultShape = getTensorSizes(rewriter, loc, input);
Value len = rewriter.create<arith::SubIOp>(loc, end, start);

// We check the difference between start and end to determine the total size:
Value stepcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
stepIndex, zero);
Value stepsign = rewriter.create<arith::SelectOp>(loc, stepcmp, one, negone);
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, stepsign);
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);

// Clamp the size to [0, ...]:
Value szcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
resultSize, zero);
resultSize = rewriter.create<arith::SelectOp>(loc, szcmp, zero, resultSize);
resultShape[dim] = resultSize;

strides.resize(inputType.getRank(), one);
offsets.resize(inputType.getRank(), zero);

offsets[dim] = start;
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
strides[dim] = stepIndex;
return success();
}
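
The net effect of the rewrite: the old size formula `floordiv(end - start + step - 1, step)` assumed a positive stride, and the new code replaces the constant `1` with `sign(step)` (the `stepsign` select) and clamps the result at zero. A scalar model of the emitted arithmetic, for intuition only — the pass builds this as `arith` ops rather than calling a helper like this:

    // `end` is assumed already clamped to [-1, dimSize] as in the code above.
    // Plain '/' truncates rather than floors, but after the final clamp it
    // agrees with arith.floordivsi for every case this function can see.
    int64_t sliceLength(int64_t start, int64_t end, int64_t step) {
      int64_t sign = step >= 0 ? 1 : -1;
      int64_t len = (end - start + step - sign) / step;
      return len < 0 ? 0 : len; // empty slice when bounds and stride disagree
    }
    // sliceLength(0, 10, 3) == 4   (indices 0, 3, 6, 9)
    // sliceLength(8, -1, -2) == 5  (indices 8, 6, 4, 2, 0)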

