Skip to content

Commit

Permalink
[FUSION]: Fuse ONNXConvOp fed by ONNXMulOp (llvm#1419)
Browse files Browse the repository at this point in the history
Fuse an ONNXMulOp when one of the input is a the result of a ONNXConvOp and the other input is a produced by a ONNXConstantOp if the ONNXConvOp weights is produced by a ONXConstantOp`

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>
  • Loading branch information
Ettore Tiotto authored May 17, 2022
1 parent c3ee6a5 commit ddb548e
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 54 deletions.
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -3116,6 +3116,7 @@ def ONNXModOp:ONNX_Op<"Mod",

def ONNXMulOp:ONNX_Op<"Mul",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Mul operation";
let description = [{
"Performs element-wise binary multiplication (with Numpy-style broadcasting support)."
Expand Down
34 changes: 17 additions & 17 deletions src/Dialect/ONNX/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,13 @@

using namespace mlir;

namespace {

// If 'lhs' is not NoneType, return 'lhs - rhs'.
// If 'lhs' is not NoneType, return 'lhs - rhs'.
// Otherwise, return '-rhs'.
Value subtractOrNeg(
PatternRewriter &rewriter, Location loc, Value lhs, Value rhs) {
if (lhs.getType().isa<NoneType>()) {
Value result = rewriter.create<ONNXNegOp>(loc, rhs);
return result;
} else {
Value result = rewriter.create<ONNXSubOp>(loc, lhs, rhs);
return result;
}
namespace onnx_mlir {

// If 'A' is NoneType, return -B. Otherwise return A-B.
Value subtractOrNeg(PatternRewriter &rewriter, Location loc, Value A, Value B) {
if (A.getType().isa<NoneType>())
return rewriter.create<ONNXNegOp>(loc, B);
return rewriter.create<ONNXSubOp>(loc, A, B);
}

// Create an ArrayAttr of IntergerAttr(s) of values in [1, N].
Expand Down Expand Up @@ -100,13 +93,13 @@ bool areProducedByTransposeOp(ValueRange values) {
});
}

} // namespace onnx_mlir

/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Dialect/ONNX/ONNXRewrite.inc"

} // end anonymous namespace

/// Register optimization patterns as "canonicalization" patterns
/// on the ONNXMatMultOp.
/// on the ONNXAddOp.
void ONNXAddOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.insert<NormalizeAddPattern>(context);
Expand All @@ -116,6 +109,13 @@ void ONNXAddOp::getCanonicalizationPatterns(
results.insert<FuseAddConvNullBiasPattern>(context);
}

/// on the ONNXMulOp.
void ONNXMulOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.insert<NormalizeMulPattern>(context);
results.insert<FuseMulConvNullBiasPattern>(context);
}

/// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
Expand Down
132 changes: 99 additions & 33 deletions src/Dialect/ONNX/Rewrite.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,38 @@ def createDenseElementsAttrFromSize : NativeCodeCall<
// If '$1' is not NoneType, do subtraction '$1 - $2'.
// Otherwise, take the negative of '$2'.
def subtractOrNeg: NativeCodeCall<
"subtractOrNeg($_builder, $0.getDefiningOp()->getLoc(), $1, $2)">;
"onnx_mlir::subtractOrNeg($_builder, $0.getDefiningOp()->getLoc(), $1, $2)">;

// Get the rank of the given value.
def getRankOf :
NativeCodeCall<"$0.getType().cast<ShapedType>().getRank()">;

// Create an ArrayAttr of IntergerAttr(s) of [$0].
def createArrayAttrOf : NativeCodeCall<
"onnx_mlir::createArrayAttrOfNToM($_builder, $0, $0)">;

// Create an ArrayAttr of IntergerAttr(s) of values in [1, N-1].
def createArrayAttrOfOneToRankOf : NativeCodeCall<
"createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 1)">;
"onnx_mlir::createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 1)">;

// Create an ArrayAttr of IntergerAttr(s) of values in [1, N-2].
def createArrayAttrOfOneToRankOfExclusive : NativeCodeCall<
"createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 2)">;
"onnx_mlir::createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 2)">;

// Create an ArrayAttr of IntergerAttr(s) of values in [2, rank - 1].
def createArrayAttrOfTwoToRankOf : NativeCodeCall<
"createArrayAttrOfNToM($_builder, 2, $0.getType().cast<ShapedType>().getRank() - 1)">;
"onnx_mlir::createArrayAttrOfNToM($_builder, 2, $0.getType().cast<ShapedType>().getRank() - 1)">;

def ONNXConstantOpNormalize: NativeCodeCall<
"onnx_mlir::normalizeConstantOp($_builder, $0, $1)">;

def AttributeIsNotNull :
Constraint<CPred<" ($_self) ">, "Attribute is null">;
Constraint<CPred<" ($_self) ">, "Attribute is null">;

def IsDenseElementsAttr :
Constraint<And<[CPred<" ($_self) ">,
CPred<" ($_self).isa<DenseElementsAttr>()">
]>, "Attribute is not a DenseElementsAttr">;

// Intended to check whether there is at least one not-Null the attributes
// However, the current table gen can only support max 4 parameters
Expand All @@ -72,33 +85,21 @@ def AttributesNotAllNull :
Constraint<Neg<And<[CPred<"($0)">, CPred<" ($1) ">, CPred<" ($2) ">]>>,
"Attributes are not null">;

def GetNullAttr :
NativeCodeCall<"Attribute()">;
def GetNullAttr : NativeCodeCall<"Attribute()">;

def GetNullFloatAttr :
NativeCodeCall<"FloatAttr()">;
def GetNullFloatAttr : NativeCodeCall<"FloatAttr()">;

def GetNullIntegerAttr :
NativeCodeCall<"IntegerAttr()">;
def GetNullIntegerAttr : NativeCodeCall<"IntegerAttr()">;

def GetNullStringAttr :
NativeCodeCall<"StringAttr()">;
def GetNullStringAttr : NativeCodeCall<"StringAttr()">;

def GetNullArrayAttr :
NativeCodeCall<"ArrayAttr()">;
def GetNullArrayAttr : NativeCodeCall<"ArrayAttr()">;

// Create a StringAttr from a string.
class StringAttrOfValue<string val>:
NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;

// Check whether an ArrayAttr contains non-zero values or not.
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
"has non-zero elements">;

// Check that a StrAttr does not contain a specific value.
class IsNotStringAttrOfValue<string val>:
Constraint<CPred<"$0.cast<StringAttr>().getValue() != \"" # val # "\"">>;

// Check the rank of a value is greater than a given integer.
class HasRankGT<int rank> :
Constraint<CPred<"$0.getType().isa<ShapedType>() && "
Expand All @@ -117,6 +118,13 @@ def HaveSameLastDim: Constraint<
"[$1.getType().cast<RankedTensorType>().getRank() - 1])">,
"Two tensors have the same last dimension">;

class HaveSameDim<int dim>: Constraint<
CPred<"onnx_mlir::hasShapeAndRank($0) && onnx_mlir::hasShapeAndRank($1) && "
"$0.getType().cast<RankedTensorType>().getShape()[" # dim # "] != -1 && "
"($0.getType().cast<RankedTensorType>().getShape()[" # dim # "] =="
"$1.getType().cast<RankedTensorType>().getShape()[" # dim # "])">,
"Two tensors have the same specified dimension">;

def GetUnitAttr: NativeCodeCall<"$_builder.getUnitAttr()">;

def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
Expand All @@ -125,18 +133,23 @@ def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;

def NotNoneType : Constraint<CPred<"!($0.getType().isa<NoneType>())">>;

def hasShapeAndRank : Constraint<CPred<"onnx_mlir::hasShapeAndRank($0)">>;
def HasShapeAndRank : Constraint<CPred<"onnx_mlir::hasShapeAndRank($0)">>;

def HasSameElementType : Constraint<
CPred<"($0.getType().dyn_cast<ShapedType>().getElementType() == "
"$1.cast<::mlir::TypeAttr>().getValue())">,
"has same element type">;

def HaveSameElementType : Constraint<
CPred<"($0.getType().dyn_cast<ShapedType>().getElementType() == "
"$1.getType().dyn_cast<ShapedType>().getElementType())">,
"has same element type">;

def IsStaticShapeTensor:
Constraint<
CPred<
"$_self.getType().cast<::mlir::ShapedType>().hasStaticShape()">,
"hasStaticShape">;
Constraint<
CPred<
"$_self.getType().cast<::mlir::ShapedType>().hasStaticShape()">,
"hasStaticShape">;

def AreTheSameAxisArray: Constraint<
CPred<"onnx_mlir::AreTheSameAxisArray("
Expand All @@ -156,6 +169,11 @@ def IsNotFromONNXConstantOp: Constraint<
CPred<"!(llvm::dyn_cast_or_null<ONNXConstantOp>($0.getDefiningOp()))">,
"Is a value not from ONNXConstantOp">;

def IsFromONNXConstantOpWithDenseElementsAttr: Constraint<
And<[CPred<" isa<ONNXConstantOp>($_self.getDefiningOp()) ">,
CPred<" onnx_mlir::getONNXConstantOp($_self).valueAttr().isa<DenseElementsAttr>() ">
]>, "Value is not a ONNXConstantOp with a DenseElementsAttr">;

def AreTheSameConstantOpDenseAttr: Constraint<
CPred<"onnx_mlir::AreTheSameConstantOpDenseAttr($_builder,"
"(onnx_mlir::hasShapeAndRank($0) ? $0.getType().cast<ShapedType>().getRank() : 0),"
Expand All @@ -172,11 +190,10 @@ class AllDimsFromAxisToEndAre<int axis, int val>: Constraint<
class RankXMinusRankYIs<int diff>: Constraint<
CPred<"($0.getType().cast<ShapedType>().getRank() "
" - $1.getType().cast<ShapedType>().getRank() == " # diff # ")">,
"X' rank is greater than Y's rank dif units">;
"X' rank is greater than Y's rank diff units">;

def TransposeVariadicInput: NativeCodeCall<
"transposeVariadicInput($_builder, $_loc, $0, $1)"
>;
"onnx_mlir::transposeVariadicInput($_builder, $_loc, $0, $1)">;

//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite
Expand Down Expand Up @@ -263,6 +280,55 @@ def FuseAddConvPattern: Pat<
(RankXMinusRankYIs<1> $res, $y)]
>;

//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'Mul o Conv' into 'Conv' if the other input
// of Mul is a constant, by multipling constant to 'w' of 'Conv':
//
// We have:
// (Conv) z = i * w + b
// (Mul) y = z x c (where c is a constant)
//
// which corresponds to the following computation:
// y = i * new_w + b
// where
// new_w = w x c
//
// The shape of 'c' must be compatible with that of 'w'
//===----------------------------------------------------------------------===//

def NormalizeMulPattern: Pat<
(ONNXMulOp $x, $y),
(ONNXMulOp $y, $x),
[(IsFromONNXConstantOp $x), (IsNotFromONNXConstantOp $y)]
>;

def FuseMulConvNullBiasPattern: Pat<
(ONNXMulOp:$res
(ONNXConvOp
$x, $w, $b, $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
(ONNXConstantOp:$y $_, $denseAttr, $_, $_, $_, $_, $_, $_)),
(ONNXConvOp
$x,
// new_w
(ONNXMulOp $w, (ONNXUnsqueezeV11Op $y, (createArrayAttrOf(getRankOf $y)))),
// unchanged operands and attributes.
$b, $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
[(HasNoneType $b),
(IsDenseElementsAttr:$denseAttr),
(IsFromONNXConstantOpWithDenseElementsAttr:$w),
(HaveSameElementType $w, $y), // multiplier and Conv weight must have the same element type.
(HasRankGT<1> $w), // rank of $w must be at least 2.
(RankXMinusRankYIs<1> $w, $y), // rank($y) must be equal to rank($w)-1.
(HaveSameDim<0> $w, $y), // the first dimension of $w and $y must be equal.
(AllDimsFromAxisToEndAre<1, 1>:$y)] // all dimensions of $y must be 1 except for the first one.
>;

// TODO add pattern for non-null bias with contraints:
// - bias must be have rank equal to 1 and
// - bias element data type must be the same as mul constant
// - bias dimension (0) must be equal to mul constant dim(0)
// codegen is different too (look it up in onnx-runtime)

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXIdentityOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -314,11 +380,11 @@ def RemoveIdentityTransposePattern: Pat<
[(IsIdentityPermuteAttribute:$p)]>;

def GetIndexOfAxisInPerm: NativeCodeCall<
"getIndexOfAxisInPerm($_builder, $0, $1)"
"onnx_mlir::getIndexOfAxisInPerm($_builder, $0, $1)"
>;

def ProducedByTransposeOp: Constraint<
CPred<"areProducedByTransposeOp($_self)">,
CPred<"onnx_mlir::areProducedByTransposeOp($_self)">,
"all values are produced by ONNXTransposeOp"
>;

Expand Down Expand Up @@ -351,7 +417,7 @@ def RemoveIdentityReshapePattern: Pat<
[(HasSpecifiedConstantShape $val, $shape)]>;

def GetReturnTypeForMatMulOpND2D: NativeCodeCall<
"getReturnTypeForMatMulOpND2D($0, $1)"
"onnx_mlir::getReturnTypeForMatMulOpND2D($0, $1)"
>;

def SwapReshapeMatMulPattern: Pattern<
Expand Down
1 change: 1 addition & 0 deletions src/Transform/ONNX/ONNXOpTransformPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ void ONNXOpTransformPass::runOnOperation() {
onnx_mlir::createDecomposeONNXToONNXPass());
dynamicPM.addPass(onnx_mlir::createShapeInferencePass());
dynamicPM.addPass(mlir::createCanonicalizerPass());
dynamicPM.addPass(onnx_mlir::createShapeInferencePass());
dynamicPM.addNestedPass<func::FuncOp>(
onnx_mlir::createConstPropONNXToONNXPass());
if (failed(runPipeline(dynamicPM, module)))
Expand Down
29 changes: 25 additions & 4 deletions test/mlir/onnx/onnx_canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func @test_conv_batchnormtestmode_fusion_nobias(%arg0 : tensor<1x3x224x224xf32>)
// CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[UNSQUEEZE:%.+]] = "onnx.UnsqueezeV11"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[UNSQUEEZE]], [[WEIGHT]]) : (tensor<*xf32>, tensor<64x3x7x7xf32>) -> tensor<*xf32>

// CHECK: [[NEG_MEAN:%.+]] = "onnx.Neg"([[MEAN]]) : (tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[NEG_MEAN]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
Expand Down Expand Up @@ -144,7 +144,7 @@ func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1
// CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[UNSQUEEZE:%.+]] = "onnx.UnsqueezeV11"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[UNSQUEEZE]], [[WEIGHT]]) : (tensor<*xf32>, tensor<64x3x7x7xf32>) -> tensor<*xf32>

// CHECK: [[SUB:%.+]] = "onnx.Sub"(%arg1, [[MEAN]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
// CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[SUB]]) : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
Expand Down Expand Up @@ -560,7 +560,7 @@ func @test_rewrite_batchnormtestmode_Nd(%arg0 : tensor<1x64x112x112xf32>) -> ten

// CHECK: [[X_A:%.*]] = "onnx.Mul"(%arg0, [[A_UNSQUEEZE]]) : (tensor<1x64x112x112xf32>, tensor<*xf32>) -> tensor<*xf32>

// CHECK: [[SUB:%.*]] = "onnx.Mul"([[MEAN]], [[A]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[SUB:%.*]] = "onnx.Mul"([[A]], [[MEAN]]) : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[B:%.*]] = "onnx.Sub"([[BIAS]], [[SUB]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[B_UNSQUEEZE:%.*]] = "onnx.UnsqueezeV11"([[B]]) {axes = [1, 2]} : (tensor<*xf32>) -> tensor<*xf32>

Expand Down Expand Up @@ -591,7 +591,7 @@ func @test_rewrite_batchnormtestmode_1d(%arg0 : tensor<64xf32>) -> tensor<64xf32

// CHECK: [[X_A:%.*]] = "onnx.Mul"(%arg0, [[A]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>

// CHECK: [[SUB:%.*]] = "onnx.Mul"([[MEAN]], [[A]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[SUB:%.*]] = "onnx.Mul"([[A]], [[MEAN]]) : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: [[B:%.*]] = "onnx.Sub"([[BIAS]], [[SUB]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>

// CHECK: [[RES:%.*]] = "onnx.Add"([[X_A]], [[B]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<64xf32>
Expand Down Expand Up @@ -628,3 +628,24 @@ func @test_fuse_add_conv(%arg0 : tensor<1x1x28x28xf32>, %arg1 : tensor<8x1x5x5xf
// CHECK: return [[RES]] : tensor<1x8x28x28xf32>
// CHECK: }
}

// -----

func @test_fuse_mul_conv(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[[[[0.0234164055, 0.0228030644], [2.442580e-02, 0.0237577036]]], [[[-0.0410864502, 0.0488203131], [0.164448678, -0.0200194642]]], [[[-4.34581793E-9, 0.025325032], [0.0373019315, 0.165243402]]], [[[-0.0198689923, 0.131284416], [0.0572107285, 2.33985098E-8]]], [[[0.0187684372, -0.148515195], [0.0154875498, 0.019133633]]], [[[0.0176953916, -0.0154658081], [0.0233727545, -0.274110436]]], [[[-0.021181887, 0.0936150252], [0.135688141, -0.0202601217]]], [[[-0.0201558527, 0.0192655921], [0.227748245, -0.196346223]]]]> : tensor<8x1x2x2xf32>} : () -> tensor<8x1x2x2xf32>
%1 = "onnx.NoValue"() {value} : () -> none
%2 = "onnx.Conv"(%arg0, %0, %1) {kernel_shape = [2, 2], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x2x2xf32>, none) -> tensor<*xf32>
%3 = "onnx.Constant"() {value = dense<[[[-0.161539719]], [[-0.433835655]], [[0.091641359]], [[-0.0168522168]], [[-0.0650264397]], [[-0.131737873]], [[0.0204175506]], [[-0.121110231]]]> : tensor<8x1x1xf32>} : () -> tensor<8x1x1xf32>
%4 = "onnx.Mul"(%2, %3) : (tensor<*xf32>, tensor<8x1x1xf32>) -> tensor<*xf32>
return %4 : tensor<*xf32>

// CHECK-LABEL: test_fuse_mul_conv
// CHECK-SAME: ([[X:%.+]]: tensor<1x1x28x28xf32>) -> tensor<*xf32> {
// CHECK: [[W:%.+]] = "onnx.Constant"() {value = dense<{{.*}}[0.0234164055, 0.0228030644], [2.442580e-02, 0.0237577036]{{.*}}, {{.*}}[-0.0410864502, 0.0488203131], [0.164448678, -0.0200194642]{{.*}}, {{.*}}[-4.34581793E-9, 0.025325032], [0.0373019315, 0.165243402]{{.*}}, {{.*}}[-0.0198689923, 0.131284416], [0.0572107285, 2.33985098E-8]{{.*}}, {{.*}}[0.0187684372, -0.148515195], [0.0154875498, 0.019133633]{{.*}}, {{.*}}[0.0176953916, -0.0154658081], [0.0233727545, -0.274110436]{{.*}}, {{.*}}[-0.021181887, 0.0936150252], [0.135688141, -0.0202601217]{{.*}}, {{.*}}[-0.0201558527, 0.0192655921], [0.227748245, -0.196346223]{{.*}}]> : tensor<8x1x2x2xf32>} : () -> tensor<8x1x2x2xf32>
// CHECK: [[NOBIAS:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y:%.+]] = "onnx.Constant"() {value = dense<{{.*}}[-0.161539719{{.*}}, {{.*}}-0.433835655{{.*}}, {{.*}}0.091641359{{.*}}, {{.*}}-0.0168522168{{.*}}, {{.*}}-0.0650264397{{.*}}, {{.*}}-0.131737873{{.*}}, {{.*}}0.0204175506{{.*}}, {{.*}}-0.121110231{{.*}}> : tensor<8x1x1xf32>} : () -> tensor<8x1x1xf32>
// CHECK: [[Y1:%.+]] = "onnx.UnsqueezeV11"([[Y]]) {axes = [3]} : (tensor<8x1x1xf32>) -> tensor<*xf32>
// CHECK: [[MUL:%.+]] = "onnx.Mul"([[Y1]], [[W]]) : (tensor<*xf32>, tensor<8x1x2x2xf32>) -> tensor<*xf32>
// CHECK: [[RES:%.+]] = "onnx.Conv"(%arg0, [[MUL]], [[NOBIAS]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [2, 2], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<*xf32>, none) -> tensor<*xf32>
// CHECK: return [[RES]] : tensor<*xf32>
}
Loading

0 comments on commit ddb548e

Please sign in to comment.