Skip to content

Commit

Permalink
Add a canonicalization pattern for aten.unflatten.int (#3656)
Browse files Browse the repository at this point in the history
Addresses an issue in <#3651>
where some unflatten ops generated from onnx models weren't propagating
static shape information. It may be necessary to add further
optimizations for the more general case when some static information is
present in the unflatten (or possibly reshape/view) op's `sizes` list,
but not reflected in the output shape. These ops will only successfully
infer shapes if the `sizes` list is gotten from a list of constant ints
(with possibly one -1). A common example where this fails is when some
of the `sizes` are determined from `aten.size.int` ops on dynamic
tensors, and other `sizes` are known statically.

This PR includes:
- a canonicalizer for `aten.unflatten.int` which converts to
`aten.unsqueeze` when it is expanding one dim to two, and one of the new
dims is statically 1.
- an improvement to the folder for `aten.__or__.bool` which does not
rely on *both* operands being static.
  • Loading branch information
zjgarvey authored Sep 3, 2024
1 parent 2960538 commit 295bf41
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 6 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9538,6 +9538,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenDimOp : Torch_Op<"aten.dim", [
Expand Down
93 changes: 88 additions & 5 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -739,12 +739,16 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
if (!valueA || !valueB) {
if (!valueA && !valueB)
return nullptr;
}

return IntegerAttr::get(IntegerType::get(getContext(), 1),
valueA.getValue() | valueB.getValue());
if ((valueA && valueA.getValue() == 1) || (valueB && valueB.getValue() == 1))
return IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
if (valueA && valueA.getValue() == 0)
return getB();
if (valueB && valueB.getValue() == 0)
return getA();
// unreachable
return nullptr;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2162,6 +2166,85 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenUnflattenIntOp
//===----------------------------------------------------------------------===//

void AtenUnflattenIntOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
// if there are only two sizes and one of them is statically 1, then convert
// to an unqueeze.
patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) {
SmallVector<Value> sizeValues;
if (!getListConstructElements(op.getSizes(), sizeValues))
return rewriter.notifyMatchFailure(op,
"sizes must come from list construct");
if (sizeValues.size() != 2)
return failure();
int64_t dim0, dim1;
bool dim0Constant = matchPattern(sizeValues[0], m_TorchConstantInt(&dim0));
bool dim1Constant = matchPattern(sizeValues[1], m_TorchConstantInt(&dim1));
if (!dim0Constant && !dim1Constant)
return failure();
if (dim0 != 1 && dim1 != 1)
return failure();
Value unflattenDim = op.getDim();
Value self = op.getSelf();
Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
// the runtime asserts below are introduced to catch malformed unflatten ops
// possibly generated from onnx IR.
Value unsqueeze;
if (dim0 == 1) {
// unsqueeze at dim
FailureOr<Value> maybeUnsqueeze =
Torch::unsqueezeTensor(rewriter, op, self, unflattenDim);
if (failed(maybeUnsqueeze))
return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
unsqueeze = maybeUnsqueeze.value();
// check if the remaining size value is either -1 or equal to original
// size at dim
Value selfSizeAtDim =
rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
Value isSameSize = rewriter.create<AtenEqIntOp>(
op.getLoc(), selfSizeAtDim, sizeValues[1]);
Value isMinusOne =
rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[1]);
Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
op.getLoc(), isMinusOne, isSameSize);
rewriter.create<Torch::RuntimeAssertOp>(
op.getLoc(), isMOneOrSameSize,
rewriter.getStringAttr("unflatten sizes must be compatible"));
}
if (dim1 == 1) {
// unsqueeze at dim + 1
Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
Value dimPlusOne =
rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
FailureOr<Value> maybeUnsqueeze =
Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
if (failed(maybeUnsqueeze))
return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
unsqueeze = maybeUnsqueeze.value();
// check if the remaining size value is either -1 or equal to original
// size at dim
Value selfSizeAtDim =
rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
Value isSameSize = rewriter.create<AtenEqIntOp>(
op.getLoc(), selfSizeAtDim, sizeValues[0]);
Value isMinusOne =
rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[0]);
Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
op.getLoc(), isMinusOne, isSameSize);
rewriter.create<Torch::RuntimeAssertOp>(
op.getLoc(), isMOneOrSameSize,
rewriter.getStringAttr("unflatten sizes must be compatible"));
}
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
unsqueeze);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenSelectIntOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
emit(
"aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True
)
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
Expand Down

0 comments on commit 295bf41

Please sign in to comment.