From ad938ae6e36d458904cc24b93648234ebe62ace0 Mon Sep 17 00:00:00 2001
From: Han-Chung Wang
Date: Thu, 12 Dec 2024 20:04:00 -0800
Subject: [PATCH] [DT][NFC] Localize CPU specific encoding materialization
 logic. (#19452)

This revision moves the CPU materialization logic from
Dialect/Codegen/Utils/Utils.[h|cpp] to CPUEncodingExternalModels. These
methods were public only during the transition state; now that all the CPU
layout attributes are implemented, they no longer need to be exposed.

Additionally, it removes the outdated fallback logic from the
MaterializeContractionOp pattern, and it drops the `transposeNarrowN` input
argument from the lowerContractionOpWithEncoding method, because all the CPU
backends enable the transposeNarrowN feature.

Signed-off-by: hanhanW
---
 .../MaterializeEncodingIntoPackUnPack.cpp     |  26 +-
 .../Codegen/Dialect/Codegen/Utils/Utils.cpp   | 270 ----------------
 .../Codegen/Dialect/Codegen/Utils/Utils.h     |  23 --
 .../CPUEncodingExternalModels.cpp             | 293 +++++++++++++++++-
 4 files changed, 294 insertions(+), 318 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index 84b854026545..ad2ce7c48c05 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -10,6 +10,7 @@
 
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
@@ -740,25 +741,14 @@ class MaterializeContractionOp
     auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
         this->getTypeConverter());
-    if (auto layoutAttr = converter->getLayoutAttr()) {
-      SmallVector<Type> convertedResTypes;
-      for (auto init : op.getDpsInits()) {
-        convertedResTypes.push_back(converter->convertType(init.getType()));
-      }
-      Operation *newOp =
-          layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
-      rewriter.replaceOp(op, newOp->getResults());
-      return success();
-    }
-
-    FailureOr<Operation *> convertedOp =
-        IREE::Codegen::lowerContractionOpWithEncoding(
-            rewriter, op, operands, converter->getTransposeNarrowN(),
-            converter->getLayoutAttr());
-    if (failed(convertedOp)) {
-      return failure();
+    IREE::Codegen::LayoutAttrInterface layoutAttr = converter->getLayoutAttr();
+    SmallVector<Type> convertedResTypes;
+    for (auto init : op.getDpsInits()) {
+      convertedResTypes.push_back(converter->convertType(init.getType()));
     }
-    rewriter.replaceOp(op.getOperation(), convertedOp.value()->getResult(0));
+    Operation *newOp =
+        layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
+    rewriter.replaceOp(op, newOp->getResults());
     return success();
   }
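[Editorial aside, not part of the patch: after this change the pattern above relies unconditionally on the layout attribute's lowerOp hook, so every materialization target must register an attribute implementing IREE::Codegen::LayoutAttrInterface. The sketch below shows roughly what such an external model looks like. All names here (MyDeviceLayoutAttrInterface, MyEncodingLayoutAttr) are hypothetical, and the method signatures are inferred from the call sites in this patch rather than copied from the interface definition; the real CPU implementations appear in CPUEncodingExternalModels.cpp further down.]

struct MyDeviceLayoutAttrInterface final
    : IREE::Codegen::LayoutAttrInterface::ExternalModel<
          MyDeviceLayoutAttrInterface, MyEncodingLayoutAttr> {
  Codegen::MaterializeEncodingInfo
  getEncodingInfo(Attribute attr, RankedTensorType type) const {
    // A default-constructed info is assumed here to denote the identity
    // layout, i.e. materialization degenerates to dropping the encoding.
    return Codegen::MaterializeEncodingInfo{};
  }
  Operation *lowerOp(Attribute attr, OpBuilder &b, Operation *op,
                     TypeRange convertedResTypes,
                     ValueRange convertedOperands) const {
    // With an identity layout, cloning onto the converted operands is a
    // valid lowering; the CPU models below emit linalg.mmt4d and friends
    // instead. Callers treat a nullptr return as "unsupported op".
    return mlir::clone(b, op, convertedResTypes, convertedOperands);
  }
};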
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
index 32dbc46563b2..c515766e396f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
@@ -305,274 +305,4 @@ getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK) {
   return encodingInfo;
 }
 
-static RankedTensorType dropEncoding(RankedTensorType type) {
-  return RankedTensorType::get(type.getShape(), type.getElementType());
-}
-
-static Operation *
-dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
-                       ValueRange convertedInputOperands,
-                       ValueRange convertedOutputOperands) {
-  SmallVector<Value> operands;
-  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
-  operands.append(convertedOutputOperands.begin(),
-                  convertedOutputOperands.end());
-  return mlir::clone(builder, op,
-                     {dropEncoding(cast<RankedTensorType>(
-                         convertedOutputOperands[0].getType()))},
-                     operands);
-}
-
-static RankedTensorType
-getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed,
-                SmallVectorImpl<ReassociationIndices> &ri) {
-  if (!isBatched) {
-    ri.assign({{0, 1}, {2, 3}});
-    if (!isTransposed) {
-      return RankedTensorType::get(
-          {1, type.getDimSize(0), 1, type.getDimSize(1)},
-          type.getElementType());
-    }
-    return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
-                                 type.getElementType());
-  }
-
-  ri.assign({{0}, {1, 2}, {3, 4}});
-  if (!isTransposed) {
-    return RankedTensorType::get(
-        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
-        type.getElementType());
-  }
-  return RankedTensorType::get(
-      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
-      type.getElementType());
-}
-
-/// Given an input Value and a desired output element type, create and return
-/// an element-wise linalg::GenericOp that extends the input Value to the
-/// output element type.
-static Value createElementWiseExtUIOp(OpBuilder &builder, Value input,
-                                      Location loc, Type outElemType) {
-  auto inputType = cast<RankedTensorType>(input.getType());
-  SmallVector<AffineMap> maps(
-      2, builder.getMultiDimIdentityMap(inputType.getRank()));
-  SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
-                                                 utils::IteratorType::parallel);
-  auto castedType = inputType.clone(outElemType);
-  SmallVector<OpFoldResult> inputMixedSizes =
-      tensor::getMixedSizes(builder, loc, input);
-  Value init =
-      builder.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
-  return builder
-      .create<linalg::GenericOp>(
-          loc, castedType, input, init, maps, iteratorTypes,
-          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
-            Value castRes =
-                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
-                    ->getResult(0);
-            b.create<linalg::YieldOp>(nestedLoc, castRes);
-          })
-      .getResult(0);
-}
-
-/// If needed, expand the input Value, and return the resulting input with
-/// the canonical mmt4d input shape. If the input element type is unsigned,
-/// create a producer Linalg::GenericOp on the input that unsigned extends the
-/// input to the output element type. This extension is required to keep the
-/// unsignedness information on the input for ukernels. If `transpose` is true,
-/// the `linalgOp`'s indexing maps are transposed.
-static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
-                             bool transpose, OpBuilder &builder,
-                             SmallVectorImpl<ReassociationIndices> &ri,
-                             ArrayRef<Type> elemTypes, int operandIdx) {
-  assert(linalgOp.getNumDpsInputs() == 2);
-  assert(linalgOp.getNumDpsInits() == 1);
-  auto cDims = linalg::inferContractionDims(linalgOp);
-  Location loc = linalgOp->getLoc();
-  Value expandedValue = value;
-  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
-  // operand is a vector and must be extended.
-  if ((cDims->m.empty() && operandIdx != 1) ||
-      (cDims->n.empty() && operandIdx != 0)) {
-    auto type = cast<RankedTensorType>(value.getType());
-    RankedTensorType newType = getExpandedType(
-        type, /*isBatched=*/!cDims->batch.empty(),
-        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
-    expandedValue =
-        builder.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
-  }
-  if (elemTypes[operandIdx].isUnsignedInteger()) {
-    return createElementWiseExtUIOp(builder, expandedValue, loc,
-                                    elemTypes.back());
-  }
-  return expandedValue;
-}
-
-TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                           IREE::Encoding::MatmulNarrowDim narrowDim,
-                           ArrayRef<int64_t> hostDefinedUpperBound) {
-  assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) &&
-         "expected hostDefinedUpperBound is empty or has upper bound for {M, "
-         "N, K}");
-  // Handle narrow-N by transposing to reduce to narrow-M. Note: the
-  // enumeratedTiles currently only enumerate narrow-M cases.
-  if (narrowDim.isN()) {
-    SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
-    std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
-    narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
-    TileMxNxK tile =
-        chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
-    std::swap(tile.M, tile.N);
-    return tile;
-  }
-  // Handle kDynamic: currently this is only used with VMVX, where there is
-  // only one enumerated tile and it has all three M/N/K dimensions dynamic,
-  // so for now we only support that. Generalize that as needed when more
-  // dynamic tile sizes are used outside of VMVX, e.g. perhaps some day with
-  // Arm SVE. Decide how to incorporate the handling of kDynamic in the
-  // cost-model evaluation below to decide when to prefer a dynamic vs a
-  // static tile shape.
-  for (auto tile : enumeratedTiles) {
-    if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
-        ShapedType::isDynamic(tile.K)) {
-      assert(enumeratedTiles.size() == 1);
-      assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
-             ShapedType::isDynamic(tile.K));
-      return tile;
-    }
-  }
-  // We're going to "rate" the enumerated tiles.
-  struct RatedTileMxNxK : TileMxNxK {
-    RatedTileMxNxK() {}
-    RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
-    // Penalize tiles that are wider in the M dimension than matmulNarrowM.
-    int64_t paddingPenalty = 0;
-    // Favor larger tiles, as long as they still minimize paddingPenalty.
-    int64_t productMxNxK = 0;
-  };
-  SmallVector<RatedTileMxNxK> ratedTiles;
-  ratedTiles.reserve(enumeratedTiles.size());
-  int64_t bestPaddingPenalty = INT64_MAX;
-  int64_t mUB = INT64_MAX;
-  int64_t nUB = INT64_MAX;
-  int64_t kUB = INT64_MAX;
-  if (!hostDefinedUpperBound.empty()) {
-    mUB = hostDefinedUpperBound[0];
-    nUB = hostDefinedUpperBound[1];
-    kUB = hostDefinedUpperBound[2];
-  }
-  for (auto tile : enumeratedTiles) {
-    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
-      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
-                 llvm::interleaveComma(
-                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-                 llvm::dbgs()
-                 << ") is skipped because it is not valid for upper_bound (";
-                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
-                                       llvm::dbgs());
-                 llvm::dbgs() << ")\n");
-      continue;
-    }
-    RatedTileMxNxK ratedTile(tile);
-    ratedTile.paddingPenalty = 0;
-    // If we are choosing a tile for a narrow-M case, we want to minimize
-    // padding along the M dimension.
-    // The PowerOf2Ceil is so that we are OK with padding up to the next
-    // power of two; we just try to avoid padding beyond that. For example,
-    // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
-    // are OK with the tile that has M==8 even though it requires some padding.
-    // Otherwise, we would be penalizing the tiles with M==8,4,2 and we would
-    // end up selecting the vecmat tile (M==1) for that case!
-    if (narrowDim) {
-      ratedTile.paddingPenalty =
-          std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
-    }
-    ratedTile.productMxNxK = tile.M * tile.N * tile.K;
-    ratedTiles.push_back(ratedTile);
-    LLVM_DEBUG(llvm::dbgs() << "candidate: "; llvm::interleaveComma(
-                   ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-               llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty << "\n");
-    bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
-  }
-  RatedTileMxNxK bestRatedTile;
-  for (auto ratedTile : ratedTiles) {
-    // Choose only among tiles that minimize paddingPenalty. Among those,
-    // maximize productMxNxK.
-    if (ratedTile.paddingPenalty == bestPaddingPenalty &&
-        bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
-      bestRatedTile = ratedTile;
-    }
-  }
-  // Sanity check. This assert can only fail if there's a programming mistake
-  // locally here.
-  assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
-  return bestRatedTile;
-}
-
-FailureOr<Operation *>
-lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
-                               ValueRange operands, bool transposeNarrowN,
-                               LayoutAttrInterface layoutAttr) {
-  if (!linalgOp.hasPureTensorSemantics()) {
-    return failure();
-  }
-
-  auto inputs = linalgOp.getDpsInputOperands();
-  auto outputs = linalgOp.getDpsInits();
-
-  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
-  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
-  auto resultType = cast<RankedTensorType>(outputs[0].getType());
-  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
-  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
-  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
-  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
-    return failure();
-  }
-
-  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
-      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
-      resultEncoding.getOperandIndex().getValue() !=
-          IREE::Encoding::MATMUL_RESULT) {
-    return failure();
-  }
-
-  MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
-      cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
-
-  if (isIdentityLayout(encodingInfo)) {
-    return dropEncodingAndCloneOp(builder, linalgOp,
-                                  operands.take_front(inputs.size()),
-                                  operands.drop_front(inputs.size()));
-  }
-
-  bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
-  SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
-  SmallVector<ReassociationIndices> ri;
-  Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder, ri,
-                                 elemTypes, /*operandIdx=*/0);
-  Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder, ri,
-                                 elemTypes, /*operandIdx=*/1);
-  Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
-                                    ri, elemTypes, /*operandIdx=*/2);
-  if (transpose) {
-    std::swap(newLhs, newRhs);
-  }
-  Type newResultType = newResult.getType();
-  auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
-  Operation *result;
-  if (cDims->batch.empty()) {
-    result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
-                                             ValueRange{newLhs, newRhs},
-                                             ValueRange{newResult});
-  } else {
-    result = builder.create<linalg::BatchMmt4DOp>(
-        linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-        ValueRange{newResult});
-  }
-  if (!ri.empty()) {
-    result = builder.create<tensor::CollapseShapeOp>(
-        linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
-  }
-  return result;
-}
-
 } // namespace mlir::iree_compiler::IREE::Codegen
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
index 1bee3ec74032..8a8c309d4ba5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
@@ -75,29 +75,6 @@ struct TileMxNxK {
 MaterializeEncodingInfo
 getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK);
 
-//===----------------------------------------------------------------------===//
-// Operation Lowering Utilities.
-//===----------------------------------------------------------------------===//
-
-// TODO(hanchung): The below methods are exposed to public because they are
-// shared between MaterializeEncodingIntoPackUnPack.cpp and
-// CPUEncodingExternalModels.cpp. They will be moved to other places after all
-// the CPU backends implement their layout attributes.
-
-/// Returns the best TileMxNxK from the `enumeratedTiles` pool. If
-/// `hostDefinedUpperBound` is not empty, the chosen tile sizes cannot be
-/// greater than the values.
-/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
-/// information to the host. For now, they are defined by the host.
-TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                           IREE::Encoding::MatmulNarrowDim narrowDim,
-                           ArrayRef<int64_t> hostDefinedUpperBound = {});
-
-FailureOr<Operation *>
-lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
-                               ValueRange operands, bool transposeNarrowN,
-                               LayoutAttrInterface layoutAttr);
-
 } // namespace mlir::iree_compiler::IREE::Codegen
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
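[Editorial aside, not part of the patch: before the tile-rating logic reappears in the new home below, here is its cost model replayed on the example from its own comment (matmulNarrowM == 7, enumerated tiles with M = 8, 4, 2, 1). Standalone C++ for illustration only, with a hand-rolled stand-in for llvm::PowerOf2Ceil; the penalty formula is copied from chooseMatmulTile.]

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t narrowM = 7;
  // Stand-in for llvm::PowerOf2Ceil: smallest power of two >= narrowM.
  int64_t pow2Ceil = 1;
  while (pow2Ceil < narrowM) pow2Ceil *= 2;  // == 8 for narrowM == 7
  const int64_t tileMs[] = {8, 4, 2, 1};
  for (int64_t m : tileMs) {
    // Same formula as chooseMatmulTile's paddingPenalty.
    int64_t penalty = std::max<int64_t>(m - pow2Ceil, 0);
    std::printf("M=%2lld paddingPenalty=%lld\n", (long long)m,
                (long long)penalty);
  }
  // All four tiles get penalty 0, so the productMxNxK tie-break picks the
  // largest (M == 8), avoiding a fallback to the vecmat tile (M == 1). A
  // hypothetical M == 16 tile would be penalized by 16 - 8 == 8 and lose.
  return 0;
}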
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
index a847abd5a1b2..89de2e6dcc16 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
@@ -24,13 +24,292 @@ using Codegen::TileMxNxK;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Utilities.
+//===----------------------------------------------------------------------===//
+
+static RankedTensorType dropEncoding(RankedTensorType type) {
+  return RankedTensorType::get(type.getShape(), type.getElementType());
+}
+
+static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
+                                         ValueRange convertedInputOperands,
+                                         ValueRange convertedOutputOperands) {
+  SmallVector<Value> operands;
+  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
+  operands.append(convertedOutputOperands.begin(),
+                  convertedOutputOperands.end());
+  return mlir::clone(builder, op,
+                     {dropEncoding(cast<RankedTensorType>(
+                         convertedOutputOperands[0].getType()))},
+                     operands);
+}
+
+static RankedTensorType
+getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed,
+                SmallVectorImpl<ReassociationIndices> &ri) {
+  if (!isBatched) {
+    ri.assign({{0, 1}, {2, 3}});
+    if (!isTransposed) {
+      return RankedTensorType::get(
+          {1, type.getDimSize(0), 1, type.getDimSize(1)},
+          type.getElementType());
+    }
+    return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
+                                 type.getElementType());
+  }
+
+  ri.assign({{0}, {1, 2}, {3, 4}});
+  if (!isTransposed) {
+    return RankedTensorType::get(
+        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
+        type.getElementType());
+  }
+  return RankedTensorType::get(
+      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
+      type.getElementType());
+}
+
+/// Given an input Value and a desired output element type, create and return
+/// an element-wise linalg::GenericOp that extends the input Value to the
+/// output element type.
+static Value createElementWiseExtUIOp(OpBuilder &builder, Value input,
+                                      Location loc, Type outElemType) {
+  auto inputType = cast<RankedTensorType>(input.getType());
+  SmallVector<AffineMap> maps(
+      2, builder.getMultiDimIdentityMap(inputType.getRank()));
+  SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
+                                                 utils::IteratorType::parallel);
+  auto castedType = inputType.clone(outElemType);
+  SmallVector<OpFoldResult> inputMixedSizes =
+      tensor::getMixedSizes(builder, loc, input);
+  Value init =
+      builder.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
+  return builder
+      .create<linalg::GenericOp>(
+          loc, castedType, input, init, maps, iteratorTypes,
+          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+            Value castRes =
+                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
+                    ->getResult(0);
+            b.create<linalg::YieldOp>(nestedLoc, castRes);
+          })
+      .getResult(0);
+}
+
+/// If needed, expand the input Value, and return the resulting input with
+/// the canonical mmt4d input shape. If the input element type is unsigned,
+/// create a producer Linalg::GenericOp on the input that unsigned extends the
+/// input to the output element type. This extension is required to keep the
+/// unsignedness information on the input for ukernels. If `transpose` is true,
+/// the `linalgOp`'s indexing maps are transposed.
+static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
+                             bool transpose, OpBuilder &builder,
+                             SmallVectorImpl<ReassociationIndices> &ri,
+                             ArrayRef<Type> elemTypes, int operandIdx) {
+  assert(linalgOp.getNumDpsInputs() == 2);
+  assert(linalgOp.getNumDpsInits() == 1);
+  auto cDims = linalg::inferContractionDims(linalgOp);
+  Location loc = linalgOp->getLoc();
+  Value expandedValue = value;
+  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
+  // operand is a vector and must be extended.
+  if ((cDims->m.empty() && operandIdx != 1) ||
+      (cDims->n.empty() && operandIdx != 0)) {
+    auto type = cast<RankedTensorType>(value.getType());
+    RankedTensorType newType = getExpandedType(
+        type, /*isBatched=*/!cDims->batch.empty(),
+        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
+    expandedValue =
+        builder.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
+  }
+  if (elemTypes[operandIdx].isUnsignedInteger()) {
+    return createElementWiseExtUIOp(builder, expandedValue, loc,
+                                    elemTypes.back());
+  }
+  return expandedValue;
+}
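[Editorial aside, not part of the patch: to make the expansion performed by getMmt4dOperand and getExpandedType concrete, the snippet below instantiates the two non-batched cases by hand for a 2-D operand of shape {16, 32}. Standalone C++ for illustration only; the shapes and reassociation indices mirror the branches of getExpandedType above.]

#include <cstdint>
#include <cstdio>

int main() {
  // A packed 2-D operand {d0, d1} gains two unit dimensions to reach the
  // canonical 4-D mmt4d operand shape; reassociation {{0, 1}, {2, 3}} is
  // recorded so the mmt4d result can later be collapsed back with
  // tensor.collapse_shape.
  const int64_t d0 = 16, d1 = 32;
  const int64_t plain[4] = {1, d0, 1, d1};    // isTransposed == false
  const int64_t swapped[4] = {d0, 1, d1, 1};  // isTransposed == true
  std::printf("{%lld, %lld} -> {%lld, %lld, %lld, %lld} (plain)\n",
              (long long)d0, (long long)d1, (long long)plain[0],
              (long long)plain[1], (long long)plain[2], (long long)plain[3]);
  std::printf("{%lld, %lld} -> {%lld, %lld, %lld, %lld} (transposed)\n",
              (long long)d0, (long long)d1, (long long)swapped[0],
              (long long)swapped[1], (long long)swapped[2],
              (long long)swapped[3]);
  return 0;
}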
+
+/// Returns the best TileMxNxK from the `enumeratedTiles` pool. If
+/// `hostDefinedUpperBound` is not empty, the chosen tile sizes cannot be
+/// greater than the values.
+/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
+/// information to the host. For now, they are defined by the host.
+TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
+                           IREE::Encoding::MatmulNarrowDim narrowDim,
+                           ArrayRef<int64_t> hostDefinedUpperBound = {}) {
+  assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) &&
+         "expected hostDefinedUpperBound is empty or has upper bound for {M, "
+         "N, K}");
+  // Handle narrow-N by transposing to reduce to narrow-M. Note: the
+  // enumeratedTiles currently only enumerate narrow-M cases.
+  if (narrowDim.isN()) {
+    SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
+    std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
+    narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
+    TileMxNxK tile =
+        chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
+    std::swap(tile.M, tile.N);
+    return tile;
+  }
+  // Handle kDynamic: currently this is only used with VMVX, where there is
+  // only one enumerated tile and it has all three M/N/K dimensions dynamic,
+  // so for now we only support that. Generalize that as needed when more
+  // dynamic tile sizes are used outside of VMVX, e.g. perhaps some day with
+  // Arm SVE. Decide how to incorporate the handling of kDynamic in the
+  // cost-model evaluation below to decide when to prefer a dynamic vs a
+  // static tile shape.
+  for (auto tile : enumeratedTiles) {
+    if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
+        ShapedType::isDynamic(tile.K)) {
+      assert(enumeratedTiles.size() == 1);
+      assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
+             ShapedType::isDynamic(tile.K));
+      return tile;
+    }
+  }
+  // We're going to "rate" the enumerated tiles.
+  struct RatedTileMxNxK : TileMxNxK {
+    RatedTileMxNxK() {}
+    RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
+    // Penalize tiles that are wider in the M dimension than matmulNarrowM.
+    int64_t paddingPenalty = 0;
+    // Favor larger tiles, as long as they still minimize paddingPenalty.
+    int64_t productMxNxK = 0;
+  };
+  SmallVector<RatedTileMxNxK> ratedTiles;
+  ratedTiles.reserve(enumeratedTiles.size());
+  int64_t bestPaddingPenalty = INT64_MAX;
+  int64_t mUB = INT64_MAX;
+  int64_t nUB = INT64_MAX;
+  int64_t kUB = INT64_MAX;
+  if (!hostDefinedUpperBound.empty()) {
+    mUB = hostDefinedUpperBound[0];
+    nUB = hostDefinedUpperBound[1];
+    kUB = hostDefinedUpperBound[2];
+  }
+  for (auto tile : enumeratedTiles) {
+    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
+      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
+                 llvm::interleaveComma(
+                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
+                 llvm::dbgs()
+                 << ") is skipped because it is not valid for upper_bound (";
+                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
+                                       llvm::dbgs());
+                 llvm::dbgs() << ")\n");
+      continue;
+    }
+    RatedTileMxNxK ratedTile(tile);
+    ratedTile.paddingPenalty = 0;
+    // If we are choosing a tile for a narrow-M case, we want to minimize
+    // padding along the M dimension.
+    // The PowerOf2Ceil is so that we are OK with padding up to the next
+    // power of two; we just try to avoid padding beyond that. For example,
+    // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
+    // are OK with the tile that has M==8 even though it requires some padding.
+    // Otherwise, we would be penalizing the tiles with M==8,4,2 and we would
+    // end up selecting the vecmat tile (M==1) for that case!
+    if (narrowDim) {
+      ratedTile.paddingPenalty =
+          std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
+    }
+    ratedTile.productMxNxK = tile.M * tile.N * tile.K;
+    ratedTiles.push_back(ratedTile);
+    LLVM_DEBUG(llvm::dbgs() << "candidate: "; llvm::interleaveComma(
+                   ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
+               llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty << "\n");
+    bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
+  }
+  RatedTileMxNxK bestRatedTile;
+  for (auto ratedTile : ratedTiles) {
+    // Choose only among tiles that minimize paddingPenalty. Among those,
+    // maximize productMxNxK.
+    if (ratedTile.paddingPenalty == bestPaddingPenalty &&
+        bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
+      bestRatedTile = ratedTile;
+    }
+  }
+  // Sanity check. This assert can only fail if there's a programming mistake
+  // locally here.
+  assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
+  return bestRatedTile;
+}
+
+FailureOr<Operation *>
+lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
+                               ValueRange operands,
+                               IREE::Codegen::LayoutAttrInterface layoutAttr) {
+  if (!linalgOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
+  auto inputs = linalgOp.getDpsInputOperands();
+  auto outputs = linalgOp.getDpsInits();
+
+  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
+  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
+  auto resultType = cast<RankedTensorType>(outputs[0].getType());
+  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
+  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
+  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
+  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
+    return failure();
+  }
+
+  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
+      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
+      resultEncoding.getOperandIndex().getValue() !=
+          IREE::Encoding::MATMUL_RESULT) {
+    return failure();
+  }
+
+  MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
+      cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
+
+  if (isIdentityLayout(encodingInfo)) {
+    return dropEncodingAndCloneOp(builder, linalgOp,
+                                  operands.take_front(inputs.size()),
+                                  operands.drop_front(inputs.size()));
+  }
+
+  bool transpose = isNarrowNResult(resultEncoding);
+  SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
+  SmallVector<ReassociationIndices> ri;
+  Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder, ri,
+                                 elemTypes, /*operandIdx=*/0);
+  Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder, ri,
+                                 elemTypes, /*operandIdx=*/1);
+  Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
+                                    ri, elemTypes, /*operandIdx=*/2);
+  if (transpose) {
+    std::swap(newLhs, newRhs);
+  }
+  Type newResultType = newResult.getType();
+  auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
+  Operation *result;
+  if (cDims->batch.empty()) {
+    result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
+                                             ValueRange{newLhs, newRhs},
+                                             ValueRange{newResult});
+  } else {
+    result = builder.create<linalg::BatchMmt4DOp>(
+        linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
+        ValueRange{newResult});
+  }
+  if (!ri.empty()) {
+    result = builder.create<tensor::CollapseShapeOp>(
+        linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Interface methods implementation for iree_cpu.cpu_encoding_layout.
 //===----------------------------------------------------------------------===//
 
 // Enumerate tile sizes to choose from on riscv32.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in IREE::CPU::chooseMatmulTile.
+// are handled by transposition in chooseMatmulTile.
 static SmallVector<TileMxNxK>
 enumerateMatmulTileRiscv32(DictionaryAttr config) {
   if (hasUkernel(config)) {
@@ -47,7 +326,7 @@ enumerateMatmulTileRiscv32(DictionaryAttr config) {
 
 // Enumerate tile sizes to choose from on arm64.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in IREE::CPU::chooseMatmulTile.
+// are handled by transposition in chooseMatmulTile.
 static SmallVector<TileMxNxK> enumerateMatmulTileArm64(TypeRange elementTypes,
                                                        DictionaryAttr config) {
   // Data-tiling for SVE is not implemented yet.
@@ -137,7 +416,7 @@ static SmallVector<TileMxNxK> enumerateMatmulTileArm64(TypeRange elementTypes,
 
 // Enumerate tile sizes to choose from on x86-64.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in IREE::CPU::chooseMatmulTile.
+// are handled by transposition in chooseMatmulTile.
 static SmallVector<TileMxNxK> enumerateMatmulTileX86_64(TypeRange elementTypes,
                                                         DictionaryAttr config) {
   assert(elementTypes.size() == 3);
@@ -309,8 +588,8 @@ struct CPUDeviceEncodingLayoutAttrInterface
       return nullptr;
     }
 
-    FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
-        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+    FailureOr<Operation *> newOp = lowerContractionOpWithEncoding(
+        b, linalgOp, convertedOperands,
         cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
     return newOp.value_or(nullptr);
   }
@@ -393,8 +672,8 @@ struct VMVXDeviceEncodingLayoutAttrInterface
       return nullptr;
    }
 
-    FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
-        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+    FailureOr<Operation *> newOp = lowerContractionOpWithEncoding(
+        b, linalgOp, convertedOperands,
         cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
     return newOp.value_or(nullptr);
   }
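[Editorial aside, not part of the patch: the reason `transposeNarrowN` could become unconditional is the identity C = A*B = (B^T * A^T)^T, which lowerContractionOpWithEncoding exploits by swapping newLhs and newRhs for narrow-N results so that the narrow dimension becomes M and the narrow-M tile enumerations apply. Standalone C++ check of that identity on a tiny narrow-N (N == 1) example, for illustration only.]

#include <cstdio>

int main() {
  const int M = 2, K = 3, N = 1;  // narrow-N matmul: C(2x1) = A(2x3) * B(3x1)
  const int A[2][3] = {{1, 2, 3}, {4, 5, 6}};
  const int B[3][1] = {{7}, {8}, {9}};
  int C[2][1] = {{0}, {0}};  // direct A * B
  int D[2][1] = {{0}, {0}};  // via (B^T * A^T)^T
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        C[i][j] += A[i][k] * B[k][j];
  // T = B^T * A^T is an N x M product; storing T[j][i] into D[i][j] applies
  // the final transpose. The narrow dimension (N) now plays the M role.
  for (int j = 0; j < N; ++j)
    for (int i = 0; i < M; ++i)
      for (int k = 0; k < K; ++k)
        D[i][j] += B[k][j] * A[i][k];
  for (int i = 0; i < M; ++i)
    std::printf("C[%d][0]=%d  D[%d][0]=%d\n", i, C[i][0], i, D[i][0]);
  return 0;  // both print 50 and 122: the two lowerings agree
}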