Skip to content

Commit

Permalink
[Torch] support adaptive_max_pool1d when return_indices equals False (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 authored Oct 11, 2024
1 parent 8787970 commit 7b11dfc
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7368,10 +7368,19 @@ class DecomposeAtenAdaptiveMaxPool1dOp
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});

rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
};
Expand Down

0 comments on commit 7b11dfc

Please sign in to comment.