add argmax lowering and integer return types #299

Merged 1 commit on Oct 13, 2021
66 changes: 66 additions & 0 deletions e2e_testing/torchscript/argmax.py
@@ -0,0 +1,66 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])

def forward(self, a):
return torch.argmax(a)


@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

# ==============================================================================

class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, dim=1)

@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, 0, True)

@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))
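
For orientation (not part of the PR's diff), a short eager-mode Python sketch of what the three test modules above exercise; the shapes and the int64 index dtype follow standard torch.argmax behavior:

import torch

a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)

print(torch.argmax(a).shape)           # torch.Size([]) - no dim: input is flattened, rank-0 result
print(torch.argmax(b, dim=1).shape)    # torch.Size([3, 5]) - dim 1 is reduced away
print(torch.argmax(a, 0, True).shape)  # torch.Size([1, 4]) - keepdim=True keeps dim 0 as size 1
print(torch.argmax(a).dtype)           # torch.int64 - argmax always returns integer indices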

1 change: 1 addition & 0 deletions e2e_testing/torchscript/main.py
@@ -34,6 +34,7 @@
from . import quantized_models
from . import elementwise
from . import reduction
from . import argmax

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
16 changes: 16 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1349,6 +1349,22 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
}

def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
}

def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
AllowsTypeRefinement
]> {
155 changes: 153 additions & 2 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -800,7 +800,7 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
// of *internal* compiler invariants, and for a user manifests as a compiler
// crash in the worst case (such as we try to canonicalize/fold/print the
// invalid op before the verifier gets to see it -- also release builds of a
// mature copmiler usually have the verifier turned off for compile time
// mature compiler usually have the verifier turned off for compile time
// reasons).
//
// The compiler cannot crash even if the user wrote an erroneous program!
@@ -1141,12 +1141,161 @@ static Value createLinalgPayloadCalculationForReduceOp(
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
elementType.isa<mlir::FloatType>())
return b.create<AddFOp>(loc, payloadArgs);

op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForReduceOp");
return nullptr;
}

namespace {
// Aten argmax lowering represents the ArgMax op as a linalg.generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the result's integer element type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by the generic op, this buffer is discarded as only the index is
// requested.
//
// The generic op updates both the maximum value and the index if the
// current value exceeds the running max.
class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
public:
using OpConversionPattern<AtenArgmaxOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenArgmaxOp argmaxOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {

Location loc = argmaxOp.getLoc();
AtenArgmaxOp::Adaptor adaptor(operands);
Value input = adaptor.self();
RankedTensorType resultType =
getTypeConverter()
->convertType(argmaxOp.getResult().getType())
.cast<RankedTensorType>();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type outElementType = resultType.getElementType();
if (!outElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.argmax to linalg.* requires integer-like result type");

bool keepDim = false;
if (!matchPattern(argmaxOp.keepdim(), m_TorchConstantBool(&keepDim)))
return failure();

int64_t dim;
if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.argmax to linalg.* requires int or NoneType value for dim");
// For PyTorch, if dim is None, argmax returns the index of the max value
// of the flattened input tensor, so here we flatten the input tensor.
SmallVector<ReassociationIndices> reassociation(1);
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
reassociation[0].push_back(i);
input = rewriter.create<linalg::TensorCollapseShapeOp>(
argmaxOp->getLoc(), input, reassociation);
// Becomes 0 for flattened tensor.
dim = 0;
// Recast to fix shape.
inputType = input.getType().cast<RankedTensorType>();
}
Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.argmax to linalg.* requires Float input element type");
}

// Constant op to account for the reduction along dim.
auto c1 = rewriter.create<mlir::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim != i) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
resultShape.push_back(currentDimSize);
} else if (keepDim)
resultShape.push_back(c1);
}
// First fill the output buffer for the index.
Value filledTensorIdx =
createZeroInitTensor(rewriter, loc, resultShape, outElementType);

// Second fill the output buffer for the running max.
Value initTensorMax =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, inElementType)
.result();

FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
inElementType,
APFloat::getLargest(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));

Value fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
.result();

// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {
exprs.push_back(rewriter.getAffineDimExpr(size.index()));

if (unsigned(dim) == size.index()) {
iteratorTypes.push_back(getReductionIteratorTypeName());
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
iteratorTypes.push_back(getParallelIteratorTypeName());
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
}
}
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc,
ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}),
input, ValueRange({filledTensorIdx, filledTensorMax}), maps,
iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value newValue = blockArgs[0];
Value oldIndex = blockArgs[1];
Value oldValue = blockArgs[2];

Value newIndex = rewriter.create<IndexCastOp>(
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));

Value predicate;
if (inElementType.isa<mlir::FloatType>())
predicate = rewriter.create<mlir::CmpFOp>(
nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
newValue, oldValue);
auto resultIndex = rewriter.create<mlir::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultIndex, resultMax}));
});

// This cast is required to fix the shape in the case of keepDim=True
rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
linalgOp.getResult(0));
return success();
}
};
} // namespace
namespace {

// Converts an elementwise op.
@@ -1896,6 +2045,8 @@ class ConvertTorchToLinalg
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
target.addIllegalOp<AtenLayerNormOp>();
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
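
To summarize the lowering added above: the linalg.generic payload is a running-max scan that carries two accumulators, the best value seen so far and its index, and only the index survives as the op's result. A rough Python sketch under those assumptions (argmax_1d is a made-up helper; the real lowering carries the accumulators as two linalg output tensors and handles dim/keepdim via affine maps and a final tensor.cast):

import sys

def argmax_1d(values):
    # Index accumulator starts at 0; max accumulator starts at the most
    # negative finite value (APFloat::getLargest(..., /*Negative=*/true) above).
    running_idx = 0
    running_max = -sys.float_info.max
    for i, v in enumerate(values):
        # Mirrors CmpFOp OGT followed by two SelectOps: both accumulators
        # are updated only when the new value is strictly greater.
        if v > running_max:
            running_max = v
            running_idx = i
    # The max accumulator is discarded; only the index is returned.
    return running_idx

print(argmax_1d([0.2, 0.9, 0.1, 0.9]))  # 1 (first occurrence of the maximum wins)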
1 change: 0 additions & 1 deletion lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -138,7 +138,6 @@ Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
sizes.push_back(-1);
continue;
}

int64_t size;
auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) {
49 changes: 40 additions & 9 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -270,6 +270,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands);
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
return visitAtenArgmaxOp(argmax, operands);
} else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
return visitAtenAnyDimOp(anyDim, operands);
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
@@ -397,6 +399,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
Operation *op, Value dim, Value keepdim,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenArgmaxOp(AtenArgmaxOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenAnyDimOp(AtenAnyDimOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
template <typename OpTy>
@@ -733,8 +738,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
llvm::SmallVector<int64_t> dimList;
bool keepdimBool;
if (matchPattern(keepdim, m_TorchConstantBool(&keepdimBool))) {
bool keepDim;
if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
knowledge.hasSizes = true;
int64_t inputRank = input.sizes.size();
// TODO: This is not safe. Need to check the list users and use aliasing
@@ -745,20 +750,47 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
for (auto en : llvm::enumerate(input.sizes)) {
if (dimSet.contains(en.index())) {
if (keepdimBool)
if (keepDim)
knowledge.sizes.push_back(1);
} else {
knowledge.sizes.push_back(en.value());
}
}
} else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
auto sizes = listConstruct.elements();
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - sizes.size(),
knowledge.sizes.resize(keepDim ? inputRank : inputRank - sizes.size(),
kUnknownSize);
}
}
return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
int64_t dim;
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
if (keepDim)
knowledge.sizes[dim] = 1;
else
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
}
} else if (op.dim().getType().isa<IntegerType>())
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1,
kUnknownSize);
}
// If dim is not an integer, keepDim is ignored
// and the result will be a rank-0 tensor.
return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
@@ -767,22 +799,21 @@ ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
int64_t dim;
bool keepdimBool;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepdimBool))) {
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
if (keepdimBool)
if (keepDim)
knowledge.sizes[dim] = 1;
else
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
}
} else {
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - 1,
kUnknownSize);
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
}
}
return getLatticeElement(op->getResult(0)).join(knowledge);
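
In plain terms, visitAtenArgmaxOp refines the result dtype to signed i64 and derives the result shape from the input shape, dim, and keepdim. A hypothetical Python helper (not from the PR) mirroring the statically known case:

def refined_argmax_sizes(input_sizes, dim, keep_dim):
    # dim is None: the input is flattened, keep_dim is ignored, result is rank 0.
    if dim is None:
        return []
    sizes = list(input_sizes)
    if dim < 0:                    # toPositiveDim
        dim += len(sizes)
    if 0 <= dim < len(sizes):      # isValidDim
        if keep_dim:
            sizes[dim] = 1         # reduced dimension kept as size 1
        else:
            del sizes[dim]         # reduced dimension dropped
    return sizes

print(refined_argmax_sizes([3, 4, 5], dim=1, keep_dim=False))  # [3, 5]
print(refined_argmax_sizes([4, 6], dim=0, keep_dim=True))      # [1, 6]
print(refined_argmax_sizes([3, 4], dim=None, keep_dim=False))  # []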