Implement HardMax #950

Merged: 4 commits, Oct 26, 2021
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -8,6 +8,7 @@ add_onnx_mlir_library(OMONNXToKrnl
Math/Clip.cpp
Math/Elementwise.cpp
Math/Gemm.cpp
Math/Hardmax.cpp
Math/LRN.cpp
Math/MatMul.cpp
Math/Reduction.cpp
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -175,6 +175,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
populateLoweringONNXCumSumOpPattern(patterns, &getContext());
populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
populateLoweringONNXGemmOpPattern(patterns, &getContext());
populateLoweringONNXHardmaxOpPattern(patterns, &getContext());
populateLoweringONNXReductionOpPattern(patterns, &getContext());
populateLoweringONNXSoftmaxOpPattern(patterns, &getContext());
populateLoweringONNXMatMulOpPattern(patterns, &getContext());
148 changes: 148 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Hardmax.cpp
@@ -0,0 +1,148 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===----------------- Hardmax.cpp - Hardmax Op ---------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Hardmax operator to the Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"

using namespace mlir;

/// Returns the indices of the maximum values along a given axis.
static Value emitArgmax(ConversionPatternRewriter &rewriter, Location loc,
Value input, int64_t axis) {
KrnlBuilder createKrnl(rewriter, loc);
MathBuilder createMath(createKrnl);
IndexExprScope scope(createKrnl);

MemRefType memRefType = input.getType().cast<MemRefType>();
Type indexType = rewriter.getIndexType();
int64_t rank = memRefType.getRank();
Value zero = createMath.constantIndex(0);

MemRefBoundsIndexCapture inputBounds(input);
SmallVector<IndexExpr, 4> inputUBS;
inputBounds.getDimList(inputUBS);

// Allocate and initialize the result.
// The result has the same shape as the input except that the axis dimension is 1.
SmallVector<IndexExpr, 4> outputUBS(inputUBS);
outputUBS[axis] = LiteralIndexExpr(1);
SmallVector<int64_t, 4> outputShape;
for (const IndexExpr &dim : outputUBS)
outputShape.push_back(dim.isLiteral() ? dim.getLiteral() : -1);
Value resMemRef = insertAllocAndDeallocSimple(rewriter, nullptr,
MemRefType::get(outputShape, indexType), loc, outputUBS,
/*insertDealloc=*/true);
createKrnl.memset(resMemRef, zero);

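// Scan the whole input: for each output position, keep the index of the
// first occurrence of the maximum along `axis` (the stored index is updated
// only on a strictly greater value, so ties resolve to the smallest index).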
ValueRange loopDef = createKrnl.defineLoops(rank);
SmallVector<IndexExpr> lbs(rank, LiteralIndexExpr(0));
createKrnl.iterateIE(loopDef, loopDef, lbs, inputUBS,
[&](KrnlBuilder &createKrnl, ValueRange inputLoopInd) {
MathBuilder createMath(createKrnl);
SCFBuilder createSCF(createKrnl);

// Load the index of the current max value.
SmallVector<Value> resLoopInd(inputLoopInd);
resLoopInd[axis] = zero;
Value maxInd = createKrnl.load(resMemRef, resLoopInd);

// Load the current max value.
SmallVector<Value> maxLoopInd(inputLoopInd);
maxLoopInd[axis] = maxInd;
Value maxValue = createKrnl.load(input, maxLoopInd);
// Load a new value.
Value next = createKrnl.load(input, inputLoopInd);

// Compare and update the index for the maximum value.
Value gt = createMath.sgt(next, maxValue);
createSCF.ifThenElse(gt, [&](SCFBuilder &createSCF) {
createKrnl.store(inputLoopInd[axis], resMemRef, resLoopInd);
});
});

return resMemRef;
}

struct ONNXHardmaxOpLowering : public ConversionPattern {
ONNXHardmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXHardmaxOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
KrnlBuilder createKrnl(rewriter, loc);
MathBuilder createMath(createKrnl);
IndexExprScope scope(createKrnl);

ONNXHardmaxOpAdaptor operandAdaptor(operands);
Value input = operandAdaptor.input();

MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
auto elementType = memRefType.getElementType();
Value zero = createMath.constantIndex(0);

int64_t rank = memRefType.getRank();
int64_t axis = llvm::cast<ONNXHardmaxOp>(op).axis();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= 0 && axis < rank && "axis is out of range");

MemRefBoundsIndexCapture inputBounds(input);
SmallVector<IndexExpr, 4> ubs;
inputBounds.getDimList(ubs);

// Insert an allocation and deallocation for the result of this operation.
bool insertDealloc = checkInsertDealloc(op);
Value resMemRef = insertAllocAndDeallocSimple(
rewriter, op, memRefType, loc, ubs, insertDealloc);

// Compute argmax.
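// The argmax buffer has the input shape with the axis dimension collapsed
// to 1 (e.g., 3x4x5 -> 3x1x5 for axis = 1).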
Value argmax = emitArgmax(rewriter, loc, input, axis);

// Produce the final result.
// Set value to 1 if index is argmax. Otherwise, 0.
ValueRange loopDef = createKrnl.defineLoops(rank);
SmallVector<IndexExpr> lbs(rank, LiteralIndexExpr(0));
createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
MathBuilder createMath(createKrnl);
SCFBuilder createSCF(createKrnl);

// Load the index of the current max value.
SmallVector<Value> maxLoopInd(loopInd);
maxLoopInd[axis] = zero;
Value maxInd = createKrnl.load(argmax, maxLoopInd);

// Set value to 1 if the index is argmax. Otherwise, 0.
Value eq = createMath.eq(maxInd, loopInd[axis]);
createSCF.ifThenElse(
eq, /*then*/
[&](SCFBuilder &createSCF) {
Value one = createMath.constant(elementType, 1);
createKrnl.store(one, resMemRef, loopInd);
},
/*else*/
[&](SCFBuilder &createSCF) {
Value zero = createMath.constant(elementType, 0);
createKrnl.store(zero, resMemRef, loopInd);
});
});

rewriter.replaceOp(op, resMemRef);
return success();
}
};

void populateLoweringONNXHardmaxOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXHardmaxOpLowering>(ctx);
}
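For reference, the lowering above is a two-pass algorithm: an argmax scan along the axis followed by a one-hot write. A minimal NumPy sketch of the same semantics (the helper name and the final self-check are illustrative, not part of this PR):

```python
import numpy as np

def hardmax_reference(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """ONNX Hardmax: 1.0 at the first maximum along `axis`, 0.0 elsewhere."""
    axis = axis if axis >= 0 else x.ndim + axis  # normalize a negative axis
    # Pass 1: argmax with the axis dimension kept as size 1 (np.argmax returns
    # the first occurrence, matching the strictly-greater update in the scan).
    idx = np.expand_dims(np.argmax(x, axis=axis), axis)
    # Pass 2: write 1.0 where the index equals the argmax, 0.0 elsewhere.
    out = np.zeros_like(x)
    np.put_along_axis(out, idx, 1.0, axis=axis)
    return out

# Each 1-D slice along the axis holds exactly one 1.0, so it sums to 1.0.
x = np.random.rand(3, 4, 5).astype(np.float32)
assert np.all(hardmax_reference(x, axis=1).sum(axis=1) == 1.0)
```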
3 changes: 3 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -254,6 +254,9 @@ void populateLoweringONNXElementwiseOpPattern(
void populateLoweringONNXGemmOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);

void populateLoweringONNXHardmaxOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);

void populateLoweringONNXLRNOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);

19 changes: 18 additions & 1 deletion src/Dialect/ONNX/ONNXOps.cpp
@@ -3623,9 +3623,26 @@ LogicalResult ONNXGreaterOrEqualOp::inferShapes(
return success();
}

static LogicalResult verify(ONNXHardmaxOp op) {
ONNXHardmaxOpAdaptor hmOp = ONNXHardmaxOpAdaptor(op);
auto input = hmOp.input();
int64_t axis = op.axis();

// Verify that the axis is in the range [-r, r - 1], where r is the rank of
// the input.
if (hasShapeAndRank(input)) {
int64_t rank = input.getType().cast<ShapedType>().getRank();
if (axis < -rank || axis > rank - 1)
return op.emitError("axis value is out of range");
}

return success();
}

LogicalResult ONNXHardmaxOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
return emitError(NOT_IMPLEMENTED_MESSAGE);
getResult().setType(getOperand().getType());
return success();
}

LogicalResult ONNXIfOp::inferShapes(
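As a sanity check on the verifier and the normalization used in the lowering, a hedged Python sketch (the function name is illustrative): an axis is legal iff it lies in [-rank, rank - 1], and a negative value is shifted by the rank.

```python
def normalize_hardmax_axis(rank: int, axis: int) -> int:
    """Mirror of the verifier + lowering: reject out-of-range axes, then shift."""
    if axis < -rank or axis > rank - 1:
        raise ValueError("axis value is out of range")
    return axis if axis >= 0 else rank + axis  # result is in [0, rank - 1]

assert normalize_hardmax_axis(3, 1) == 1    # test_hardmax_axis_1
assert normalize_hardmax_axis(3, -1) == 2   # test_hardmax_negative_axis
```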
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
@@ -2063,6 +2063,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax",
return {20};
}
}];
let verifier = [{ return ::verify(*this); }];
}

def ONNXIdentityOp:ONNX_Op<"Identity",
7 changes: 7 additions & 0 deletions test/backend/test.py
@@ -392,6 +392,13 @@
"test_gru_with_initial_bias_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{0:{0,1,2}}, CONSTANT_INPUT:{1,2}},

# Hard Max
"test_hardmax_axis_0_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_hardmax_axis_2_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_hardmax_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_hardmax_one_hot_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_hardmax_axis_1_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_hardmax_default_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_hardmax_negative_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# Hard Sigmoid
"test_hardsigmoid_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
95 changes: 95 additions & 0 deletions test/mlir/onnx/onnx_lowering_with_canonicalize.mlir
@@ -3121,3 +3121,98 @@ builtin.func @compress_no_axis_enough_cond(%arg0: tensor<3x2xf32>, %arg1: tensor
// CHECK: return [[RES_1_]] : memref<?xf32>
// CHECK: }
}

// -----

func @test_hardmax_axis_1(%arg0: tensor<3x4x5xf32>) -> tensor<*xf32> {
%0 = "onnx.Hardmax"(%arg0) {axis = 1 : si64} : (tensor<3x4x5xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>

// mlir2FileCheck.py -a'["input"]'
// CHECK-LABEL: func @test_hardmax_axis_1
// CHECK-SAME: ([[INPUT_:%.+]]: memref<3x4x5xf32>) -> memref<3x4x5xf32> {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x4x5xf32>
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3x1x5xindex>
// CHECK: krnl.memset [[RES_1_]], [[VAR_c0_]] : memref<3x1x5xindex>
// CHECK: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 3, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 4, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 5) {
// CHECK: [[VAR_4_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_4_]]#0, [[VAR_c0_]], [[VAR_4_]]#2] : memref<3x1x5xindex>
// CHECK-DAG: [[LOAD_INPUT_MEM_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_4_]]#0, [[LOAD_RES_1_MEM_]], [[VAR_4_]]#2] : memref<3x4x5xf32>
// CHECK-DAG: [[LOAD_INPUT_MEM_1_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_4_]]#0, [[VAR_4_]]#1, [[VAR_4_]]#2] : memref<3x4x5xf32>
// CHECK: [[VAR_8_:%.+]] = arith.cmpf ogt, [[LOAD_INPUT_MEM_1_]], [[LOAD_INPUT_MEM_]] : f32
// CHECK: scf.if [[VAR_8_]] {
// CHECK: krnl.store [[VAR_4_]]#1, [[RES_1_]]{{.}}[[VAR_4_]]#0, [[VAR_c0_]], [[VAR_4_]]#2] : memref<3x1x5xindex>
// CHECK: }
// CHECK: }
// CHECK: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 4, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 5) {
// CHECK: [[VAR_4_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_4_1_]]#0, [[VAR_c0_]], [[VAR_4_1_]]#2] : memref<3x1x5xindex>
// CHECK: [[LOAD_INPUT_MEM_2_:%.+]] = arith.cmpi eq, [[LOAD_RES_1_MEM_1_]], [[VAR_4_1_]]#1 : index
// CHECK: scf.if [[LOAD_INPUT_MEM_2_]] {
// CHECK: krnl.store [[VAR_cst_0_]], [[RES_]]{{.}}[[VAR_4_1_]]#0, [[VAR_4_1_]]#1, [[VAR_4_1_]]#2] : memref<3x4x5xf32>
// CHECK: } else {
// CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[VAR_4_1_]]#0, [[VAR_4_1_]]#1, [[VAR_4_1_]]#2] : memref<3x4x5xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES_]] : memref<3x4x5xf32>
// CHECK: }
}

// -----

func @test_hardmax_unknown_dims(%arg0: tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 = "onnx.Hardmax"(%arg0) {axis = 1 : si64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>

// mlir2FileCheck.py -a'["input"]'
// CHECK-DAG: #map0 = affine_map<(d0) -> (d0)>
// CHECK-DAG: #map1 = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #map2 = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @test_hardmax_unknown_dims
// CHECK-SAME: ([[INPUT_:%.+]]: memref<?x?x?xf32>) -> memref<?x?x?xf32> {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[INPUT_]], [[VAR_c0_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[INPUT_]], [[VAR_c1_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = memref.dim [[INPUT_]], [[VAR_c2_]] : memref<?x?x?xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {{.*}}: memref<?x?x?xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = memref.dim [[INPUT_]], [[VAR_c0_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = memref.dim [[INPUT_]], [[VAR_c1_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = memref.dim [[INPUT_]], [[VAR_c2_]] : memref<?x?x?xf32>
// CHECK: [[RES_1_:%.+]] = memref.alloc([[VAR_4_]], [[VAR_6_]]) {{.*}}: memref<?x1x?xindex>
// CHECK: krnl.memset [[RES_1_]], [[VAR_c0_]] : memref<?x1x?xindex>
// CHECK: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to #map1([[VAR_4_]], [[VAR_5_]]), [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to #map2([[VAR_4_]], [[VAR_5_]], [[VAR_6_]])) {
// CHECK: [[VAR_10_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_10_]]#0, [[VAR_c0_]], [[VAR_10_]]#2] : memref<?x1x?xindex>
// CHECK-DAG: [[LOAD_INPUT_MEM_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_10_]]#0, [[LOAD_RES_1_MEM_]], [[VAR_10_]]#2] : memref<?x?x?xf32>
// CHECK-DAG: [[LOAD_INPUT_MEM_1_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_10_]]#0, [[VAR_10_]]#1, [[VAR_10_]]#2] : memref<?x?x?xf32>
// CHECK: [[VAR_14_:%.+]] = arith.cmpf ogt, [[LOAD_INPUT_MEM_1_]], [[LOAD_INPUT_MEM_]] : f32
// CHECK: scf.if [[VAR_14_]] {
// CHECK: krnl.store [[VAR_10_]]#1, [[RES_1_]]{{.}}[[VAR_10_]]#0, [[VAR_c0_]], [[VAR_10_]]#2] : memref<?x1x?xindex>
// CHECK: }
// CHECK: }
// CHECK: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to #map0([[VAR_0_]]), [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to #map1([[VAR_0_]], [[VAR_1_]]), [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to #map2([[VAR_0_]], [[VAR_1_]], [[VAR_2_]])) {
// CHECK: [[VAR_10_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_10_1_]]#0, [[VAR_c0_]], [[VAR_10_1_]]#2] : memref<?x1x?xindex>
// CHECK: [[LOAD_INPUT_MEM_2_:%.+]] = arith.cmpi eq, [[LOAD_RES_1_MEM_1_]], [[VAR_10_1_]]#1 : index
// CHECK: scf.if [[LOAD_INPUT_MEM_2_]] {
// CHECK: krnl.store [[VAR_cst_0_]], [[RES_]]{{.}}[[VAR_10_1_]]#0, [[VAR_10_1_]]#1, [[VAR_10_1_]]#2] : memref<?x?x?xf32>
// CHECK: } else {
// CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[VAR_10_1_]]#0, [[VAR_10_1_]]#1, [[VAR_10_1_]]#2] : memref<?x?x?xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES_]] : memref<?x?x?xf32>
// CHECK: }
}
10 changes: 10 additions & 0 deletions test/mlir/onnx/onnx_shape_inference.mlir
@@ -2176,3 +2176,13 @@ func @compress_no_axis(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor<?x
// CHECK: return [[VAR_0_]] : tensor<?xf32>
// CHECK: }
}

// -----

func @hardmax(%arg0: tensor<3x4x5xf32>) -> tensor<*xf32>{
%0 = "onnx.Hardmax"(%arg0) {axis = 1 : si64} : (tensor<3x4x5xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: hardmax
// CHECK: [[RES:%.+]] = "onnx.Hardmax"(%arg0) {axis = 1 : si64} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: return [[RES]] : tensor<3x4x5xf32>
}
1 change: 1 addition & 0 deletions utils/gen_onnx_mlir.py
@@ -375,6 +375,7 @@
'Conv',
'DepthToSpace',
'Expand',
'Hardmax',
'InstanceNormalization',
'Mod',
'NonMaxSuppression',