[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout an…

vivekkhandelwal1 authored Feb 27, 2024
1 parent 3cbe6c9 commit d81747e

Showing 11 changed files with 243 additions and 21 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11231,6 +11231,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
@@ -11254,6 +11255,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
38 changes: 32 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1339,12 +1339,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value ratio, trainingMode;
if (numOperands == 3) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
Value trainingModeScalar =
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
trainingMode = rewriter.create<Torch::AtenEqIntOp>(
loc, trainingModeScalar, cstOne);
Value trainVal = operands[2];
auto trainTensorType =
trainVal.getType().dyn_cast<Torch::BaseTensorType>();
if (!trainTensorType)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have a type");

Type inputDtype = trainTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(1))
return rewriter.notifyMatchFailure(
binder.op,
"train tensor must have an integer dtype of width 1");

std::optional<unsigned> inputRank = Torch::getTensorRank(trainVal);
if (!inputRank || *inputRank != 0)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have rank 0");

if (auto valueTensorLiteralOp =
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<bool>();
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
} else {
Value trainingModeScalar =
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
trainingMode = rewriter.create<Torch::AtenEqIntOp>(
loc, trainingModeScalar, cstOne);
}
} else if (numOperands == 2) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false);
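Note: with this change, a boolean `torch.vtensor.literal` feeding Dropout's `training_mode` operand folds straight to a `torch.constant.bool`; the `AtenIntImplicitOp`/`AtenEqIntOp` pair is only emitted for non-constant inputs. A minimal, hypothetical sketch of the folded case (function name, shapes, and values are illustrative, not taken from the test suite):

```mlir
func.func @dropout_training_mode_folded() -> !torch.bool {
  // 0-d boolean literal supplying Dropout's training_mode (training disabled).
  %train = torch.vtensor.literal(dense<false> : tensor<i1>) : !torch.vtensor<[],i1>
  // The lowering now materializes the flag directly as a constant bool
  // instead of routing it through aten.IntImplicit and aten.eq.int.
  %training_mode = torch.constant.bool false
  return %training_mode : !torch.bool
}
```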
20 changes: 14 additions & 6 deletions lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -191,12 +191,14 @@ class ConvertPrimNumToTensorScalarOp
} // namespace

namespace {
class ConvertAtenScalarImplicitOp
: public OpConversionPattern<AtenScalarImplicitOp> {
// Converts a tensor with one element to a scalar value.
template <typename OpTy>
class ConvertAtenImplicitLikeOp : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor,
matchAndRewrite(OpTy op,
typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
return success();
@@ -224,6 +226,12 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>();
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
patterns.add<ConvertAtenScalarImplicitOp>(typeConverter, context);
target.addIllegalOp<AtenScalarImplicitOp>();
patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
context);
patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,
context);
patterns.add<ConvertAtenImplicitLikeOp<AtenIntImplicitOp>>(typeConverter,
context);
target.addIllegalOp<AtenScalarImplicitOp, AtenFloatImplicitOp,
AtenIntImplicitOp>();
}
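For reference, all three `*Implicit` ops read the single element of their rank-0 tensor operand, so the converted IR ends in a plain `tensor.extract`. A minimal sketch of the converted form (function name and types are illustrative):

```mlir
func.func @int_implicit_converted(%arg0: tensor<i64>) -> i64 {
  // ConvertAtenImplicitLikeOp rewrites aten.IntImplicit (and the Float/Scalar
  // variants) into a single-element extract from the rank-0 tensor operand.
  %0 = tensor.extract %arg0[] : tensor<i64>
  return %0 : i64
}
```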
14 changes: 9 additions & 5 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -725,13 +725,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(div.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
div.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::DivFOp>(loc, lhs, rhs);
if (dtype.isa<mlir::FloatType>())
return b.create<arith::DivFOp>(loc, lhs, rhs);
else if (dtype.isa<mlir::IntegerType>()) {
if (dtype.isUnsignedInteger())
return b.create<arith::DivUIOp>(loc, lhs, rhs);
return b.create<arith::DivSIOp>(loc, lhs, rhs);
}
div.emitError("unimplemented: non-floating point and non-integer dtype");
return nullptr;
}
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
AtenDivTensorModeOp::Adaptor adaptor(operands);
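The elementwise payload for `aten.div.Tensor` now also handles integer element types; the generated arithmetic is roughly the following (a hand-written sketch, not compiler output):

```mlir
func.func @div_payload(%lhs: i32, %rhs: i32, %ulhs: i8, %urhs: i8) -> (i32, i8) {
  // Signed integer result dtypes lower to arith.divsi ...
  %signed = arith.divsi %lhs, %rhs : i32
  // ... and unsigned result dtypes (e.g. ui8 tensors) lower to arith.divui.
  %unsigned = arith.divui %ulhs, %urhs : i8
  return %signed, %unsigned : i32, i8
}
```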
32 changes: 32 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1568,6 +1568,38 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenFloatImplicitOp
//===----------------------------------------------------------------------===//
void AtenFloatImplicitOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
Value scalarValue = getScalarFloatValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOp(op, scalarValue);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenIntImplicitOp
//===----------------------------------------------------------------------===//
void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOp(op, scalarValue);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenSizeOp
//===----------------------------------------------------------------------===//
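Together with the ODS change above, these patterns fold `aten.FloatImplicit`/`aten.IntImplicit` of a statically known 0-d tensor into a scalar constant; the new canonicalize.mlir tests further down cover both the NumToTensor and literal cases. A condensed sketch of the literal case:

```mlir
func.func @int_implicit_folds_to_constant() -> !torch.int {
  // A 0-d si64 literal feeding aten.IntImplicit ...
  %0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
  // ... canonicalizes away, leaving only: %int1 = torch.constant.int 1
  %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
  return %1 : !torch.int
}
```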
9 changes: 7 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -335,6 +335,9 @@

# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",

"FloatImplicitModule_basic",
"IntImplicitModule_basic",
}

TORCHDYNAMO_CRASHING_SET = {
@@ -989,6 +992,8 @@
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorUnsignedIntegerModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseEqBoolScalarModule_basic",
@@ -2146,8 +2151,6 @@
"ElementwiseSigmoidIntModule_basic",

# Failure - unknown
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",
@@ -2168,6 +2171,8 @@
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"FloatImplicitModule_basic",
"IntImplicitModule_basic",
}

ONNX_CRASHING_SET = { }
@@ -669,8 +669,8 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)")
emit("aten::FloatImplicit : (Tensor) -> (float)")
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
44 changes: 44 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -3719,6 +3719,50 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100))


# ==============================================================================


class FloatImplicitModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([], torch.float64, True),
])
def forward(self, x):
return float(torch.ops.aten.FloatImplicit(x))


@register_test_case(module_factory=lambda: FloatImplicitModule())
def FloatImplicitModule_basic(module, tu: TestUtils):
module.forward(tu.rand().double())


# ==============================================================================


class IntImplicitModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([], torch.int64, True),
])
def forward(self, x):
return float(torch.ops.aten.IntImplicit(x))


@register_test_case(module_factory=lambda: IntImplicitModule())
def IntImplicitModule_basic(module, tu: TestUtils):
module.forward(tu.randint())


# ==============================================================================

class PowIntFloat(torch.nn.Module):
46 changes: 46 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -2595,6 +2595,52 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseDivTensorIntegerModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1, -1], torch.int32, True),
])
def forward(self, a, b):
return torch.div(a, b)


@register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule())
def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32))


# ==============================================================================


class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.uint8, True),
([-1, -1], torch.uint8, True),
])
def forward(self, a, b):
return torch.div(a, b)


@register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule())
def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8))


# ==============================================================================


class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):

def __init__(self):
9 changes: 9 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -706,6 +706,15 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3

// -----

// CHECK-LABEL: @test_div_int32
func.func @test_div_int32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],si32>
%0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32>
return %0 : !torch.vtensor<[3,4,5],si32>
}

// -----

// CHECK-LABEL: @test_div_uint8
func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8>
46 changes: 46 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -2145,6 +2145,52 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number
return %1 : !torch.number
}

// -----

// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: return %[[FLOAT1]] : !torch.float
func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
%float1 = torch.constant.float 1.0
%0 = torch.prim.NumToTensor.Scalar %float1 : !torch.float -> !torch.vtensor<[],f64>
%1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
return %1 : !torch.float
}

// -----

// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: return %[[FLOAT1]] : !torch.float
func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
%0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
%1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
return %1 : !torch.float
}

// -----

// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: return %[[INT1]] : !torch.int
func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
return %1 : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: return %[[INT1]] : !torch.int
func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
return %1 : !torch.int
}

// -----

// CHECK-LABEL: func.func @torch.prims.view_of$fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>
