[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#…
Vremold authored Aug 4, 2023
1 parent c0d8248 commit 20a2b68
Showing 8 changed files with 148 additions and 0 deletions.
4 changes: 4 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -310,6 +310,8 @@
}

STABLEHLO_PASS_SET = {
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"AtenIntTensorByteDtypeModule_basic",
@@ -867,6 +869,8 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"AliasModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7746,6 +7746,30 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
}];
}

def Torch_AtenTileOp : Torch_Op<"aten.tile", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::tile : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenTileOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenTileOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
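
For reference, the upstream PyTorch semantics this op mirrors: `torch.Tensor.tile(dims)` pads `dims` with leading 1s when it is shorter than the input rank, and otherwise behaves exactly like `repeat`. A quick illustrative sketch in plain Python (not part of the diff; assumes a standard `torch` install):

import torch

x = torch.rand(3, 1, 2)
# len(dims) < x.dim(): tile([3, 4]) acts like repeat([1, 3, 4]).
assert x.tile([3, 4]).shape == (3, 3, 8)
# len(dims) >= x.dim(): tile matches repeat exactly.
assert x.tile([3, 4, 5, 6]).shape == x.repeat(3, 4, 5, 6).shape == (3, 12, 5, 12)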

def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
AllowsTypeRefinement,
ReadOnly
21 changes: 21 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6790,6 +6790,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.tile\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.lt.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.list<int>) {\n"
" %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>\n"
" %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n"
" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" %8 = torch.aten.add.t %7, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" torch.prim.If.yield %8 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.list<int>\n"
" }\n"
" %4 = call @\"__torch_mlir_shape_fn.aten.repeat\"(%arg0, %3) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.roll\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8635,6 +8652,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tile\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
40 changes: 40 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4634,6 +4634,45 @@ class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
};
} // namespace

namespace {
// Unconditionally decompose `aten.tile` into `aten.repeat`.
class DecomposeAtenTileOp : public OpRewritePattern<AtenTileOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTileOp op,
PatternRewriter &rewriter) const override {
auto input = op.getSelf();
auto repeats = op.getDims();
SmallVector<Value> dimsElements;
if (!getListConstructElements(repeats, dimsElements)) {
return rewriter.notifyMatchFailure(
op, "failed to get elements of `dims` param");
}
auto dimsSize = dimsElements.size();
auto inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "only support input tensor with shape information");
}
auto inputRank = inputType.getSizes().size();
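// `aten.repeat` requires `repeats` to cover every input dimension, so when
// `dims` is shorter than the input rank, prepend constant 1s before
// rewriting to `aten.repeat` (matching `aten.tile`'s padding semantics).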
if (dimsSize < inputRank) {
auto constantOne = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(1));
for (auto i = dimsSize; i < inputRank; ++i) {
dimsElements.insert(dimsElements.begin(), constantOne);
}
repeats = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(Torch::IntType::get(op.getContext())),
dimsElements);
}
rewriter.replaceOpWithNewOp<Torch::AtenRepeatOp>(op, op.getType(), input,
repeats);
return success();
}
};
} // namespace
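
The rewrite above is equivalent to the following Python sketch (illustrative only; `tile_via_repeat` is a hypothetical helper name, assuming a standard `torch` install):

import torch

def tile_via_repeat(x: torch.Tensor, dims):
    # Pad `dims` with leading 1s up to the input rank, then forward to
    # repeat; when len(dims) >= x.dim(), repeat accepts `dims` unchanged.
    if len(dims) < x.dim():
        dims = [1] * (x.dim() - len(dims)) + list(dims)
    return x.repeat(dims)

x = torch.rand(3, 1, 2)
assert torch.equal(tile_via_repeat(x, [3, 4]), x.tile([3, 4]))
assert torch.equal(tile_via_repeat(x, [3, 4, 5, 6]), x.tile([3, 4, 5, 6]))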

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -4805,6 +4844,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -485,6 +485,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenScalarTensorOp>();
target.addIllegalOp<AtenScatterValueOp>();
target.addIllegalOp<AtenTypeAsOp>();
target.addIllegalOp<AtenTileOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
@@ -442,6 +442,17 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
out.append(self[i] * repeats[i + leading_rank])
return out

@check_shape_function([
Invocation(TensorOfShape(3, 2, 8), [2, 2]), # dims_length < self_length
Invocation(TensorOfShape(3, 2, 8), [2, 2, 2]) # dims_length >= self_length
])
def aten〇tile〡shape(self: List[int], dims: List[int]) -> List[int]:
dims_length = len(dims)
self_length = len(self)
if dims_length < self_length:
dims = [1] * (self_length - dims_length) + dims
return aten〇repeat〡shape(self, dims)
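
Concretely, for the two invocations checked above (an illustrative sketch of what the shape function computes, not part of the diff):

# dims shorter than rank: [2, 2] is padded to [1, 2, 2] first.
assert aten〇tile〡shape([3, 2, 8], [2, 2]) == [3, 4, 16]
# dims at least as long as rank: delegated to aten〇repeat〡shape unchanged.
assert aten〇tile〡shape([3, 2, 8], [2, 2, 2]) == [6, 4, 16]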

def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -1775,6 +1786,11 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int])
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[1]))
def aten〇tile〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]))
def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
@@ -513,6 +513,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
41 changes: 41 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1462,6 +1462,47 @@ def RepeatModule_basic(module, tu: TestUtils):
# ==============================================================================


class TileSmallDimsSizeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 1, 2], torch.float32, True),
])
def forward(self, x):
return x.tile([3, 4])


@register_test_case(module_factory=lambda: TileSmallDimsSizeModule())
def TileSmallDimsSizeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================

class TileBigDimsSizeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 1, 2], torch.float32, True),
])
def forward(self, x):
return x.tile([3, 4, 5, 6])


@register_test_case(module_factory=lambda: TileBigDimsSizeModule())
def TileBigDimsSizeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================


class ExpandModule(torch.nn.Module):

def __init__(self):