Add aten.isclose support and its torch-to-tosa lowering #2512

Merged: 5 commits, Oct 16, 2023. Changes from all commits.
4 changes: 4 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -18,6 +18,8 @@
    # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    "UnflattenStaticModule_basic",
    "IscloseStaticModule_basic",
    "IscloseStaticModuleTrue_basic",
}

TORCHDYNAMO_XFAIL_SET = {
@@ -928,6 +930,8 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "IscloseStaticModule_basic",
    "IscloseStaticModuleTrue_basic",
    "TileBigDimsSizeModule_basic",
    "TileSmallDimsSizeModule_basic",
    "IndexPutImpl2DNoneIndexStaticModule_basic",
27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4162,6 +4162,33 @@ def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [
  }];
}

def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other,
    Torch_FloatType:$rtol,
    Torch_FloatType:$atol,
    Torch_BoolType:$equal_nan
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenIscloseOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 5, 1);
    }
    void AtenIscloseOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 5, 1);
    }
  }];
}

def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
    AllowsTypeRefinement,
    HasValueSemantics,
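For reference, the generated op mirrors PyTorch's torch.isclose signature and defaults. A minimal eager-mode check (the example values are chosen for illustration only):

import torch

# aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)
# Defaults: rtol=1e-05, atol=1e-08, equal_nan=False.
x = torch.tensor([1.0, 2.0])
y = torch.tensor([1.0, 2.1])
print(torch.isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False))
# tensor([ True, False])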
54 changes: 54 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3920,6 +3920,59 @@ LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
    AtenIscloseOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Check that rtol/atol/equal_nan are compile-time constants; only
  // equal_nan == false is supported.
  double rtol, atol;
  bool equalNan;
  if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol)))
    return rewriter.notifyMatchFailure(op, "rtol must be a scalar constant");
  if (!matchPattern(op.getAtol(), m_TorchConstantFloat(&atol)))
    return rewriter.notifyMatchFailure(op, "atol must be a scalar constant");
  if (!matchPattern(op.getEqualNan(), m_TorchConstantBool(&equalNan)) ||
      equalNan)
    return rewriter.notifyMatchFailure(
        op, "unimplemented: equal_nan is expected to be false");

  // Check tensor types: only statically shaped float tensors are supported.
  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
  auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
  if (!selfType || !otherType)
    return rewriter.notifyMatchFailure(
        op, "only tensor type inputs are currently supported");
  if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "only tensor types with static shape are supported");
  if (!selfType.getElementType().isa<mlir::FloatType>() ||
      !otherType.getElementType().isa<mlir::FloatType>()) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: only FP element type is supported");
  }

  // Lower isclose to the elementwise test |self - other| <= atol + rtol * |other|.
  auto rhsSubOp = rewriter.create<tosa::SubOp>(
      op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther());
  auto rhsAbsOp =
      rewriter.create<tosa::AbsOp>(op->getLoc(), selfType, rhsSubOp);

  auto lhsAbsOp =
      rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, adaptor.getOther());
  auto rtolConstOp =
      tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(rtol));
  auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
                                            rtolConstOp, lhsAbsOp, /*shift=*/0);
  auto atolConstOp =
      tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(atol));
  auto addOp =
      rewriter.create<tosa::AddOp>(op->getLoc(), otherType, atolConstOp, mulOp);

  auto outType = getTypeConverter()->convertType(op.getType());
  rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(op, outType, addOp,
                                                    rhsAbsOp);

  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
    AtenClampOp op, OpAdaptor adaptor,
@@ -5134,6 +5187,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
    INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
    INSERT_ATENOP_PATTERN(AtenCatOp);
    INSERT_ATENOP_PATTERN(AtenSqrtOp);
    INSERT_ATENOP_PATTERN(AtenIscloseOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
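The op sequence above (sub, abs, mul, add, greater_equal) implements the standard isclose predicate |self - other| <= atol + rtol * |other|. A sketch in eager PyTorch of what the emitted TOSA graph computes, valid only for equal_nan=False and finite float inputs (isclose_tosa_reference is an illustrative name, not part of the patch):

import torch

def isclose_tosa_reference(self, other, rtol=1e-05, atol=1e-08):
    # sub + abs: |self - other|
    diff = (self - other).abs()
    # abs + mul + add: atol + rtol * |other|
    bound = atol + rtol * other.abs()
    # greater_equal, with operands swapped relative to torch.isclose
    return bound >= diff

x, y = torch.rand(5, 5), torch.rand(5, 5)
assert torch.equal(isclose_tosa_reference(x, y), torch.isclose(x, y))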
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7480,6 +7480,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.isclose\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9093,6 +9097,10 @@
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.isclose\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -844,6 +844,9 @@ def aten〇lt〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
def aten〇le〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)

def aten〇isclose〡shape(self: List[int], other: List[int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)

def aten〇unsqueeze〡shape(self: List[int], dim: int) -> List[int]:
    return upstream_shape_functions.unsqueeze(self, dim)

@@ -2171,6 +2174,10 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
def aten〇isclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> int:
    return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
    _, query_dtype = query_rank_dtype
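Both new interp functions are thin wrappers: the shape function broadcasts the operand shapes and the dtype function unconditionally returns torch.bool (serialized as ScalarType code 11 in AbstractInterpLibrary.cpp above). A quick eager-mode sanity check of both properties:

import torch

a = torch.rand(5, 5)
b = torch.rand(1, 5)
out = torch.isclose(a, b)
print(out.shape, out.dtype)  # torch.Size([5, 5]) torch.bool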
@@ -342,6 +342,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
emit("aten::view_as_real : (Tensor) -> (Tensor)")
emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")

# Ops with dynamic number of outputs
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
45 changes: 45 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -4580,3 +4580,48 @@ def forward(self, x):
@register_test_case(module_factory=lambda: Add_Module())
def Add_Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3))


# ==============================================================================


class IscloseStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, 5], torch.float32, True),
        ([5, 5], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.isclose(x, y)


@register_test_case(module_factory=lambda: IscloseStaticModule())
def IscloseStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 5), tu.rand(5, 5))


# ==============================================================================


class IscloseStaticModuleTrue(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.register_buffer('tensor', torch.ones(1))

    @export
    @annotate_args([
        None,
        ([5, 5], torch.float32, True),
    ])
    def forward(self, x):
        return torch.isclose(x, self.tensor)


@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
    module.forward(torch.ones(5, 5))
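IscloseStaticModuleTrue compares an all-ones (5, 5) input against an all-ones (1,) buffer, so the broadcast result should be True everywhere. The eager-mode equivalent:

import torch

assert torch.isclose(torch.ones(5, 5), torch.ones(1)).all()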
29 changes: 29 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -1155,3 +1155,32 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to
  %0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
  return %0 : !torch.vtensor<[2, 4],f32>
}

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08
// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1>
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
  %float1.000000e-08 = torch.constant.float 1.000000e-08
  %float1.000000e-05 = torch.constant.float 1.000000e-05
  %false = torch.constant.bool false
  %0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1>
  return %0 : !torch.vtensor<[5,5],i1>
}
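The odd-looking constants in the CHECK lines are simply the isclose defaults rounded to float32 by getTosaConstTensorSingleF32. They can be reproduced by round-tripping the doubles through a 4-byte float:

import struct

# 1e-05 and 1e-08 as float32 -- printed by MLIR as 9.99999974E-6 and 9.99999993E-9.
print(struct.unpack('f', struct.pack('f', 1e-05))[0])  # ~9.9999997e-06
print(struct.unpack('f', struct.pack('f', 1e-08))[0])  # ~9.9999999e-09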