Skip to content

Commit

Permalink
[Stablehlo] support aten.all.dim (#3746)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Sep 30, 2024
1 parent 7f63cb2 commit 476b32a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
5 changes: 3 additions & 2 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}

if (isa<AtenAllOp>(op)) {
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
auto constAttr =
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
Expand Down Expand Up @@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
AtenLinalgVectorNormOp>(op)) {
result = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAllOp>(op)) {
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
result = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
Expand Down Expand Up @@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
options)
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN

#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
Expand Down
9 changes: 1 addition & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,14 +744,6 @@
"RandnLikeDtypeModule_basic",
"RandnLikeModule_basic",
"RandnModule_basic",
"ReduceAllDimBool_basic",
"ReduceAllDimEmpty_basic",
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ReduceMinKeepDimReturnBoth_basic",
"ReduceMinKeepDim_basic",
"ReduceProdDimIntFloatModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
Expand All @@ -770,6 +762,7 @@
"RsubInt0d_NumToTensor_Module_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
Expand Down
20 changes: 20 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


class ReduceAllDimFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.all(a, dim=0)


@register_test_case(module_factory=lambda: ReduceAllDimFloatModule())
def ReduceAllDimFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


Expand Down

0 comments on commit 476b32a

Please sign in to comment.