Skip to content

Commit

Permalink
add argmax lowering
Browse files Browse the repository at this point in the history
Add argmax lowering from torch to linalg
  • Loading branch information
dan-garvey committed Oct 12, 2021
1 parent 19e9fc4 commit 0038b4f
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 12 deletions.
66 changes: 66 additions & 0 deletions e2e_testing/torchscript/argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])

def forward(self, a):
return torch.argmax(a)


@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

# ==============================================================================

class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, dim=1)

@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, 0, True)

@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))

1 change: 1 addition & 0 deletions e2e_testing/torchscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import quantized_models
from . import elementwise
from . import reduction
from . import argmax

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
Expand Down
16 changes: 16 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,22 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
}

def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
}

def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
AllowsTypeRefinement
]> {
Expand Down
173 changes: 171 additions & 2 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
// of *internal* compiler invariants, and for a user manifests as a compiler
// crash in the worst case (such as we try to canonicalize/fold/print the
// invalid op before the verifier gets to see it -- also release builds of a
// mature copmiler usually have the verifier turned off for compile time
// mature compiler usually have the verifier turned off for compile time
// reasons).
//
// The compiler cannot crash even if the user wrote an erroneous program!
Expand Down Expand Up @@ -1141,12 +1141,179 @@ static Value createLinalgPayloadCalculationForReduceOp(
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
elementType.isa<mlir::FloatType>())
return b.create<AddFOp>(loc, payloadArgs);

op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForReduceOp");
return nullptr;
}

namespace {
// Aten argmax lowering represents the ArgMax op as an linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is disgarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
public:
using OpConversionPattern<AtenArgmaxOp>::OpConversionPattern;

LogicalResult matchAndRewrite(AtenArgmaxOp argmaxOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {

Location loc = argmaxOp.getLoc();
AtenArgmaxOp::Adaptor adaptor(operands);
Value input = adaptor.self();
RankedTensorType resultType = getTypeConverter()
->convertType(argmaxOp.getResult().getType())
.cast<RankedTensorType>();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type outElementType = resultType.getElementType();
if (!outElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires integer-like result type");

bool keepDim = false;
if (!matchPattern(argmaxOp.keepdim(),
m_TorchConstantBool(&keepDim)))
return failure();

int64_t dim;
if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires int or NoneType value for Dim");
// For pytorch, if the value of Dim is None, argmax
// returns the index of the max value of the flattened input tensor,
// so here we flatten the input tensor.
SmallVector<ReassociationIndices> reassociation(1);
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
reassociation[0].push_back(i);
input = rewriter.create<linalg::TensorCollapseShapeOp>(
argmaxOp->getLoc(), input, reassociation);
// Becomes 0 for flattened tensor.
dim=0;
// Recast to fix shape.
inputType = input.getType().cast<RankedTensorType>();
}
Type inElementType = inputType.getElementType();
if (inElementType.isa<mlir::FloatType>()){
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires Float input element type");
}

// Constant op to account for the reduction along dim.
auto c1 = rewriter.create<mlir::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim!=i) {
auto currentDimSize =
rewriter.create<tensor::DimOp>(loc, input, i);
resultShape.push_back(currentDimSize);
}
else if (keepDim)
resultShape.push_back(c1);
}
// First fill the output buffer for the index.
Value filledTensorIdx = createZeroInitTensor(rewriter, loc, resultShape, outElementType);

// Second fill the output buffer for the running max.
Value initTensorMax =
rewriter
.create<linalg::InitTensorOp>(loc, resultShape, inElementType)
.result();

FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
inElementType, APFloat::getLargest(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));

if (!fillValueMaxAttr)
return rewriter.notifyMatchFailure(
argmaxOp, "unsupported aten.argmax element type");

Value fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
.result();


// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {
exprs.push_back(rewriter.getAffineDimExpr(size.index()));

if (unsigned(dim)==size.index()) {
iteratorTypes.push_back(getReductionIteratorTypeName());
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
iteratorTypes.push_back(getParallelIteratorTypeName());
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
}
}
bool didEncounterError = false;
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}), input,
ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value newValue = blockArgs[0];
Value oldIndex = blockArgs[1];
Value oldValue = blockArgs[2];

Value newIndex = rewriter.create<IndexCastOp>(
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));

Value predicate;
if (inElementType.isa<mlir::FloatType>()) {
predicate = rewriter.create<mlir::CmpFOp>(
nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
// The following code is meaningless until int element type
// is supported.
} else if (inElementType.isa<IntegerType>()) {
predicate = rewriter.create<mlir::CmpIOp>(
nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
} else {
didEncounterError = true;
return;
}

auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
newValue, oldValue);
auto resultIndex = rewriter.create<mlir::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultIndex, resultMax}));
});

if (didEncounterError)
return rewriter.notifyMatchFailure(
argmaxOp, "unsupported aten.argmax element type");

// This cast is required to fix the shape in the case of keepDim=True
rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
linalgOp.getResult(0));
return success();
}
};
} // namespace
namespace {

// Converts an elementwise op.
Expand Down Expand Up @@ -1896,6 +2063,8 @@ class ConvertTorchToLinalg
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
target.addIllegalOp<AtenLayerNormOp>();
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand Down
19 changes: 17 additions & 2 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
if (succeeded(parser.parseOptionalQuestion())) {
sizes.push_back(-1);
continue;
}

}
int64_t size;
auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) {
Expand Down Expand Up @@ -273,6 +272,22 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
TensorType ValueTensorType::toBuiltinTensor() const {
if (!hasDtype())
return nullptr;
auto dtype = getDtype();
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
if (type.isSigned()) {
for (unsigned width : {8, 16, 32, 64}) {
if (type.getWidth() == width){
if (!hasSizes())
return UnrankedTensorType::get(
IntegerType::get(type.getContext(), width));
return RankedTensorType::get(getSizes(),
IntegerType::get(type.getContext(), width));
}
}
}
}


if (!hasSizes())
return UnrankedTensorType::get(getDtype());
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
Expand Down
Loading

0 comments on commit 0038b4f

Please sign in to comment.