add argmax lowering
Add argmax lowering from torch to linalg
dan-garvey committed Sep 17, 2021
1 parent 2e63f4b commit a2c96d2
Showing 7 changed files with 273 additions and 13 deletions.
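For context, aten::argmax returns the index of the maximum value as an int64 tensor: with dim=None it reduces over the flattened input to a rank-0 result, and keepdim controls whether the reduced dimension is kept as size 1. A quick eager-mode illustration (plain PyTorch, not part of this patch):

    import torch

    x = torch.rand(3, 4)
    print(torch.argmax(x).shape, torch.argmax(x).dtype)   # torch.Size([]) torch.int64
    print(torch.argmax(x, dim=1).shape)                   # torch.Size([3])
    print(torch.argmax(x, dim=1, keepdim=True).shape)     # torch.Size([3, 1])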
@@ -505,6 +505,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)")
@@ -1329,6 +1329,22 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
}

def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
}
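The generated Torch_AtenArgmaxOp mirrors the TorchScript operator one-for-one: the optional dim becomes TorchOptionalIntType and keepdim becomes Torch_BoolType. One way to see the aten::argmax call this op models is to script a small function and inspect its graph (illustrative only, not part of this patch):

    import torch

    @torch.jit.script
    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.argmax(x, dim=1, keepdim=False)

    # The printed graph should contain an aten::argmax node taking
    # (Tensor, int?, bool) operands, matching the ODS definition above.
    print(f.graph)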

def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
AllowsTypeRefinement
]> {
32 changes: 29 additions & 3 deletions external/torch-mlir/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -69,6 +69,7 @@ static bool isValidTorchDtype(Type dtype) {
return type.getWidth() == 8;
}
}
dtype.dump();
return false;
}

@@ -136,8 +137,7 @@ Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
if (succeeded(parser.parseOptionalQuestion())) {
sizes.push_back(-1);
continue;
-    }
-
+    }
int64_t size;
auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) {
@@ -264,15 +264,41 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
}

ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
auto elementType = type.getElementType();
if (IntegerType itype = elementType.dyn_cast<IntegerType>()) {
if (!itype.isSigned()) {
for (unsigned width : {8, 16, 32, 64}) {
if (itype.getWidth() == width){
elementType = IntegerType::get(type.getContext(), width, IntegerType::Signed);
}
}
}
}
return ValueTensorType::get(type.getContext(),
type.hasRank() ? type.getShape()
: Optional<ArrayRef<int64_t>>(),
-                              type.getElementType());
+                              elementType);
}

TensorType ValueTensorType::toBuiltinTensor() const {
if (!hasDtype())
return nullptr;
auto dtype = getDtype();
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
if (type.isSigned()) {
for (unsigned width : {8, 16, 32, 64}) {
if (type.getWidth() == width){
if (!hasSizes())
return UnrankedTensorType::get(
IntegerType::get(type.getContext(), width));
return RankedTensorType::get(getSizes(),
IntegerType::get(type.getContext(), width));
}
}
}
}


if (!hasSizes())
return UnrankedTensorType::get(getDtype());
return RankedTensorType::get(getSizes(), getDtype());
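Both getFromShaped and toBuiltinTensor need this change because argmax produces integer results: the torch dialect models integer dtypes as signed (si8/si16/si32/si64), while builtin tensor types use signless integers (i8/i16/i32/i64), so the element type must be converted in each direction. A minimal sketch of that width-preserving mapping (the helper names are hypothetical, not part of the patch):

    # Hypothetical helpers mirroring the C++ logic above: torch-side signed
    # integer element types map to builtin signless integers of the same width.
    _WIDTHS = (8, 16, 32, 64)

    def torch_to_builtin(dtype: str) -> str:
        # "si64" -> "i64"; other dtypes pass through unchanged.
        if dtype.startswith("si") and dtype[2:].isdigit() and int(dtype[2:]) in _WIDTHS:
            return "i" + dtype[2:]
        return dtype

    def builtin_to_torch(dtype: str) -> str:
        # "i64" -> "si64"; other dtypes pass through unchanged.
        if dtype.startswith("i") and dtype[1:].isdigit() and int(dtype[1:]) in _WIDTHS:
            return "si" + dtype[1:]
        return dtype

    assert torch_to_builtin("si64") == "i64" and builtin_to_torch("i32") == "si32"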
50 changes: 42 additions & 8 deletions external/torch-mlir/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -269,6 +269,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands);
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
return visitAtenArgmaxOp(argmax, operands);
} else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
return visitAtenAnyDimOp(anyDim, operands);
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
@@ -395,6 +397,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
ChangeResult visitReductionAlongDimIntListOp(
Operation *op, Value dim, Value keepdim,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitAtenArgmaxOp(
AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenAnyDimOp(AtenAnyDimOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -732,8 +736,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
llvm::SmallVector<int64_t> dimList;
-  bool keepdimBool;
-  if (matchPattern(keepdim, m_TorchConstantBool(&keepdimBool))) {
+  bool keepDim;
+  if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
knowledge.hasSizes = true;
int64_t inputRank = input.sizes.size();
// TODO: This is not safe. Need to check the list users and use aliasing
@@ -744,20 +748,50 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
for (auto en : llvm::enumerate(input.sizes)) {
if (dimSet.contains(en.index())) {
-        if (keepdimBool)
+        if (keepDim)
knowledge.sizes.push_back(1);
} else {
knowledge.sizes.push_back(en.value());
}
}
} else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
auto sizes = listConstruct.elements();
-    knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - sizes.size(),
+    knowledge.sizes.resize(keepDim ? inputRank : inputRank - sizes.size(),
kUnknownSize);
}
}
return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
int64_t dim;
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
if (keepDim)
knowledge.sizes[dim] = 1;
else
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
}
} else {
// Assumes if dim is not an int, that it is None.
// keepDim is ignored in this case, and the result will be
// a rank-0 tensor.
knowledge.hasSizes = true;

}
}
return getLatticeElement(op->getResult(0)).join(knowledge);
}


ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
@@ -766,21 +800,21 @@
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
int64_t dim;
-  bool keepdimBool;
-  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepdimBool))) {
+  bool keepDim;
+  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
-        if (keepdimBool)
+        if (keepDim)
knowledge.sizes[dim] = 1;
else
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
}
} else {
-      knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - 1,
+      knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1,
kUnknownSize);
}
}
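The new visitAtenArgmaxOp refinement above follows the usual reduction rule: a constant dim is either kept with size 1 (keepdim=true) or dropped, and a non-constant dim is assumed to be None, which collapses the result to rank 0. A small Python sketch of the same rule, checked against eager torch.argmax (illustrative, not the C++ implementation):

    import torch

    def argmax_result_sizes(sizes, dim, keepdim):
        # dim=None: reduce over the flattened input -> rank-0 result.
        if dim is None:
            return []
        dim = dim % len(sizes)  # toPositiveDim
        if keepdim:
            return [1 if i == dim else s for i, s in enumerate(sizes)]
        return [s for i, s in enumerate(sizes) if i != dim]

    x = torch.rand(3, 4, 5)
    assert list(torch.argmax(x, dim=1, keepdim=True).shape) == argmax_result_sizes([3, 4, 5], 1, True)
    assert list(torch.argmax(x, dim=-1).shape) == argmax_result_sizes([3, 4, 5], -1, False)
    assert list(torch.argmax(x).shape) == argmax_result_sizes([3, 4, 5], None, False)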
36 changes: 36 additions & 0 deletions frontends/pytorch/e2e_testing/torchscript/argmax.py
@@ -0,0 +1,36 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from torch_mlir_torchscript.e2e_test.framework import TestUtils
from torch_mlir_torchscript.e2e_test.registry import register_test_case
from torch_mlir_torchscript.annotations import annotate_args, export

# ==============================================================================

class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a, dim=None, keepDim=False):
return torch.argmax(a, dim, keepDim)


@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4), dim=1, keepDim=False)

@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6), dim=0, keepDim=True)
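Outside the e2e harness, ArgmaxModule as written simply forwards to torch.argmax, so its expected behavior can be checked in eager mode (illustrative usage, assuming the definitions above are imported):

    m = ArgmaxModule()
    x = torch.rand(3, 4)
    print(m.forward(x))                             # rank-0 int64 index into the flattened input
    print(m.forward(x, dim=1, keepDim=True).shape)  # torch.Size([3, 1])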
1 change: 1 addition & 0 deletions frontends/pytorch/e2e_testing/torchscript/main.py
@@ -37,6 +37,7 @@
from . import elementwise
from . import list_programs
from . import reduction
from . import argmax

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend']
