[MLIR][TORCH] Fix failing OnnxToLinalg lowering for aten.dropout op
This commit adds support for the failing OnnxToLinalg lowering tests
for the aten.dropout op.
It also adds the TorchToLinalg lowering and canonicalizers for
AtenFloatImplicitOp and AtenIntImplicitOp.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
vivekkhandelwal1 committed Feb 22, 2024
1 parent 124bd23 commit 0bf4b86
Showing 8 changed files with 177 additions and 16 deletions.
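
The new canonicalizers fold aten.FloatImplicit and aten.IntImplicit away when their 0-d tensor operand is statically known. A minimal sketch of the rewrite for a float literal (illustrative values; the added canonicalize.mlir tests below exercise the same pattern):

  %0 = torch.vtensor.literal(dense<2.5> : tensor<f64>) : !torch.vtensor<[],f64>
  %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
  // canonicalizes to:
  %1 = torch.constant.float 2.5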
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11206,6 +11206,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
@@ -11229,6 +11230,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
38 changes: 32 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1339,12 +1339,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value ratio, trainingMode;
if (numOperands == 3) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
Value trainingModeScalar =
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
trainingMode = rewriter.create<Torch::AtenEqIntOp>(
loc, trainingModeScalar, cstOne);
Value trainVal = operands[2];
auto trainTensorType =
trainVal.getType().dyn_cast<Torch::BaseTensorType>();
if (!trainTensorType)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have a type");

Type inputDtype = trainTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(1))
return rewriter.notifyMatchFailure(
binder.op,
"train tensor must have an integer dtype of width 1");

std::optional<unsigned> inputRank = Torch::getTensorRank(trainVal);
if (!inputRank || *inputRank != 0)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have rank 0");

if (auto valueTensorLiteralOp =
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<bool>();
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
} else {
Value trainingModeScalar =
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
trainingMode = rewriter.create<Torch::AtenEqIntOp>(
loc, trainingModeScalar, cstOne);
}
} else if (numOperands == 2) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false);
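
With this change, the ONNX Dropout lowering inspects the training_mode operand: a splat torch.vtensor.literal is turned directly into a torch.constant.bool, while any other value still goes through Torch::AtenIntImplicitOp and an integer comparison. A rough, hand-written sketch of the fallback IR for a non-constant training_mode value %train (value names illustrative; not taken from the repository's tests):

  %train_i = torch.aten.IntImplicit %train : !torch.vtensor<[],i1> -> !torch.int
  %int1 = torch.constant.int 1
  %training = torch.aten.eq.int %train_i, %int1 : !torch.int, !torch.int -> !torch.bool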
20 changes: 14 additions & 6 deletions lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -191,12 +191,14 @@ class ConvertPrimNumToTensorScalarOp
} // namespace

namespace {
class ConvertAtenScalarImplicitOp
: public OpConversionPattern<AtenScalarImplicitOp> {
// Converts a tensor with one element to a scalar value.
template <typename OpTy>
class ConvertAtenImplicitLikeOp : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor,
matchAndRewrite(OpTy op,
typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
return success();
@@ -224,6 +226,12 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>();
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
patterns.add<ConvertAtenScalarImplicitOp>(typeConverter, context);
target.addIllegalOp<AtenScalarImplicitOp>();
patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
context);
patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,
context);
patterns.add<ConvertAtenImplicitLikeOp<AtenIntImplicitOp>>(typeConverter,
context);
target.addIllegalOp<AtenScalarImplicitOp, AtenFloatImplicitOp,
AtenIntImplicitOp>();
}
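
All three *Implicit ops now share the templated ConvertAtenImplicitLikeOp pattern, which lowers the op to a tensor.extract on the type-converted 0-d operand. Roughly, assuming the TorchToLinalg type converter has already rewritten the operand to a builtin tensor<f64> (a sketch, not one of the repository's lit tests):

  // before: %f = torch.aten.FloatImplicit %t : !torch.vtensor<[],f64> -> !torch.float
  // after (with %t_cast the type-converted operand; name illustrative):
  %f = tensor.extract %t_cast[] : tensor<f64>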
32 changes: 32 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1568,6 +1568,38 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenFloatImplicitOp
//===----------------------------------------------------------------------===//
void AtenFloatImplicitOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
Value scalarValue = getScalarFloatValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOp(op, scalarValue);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenIntImplicitOp
//===----------------------------------------------------------------------===//
void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOp(op, scalarValue);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenSizeOp
//===----------------------------------------------------------------------===//
7 changes: 5 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -335,6 +335,9 @@

# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",

"FloatImplicitModule_basic",
"IntImplicitModule_basic",
}

TORCHDYNAMO_CRASHING_SET = {
@@ -2173,8 +2176,6 @@
"ElementwiseSigmoidIntModule_basic",

# Failure - unknown
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",
@@ -2195,6 +2196,8 @@
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"FloatImplicitModule_basic",
"IntImplicitModule_basic",
}

ONNX_CRASHING_SET = {
@@ -668,8 +668,8 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)")
emit("aten::FloatImplicit : (Tensor) -> (float)")
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
44 changes: 44 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -3719,6 +3719,50 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100))


# ==============================================================================


class FloatImplicitModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([], torch.float64, True),
])
def forward(self, x):
return float(torch.ops.aten.FloatImplicit(x))


@register_test_case(module_factory=lambda: FloatImplicitModule())
def FloatImplicitModule_basic(module, tu: TestUtils):
module.forward(tu.rand().double())


# ==============================================================================


class IntImplicitModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([], torch.int64, True),
])
def forward(self, x):
return float(torch.ops.aten.IntImplicit(x))


@register_test_case(module_factory=lambda: IntImplicitModule())
def IntImplicitModule_basic(module, tu: TestUtils):
module.forward(tu.randint())


# ==============================================================================

class PowIntFloat(torch.nn.Module):
46 changes: 46 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -2145,6 +2145,52 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number
return %1 : !torch.number
}

// -----

// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: return %[[FLOAT1]] : !torch.float
func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
%float1 = torch.constant.float 1.0
%0 = torch.prim.NumToTensor.Scalar %float1 : !torch.float -> !torch.vtensor<[],f64>
%1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
return %1 : !torch.float
}

// -----

// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: return %[[FLOAT1]] : !torch.float
func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
%0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
%1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
return %1 : !torch.float
}

// -----

// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: return %[[INT1]] : !torch.int
func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
return %1 : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: return %[[INT1]] : !torch.int
func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
return %1 : !torch.int
}

// -----

// CHECK-LABEL: func.func @torch.prims.view_of$fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>