[tosa] Add support for some cases of aten.broadcast_to op (llvm#1429)
This commit adds support for TorchToTosa lowering of the
`aten.broadcast_to` op for the following cases:
1.) When the input and output tensors have equal rank.
2.) When the input tensor has rank zero.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
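For reference, a minimal PyTorch sketch of the two newly supported cases (the shapes are illustrative, not taken from the commit):

import torch

# Case 1: input and output tensors have equal rank; only size-1 dims expand.
x = torch.rand(3, 1, 1)
a = torch.broadcast_to(x, (3, 1, 8))   # rank 3 -> rank 3

# Case 2: the input tensor has rank zero (a scalar tensor).
s = torch.rand(())
b = torch.broadcast_to(s, (3, 1, 8))   # rank 0 -> rank 3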
vivekkhandelwal1 authored Sep 29, 2022
1 parent 0f15b3a commit 6db513c
Showing 6 changed files with 73 additions and 31 deletions.
4 changes: 2 additions & 2 deletions e2e_testing/xfail_sets.py
@@ -23,7 +23,6 @@
 }
 
 MHLO_PASS_SET = {
-    "BroadcastToIdentityCaseStaticModule_basic",
     "GatherStaticModule_basic",
     "GatherModule_basic",
     "Gather2DInputModdule_basic",
@@ -454,7 +453,8 @@
     "_LogSoftmaxModuleStable_basic",
     "LiftFreshCopyModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
-    "BroadcastToIdentityCaseStaticModule_basic",
+    "BroadcastToSameRankStaticModule_basic",
+    "BroadcastZeroRankInputStaticModule_basic",
     "SliceStaticModule_basic",
 }

2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -50,6 +50,8 @@ bool isBuiltInType(Type type);
 // -1 is returned if the tensorRank can't be determined.
 int getTensorRank(Value tensor);
 
+bool isViewLikeOp(Operation *op);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
32 changes: 23 additions & 9 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2980,15 +2980,29 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
         "size must consist of Scalar constants");
 
   SmallVector<int64_t> inputShape(selfType.getShape());
-  if (!llvm::equal(inputShape, outShape))
-    return rewriter.notifyMatchFailure(op,
-                                       "Only identity cases are supported.");
-
-  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
-      rewriter.getI64ArrayAttr(outShape));
-
-  return success();
+  if (inputShape.size() == outShape.size() || inputShape.size() == 0) {
+    // Check for the identity case, e.g., [a, b, c] -> [a, b, c]: if the
+    // shapes match, the op result can be replaced with the input operand
+    // irrespective of the users of the op result.
+    if (!llvm::equal(inputShape, outShape)) {
+      for (auto user : op->getResult(0).getUsers()) {
+        // A non-identity broadcast is only supported when the result of the
+        // `broadcast_to` op is not used by a view-like op.
+        if (isViewLikeOp(user)) {
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: broadcast not supported for this case");
+        }
+      }
+    }
+    // If we reach here, the given case is handled by the implicit
+    // broadcasting done by TOSA.
+    op.replaceAllUsesWith(op.self());
+    rewriter.eraseOp(op);
+    return success();
+  }
+  return rewriter.notifyMatchFailure(
+      op,
+      "unimplemented: broadcasts other than same rank or zero ranked tensor.");
 }
 
 template <typename AtenOpT, typename TosaOpT>
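Why erasing the op is sound in these cases: TOSA elementwise ops broadcast their operands implicitly, so a same-rank or zero-rank broadcast feeding an elementwise consumer is redundant. A minimal PyTorch sketch of the equivalence the lowering relies on (illustrative shapes, not part of the commit):

import torch

x = torch.rand(3, 1, 8)
y = torch.rand(3, 1, 1)

# Explicit broadcast followed by an elementwise op...
explicit = torch.sub(x, torch.broadcast_to(y, (3, 1, 8)))
# ...is equivalent to letting the elementwise op broadcast implicitly,
# which is what remains after the broadcast_to op is erased.
implicit = torch.sub(x, y)
assert torch.equal(explicit, implicit)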
15 changes: 1 addition & 14 deletions lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -27,20 +28,6 @@ static Value assertNonValueTensor(Value tensor) {
   return tensor;
 }
 
-static bool isViewLikeOp(Operation *op) {
-  // AtenContiguousOp might return a view, so this is conservatively
-  // correct. We could potentially be more precise and identify the cases
-  // that it does not return a view and treat those as having value
-  // semantics.
-  return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
-             AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
-             Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
-             AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
-             AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
-             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-             AtenNarrowOp, AtenToDeviceOp>(op);
-}
-
 namespace {
 class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
     : public OpRewritePattern<CopyToNonValueTensorOp> {
14 changes: 14 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -142,3 +142,17 @@ int Torch::getTensorRank(Value tensor) {
   }
   return tensorRank;
 }
+
+bool Torch::isViewLikeOp(Operation *op) {
+  // AtenContiguousOp might return a view, so this is conservatively
+  // correct. We could potentially be more precise and identify the cases
+  // that it does not return a view and treat those as having value
+  // semantics.
+  return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
+             AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
+             Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
+             AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
+             AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
+             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
+             AtenNarrowOp, AtenToDeviceOp>(op);
+}
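For intuition on the view-like restriction: an elementwise user only needs the broadcast semantics, which TOSA supplies implicitly, but a view-like user depends on the materialized broadcast shape, so erasing the op in its presence would change behavior. A small PyTorch sketch (illustrative only, not part of the commit):

import torch

y = torch.rand(3, 1, 1)
b = torch.broadcast_to(y, (3, 1, 8))

# A view-like user observes the materialized broadcast shape, so the
# explicit broadcast cannot simply be erased in its presence.
viewed = b.reshape(3, 8)   # valid: b has 3 * 1 * 8 = 24 elements
# y.reshape(3, 8)          # would fail: y has only 3 elements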
37 changes: 31 additions & 6 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1089,24 +1089,49 @@ def BroadcastToModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class BroadcastToIdentityCaseStaticModule(torch.nn.Module):
+class BroadcastToSameRankStaticModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
+        ([3, 1, 8], torch.float32, True),
         ([3, 1, 1], torch.float32, True),
     ])
-    def forward(self, x):
-        return torch.broadcast_to(x, [3, 1, 1])
+    def forward(self, x, y):
+        y = torch.broadcast_to(y, [3, 1, 8])
+        return torch.ops.aten.sub(x, y)
 
 
-@register_test_case(module_factory=lambda: BroadcastToIdentityCaseStaticModule())
-def BroadcastToIdentityCaseStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 1, 1))
+@register_test_case(module_factory=lambda: BroadcastToSameRankStaticModule())
+def BroadcastToSameRankStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 8), tu.rand(3, 1, 1))
+
+
+# ==============================================================================
+
+
+class BroadcastZeroRankInputStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 1, 8], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        y = torch.broadcast_to(y, [3, 1, 8])
+        return torch.ops.aten.sub(x, y)
+
+
+@register_test_case(module_factory=lambda: BroadcastZeroRankInputStaticModule())
+def BroadcastZeroRankInputStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 8), tu.rand())
 
 # ==============================================================================
 