[linalg] Fix handling of trailing size-1 dimensions in aten.view (#2474)
This commit extends the lowering of `aten.view` to handle the
following cases (illustrated below):

- `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)`
- `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))`
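Concretely, in PyTorch terms (shapes chosen to match the new e2e tests added below):

```python
import torch

a = torch.rand(128)

# (..., a.size(i)) -> (..., a.size(i), 1, ..., 1)
expanded = a.view(a.size(0), 1, 1, 1)
assert expanded.shape == (128, 1, 1, 1)

# (..., a.size(i), 1, ..., 1) -> (..., a.size(i))
collapsed = expanded.view(expanded.size(0))
assert collapsed.shape == (128,)
```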

Fixes: #2448
ramiro050 authored Sep 27, 2023
1 parent e69266a commit 7c6b9d2
Showing 2 changed files with 79 additions and 8 deletions.
51 changes: 44 additions & 7 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
ArrayRef<int64_t> yDims,
SmallVector<int64_t> &xIndices,
SmallVector<int64_t> &yIndices) {
+if (xDims.empty() || yDims.empty())
+  return failure();
+
auto isValidReduction = [](int64_t expectedReductionProduct,
ArrayRef<int64_t> arrayToReduce) -> bool {
if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
@@ -255,13 +258,34 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
return success();
}
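Only the tail of `mapAllDimsToSingleDim` is visible in this hunk. As a rough sketch of its contract (hypothetical Python, static shapes only, not the actual implementation), it succeeds when one side is a single static dim whose size equals the product of the other side's dims:

```python
from math import prod

kUnknownSize = -1  # sentinel for dynamic dims, as in the C++ lowering

def map_all_dims_to_single_dim(x_dims, y_dims):
    """Return (x_indices, y_indices) on success, None on failure."""
    if not x_dims or not y_dims:  # the guard added by this commit
        return None
    if len(x_dims) == 1 and x_dims[0] != kUnknownSize:
        if kUnknownSize in y_dims or prod(y_dims) != x_dims[0]:
            return None
        return [0], list(range(len(y_dims)))
    if len(y_dims) == 1 and y_dims[0] != kUnknownSize:
        if kUnknownSize in x_dims or prod(x_dims) != y_dims[0]:
            return None
        return list(range(len(x_dims))), [0]
    return None

print(map_all_dims_to_single_dim([24], [2, 3, 4]))  # ([0], [0, 1, 2])
```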

+// If one of the two dims arrays has size 0 and the other array only
+// has dims of size 1, a mapping is created from no dimensions to
+// all the dimensions of the other array.
+static LogicalResult mapTrailingSizeOneDims(ArrayRef<int64_t> xDims,
+                                            ArrayRef<int64_t> yDims,
+                                            SmallVector<int64_t> &xIndices,
+                                            SmallVector<int64_t> &yIndices) {
+  SmallVector<int64_t> ignoredIndices;
+  if (xDims.empty()) {
+    return mapAllDimsToSingleDim(ArrayRef<int64_t>({1}), yDims,
+                                 ignoredIndices, yIndices);
+  } else if (yDims.empty()) {
+    return mapAllDimsToSingleDim(xDims, ArrayRef<int64_t>({1}), xIndices,
+                                 ignoredIndices);
+  } else {
+    return failure();
+  }
+}
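In the same hypothetical Python terms, the new helper reduces to the sketch above: the empty side is modeled as a single size-1 dim, so it succeeds only when the non-empty side contains nothing but static 1s:

```python
def map_trailing_size_one_dims(x_dims, y_dims):
    """Return (x_indices, y_indices) on success, None on failure."""
    if not x_dims and y_dims:
        mapped = map_all_dims_to_single_dim([1], y_dims)
        return ([], mapped[1]) if mapped else None
    if not y_dims and x_dims:
        mapped = map_all_dims_to_single_dim(x_dims, [1])
        return (mapped[0], []) if mapped else None
    return None  # both sides non-empty (or both empty): not this helper's case

print(map_trailing_size_one_dims([], [1, 1, 1]))  # ([], [0, 1, 2])
```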

// Calculates the size of a dynamic dimension if all other dimensions are
// statically known, and rewrites that dynamic dimension with the static size.
//
// Note: this function assumes that all the dimensions in `inputShape` map to
// all the dimensions in `outputShape`.
static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
MutableArrayRef<int64_t> outputShape) {
+if (inputShape.empty() || outputShape.empty())
+  return;
int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
if (inputDynamicDimCount + outputDynamicDimCount != 1)
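The rest of the function is collapsed in this view. For intuition, a minimal sketch of what `calculateSingleDynamicSize` computes (hypothetical Python, assuming both shapes describe the same number of elements):

```python
from math import prod

kUnknownSize = -1

def calculate_single_dynamic_size(input_shape, output_shape):
    """Rewrite the single dynamic dim in place, if there is exactly one."""
    if not input_shape or not output_shape:  # the guard added by this commit
        return
    dyn_in = input_shape.count(kUnknownSize)
    dyn_out = output_shape.count(kUnknownSize)
    if dyn_in + dyn_out != 1:
        return
    known, unknown = ((output_shape, input_shape) if dyn_in
                      else (input_shape, output_shape))
    static_product = prod(d for d in unknown if d != kUnknownSize)
    unknown[unknown.index(kUnknownSize)] = prod(known) // static_product

shape = [2, kUnknownSize]
calculate_single_dynamic_size(shape, [4, 3])
print(shape)  # [2, 6]
```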
@@ -420,7 +444,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) {
// Used for ensuring that we don't have an ambiguous expansion
bool assumedDynamicDimNotSplit = false;
-while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
+while (inputDim < nextUnchangedInput || outputDim < nextUnchangedOutput) {
auto inputShapeSlice =
MutableArrayRef<int64_t>(inputShape)
.slice(inputDim, nextUnchangedInput - inputDim);
@@ -441,9 +465,15 @@
"(e.g. [-1, -1] -> [-1, -1, -1])");
}

-if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice,
-                                    inputSliceIndices,
-                                    outputSliceIndices))) {
+if (succeeded(mapTrailingSizeOneDims(inputShapeSlice, outputShapeSlice,
+                                     inputSliceIndices,
+                                     outputSliceIndices))) {
+} else if (outputShapeSlice.empty()) {
+  inputSliceIndices.assign(
+      llvm::to_vector(llvm::seq<int64_t>(0, inputShapeSlice.size())));
+} else if (succeeded(mapAllDimsToSingleDim(
+               inputShapeSlice, outputShapeSlice, inputSliceIndices,
+               outputSliceIndices))) {
calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice);
// Update shape to pass the tensor.expand_shape and
// tensor.collapse_shape verifiers. If one of the dimensions of the
@@ -462,7 +492,8 @@
/// `mapStaticallyKnownDims` maps the smallest number of
/// input and output dimensions in the slice statically
/// known to have the same number of elements.
-} else if (inputShapeSlice[0] == kUnknownSize) {
+} else if (inputShapeSlice.size() > 0 &&
+           inputShapeSlice[0] == kUnknownSize) {
// If the input is dynamic, assume it is not split
checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
outputSizeInt[outputDim]);
@@ -478,8 +509,14 @@
"in `aten.view`");
}

-inputAssociations.emplace_back();
-outputAssociations.emplace_back();
+// If one of the slices is empty, this means we are handling
+// the case of trailing dimensions, which does not require a
+// new reassociation; the trailing dimensions get added to the
+// last reassociation created.
+if (inputShapeSlice.size() > 0 && outputShapeSlice.size() > 0) {
+  inputAssociations.emplace_back();
+  outputAssociations.emplace_back();
+}
for (int64_t inputSliceIndex : inputSliceIndices)
inputAssociations.back().push_back(inputSliceIndex + inputDim);
for (int64_t outputSliceIndex : outputSliceIndices)
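Taken together, for a pure expansion such as `a.view(a.size(0), 1, 1, 1)` the trailing output dims now join the last reassociation group instead of opening an empty one. A hedged sketch of the resulting grouping (hypothetical helper, static shapes, no size-1 dims in the input):

```python
def expand_reassociation(input_shape, output_shape):
    """Group output dims under the input dim they expand from."""
    groups, j = [], 0
    for dim in input_shape:
        group, p = [], 1
        while p < dim:  # absorb output dims until the products match
            group.append(j)
            p *= output_shape[j]
            j += 1
        groups.append(group)
    while j < len(output_shape):  # trailing size-1 dims join the last group
        assert output_shape[j] == 1
        groups[-1].append(j)
        j += 1
    return groups

print(expand_reassociation([128], [128, 1, 1, 1]))  # [[0, 1, 2, 3]]
```

In the lowering, this grouping becomes the reassociation of a `tensor.expand_shape`; the symmetric collapse case feeds the same indices to `tensor.collapse_shape`.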
36 changes: 35 additions & 1 deletion python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -672,6 +672,40 @@ def forward(self, a):
def ViewNegativeStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 128))

+class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0), 1, 1, 1)
+
+@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule())
+def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128))
+
+class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 1, 1, 1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0))
+
+@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule())
+def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128, 1, 1, 1))
+
# ==============================================================================

class ReshapeAliasExpandModule(torch.nn.Module):
@@ -710,4 +744,4 @@ def forward(self, a):

@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 4))
\ No newline at end of file
+    module.forward(tu.rand(2, 4))
