Skip to content

Commit

Permalink
[Torch Dialect] emit aten.reshape_as op and add decomposition pattern. (
Browse files Browse the repository at this point in the history
  • Loading branch information
Vremold authored Nov 5, 2023
1 parent 71ca529 commit d5ee8ee
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 0 deletions.
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9223,6 +9223,29 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
}];
}

def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::reshape_as : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenReshapeAsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenReshapeAsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [
AllowsTypeRefinement,
ReadOnly
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6899,6 +6899,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.reshape_as\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._reshape_alias\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -8908,6 +8912,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reshape_as\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.resize_\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
23 changes: 23 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5375,6 +5375,28 @@ class DecomposeAtenTileOp : public OpRewritePattern<AtenTileOp> {
};
} // namespace

namespace {
// Unconditionally decompose `aten.reshape_as` into `aten.size` +
// `aten.reshape`.
class DecomposeAtenReshapeAsOp : public OpRewritePattern<AtenReshapeAsOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenReshapeAsOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value input = op.getSelf();
Value other = op.getOther();

auto otherShape = rewriter.create<Torch::AtenSizeOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), other);
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(op, op.getType(), input,
otherShape);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -5557,6 +5579,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);

GreedyRewriteConfig config;
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenScatterValueOp>();
target.addIllegalOp<AtenTypeAsOp>();
target.addIllegalOp<AtenTileOp>();
target.addIllegalOp<AtenReshapeAsOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
Expand Down
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"ReshapeAsModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
Expand Down Expand Up @@ -1137,6 +1138,7 @@
"ViewNoChangeStaticModule_basic",
"UnsafeViewExpandModule_basic",
"ReshapeCollapseModule_basic",
"ReshapeAsModule_basic",
"ElementwiseGeluModule_basic",
"GeluBackwardModule_basic",
"ElementwiseNeIntScalarModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ def aten〇view〡shape(self: List[int], size: List[int]) -> List[int]:
def aten〇reshape〡shape(self: List[int], shape: List[int]) -> List[int]:
return upstream_shape_functions.view(self, shape)

def aten〇reshape_as〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.view(self, other)

def aten〇_reshape_alias〡shape(self: List[int], size: List[int], stride: List[int]) -> List[int]:
return upstream_shape_functions.view(self, size)

Expand Down Expand Up @@ -1942,6 +1945,11 @@ def aten〇reshape〡dtype(self_rank_dtype: Tuple[int, int], shape: List[int]) -
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_two_tensor_op())
def aten〇reshape_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]))
def aten〇resize_〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], memory_format: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
Expand Down
19 changes: 19 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,25 @@ def UnsafeView1DFoldModule_basic(module, tu: TestUtils):

# ==============================================================================

class ReshapeAsModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

@export
@annotate_args([
None,
([4, 3], torch.float32, True),
([2, 6], torch.float32, True),
])
def forward(self, a, b):
return torch.ops.aten.reshape_as(a, b)

@register_test_case(module_factory=lambda: ReshapeAsModule())
def ReshapeAsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3), tu.rand(2, 6))

# ==============================================================================

class ReshapeExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit d5ee8ee

Please sign in to comment.