diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 2d411da967..e778cf3f41 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -90,10 +90,12 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addPass(onnx_mlir::createONNXToZHighPass()); pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + // There are more opportunities for const propagation once all zhigh ops were // generated. pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); pm.addPass(mlir::createCanonicalizerPass()); + // Layout propagation at ZHighIR. pm.addNestedPass( onnx_mlir::zhigh::createZHighLayoutPropagationPass()); @@ -110,13 +112,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); } - // After all optimizations, if there are still light-weight ops (e.g. add, - // sub, ...) that are of `stick -> light-weight op -> unstick`, it's better to - // use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle - // these ops, e.g vectorize the computation. - if (nnpaEnableZHighToOnnx) - pm.addNestedPass(onnx_mlir::createZHighToONNXPass()); - // One more call to ONNX shape inference/canonicalization/... to update shape // if possible. if (enableONNXHybridPass) { @@ -134,13 +129,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { // ZHighConstPropagation currently assumes that DenseElementsAttr is used. pm.addPass(createScrubDisposablePass()); - // Constant propagation at ZHighIR: constant stickify. - // Only support BE machines. - bool isBE = llvm::endianness::native == llvm::endianness::big; - if (isBE) - pm.addNestedPass( - onnx_mlir::zhigh::createZHighConstPropagationPass()); - // Experimental feature: Decompose stick/unstick into two phases: layout // transform and data conversion. Do some optimizations after decomposing. // Then, recompose again layout and data conversion if they are not optimized. @@ -152,6 +140,20 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { onnx_mlir::zhigh::createZHighRecomposeToStickUnstickPass()); } + // After all optimizations, if there are still light-weight ops (e.g. add, + // sub, ...) that are of `stick -> light-weight op -> unstick`, it's better to + // use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle + // these ops, e.g vectorize the computation. + if (nnpaEnableZHighToOnnx) + pm.addNestedPass(onnx_mlir::createZHighToONNXPass()); + + // Constant propagation at ZHighIR: constant stickify. + // Only support BE machines. + bool isBE = llvm::endianness::native == llvm::endianness::big; + if (isBE) + pm.addNestedPass( + onnx_mlir::zhigh::createZHighConstPropagationPass()); + // Remove common sub-expressions. pm.addPass(mlir::createCSEPass()); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td index 2175d0ecc2..8fd410c98f 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td @@ -37,52 +37,30 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create($_loc, $0.getT // ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighAddPattern1 : Pat< - (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (ONNXAddOp $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)] ->; - -def replaceZHighAddPattern2 : Pat< - (ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (ONNXAddOp (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)] +def replaceZHighAddPattern : Pat< + (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (ONNXAddOp $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// // ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMulPattern1 : Pat< - (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (ONNXMulOp $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighMulPattern2 : Pat< - (ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (ONNXMulOp (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [], - (addBenefit 0) +def replaceZHighMulPattern : Pat< + (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (ONNXMulOp $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// // ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighSubPattern1 : Pat< - (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (ONNXSubOp $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighSubPattern2 : Pat< - (ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (ONNXSubOp (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], - (addBenefit 0) +def replaceZHighSubPattern : Pat< + (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (ONNXSubOp $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// @@ -90,54 +68,30 @@ def replaceZHighSubPattern2 : Pat< // %X),(ZHighStickOp %Y)) // Note: turn off this pattern since NNPA is faster at this moment. //===----------------------------------------------------------------------===// -//def replaceZHighDivPattern1 : Pat< -// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)), -// (ONNXDivOp $x, (ZHighUnstickOp $y)), -// [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], -// (addBenefit 1) -//>; -// -//def replaceZHighDivPattern2 : Pat< -// (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))), -// (ONNXDivOp (ZHighUnstickOp $x), $y), -// [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], -// (addBenefit 0) -//>; +// def replaceZHighDivPattern : Pat< +// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), +// (ONNXDivOp $x, $y), +// [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] +// >; //===----------------------------------------------------------------------===// // ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMinPattern1 : Pat< - (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighMinPattern2 : Pat< - (ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (CreateONNXMinOp $u, (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], - (addBenefit 0) +def replaceZHighMinPattern : Pat< + (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (CreateONNXMinOp $u, $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// // ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMaxPattern1 : Pat< - (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighMaxPattern2 : Pat< - (ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], - (addBenefit 0) +def replaceZHighMaxPattern : Pat< + (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (CreateONNXMaxOp $u, $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===//