Skip to content


add argmax lowering
Browse files Browse the repository at this point in the history
Add argmax lowering from torch to linalg
  • Loading branch information
dan-garvey committed Oct 13, 2021
1 parent 19e9fc4 commit 6114dea
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 12 deletions.
66 changes: 66 additions & 0 deletions e2e_testing/torchscript/
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class ArgmaxModule(torch.nn.Module):
def __init__(self):

([-1, -1], torch.float32, True),

def forward(self, a):
return torch.argmax(a)

@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
    """Run the module's `forward` on the input built by `tu.rand(3, 4)`."""
    sample_input = tu.rand(3, 4)
    module.forward(sample_input)

# ==============================================================================

class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):

([-1, -1, -1], torch.float32, True),
def forward(self, a):
return torch.argmax(a, dim=1)

@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
    """Run the module's `forward` on the input built by `tu.rand(3, 4, 5)`."""
    sample_input = tu.rand(3, 4, 5)
    module.forward(sample_input)

# ==============================================================================

class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):

([-1, -1], torch.float32, True),
def forward(self, a):
return torch.argmax(a, 0, True)

@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
    """Run the module's `forward` on the input built by `tu.rand(4, 6)`."""
    sample_input = tu.rand(4, 6)
    module.forward(sample_input)

1 change: 1 addition & 0 deletions e2e_testing/torchscript/
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import quantized_models
from . import elementwise
from . import reduction
from . import argmax

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
Expand Down
16 changes: 16 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,22 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";

def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
]> {
let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`";
let arguments = (ins
let results = (outs
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";

def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
]> {
Expand Down
160 changes: 158 additions & 2 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
// of *internal* compiler invariants, and for a user manifests as a compiler
// crash in the worst case (such as we try to canonicalize/fold/print the
// invalid op before the verifier gets to see it -- also release builds of a
// mature copmiler usually have the verifier turned off for compile time
// mature compiler usually have the verifier turned off for compile time
// reasons).
// The compiler cannot crash even if the user wrote an erroneous program!
Expand Down Expand Up @@ -1141,12 +1141,166 @@ static Value createLinalgPayloadCalculationForReduceOp(
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
return b.create<AddFOp>(loc, payloadArgs);

op->emitError("unimplemented lowering in "
return nullptr;

namespace {
// Aten argmax lowering represents the ArgMax op as a linalg.generic op
// (using linalg.index to obtain the current position along the reduced
// dimension), producing two output buffers.
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the resulting integer type.
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by the linalg.generic op, this buffer is discarded as only the
// index is requested.
// The linalg.generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
using OpConversionPattern<AtenArgmaxOp>::OpConversionPattern;

matchAndRewrite(AtenArgmaxOp argmaxOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {

Location loc = argmaxOp.getLoc();
AtenArgmaxOp::Adaptor adaptor(operands);
Value input = adaptor.self();
RankedTensorType resultType =
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type outElementType = resultType.getElementType();
if (!outElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
"aten.arg_max to linalg.* requires integer-like result type");

bool keepDim = false;
if (!matchPattern(argmaxOp.keepdim(), m_TorchConstantBool(&keepDim)))
return failure();

int64_t dim;
if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
"aten.arg_max to linalg.* requires int or NoneType value for Dim");
// For pytorch, if the value of Dim is None, argmax
// returns the index of the max value of the flattened input tensor,
// so here we flatten the input tensor.
SmallVector<ReassociationIndices> reassociation(1);
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
input = rewriter.create<linalg::TensorCollapseShapeOp>(
argmaxOp->getLoc(), input, reassociation);
// Becomes 0 for flattened tensor.
dim = 0;
// Recast to fix shape.
inputType = input.getType().cast<RankedTensorType>();
Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
"aten.arg_max to linalg.* requires Float input element type");

// Constant op to account for the reduction along dim.
auto c1 = rewriter.create<mlir::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim != i) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
} else if (keepDim)
// First fill the output buffer for the index.
Value filledTensorIdx =
createZeroInitTensor(rewriter, loc, resultShape, outElementType);

// Second fill the output buffer for the running max.
Value initTensorMax =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, inElementType)

FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));

Value fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)

// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {

if (unsigned(dim) == size.index()) {
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
} else {
bool didEncounterError = false;
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>(
ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}),
input, ValueRange({filledTensorIdx, filledTensorMax}), maps,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value newValue = blockArgs[0];
Value oldIndex = blockArgs[1];
Value oldValue = blockArgs[2];

Value newIndex = rewriter.create<IndexCastOp>(
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));

Value predicate;
if (inElementType.isa<mlir::FloatType>())
predicate = rewriter.create<mlir::CmpFOp>(
nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
newValue, oldValue);
auto resultIndex = rewriter.create<mlir::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedLoc, ValueRange({resultIndex, resultMax}));

if (didEncounterError)
return rewriter.notifyMatchFailure(
argmaxOp, "unsupported aten.argmax element type");

// This cast is required to fix the shape in the case of keepDim=True
rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
return success();
} // namespace
namespace {

// Converts an elementwise op.
Expand Down Expand Up @@ -1896,6 +2050,8 @@ class ConvertTorchToLinalg
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);

if (failed(applyPartialConversion(getOperation(), target,
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/Torch/IR/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,

int64_t size;
auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) {
Expand Down
50 changes: 41 additions & 9 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands);
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
return visitAtenArgmaxOp(argmax, operands);
} else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
return visitAtenAnyDimOp(anyDim, operands);
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
Expand Down Expand Up @@ -397,6 +399,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
Operation *op, Value dim, Value keepdim,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
visitAtenArgmaxOp(AtenArgmaxOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
visitAtenAnyDimOp(AtenAnyDimOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
template <typename OpTy>
Expand Down Expand Up @@ -733,8 +738,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
knowledge.dtype = input.dtype;
llvm::SmallVector<int64_t> dimList;
bool keepdimBool;
if (matchPattern(keepdim, m_TorchConstantBool(&keepdimBool))) {
bool keepDim;
if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
knowledge.hasSizes = true;
int64_t inputRank = input.sizes.size();
// TODO: This is not safe. Need to check the list users and use aliasing
Expand All @@ -745,20 +750,48 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
for (auto en : llvm::enumerate(input.sizes)) {
if (dimSet.contains(en.index())) {
if (keepdimBool)
if (keepDim)
} else {
} else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
auto sizes = listConstruct.elements();
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - sizes.size(),
knowledge.sizes.resize(keepDim ? inputRank : inputRank - sizes.size(),
return getLatticeElement(op->getResult(0)).join(knowledge);
ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
int64_t dim;
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
if (keepDim)
knowledge.sizes[dim] = 1;
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
} else {
// Assumes if dim is not an int, that it is None.
// keepDim is ignored in this case, and the result will be
// a rank-0 tensor.
knowledge.hasSizes = true;
return getLatticeElement(op->getResult(0)).join(knowledge);

ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
Expand All @@ -767,22 +800,21 @@ ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
knowledge.dtype = input.dtype;
int64_t dim;
bool keepdimBool;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepdimBool))) {
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
if (keepdimBool)
if (keepDim)
knowledge.sizes[dim] = 1;
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
} else {
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - 1,
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
return getLatticeElement(op->getResult(0)).join(knowledge);
Expand Down

0 comments on commit 6114dea

Please sign in to comment.