BladeDISC related patches
* fix float width
* fix divide_floor & export promoteTypes api (#9)
* Comply with older PyTorch versions
* Add native_dropout_backward & native_layer_norm_backward decomposition (#15)
* add native_dropout and related ops pattern (#1211)
* [MHLO] fix dot general contract
* Fix batch_norm, div.Tensor_mode and folder (#21)
* reimplement linear lowering
* reimplement 2-D rhs for matmul
* add torchdynamo
Tanyo Kwok committed Nov 9, 2022
1 parent 9a73b9e commit 6c929e7
Showing 23 changed files with 798 additions and 166 deletions.
28 changes: 27 additions & 1 deletion include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5977,9 +5977,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
 }
 
 def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
+    Pure,
     AllowsTypeRefinement,
     HasValueSemantics,
-    ReadOnly
+    ReadOnly,
   ]> {
   let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
   let arguments = (ins
@@ -6510,6 +6511,31 @@ def Torch_AtenMaxOp : Torch_Op<"aten.max", [
   }];
 }
 
+def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::amax : (Tensor, int[]?, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenAmaxOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
     AllowsTypeRefinement,
     HasValueSemantics,
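Unlike aten.max.dim, the new aten::amax returns only the reduced values, never indices. A minimal C++ sketch of the reduction semantics for one dim of a 2-D tensor with keepdim == false; the helper name amax2d and the flat row-major layout are assumptions for illustration, not code from this patch:

#include <algorithm>
#include <cstddef>
#include <vector>

// Reduce one dimension of a rows x cols tensor with elementwise max.
std::vector<float> amax2d(const std::vector<float> &t, size_t rows,
                          size_t cols, int dim) {
  std::vector<float> out;
  if (dim == 0) { // reduce over rows -> one value per column
    for (size_t c = 0; c < cols; ++c) {
      float m = t[c];
      for (size_t r = 1; r < rows; ++r)
        m = std::max(m, t[r * cols + c]);
      out.push_back(m);
    }
  } else { // reduce over columns -> one value per row
    for (size_t r = 0; r < rows; ++r) {
      float m = t[r * cols];
      for (size_t c = 1; c < cols; ++c)
        m = std::max(m, t[r * cols + c]);
      out.push_back(m);
    }
  }
  return out;
}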
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -160,6 +160,7 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
 //===-----------------------------------------------------------------------===//
 enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
 
+ScalarType promoteTypes(ScalarType a, ScalarType b);
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
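The newly exported declaration mirrors PyTorch's c10::promoteTypes: it returns the dtype that both operands implicitly promote to. A hedged usage sketch (the surrounding helper is hypothetical, only the declaration above is from the patch):

#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"

using mlir::torch::torch_upstream::promoteTypes;
using mlir::torch::torch_upstream::ScalarType;

// Hypothetical helper: pick the result dtype for a binary elementwise op.
ScalarType resultDtype(ScalarType lhs, ScalarType rhs) {
  // e.g. promoteTypes(ScalarType::Int, ScalarType::Float) yields
  // ScalarType::Float, matching PyTorch's type-promotion lattice.
  return promoteTypes(lhs, rhs);
}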
include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td
@@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
   }];
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor",
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
   }];
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -80,6 +82,7 @@ def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [
   let assemblyFormat = [{
     $operand attr-dict
   }];
+  let hasFolder = 1;
 }
 
 def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
@@ -98,6 +101,7 @@ def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
   let assemblyFormat = [{
     $operand attr-dict
   }];
+  let hasFolder = 1;
 }
 
 def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
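These hasCanonicalizer/hasFolder flags only declare the hooks; the bodies live in the TorchConversion dialect sources, which this diff does not show. The usual purpose of such hooks is to cancel materialization round-trips (to_builtin_tensor of from_builtin_tensor, from_i1 of to_i1, and so on). A plausible shape for the from_i1 folder, as a sketch under that assumption rather than the patch's actual implementation:

// Sketch: from_i1(to_i1(x)) folds to x, so booleans that merely round-trip
// through the builtin i1 type disappear after folding.
OpFoldResult FromI1Op::fold(ArrayRef<Attribute> operands) {
  if (auto toI1 = getOperand().getDefiningOp<ToI1Op>())
    return toI1.getOperand();
  return nullptr; // no fold possible
}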
6 changes: 5 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -383,13 +383,17 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
     target.addIllegalOp<Torch::ConstantIntOp>();
     patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
                                                                context);
-    target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
+    target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
+                        AtenRemainderIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
         typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
+        typeConverter, context);
+
     target.addIllegalOp<AtenSubFloatOp>();
     patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
         typeConverter, context);
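One caveat worth noting on the new mapping: arith.remsi is a truncating, C-style remainder whose sign follows the dividend, whereas Python's % follows the divisor; the two agree whenever both operands are non-negative. A scalar C++ illustration (helper names are mine, not from the patch):

#include <cstdint>

// arith.remsi semantics: sign follows the dividend (C/C++ `%`).
int64_t remsi(int64_t a, int64_t b) { return a % b; } // remsi(-7, 2) == -1

// Python `%` semantics for comparison: sign follows the divisor.
int64_t pymod(int64_t a, int64_t b) {
  int64_t r = a % b;
  return (r != 0 && (r < 0) != (b < 0)) ? r + b : r; // pymod(-7, 2) == 1
}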
56 changes: 38 additions & 18 deletions lib/Conversion/TorchToMhlo/Basic.cpp
@@ -29,6 +29,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::TorchConversion;
 using namespace mlir::torch::torch_to_mhlo;
 
 bool skipMultiplyAlpha(Value alphaValue) {
@@ -112,16 +113,19 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
     if (!selfTy)
       return op.emitError("only Tensor types supported in MHLO");
 
     if (selfTy.getElementType().isa<mlir::FloatType>()) {
-      rewriter.replaceOpWithNewOp<MhloOpT>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          self);
-      return success();
+      auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+          op.getType());
+      if (selfTy != outTy) {
+        auto out = rewriter.create<MhloOpT>(op.getLoc(), selfTy, self);
+        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, out);
+        return success();
+      } else {
+        rewriter.replaceOpWithNewOp<MhloOpT>(op, outTy, self);
+        return success();
+      }
     } else {
       return op.emitError(
           "only floating-point datatype legalization supported");
     }
   }
 };
@@ -307,15 +311,10 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
     } else if (!rhsType) {
       rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
     }
-    DenseIntElementsAttr bcastDimensions;
-    lhs = mhlo::promoteType(rewriter, lhs, outType);
-    rhs = mhlo::promoteType(rewriter, rhs, outType);
-    auto loc = op.getLoc();
-    Value result =
-        rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
 
     if (!isa<AtenDivTensorModeOp>(op)) {
-      rewriter.replaceOp(op, result);
+      lhs = mhlo::promoteType(rewriter, lhs, outType);
+      rhs = mhlo::promoteType(rewriter, rhs, outType);
+      rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs, nullptr);
       return success();
     }
 
@@ -327,6 +326,17 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "only support constant str rounding mode");
 
+    auto computeTy = outType;
+    if (outElemTy.isIntOrIndex()) {
+      computeTy =
+          RankedTensorType::get(outType.getShape(), rewriter.getF32Type());
+    }
+    lhs = mhlo::promoteType(rewriter, lhs, computeTy);
+    rhs = mhlo::promoteType(rewriter, rhs, computeTy);
+    auto loc = op.getLoc();
+    auto result =
+        rewriter.create<ChloOpT>(loc, computeTy, lhs, rhs, nullptr).getResult();
+
     if (roundingMode == "trunc") {
       // "trunc" - rounds the results of the division towards zero. Equivalent
       // to C-style integer division.
@@ -340,7 +350,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
       // floor division in Python (the // operator)
       result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
     }
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, result);
     return success();
   }
 };
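The two hunks above are the divide_floor fix from the commit message: for integer dtypes the division runs in f32, is rounded, and is then converted back to the integer result type. A scalar C++ model of the "floor" path (the function name is assumed; note that very large 64-bit values would lose precision in f32, a limitation this sketch inherits):

#include <cmath>
#include <cstdint>

int64_t divide_floor(int64_t a, int64_t b) {
  // Compute in float, floor toward negative infinity, convert back,
  // mirroring promote-to-f32, chlo division, mhlo.floor, mhlo.convert.
  float q = static_cast<float>(a) / static_cast<float>(b);
  return static_cast<int64_t>(std::floor(q)); // divide_floor(-7, 2) == -4
}

C-style truncating division would instead give -7 / 2 == -3, which is why the "trunc" and "floor" rounding modes must diverge for negative operands.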
@@ -401,6 +411,9 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
                std::is_same<AtenOpT, AtenGtScalarOp>()) {
       compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
           op->getContext(), chlo::ComparisonDirection::GT);
+    } else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
+      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
+          op->getContext(), chlo::ComparisonDirection::GE);
     } else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
                std::is_same<AtenOpT, AtenEqScalarOp>()) {
       compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
@@ -723,7 +736,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
       APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
                        false),
       lhs);
-  rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
+  auto outType = getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast<TensorType>();
+
+  rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, outType, lhs, zeroTensor);
   return success();
 }
 
@@ -749,7 +766,11 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
   auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
   auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
   auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
-  rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
+  auto outType = getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast<TensorType>();
+
+  rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, input, halfMul);
   return success();
 }
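As in the relu hunk above, only the result type of the final op changes; both fixes pass the type-converted result type explicitly instead of letting MHLO infer it from the operands. The lowering still computes the exact, erf-based GELU. For reference, the scalar function being materialized (a reference sketch, not code from the patch):

#include <cmath>

// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
double gelu(double x) {
  return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}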

@@ -1203,7 +1224,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
+  INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
 #undef INSERT_UNARY_FPONLY_PATTERN
 
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@@ -1240,6 +1261,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
+  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
@@ -1259,12 +1281,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
-  INSERT_ATENOP_PATTERN(AtenContiguousOp);
-
   INSERT_ATENOP_PATTERN(AtenReluOp);
   INSERT_ATENOP_PATTERN(AtenGeluOp);
   INSERT_ATENOP_PATTERN(AtenErfOp);
 
   INSERT_ATENOP_PATTERN(AtenCatOp);
   INSERT_ATENOP_PATTERN(AtenClampOp);
   INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
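For context on how these registrations are consumed: a conversion pass collects the patterns and the legality marks into one partial conversion. A hedged sketch of the call site (the real driver lives elsewhere in lib/Conversion/TorchToMhlo and may differ in detail):

void runTorchToMhlo(ModuleOp module, MLIRContext *context) {
  ConversionTarget target(*context);
  TypeConverter typeConverter; // Torch types -> builtin tensors (setup elided)
  RewritePatternSet patterns(context);
  torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
                                                    target);
  // Ops marked illegal above must be rewritten by the registered patterns,
  // otherwise the partial conversion fails.
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("TorchToMhlo conversion failed");
}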