diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 1ef0090049..19bbcd9d86 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -1289,7 +1289,9 @@ template <> GenOpMix getGenOpMix(Type t, Operation *op) { return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2}, {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3}, - {GenericOps::FloorGop, 2}}; + {GenericOps::FloorGop, 2}, + {GenericOps::EstimatedVectorRegisterPressure, + 4 /* Little parallelism in code. */}}; } template <> diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 5d31eb9b8b..f67be95176 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -662,22 +662,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType, return 1; } // Gather operation statics - int64_t vectorizedOpNum, scalarOpNum; - double avgVL = VectorMachineSupport::getAvgArchVectorLength( - genOps, elementType, vectorizedOpNum, scalarOpNum); + int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure; + double avgVL = + VectorMachineSupport::getAvgArchVectorLength(genOps, elementType, + vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure); if (avgVL < 1.5) { LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with " << avgVL << " avg VL\n"); return 1; } - LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n"); + LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL + << ", vec op num " << vectorizedOpNum + << ", max reg pressure " + << estimatedMaxVectorRegisterPressure << "\n"); // Define a target max unroll as a function of register pressure. int64_t unrollVL; int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum(); - if (vectorizedOpNum >= vrNum / 2) + if (estimatedMaxVectorRegisterPressure >= vrNum) + unrollVL = 1; + else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum) unrollVL = 2; - else if (vectorizedOpNum >= vrNum / 4) + else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum) unrollVL = 4; else unrollVL = 8; @@ -743,6 +749,22 @@ int64_t capVLForMaxUnroll( return archVL * unrollVL; } +int64_t boostVLForMinUnroll( + MemRefType memRefType, MemRefType convertedMemRefType, int64_t totVL) { + if (totVL == 1) + return 1; // Simd already disabled, nothing to cap. + Type convertedElementType = convertedMemRefType.getElementType(); + int64_t convertedArchVL = + VectorMachineSupport::getArchVectorLength(convertedElementType); + if (convertedArchVL > totVL) { + LLVM_DEBUG(llvm::dbgs() + << " simd enable: boost totVL to " << convertedArchVL + << " because of type conversions.\n"); + return convertedArchVL; + } + return totVL; +} + int64_t capVLForSimdOnly( MemRefType memRefType, int64_t totVL, int64_t simdLoopStaticTripCount) { if (totVL == 1) diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index db72437837..87f7fdaf8c 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -663,6 +663,12 @@ int64_t computeSuitableUnrollFactor(mlir::MemRefType memRefType, // Cap totVL so that it is at most maxUnrollVL * archVL. int64_t capVLForMaxUnroll( mlir::MemRefType memRefType, int64_t totVL, int64_t maxUnrollVL); +// In some type conversion loops we may have a given totVL based on a given +// memRef type and gen op mix. But the final result may be converted to a +// different type, which may requires a minimum unroll to proceed as a single +// SIMD operation. This call adjust the totVL for that case. +int64_t boostVLForMinUnroll(mlir::MemRefType memRefType, + mlir::MemRefType convertedMemRefType, int64_t totVL); // Enabling a simdOnly code generation scheme by capping totVL so that it // divides simdLoopStaticTripCount. When not possible (either because // there is no totVL that divides simdLoopStaticTripCount or trip count is diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 775ee0cc35..3c3143c4ad 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -29,8 +29,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax, Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD, bool enableParallel) { - MultiDialectBuilder create( - rewriter, loc); + MultiDialectBuilder + create(rewriter, loc); // Types Type quantizedElementType = quantizedType.getElementType(); @@ -54,7 +54,9 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5}, {GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2}, {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3}, - {GenericOps::FloorGop, 2}}; + {GenericOps::FloorGop, 2}, + {GenericOps::EstimatedVectorRegisterPressure, + 8 /* Little parallelism in code. */}}; totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/, innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, simdOnly); @@ -68,8 +70,16 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, inputAF.emplace_back(zero); DimsExpr outputAF; outputAF.emplace_back(zero); + + // faster than original loop on z16, takes 124us for 64k vals + // Allocate output buffers. + MemRefType flatBufferType = llvm::cast(flatInput.getType()); + Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims); + DimsExpr bufferAF; + bufferAF.emplace_back(zero); + create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, - {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, + {flatInput}, {inputAF}, {flatBuffer}, {bufferAF}, {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { MultiDialectBuilder create(kb); Value x = inputVals[0]; @@ -83,11 +93,31 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, adjustX = create.math.add(roundX, zeroPoint); else adjustX = roundX; - // Saturate + // Saturate: use max into a min. Value saturateX = create.math.clip(adjustX, qMin, qMax); - Value res = create.math.cast(quantizedElementType, saturateX); - return res; + // Old approach. + // return create.math.cast(quantizedElementType, saturateX); + return saturateX; }}); + + // A second loop that performs scalar float to int performs better than the + // compiler's attempt to generate SIMD conversion code. This might not hold + // with all data types, but is definitely noticeable with uint8. + // + // Investigate further: we might save the vector to a buffer on the fly + // (avoiding a second loop as below), and then reload each value as scalar and + // then saved them as scalar (thus avoiding the insert/extract SIMD operations + // that also do not perform well). We can have a SIMD buffer in memory for the + // non-quantized and quantized simd values, but then we also need to privatize + // it, which is also not easy in this scheme. So ignore this for now. + create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel, + [&](KrnlBuilder &kb, ValueRange loopInd) { + MultiDialectBuilder create(kb); + Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]}); + Value res = create.math.cast(quantizedElementType, buffVal); + create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]}); + }); + if (totVL > 1) onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "quantizationLinear whole tensor"); diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 75fa59cb96..89ae3c5818 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -2073,6 +2073,29 @@ void VectorBuilder::multiReduction(ArrayRef inputVecArray, } } +Value VectorBuilder::extractElement(Value vector, int64_t index) const { + MultiDialectBuilder create(*this); + VectorType type = llvm::cast(vector.getType()); + int64_t VL = type.getShape()[0]; + assert(type.getRank() == 1 && "expected 1D vector only"); + assert(index >= 0 && index < VL && "out of range vector index"); + Value position = create.math.constantIndex(index); + return b().create(loc(), vector, position); +} + +Value VectorBuilder::insertElement( + Value vector, Value element, int64_t index) const { + MultiDialectBuilder create(*this); + VectorType type = llvm::cast(vector.getType()); + int64_t VL = type.getShape()[0]; + assert(type.getRank() == 1 && "expected 1D vector only"); + assert(index >= 0 && index < VL && "out of range vector index"); + Value position = create.math.constantIndex(index); + // Unlike LLVM insert element which takes , vector + // take + return b().create(loc(), element, vector, position); +} + //===----------------------------------------------------------------------===// // LLVM Builder //===----------------------------------------------------------------------===// diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index f1c65bb32c..f247f6c9fe 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -574,6 +574,11 @@ struct VectorBuilder final : DialectBuilder { void multiReduction(mlir::ArrayRef inputVecArray, F2 reductionFct, llvm::SmallVectorImpl &outputVecArray); + // Insert and extract. + mlir::Value extractElement(mlir::Value vector, int64_t position) const; + mlir::Value insertElement( + mlir::Value vector, mlir::Value element, int64_t position) const; + private: bool isPowerOf2(uint64_t num) const; uint64_t getLengthOf1DVector(mlir::Value vec) const; diff --git a/src/Dialect/Mlir/VectorMachineSupport.cpp b/src/Dialect/Mlir/VectorMachineSupport.cpp index f5e6cf897e..2e5ab3ef4a 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.cpp +++ b/src/Dialect/Mlir/VectorMachineSupport.cpp @@ -78,21 +78,30 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { } /*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps, - Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) { + Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum, + int64_t &maxVectorRegisterPressure) { int64_t size = genOps.size(); + vectorizedOpNum = maxVectorRegisterPressure = 0; if (!hasSimd()) { - vectorizedOpNum = 0; scalarOpNum = size; return 1; } int64_t totProcessedValues = 0.0; - vectorizedOpNum = 0; scalarOpNum = 0; + bool hasRegisterPressure = false; + // Determine which operations support SIMD and accumulate their vector // lengths. for (auto pair : genOps) { GenericOps genOp = pair.first; int64_t num = pair.second; + // Handle other metrics first. + if (genOp == GenericOps::EstimatedVectorRegisterPressure) { + maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num); + hasRegisterPressure = true; + continue; + } + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t vl = getArchVectorLength(genOp, elementType); // If past last value, assume 1; otherwise use actual value. // Accumulate weighted scalar/vectorized num and vl length. @@ -106,7 +115,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { } // Compute final values int64_t totNum = vectorizedOpNum + scalarOpNum; - scalarOpNum = size - vectorizedOpNum; + if (!hasRegisterPressure) { + // Estimate default register pressure as one per 2 vector operation. + maxVectorRegisterPressure = std::max(vectorizedOpNum / 2, (int64_t)1); + } return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0; } @@ -115,13 +127,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { // ============================================================================= int64_t Z16VectorMachineSupport::computeArchVectorLength( - GenericOps Gop, Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); - // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -137,10 +149,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( // Supports only 32 and 64 bit Floats; There is support for extended too // but ignore this for now. if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: /* Supported via compare and select */ case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: /* Use load integer & rounding modes*/ @@ -161,7 +173,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -190,13 +202,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( // ============================================================================= int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( - GenericOps Gop, Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -212,10 +225,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( // Supports only 32 and 64 bit Floats; There is support for extended too // but ignore this for now. if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: @@ -237,7 +250,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -276,13 +289,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( // ============================================================================= int64_t NeonVectorMachineSupport::computeArchVectorLength( - GenericOps Gop, Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -297,10 +311,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( if (isFloat) { // Supports only 32 and 64 bit Floats; if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: @@ -322,7 +336,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -370,10 +384,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) { for (auto pair : mix1) { GenericOps genOp = pair.first; int64_t num = pair.second; - if (u.find(genOp) != u.end()) - u[genOp] += num; // Has this op already, add to it. - else + if (u.find(genOp) != u.end()) { + // Merge the 2 operation counts/metrics. + if (genOp == GenericOps::EstimatedVectorRegisterPressure) { + // For register pressure, pick the max of both. + u[genOp] = std::max(u[genOp], num); + } else { + // For operation count, use the sum of both + u[genOp] += num; + } + } else { + // First time we have this. u[genOp] = num; + } } return u; } diff --git a/src/Dialect/Mlir/VectorMachineSupport.hpp b/src/Dialect/Mlir/VectorMachineSupport.hpp index bcd2ad1a88..0d1104bbad 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.hpp +++ b/src/Dialect/Mlir/VectorMachineSupport.hpp @@ -32,6 +32,10 @@ namespace onnx_mlir { // (e.g. all the compares). enum class GenericOps { + ///////////////////////////////////// + // Generic ops. + ///////////////////////////////////// + AbsGop, ArithmeticGop, /* Simple compute ops: add/sub/neg + ops of same complexity. */ CeilDivGop, @@ -62,6 +66,17 @@ enum class GenericOps { TrigArcGop, /* Arc trigonometry ops: asin, acos, atan. */ TrigGop, /* Trigonometry ops: sin, cos, tan. */ TrigHyperbolicGop, /* Hyperbolic trig. */ + + LastGop, /* Marker of the last op. Used to delineate from other metrics. */ + + ///////////////////////////////////// + // Metrics others than operations. + ///////////////////////////////////// + + // Metric that provides an estimate of the maximum number of vector registers + // used in a kernel. If none is provided, we estimate the pressure based on + // the number of operations. + EstimatedVectorRegisterPressure, }; // Describe the mix of Generic operations in a given kernel. Each generic @@ -132,8 +147,12 @@ class VectorMachineSupport { // number of times that generic operation was found. Note that scalar // operation have a vector length of one in the weighted average as they still // contribute one result. + // Max vector register pressure is also reported, either from an explicit + // mention in the genOps, or estimated as one vector register per vector + // operation. static double getAvgArchVectorLength(GenOpMix &genOps, mlir::Type elementType, - int64_t &vectorizedOpNum, int64_t &scalarOpNum); + int64_t &vectorizedOpNum, int64_t &scalarOpNum, + int64_t &maxVectorRegisterPressure); protected: // Virtual functions that do the actual work. Called by the "get" functions. diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir index 55dbdb1942..934fba2240 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir @@ -31,22 +31,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ -// CHECK: [[VAR_31_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_]]#0, [[VAR_31_]]#1] : memref +// CHECK: [[VAR_32_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_]]#0, [[VAR_32_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref -// CHECK: [[VAR_34_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_34_]], [[RES_3_]][] : memref +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_35_]], [[RES_3_]][] : memref // CHECK: } // CHECK: [[RES_4_:%.+]] = memref.alloc() : memref // CHECK: krnl.memset [[RES_4_]], [[CST_0_]] : memref // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 // CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ -// CHECK: [[VAR_31_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_1_]]#0, [[VAR_31_1_]]#1] : memref +// CHECK: [[VAR_32_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_1_]]#0, [[VAR_32_1_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK: [[VAR_34_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_34_1_]], [[RES_4_]][] : memref +// CHECK: [[VAR_35_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_35_1_]], [[RES_4_]][] : memref // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref @@ -87,33 +87,40 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_28_]]) {{.*}}: memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_31_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_31_2_]]{{.}} : memref +// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_32_2_]]{{.}} : memref // CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 -// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_35_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_36_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpf ogt, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : f32 -// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_2_]], [[VAR_41_]] : f32 -// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_39_:%.+]] = arith.select [[VAR_37_]], [[VAR_38_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_40_:%.+]] = arith.mulf [[VAR_35_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_41_:%.+]] = math.floor [[VAR_40_]] : f32 +// CHECK: [[VAR_42_:%.+]] = arith.mulf [[VAR_41_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_43_:%.+]] = arith.subf [[VAR_35_2_]], [[VAR_42_]] : f32 +// CHECK-DAG: [[VAR_44_:%.+]] = arith.cmpf oeq, [[VAR_43_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_45_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : f32 -// CHECK: [[VAR_48_:%.+]] = arith.addf [[VAR_47_]], [[VAR_25_]] : f32 -// CHECK: [[VAR_49_:%.+]] = arith.maxnumf [[VAR_48_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_50_:%.+]] = arith.minnumf [[VAR_49_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_51_:%.+]] = arith.fptoui [[VAR_50_]] : f32 to i8 -// CHECK: [[VAR_52_:%.+]] = builtin.unrealized_conversion_cast [[VAR_51_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_52_]], [[VAR_reshape_14_]]{{.}}[[VAR_31_2_]]{{.}} : memref +// CHECK-DAG: [[VAR_46_:%.+]] = arith.select [[VAR_44_]], [[VAR_45_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_48_:%.+]] = arith.select [[VAR_47_]], [[VAR_46_]], [[VAR_39_]] : f32 +// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32 +// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32 +// CHECK: krnl.store [[VAR_51_]], [[RES_7_]]{{.}}[[VAR_32_2_]]{{.}} : memref +// CHECK: } +// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ +// CHECK: [[VAR_32_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_32_3_]]{{.}} : memref +// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8 +// CHECK: [[VAR_35_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_35_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_3_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir index 61800a518a..637cd5fdaf 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir @@ -13,11 +13,11 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_only // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 @@ -42,16 +42,16 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> @@ -97,37 +97,44 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> +// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<4096xf32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> -// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<16xf32> +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<16xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<16xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<16xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: vector.store [[VAR_54_]], [[RES_10_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK: } +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4096){ +// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_]] : f32 to i8 +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_1_1_]] : i8 to ui8 +// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<256x16xui8>, memref, memref // CHECK: } @@ -143,11 +150,11 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // mlir2FileCheck.py // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_and_scalar // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 @@ -172,29 +179,29 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4304){ -// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_34_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_34_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_39_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_40_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_40_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_41_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_34_1_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_34_1_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> -// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> @@ -240,65 +247,70 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> +// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<4335xf32> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4328){ -// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<8xf32> -// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_]] : vector<8xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4320){ +// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_]] : vector<16xf32> +// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_]], [[LOAD_RES_6_MEM_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_43_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = math.floor [[VAR_43_]] : vector<8xf32> -// CHECK: [[VAR_45_:%.+]] = arith.mulf [[VAR_44_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_46_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_45_]] : vector<8xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_46_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_43_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_44_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_45_:%.+]] = math.floor [[VAR_44_]] : vector<16xf32> +// CHECK: [[VAR_46_:%.+]] = arith.mulf [[VAR_45_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_47_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_46_]] : vector<16xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_47_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_47_]], [[VAR_48_]], [[LOAD_RES_6_MEM_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_48_]], [[VAR_49_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_:%.+]] = arith.select [[VAR_50_]], [[VAR_49_]], [[VAR_42_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_52_:%.+]] = vector.splat [[VAR_29_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_]], [[VAR_52_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.select [[VAR_51_]], [[VAR_50_]], [[VAR_43_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_53_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_]], [[VAR_53_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: vector.store [[VAR_56_]], [[RES_10_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4328 to 4335){ -// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4320 to 4335){ +// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> // CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_11_]] : f32 // CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = math.floor [[LOAD_VAR_reshape_MEM_3_1_]] : f32 // CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_3_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_39_3_:%.+]] = arith.cmpf ogt, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_3_]], [[VAR_40_3_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_43_1_:%.+]] = math.floor [[VAR_42_1_]] : f32 -// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[VAR_43_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_44_1_]] : f32 -// CHECK-DAG: [[VAR_46_1_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[LOAD_RES_4_MEM_1_1_]] : f32 +// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_44_1_:%.+]] = math.floor [[VAR_43_1_]] : f32 +// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[VAR_44_1_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_45_1_]] : f32 +// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.select [[VAR_46_1_]], [[VAR_47_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.cmpf oeq, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_50_1_:%.+]] = arith.select [[VAR_49_1_]], [[VAR_48_1_]], [[VAR_41_1_]] : f32 -// CHECK: [[VAR_51_1_:%.+]] = arith.addf [[VAR_50_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_52_1_:%.+]] = arith.maxnumf [[VAR_51_1_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_53_1_:%.+]] = arith.minnumf [[VAR_52_1_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.fptoui [[VAR_53_1_]] : f32 to i8 -// CHECK: [[VAR_55_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_55_1_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xui8> +// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.select [[VAR_47_1_]], [[VAR_48_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 +// CHECK-DAG: [[VAR_50_1_:%.+]] = arith.cmpf oeq, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_51_1_:%.+]] = arith.select [[VAR_50_1_]], [[VAR_49_1_]], [[VAR_42_1_]] : f32 +// CHECK: [[VAR_52_1_:%.+]] = arith.addf [[VAR_51_1_]], [[VAR_29_]] : f32 +// CHECK: [[VAR_53_1_:%.+]] = arith.maxnumf [[VAR_52_1_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_1_]], [[CST_2_dot_550000_]] : f32 +// CHECK: krnl.store [[VAR_54_1_]], [[RES_10_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> +// CHECK: } +// CHECK: [[LOOP_4_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_4_:%.+]] = 0 to 4335){ +// CHECK: [[VAR_35_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_1_]] : f32 to i8 +// CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_3_1_]] : i8 to ui8 +// CHECK: krnl.store [[LOAD_RES_4_MEM_1_1_]], [[VAR_reshape_21_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<255x17xui8>, memref, memref // CHECK: } @@ -343,16 +355,16 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -398,37 +410,44 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> +// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: vector.store [[VAR_54_]], [[RES_10_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: } +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 8){ +// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_]] : f32 to i8 +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_1_1_]] : i8 to ui8 +// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir index 2bcf1dba86..dd048e3ab7 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir @@ -18,11 +18,11 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 * -512 + 4096, 512)> // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_only // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> @@ -49,46 +49,46 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_34_:%.+]] = affine.apply [[MAP_0_]]([[VAR_33_]]) -// CHECK-DAG: [[VAR_35_:%.+]] = affine.min [[MAP_1_]]([[VAR_33_]]) -// CHECK-DAG: [[VAR_36_:%.+]] = affine.apply [[MAP_2_]]([[VAR_33_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_37_:%.+]] = affine.min [[MAP_3_]]([[VAR_33_]]) -// CHECK: scf.for [[I_1_:%.+]] = [[VAR_34_]] to [[VAR_37_]] step [[CST_32_]] { +// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_:%.+]] = affine.apply [[MAP_0_]]([[VAR_34_]]) +// CHECK-DAG: [[VAR_36_:%.+]] = affine.min [[MAP_1_]]([[VAR_34_]]) +// CHECK-DAG: [[VAR_37_:%.+]] = affine.apply [[MAP_2_]]([[VAR_34_]]) +// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_3_]]([[VAR_34_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_35_]] to [[VAR_38_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_50_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_51_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_51_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_52_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_4_]]([[VAR_33_]]) -// CHECK: [[VAR_39_:%.+]] = arith.remsi [[VAR_38_]], [[CST_32_]] : index -// CHECK: [[VAR_40_:%.+]] = arith.subi [[VAR_38_]], [[VAR_39_]] : index -// CHECK: [[VAR_41_:%.+]] = arith.addi [[VAR_34_]], [[VAR_40_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_41_]] to [[VAR_35_]] step [[CST_1_]] { +// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_4_]]([[VAR_34_]]) +// CHECK: [[VAR_40_:%.+]] = arith.remsi [[VAR_39_]], [[CST_32_]] : index +// CHECK: [[VAR_41_:%.+]] = arith.subi [[VAR_39_]], [[VAR_40_]] : index +// CHECK: [[VAR_42_:%.+]] = arith.addi [[VAR_35_]], [[VAR_41_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_42_]] to [[VAR_36_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_50_1_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> -// CHECK: memref.store [[VAR_51_1_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_51_1_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_52_1_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_44_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_44_]], [[RES_5_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_45_]], [[RES_7_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_45_]], [[RES_5_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_46_]], [[RES_7_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref @@ -96,16 +96,16 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_34_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_34_1_]] : f32 -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_35_1_]] : f32 -// CHECK: krnl.store [[VAR_38_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_39_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_35_1_]] : f32 +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_36_1_]] : f32 +// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -151,38 +151,46 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_11_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> +// CHECK-DAG: [[RES_12_:%.+]] = memref.alloc() {{.*}}: memref<4096xf32> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xf32>, vector<8xf32> -// CHECK-DAG: [[VAR_35_2_:%.+]] = vector.splat [[VAR_11_]] : vector<8xf32> -// CHECK: [[VAR_36_1_:%.+]] = arith.divf [[VAR_34_1_]], [[VAR_35_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = math.floor [[VAR_36_1_]] : vector<8xf32> -// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_1_]], [[VAR_37_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_2_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_2_]], [[VAR_40_1_]], [[VAR_37_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_44_1_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_37_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_41_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_51_2_:%.+]] = vector.splat [[VAR_29_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_2_]], [[VAR_51_2_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_25_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xui8>, vector<8xui8> +// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<16xf32> +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> +// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: vector.store [[VAR_55_]], [[RES_12_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK: } +// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.parallel([[LOOP_3_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to 4096){ +// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_35_1_1_:%.+]] = krnl.load [[RES_12_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4096xf32> +// CHECK: [[VAR_36_3_:%.+]] = arith.fptoui [[VAR_35_1_1_]] : f32 to i8 +// CHECK: [[VAR_37_2_:%.+]] = builtin.unrealized_conversion_cast [[VAR_36_3_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_37_2_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4096xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<256x16xui8>, memref, memref // CHECK: } @@ -203,11 +211,11 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 * -542 + 4335, 542)> // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_and_scalar // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> @@ -234,46 +242,46 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_:%.+]] = affine.apply [[MAP_0_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_36_:%.+]] = affine.min [[MAP_1_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_37_:%.+]] = affine.apply [[MAP_2_]]([[VAR_34_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_3_]]([[VAR_34_]]) -// CHECK: scf.for [[I_1_:%.+]] = [[VAR_35_]] to [[VAR_38_]] step [[CST_32_]] { +// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_36_:%.+]] = affine.apply [[MAP_0_]]([[VAR_35_]]) +// CHECK-DAG: [[VAR_37_:%.+]] = affine.min [[MAP_1_]]([[VAR_35_]]) +// CHECK-DAG: [[VAR_38_:%.+]] = affine.apply [[MAP_2_]]([[VAR_35_]]) +// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_3_]]([[VAR_35_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_36_]] to [[VAR_39_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_52_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_51_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_52_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_53_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_52_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_53_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_4_]]([[VAR_34_]]) -// CHECK: [[VAR_40_:%.+]] = arith.remsi [[VAR_39_]], [[CST_32_]] : index -// CHECK: [[VAR_41_:%.+]] = arith.subi [[VAR_39_]], [[VAR_40_]] : index -// CHECK: [[VAR_42_:%.+]] = arith.addi [[VAR_35_]], [[VAR_41_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_42_]] to [[VAR_36_]] step [[CST_1_]] { +// CHECK: [[VAR_40_:%.+]] = affine.min [[MAP_4_]]([[VAR_35_]]) +// CHECK: [[VAR_41_:%.+]] = arith.remsi [[VAR_40_]], [[CST_32_]] : index +// CHECK: [[VAR_42_:%.+]] = arith.subi [[VAR_40_]], [[VAR_41_]] : index +// CHECK: [[VAR_43_:%.+]] = arith.addi [[VAR_36_]], [[VAR_42_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_43_]] to [[VAR_37_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_51_1_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> -// CHECK: memref.store [[VAR_52_1_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_53_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_52_1_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_53_1_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_45_]], [[RES_5_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_46_]], [[RES_7_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_47_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_46_]], [[RES_5_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_47_]], [[RES_7_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref @@ -281,16 +289,16 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_37_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_35_1_]] : f32 -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_36_1_]] : f32 -// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_36_1_]] : f32 +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_37_1_]] : f32 +// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -336,66 +344,72 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_11_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> +// CHECK-DAG: [[RES_12_:%.+]] = memref.alloc() {{.*}}: memref<4335xf32> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4328){ -// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<8xf32> -// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<8xf32> -// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<8xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<8xui8> +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4320){ +// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_36_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_37_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = arith.divf [[VAR_36_1_]], [[VAR_37_2_]] : vector<16xf32> +// CHECK: [[VAR_39_1_:%.+]] = math.floor [[VAR_38_1_]] : vector<16xf32> +// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_39_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_1_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_39_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> +// CHECK: [[VAR_46_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_47_1_:%.+]] = arith.subf [[VAR_39_1_]], [[VAR_46_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_47_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_52_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_43_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_53_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_2_]], [[VAR_53_2_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: vector.store [[VAR_56_]], [[RES_12_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4328 to 4335){ -// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[VAR_35_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> -// CHECK: [[VAR_36_3_:%.+]] = arith.divf [[VAR_35_1_1_]], [[VAR_11_]] : f32 -// CHECK: [[VAR_37_2_:%.+]] = math.floor [[VAR_36_3_]] : f32 -// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_3_]], [[VAR_37_2_]] : f32 -// CHECK-DAG: [[VAR_39_3_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.addf [[VAR_37_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.select [[VAR_39_3_]], [[VAR_40_3_]], [[VAR_37_2_]] : f32 -// CHECK-DAG: [[VAR_42_2_:%.+]] = arith.mulf [[VAR_37_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = math.floor [[VAR_42_2_]] : f32 +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4320 to 4335){ +// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_36_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_37_3_:%.+]] = arith.divf [[VAR_36_1_1_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_38_2_:%.+]] = math.floor [[VAR_37_3_]] : f32 +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_3_]], [[VAR_38_2_]] : f32 +// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_42_2_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[VAR_38_2_]] : f32 +// CHECK-DAG: [[VAR_43_2_:%.+]] = arith.mulf [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = math.floor [[VAR_43_2_]] : f32 // CHECK: [[LOAD_RES_6_MEM_2_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_2_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_45_2_:%.+]] = arith.subf [[VAR_37_2_]], [[LOAD_RES_6_MEM_2_1_]] : f32 -// CHECK-DAG: [[VAR_46_2_:%.+]] = arith.cmpf oeq, [[VAR_45_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = arith.addf [[VAR_37_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.select [[VAR_46_2_]], [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_37_2_]] : f32 -// CHECK-DAG: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.select [[LOAD_RES_4_MEM_1_1_]], [[LOAD_VAR_reshape_MEM_3_1_]], [[VAR_41_2_]] : f32 -// CHECK: [[VAR_51_3_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_52_3_:%.+]] = arith.maxnumf [[VAR_51_3_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_53_1_:%.+]] = arith.minnumf [[VAR_52_3_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.fptoui [[VAR_53_1_]] : f32 to i8 -// CHECK: [[VAR_55_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_55_1_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xui8> +// CHECK: [[VAR_46_2_:%.+]] = arith.subf [[VAR_38_2_]], [[LOAD_RES_6_MEM_2_1_]] : f32 +// CHECK-DAG: [[VAR_47_2_:%.+]] = arith.cmpf oeq, [[VAR_46_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.select [[VAR_47_2_]], [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_38_2_]] : f32 +// CHECK-DAG: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.select [[LOAD_RES_4_MEM_1_1_]], [[LOAD_VAR_reshape_MEM_3_1_]], [[VAR_42_2_]] : f32 +// CHECK: [[VAR_52_3_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_1_]], [[VAR_29_]] : f32 +// CHECK: [[VAR_53_3_:%.+]] = arith.maxnumf [[VAR_52_3_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_3_]], [[CST_2_dot_550000_]] : f32 +// CHECK: krnl.store [[VAR_54_1_]], [[RES_12_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> +// CHECK: } +// CHECK: [[LOOP_4_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.parallel([[LOOP_4_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_6_:%.+]] = 0 to 4335){ +// CHECK: [[VAR_35_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_36_1_1_:%.+]] = krnl.load [[RES_12_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_37_4_:%.+]] = arith.fptoui [[VAR_36_1_1_]] : f32 to i8 +// CHECK: [[VAR_38_3_:%.+]] = builtin.unrealized_conversion_cast [[VAR_37_4_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_38_3_]], [[VAR_reshape_25_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<255x17xui8>, memref, memref // CHECK: } @@ -440,16 +454,16 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -495,37 +509,46 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> +// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__1_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: vector.store [[VAR_54_]], [[RES_10_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: } +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.parallel([[LOOP_2_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 8){ +// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_]] : f32 to i8 +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_1_1_]] : i8 to ui8 +// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir index e456311773..ff079672bd 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir @@ -60,22 +60,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ -// CHECK: [[VAR_12_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_]]#0, [[VAR_12_]]#1] : memref +// CHECK: [[VAR_13_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_13_]]#0, [[VAR_13_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref -// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_15_]], [[RES_3_]][] : memref +// CHECK: [[VAR_16_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_16_]], [[RES_3_]][] : memref // CHECK: } // CHECK: [[RES_4_:%.+]] = memref.alloc() : memref // CHECK: krnl.memset [[RES_4_]], [[CST_0_1_]] : memref // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 // CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ -// CHECK: [[VAR_12_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_1_]]#0, [[VAR_12_1_]]#1] : memref +// CHECK: [[VAR_13_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_13_1_]]#0, [[VAR_13_1_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK: [[VAR_15_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_15_1_]], [[RES_4_]][] : memref +// CHECK: [[VAR_16_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_16_1_]], [[RES_4_]][] : memref // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref @@ -95,32 +95,39 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: affine.store [[VAR_10_]], [[RES_6_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_9_]]) {{.*}}: memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_12_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK: [[VAR_13_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_13_2_]]{{.}} : memref // CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_15_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_15_2_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_16_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_17_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_16_2_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf ogt, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_16_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_2_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_2_]], [[VAR_22_]] : f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_16_2_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.mulf [[VAR_16_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_22_:%.+]] = math.floor [[VAR_21_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.mulf [[VAR_22_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.subf [[VAR_16_2_]], [[VAR_23_]] : f32 +// CHECK-DAG: [[VAR_25_:%.+]] = arith.cmpf oeq, [[VAR_24_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_26_:%.+]] = arith.addf [[VAR_16_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_2_]] : f32 -// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[VAR_28_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.minnumf [[VAR_29_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_31_:%.+]] = arith.fptoui [[VAR_30_]] : f32 to i8 -// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_32_]], [[VAR_reshape_14_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_2_]] : f32 +// CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 +// CHECK: [[VAR_30_:%.+]] = arith.maxnumf [[VAR_29_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_31_:%.+]] = arith.minnumf [[VAR_30_]], [[CST_2_dot_550000_]] : f32 +// CHECK: krnl.store [[VAR_31_]], [[RES_7_]]{{.}}[[VAR_13_2_]]{{.}} : memref +// CHECK: } +// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ +// CHECK: [[VAR_13_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_13_3_]]{{.}} : memref +// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8 +// CHECK: [[VAR_16_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_16_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_13_3_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref // CHECK: } @@ -143,32 +150,39 @@ func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %a // CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xui8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<6xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_2_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_4_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_5_:%.+]] = math.floor [[VAR_4_]] : f32 -// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 -// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf ogt, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_3_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_5_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_6_:%.+]] = math.floor [[VAR_5_]] : f32 +// CHECK: [[VAR_7_:%.+]] = arith.subf [[VAR_5_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpf ogt, [[VAR_7_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_6_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_]], [[VAR_8_]], [[VAR_5_]] : f32 -// CHECK-DAG: [[VAR_10_:%.+]] = arith.mulf [[VAR_5_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_11_:%.+]] = math.floor [[VAR_10_]] : f32 -// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_5_]], [[VAR_12_]] : f32 -// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.select [[VAR_8_]], [[VAR_9_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.mulf [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_12_:%.+]] = math.floor [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.mulf [[VAR_12_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.subf [[VAR_6_]], [[VAR_13_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.cmpf oeq, [[VAR_14_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addf [[VAR_6_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_5_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[VAR_16_]], [[VAR_9_]] : f32 -// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[VAR_18_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_20_:%.+]] = arith.minnumf [[VAR_19_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.fptoui [[VAR_20_]] : f32 to i8 -// CHECK: [[VAR_22_:%.+]] = builtin.unrealized_conversion_cast [[VAR_21_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_22_]], [[RES_]]{{.}}[[VAR_2_]]{{.}} : memref<6xui8> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.select [[VAR_15_]], [[VAR_16_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf oeq, [[VAR_7_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.select [[VAR_18_]], [[VAR_17_]], [[VAR_10_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.maxnumf [[VAR_19_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.minnumf [[VAR_20_]], [[CST_2_dot_550000_]] : f32 +// CHECK: krnl.store [[VAR_21_]], [[RES_1_]]{{.}}[[VAR_3_]]{{.}} : memref<6xf32> +// CHECK: } +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 6){ +// CHECK: [[VAR_3_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_3_1_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_5_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_]] : f32 to i8 +// CHECK: [[VAR_6_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_5_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_6_1_]], [[RES_]]{{.}}[[VAR_3_1_]]{{.}} : memref<6xui8> // CHECK: } // CHECK: return [[RES_]] : memref<6xui8> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir index 657df8c178..6ec8f9d8cb 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir @@ -23,33 +23,40 @@ func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %a // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref // CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 // CHECK-DAG: [[VAR_3_:%.+]] = arith.uitofp [[VAR_2_]] : i8 to f32 +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<6xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_7_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_8_:%.+]] = math.floor [[VAR_7_]] : f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpf ogt, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_6_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_8_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.floor [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpf ogt, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_10_]], [[VAR_11_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_8_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_14_:%.+]] = math.floor [[VAR_13_]] : f32 -// CHECK: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_8_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.select [[VAR_11_]], [[VAR_12_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.mulf [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.mulf [[VAR_15_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_9_]], [[VAR_16_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_3_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_25_:%.+]] = arith.fptoui [[VAR_24_]] : f32 to i8 -// CHECK: [[VAR_26_:%.+]] = builtin.unrealized_conversion_cast [[VAR_25_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_26_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xui8> +// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[VAR_20_]], [[VAR_13_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.addf [[VAR_22_]], [[VAR_3_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.maxnumf [[VAR_23_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.minnumf [[VAR_24_]], [[CST_2_dot_550000_]] : f32 +// CHECK: krnl.store [[VAR_25_]], [[RES_1_]]{{.}}[[VAR_6_]]{{.}} : memref<6xf32> +// CHECK: } +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 6){ +// CHECK: [[VAR_6_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_6_1_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_8_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_]] : f32 to i8 +// CHECK: [[VAR_9_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_8_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_9_1_]], [[RES_]]{{.}}[[VAR_6_1_]]{{.}} : memref<6xui8> // CHECK: } // CHECK: return [[RES_]] : memref<6xui8> // CHECK: } @@ -57,10 +64,12 @@ func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %a // ----- + func.func @test_quantize_linear_i8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xi8> { %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xi8> return %0 : tensor<6xi8> +// mlir2FileCheck.py // CHECK-LABEL: func.func @test_quantize_linear_i8 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xi8> { // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 @@ -73,33 +82,41 @@ func.func @test_quantize_linear_i8(%arg0: tensor<6xf32>, %arg1: tensor, %ar // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_2_:%.+]] = arith.sitofp [[LOAD_PARAM_2_MEM_]] : i8 to f32 +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<6xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_4_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_6_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_7_:%.+]] = math.floor [[VAR_6_]] : f32 -// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_7_]] : f32 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpf ogt, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_10_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_7_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_8_:%.+]] = math.floor [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpf ogt, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_9_]], [[VAR_10_]], [[VAR_7_]] : f32 -// CHECK-DAG: [[VAR_12_:%.+]] = arith.mulf [[VAR_7_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_13_:%.+]] = math.floor [[VAR_12_]] : f32 -// CHECK: [[VAR_14_:%.+]] = arith.mulf [[VAR_13_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = arith.subf [[VAR_7_]], [[VAR_14_]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf oeq, [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_10_]], [[VAR_11_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_8_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_14_:%.+]] = math.floor [[VAR_13_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_8_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_16_]], [[VAR_17_]], [[VAR_7_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.cmpf oeq, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[VAR_18_]], [[VAR_11_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.addf [[VAR_20_]], [[VAR_2_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.maxnumf [[VAR_21_]], [[CST_minus_1_dot_280000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.minnumf [[VAR_22_]], [[CST_1_dot_270000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.fptosi [[VAR_23_]] : f32 to i8 -// CHECK: krnl.store [[VAR_24_]], [[RES_]]{{.}}[[VAR_4_]]{{.}} : memref<6xi8> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_2_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_1_dot_270000_]] : f32 +// CHECK: krnl.store [[VAR_24_]], [[RES_1_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> +// CHECK: } +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 6){ +// CHECK: [[VAR_5_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_5_1_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_7_1_:%.+]] = arith.fptosi [[LOAD_PARAM_0_MEM_1_]] : f32 to i8 +// CHECK: krnl.store [[VAR_7_1_]], [[RES_]]{{.}}[[VAR_5_1_]]{{.}} : memref<6xi8> // CHECK: } // CHECK: return [[RES_]] : memref<6xi8> // CHECK: } } +