Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch Dialect] emit aten.tile op and decompose it into aten.repeat #2355

Merged
merged 1 commit into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@
}

STABLEHLO_PASS_SET = {
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"AtenIntTensorByteDtypeModule_basic",
Expand Down Expand Up @@ -867,6 +869,8 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"AliasModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
Expand Down
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7697,6 +7697,30 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
}];
}

def Torch_AtenTileOp : Torch_Op<"aten.tile", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::tile : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenTileOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenTileOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
AllowsTypeRefinement,
ReadOnly
Expand Down
21 changes: 21 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6790,6 +6790,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.tile\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.lt.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.list<int>) {\n"
" %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>\n"
" %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n"
" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" %8 = torch.aten.add.t %7, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" torch.prim.If.yield %8 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.list<int>\n"
" }\n"
" %4 = call @\"__torch_mlir_shape_fn.aten.repeat\"(%arg0, %3) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.roll\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -8632,6 +8649,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tile\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
40 changes: 40 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4634,6 +4634,45 @@ class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
};
} // namespace

namespace {
// Unconditionally decompose `aten.tile` into `aten.repeat`.
class DecomposeAtenTileOp : public OpRewritePattern<AtenTileOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTileOp op,
PatternRewriter &rewriter) const override {
auto input = op.getSelf();
auto repeats = op.getDims();
SmallVector<Value> dimsElements;
if (!getListConstructElements(repeats, dimsElements)) {
return rewriter.notifyMatchFailure(
op, "failed to get elements of `dims` param");
}
auto dimsSize = dimsElements.size();
auto inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "only support input tensor with shape information");
}
auto inputRank = inputType.getSizes().size();
if (dimsSize < inputRank) {
auto constantOne = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(1));
for (auto i = dimsSize; i < inputRank; ++i) {
dimsElements.insert(dimsElements.begin(), constantOne);
}
repeats = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(Torch::IntType::get(op.getContext())),
dimsElements);
}
rewriter.replaceOpWithNewOp<Torch::AtenRepeatOp>(op, op.getType(), input,
repeats);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -4805,6 +4844,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenScalarTensorOp>();
target.addIllegalOp<AtenScatterValueOp>();
target.addIllegalOp<AtenTypeAsOp>();
target.addIllegalOp<AtenTileOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,17 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
out.append(self[i] * repeats[i + leading_rank])
return out

@check_shape_function([
Invocation(TensorOfShape(3, 2, 8), [2, 2]), # dims_length < self_length
Invocation(TensorOfShape(3, 2, 8), [2, 2, 2]) # dims_length >= self_length
])
def aten〇tile〡shape(self: List[int], dims: List[int]) -> List[int]:
dims_length = len(dims)
self_length = len(self)
if dims_length < self_length:
dims = [1] * (self_length - dims_length) + dims
return aten〇repeat〡shape(self, dims)

def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
return upstream_shape_functions.unary(self)

Expand Down Expand Up @@ -1772,6 +1783,11 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int])
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[1]))
def aten〇tile〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]))
def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
Expand Down
41 changes: 41 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,47 @@ def RepeatModule_basic(module, tu: TestUtils):
# ==============================================================================


class TileSmallDimsSizeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 1, 2], torch.float32, True),
])
def forward(self, x):
return x.tile([3, 4])


@register_test_case(module_factory=lambda: TileSmallDimsSizeModule())
def TileSmallDimsSizeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================

class TileBigDimsSizeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 1, 2], torch.float32, True),
])
def forward(self, x):
return x.tile([3, 4, 5, 6])


@register_test_case(module_factory=lambda: TileBigDimsSizeModule())
def TileBigDimsSizeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================


class ExpandModule(torch.nn.Module):

def __init__(self):
Expand Down