diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index 84b854026545..ad2ce7c48c05 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
@@ -740,25 +741,14 @@ class MaterializeContractionOp
     auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
         this->getTypeConverter());
-    if (auto layoutAttr = converter->getLayoutAttr()) {
-      SmallVector<Type> convertedResTypes;
-      for (auto init : op.getDpsInits()) {
-        convertedResTypes.push_back(converter->convertType(init.getType()));
-      }
-      Operation *newOp =
-          layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
-      rewriter.replaceOp(op, newOp->getResults());
-      return success();
-    }
-
-    FailureOr<Operation *> convertedOp =
-        IREE::Codegen::lowerContractionOpWithEncoding(
-            rewriter, op, operands, converter->getTransposeNarrowN(),
-            converter->getLayoutAttr());
-    if (failed(convertedOp)) {
-      return failure();
+    IREE::Codegen::LayoutAttrInterface layoutAttr = converter->getLayoutAttr();
+    SmallVector<Type> convertedResTypes;
+    for (auto init : op.getDpsInits()) {
+      convertedResTypes.push_back(converter->convertType(init.getType()));
     }
-    rewriter.replaceOp(op.getOperation(), convertedOp.value()->getResult(0));
+    Operation *newOp =
+        layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
+    rewriter.replaceOp(op, newOp->getResults());
     return success();
   }
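The pattern now has a single lowering path: every backend's layout attribute must provide `lowerOp`, so the helper-based fallback (deleted from Utils.cpp below) is gone. For readers who have not seen the interface, the following is a paraphrase of the two hooks involved, not the TableGen-generated declarations from `IREECodegenInterfaces.h`:

```cpp
// Paraphrase only: the real LayoutAttrInterface is TableGen-generated and
// differs in plumbing (AttrInterface base, const-ness, etc.).
//
//   MaterializeEncodingInfo getEncodingInfo(RankedTensorType type);
//     The layout (tile sizes, permutations) this attribute assigns to an
//     encoded tensor type; an identity layout means "leave the op alone".
//
//   Operation *lowerOp(OpBuilder &b, Operation *op,
//                      TypeRange convertedResTypes,
//                      ValueRange convertedOperands);
//     Materializes an encoded op into its data-tiled form (contractions
//     typically become linalg.mmt4d); the CPU/VMVX models below return
//     nullptr for unsupported ops.
```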
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
index 32dbc46563b2..c515766e396f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
@@ -305,274 +305,4 @@ getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK) {
   return encodingInfo;
 }
 
-static RankedTensorType dropEncoding(RankedTensorType type) {
-  return RankedTensorType::get(type.getShape(), type.getElementType());
-}
-
-static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
-                                         ValueRange convertedInputOperands,
-                                         ValueRange convertedOutputOperands) {
-  SmallVector<Value> operands;
-  operands.append(convertedInputOperands.begin(),
-                  convertedInputOperands.end());
-  operands.append(convertedOutputOperands.begin(),
-                  convertedOutputOperands.end());
-  return mlir::clone(builder, op,
-                     {dropEncoding(cast<RankedTensorType>(
-                         convertedOutputOperands[0].getType()))},
-                     operands);
-}
-
-static RankedTensorType
-getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed,
-                SmallVectorImpl<ReassociationIndices> &ri) {
-  if (!isBatched) {
-    ri.assign({{0, 1}, {2, 3}});
-    if (!isTransposed) {
-      return RankedTensorType::get(
-          {1, type.getDimSize(0), 1, type.getDimSize(1)},
-          type.getElementType());
-    }
-    return RankedTensorType::get(
-        {type.getDimSize(0), 1, type.getDimSize(1), 1},
-        type.getElementType());
-  }
-
-  ri.assign({{0}, {1, 2}, {3, 4}});
-  if (!isTransposed) {
-    return RankedTensorType::get(
-        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
-        type.getElementType());
-  }
-  return RankedTensorType::get(
-      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
-      type.getElementType());
-}
-
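To make the reassociation bookkeeping above concrete, here is a hedged usage sketch of the non-batched, non-transposed branch; the shapes are invented and `builder` is assumed to be in scope:

```cpp
// A rank-2 vecmat operand tensor<4x16xf32> expands to tensor<1x4x1x16xf32>.
// ri = {{0, 1}, {2, 3}} records that expanded dims (0, 1) and (2, 3) map to
// source dims 0 and 1, so a trailing tensor.collapse_shape can undo it.
SmallVector<ReassociationIndices> ri;
auto src = RankedTensorType::get({4, 16}, builder.getF32Type());
RankedTensorType expanded =
    getExpandedType(src, /*isBatched=*/false, /*isTransposed=*/false, ri);
assert(expanded.getShape() == ArrayRef<int64_t>({1, 4, 1, 16}));
```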
-/// Given an input Value and a desired output element type, create and return
-/// an element-wise linalg::GenericOp that extends the input Value to the
-/// output element type.
-static Value createElementWiseExtUIOp(OpBuilder &builder, Value input,
-                                      Location loc, Type outElemType) {
-  auto inputType = cast<RankedTensorType>(input.getType());
-  SmallVector<AffineMap> maps(
-      2, builder.getMultiDimIdentityMap(inputType.getRank()));
-  SmallVector<utils::IteratorType> iteratorTypes(
-      inputType.getRank(), utils::IteratorType::parallel);
-  auto castedType = inputType.clone(outElemType);
-  SmallVector<OpFoldResult> inputMixedSizes =
-      tensor::getMixedSizes(builder, loc, input);
-  Value init =
-      builder.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
-  return builder
-      .create<linalg::GenericOp>(
-          loc, castedType, input, init, maps, iteratorTypes,
-          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
-            Value castRes =
-                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
-                    ->getResult(0);
-            b.create<linalg::YieldOp>(nestedLoc, castRes);
-          })
-      .getResult(0);
-}
-
-/// If needed, expand and the input Value, and return the resulting input with
-/// the canonical mmt4d input shape. If the input element type is unsigned,
-/// create a producer Linalg::GenericOp on the input that unsigned extends the
-/// input to the output element type. This extension is required to keep the
-/// unsignedness information on the input for ukernels. If `transpose` is true,
-/// the `linalgOp`'s indexing maps are transposed.
-static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
-                             bool transpose, OpBuilder &builder,
-                             SmallVectorImpl<ReassociationIndices> &ri,
-                             ArrayRef<Type> elemTypes, int operandIdx) {
-  assert(linalgOp.getNumDpsInputs() == 2);
-  assert(linalgOp.getNumDpsInits() == 1);
-  auto cDims = linalg::inferContractionDims(linalgOp);
-  Location loc = linalgOp->getLoc();
-  Value expandedValue = value;
-  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
-  // operand is a vector and must be extended
-  if ((cDims->m.empty() && operandIdx != 1) ||
-      (cDims->n.empty() && operandIdx != 0)) {
-    auto type = cast<RankedTensorType>(value.getType());
-    RankedTensorType newType = getExpandedType(
-        type, /*isBatched=*/!cDims->batch.empty(),
-        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()),
-        ri);
-    expandedValue =
-        builder.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
-  }
-  if (elemTypes[operandIdx].isUnsignedInteger()) {
-    return createElementWiseExtUIOp(builder, expandedValue, loc,
-                                    elemTypes.back());
-  }
-  return expandedValue;
-}
-
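The unsigned-extension helper above exists because signedness is a property of ops, not of MLIR integer types: once a ui8 tensor feeds an i32 matmul, nothing would record that the widening must be zero-extending. A sketch of the IR it builds, with invented shapes (`packedLhs` and `loc` are placeholders):

```cpp
// Roughly, for a tensor<?x?xi8> input that the encoding marks as unsigned:
//   %init = tensor.empty(...) : tensor<?x?xi32>
//   %ext  = linalg.generic {identity maps, all parallel}
//             ins(%in) outs(%init) {
//     ^bb0(%a: i8, %out: i32):
//       %w = arith.extui %a : i8 to i32   // the "ui" fact lives here
//       linalg.yield %w : i32
//   }
// Ukernel lowering later pattern-matches this producer to pick the
// unsigned-input microkernel variant.
Value extended = createElementWiseExtUIOp(builder, packedLhs, loc,
                                          builder.getIntegerType(32));
```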
-TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                           IREE::Encoding::MatmulNarrowDim narrowDim,
-                           ArrayRef<int64_t> hostDefinedUpperBound) {
-  assert((hostDefinedUpperBound.empty() ||
-          hostDefinedUpperBound.size() >= 3) &&
-         "expected hostDefinedUpperBound is empty or has upper bound for {M, "
-         "N, K}");
-  // Handle narrow-N by transposing to reduce to narrow-M. Note: the
-  // enumeratedTiles currently only enumerate narrow-M cases.
-  if (narrowDim.isN()) {
-    SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
-    std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
-    narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
-    TileMxNxK tile =
-        chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
-    std::swap(tile.M, tile.N);
-    return tile;
-  }
-  // Handle kDynamic: currently this is only used with VMVX, where there is
-  // only one enumerated tile and it has all three M/N/K dimensions dynamic,
-  // so for now we only support that. Generalize that as needed when more
-  // dynamic tile sizes are used outside of VMVX, e.g. perhaps some day with
-  // Arm SVE. Decide how to incorporate the handling of kDynamic in the
-  // cost-model evaluation below to decide when to prefer a dynamic vs a
-  // static tile shape.
-  for (auto tile : enumeratedTiles) {
-    if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
-        ShapedType::isDynamic(tile.K)) {
-      assert(enumeratedTiles.size() == 1);
-      assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
-             ShapedType::isDynamic(tile.K));
-      return tile;
-    }
-  }
-  // We're going to "rate" the enumerated tiles.
-  struct RatedTileMxNxK : TileMxNxK {
-    RatedTileMxNxK() {}
-    RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
-    // Penalize tiles that are wider in the M dimension than matmulNarrowM.
-    int64_t paddingPenalty = 0;
-    // Favor larger tiles, as long as they still minimize paddingPenalty.
-    int64_t productMxNxK = 0;
-  };
-  SmallVector<RatedTileMxNxK> ratedTiles;
-  ratedTiles.reserve(enumeratedTiles.size());
-  int64_t bestPaddingPenalty = INT64_MAX;
-  int64_t mUB = INT64_MAX;
-  int64_t nUB = INT64_MAX;
-  int64_t kUB = INT64_MAX;
-  if (!hostDefinedUpperBound.empty()) {
-    mUB = hostDefinedUpperBound[0];
-    nUB = hostDefinedUpperBound[1];
-    kUB = hostDefinedUpperBound[2];
-  }
-  for (auto tile : enumeratedTiles) {
-    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
-      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
-                 llvm::interleaveComma(
-                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-                 llvm::dbgs()
-                 << ") is skipped because it is not valid for upper_bound (";
-                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
-                                       llvm::dbgs());
-                 llvm::dbgs() << ")\n");
-      continue;
-    }
-    RatedTileMxNxK ratedTile(tile);
-    ratedTile.paddingPenalty = 0;
-    // If we are choosing a tile for a narrow-M case, we want to minimize
-    // padding along the M dimension.
-    // The PowerOf2Ceil is so that we are OK with padding up to the next
-    // power of two, we just try to avoid padding beyond that. For example,
-    // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
-    // are OK with the tile that has M==8 even though it requires some
-    // padding. Otherwise, we would be penalizing the tiles with M==8,4,2 and
-    // we would end up selecting the vecmat tile (M==1) for that case!
-    if (narrowDim) {
-      ratedTile.paddingPenalty =
-          std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
-    }
-    ratedTile.productMxNxK = tile.M * tile.N * tile.K;
-    ratedTiles.push_back(ratedTile);
-    LLVM_DEBUG(llvm::dbgs() << "candidate: ";
-               llvm::interleaveComma(
-                   ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-               llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty
-                            << "\n");
-    bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
-  }
-  RatedTileMxNxK bestRatedTile;
-  for (auto ratedTile : ratedTiles) {
-    // Choose only among tiles that minimize paddingPenalty. Among those,
-    // maximize productMxNxK.
-    if (ratedTile.paddingPenalty == bestPaddingPenalty &&
-        bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
-      bestRatedTile = ratedTile;
-    }
-  }
-  // Sanity check. This assert can only fail if there's a programming mistake
-  // locally here.
-  assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
-  return bestRatedTile;
-}
-
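Working the numbers from the comment above (an illustrative calculation, not code from the tree):

```cpp
// narrowDim.size == 7, candidate tiles M = 8, 4, 2, 1 (same N and K):
//   llvm::PowerOf2Ceil(7) == 8
//   M == 8: std::max<int64_t>(8 - 8, 0) == 0   // padding 7 -> 8 is tolerated
//   M == 4: std::max<int64_t>(4 - 8, 0) == 0
//   M == 2: std::max<int64_t>(2 - 8, 0) == 0
//   M == 1: std::max<int64_t>(1 - 8, 0) == 0
// All penalties tie at 0, so the largest productMxNxK (the M == 8 tile)
// wins. With narrowDim.size == 1 (a true vecmat), PowerOf2Ceil(1) == 1 and
// the M == 8 tile gets penalty 7, steering selection to the M == 1 tile.
```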
-FailureOr<Operation *>
-lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
-                               ValueRange operands, bool transposeNarrowN,
-                               LayoutAttrInterface layoutAttr) {
-  if (!linalgOp.hasPureTensorSemantics()) {
-    return failure();
-  }
-
-  auto inputs = linalgOp.getDpsInputOperands();
-  auto outputs = linalgOp.getDpsInits();
-
-  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
-  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
-  auto resultType = cast<RankedTensorType>(outputs[0].getType());
-  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
-  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
-  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
-  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
-    return failure();
-  }
-
-  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
-      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
-      resultEncoding.getOperandIndex().getValue() !=
-          IREE::Encoding::MATMUL_RESULT) {
-    return failure();
-  }
-
-  MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
-      cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
-
-  if (isIdentityLayout(encodingInfo)) {
-    return dropEncodingAndCloneOp(builder, linalgOp,
-                                  operands.take_front(inputs.size()),
-                                  operands.drop_front(inputs.size()));
-  }
-
-  bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
-  SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
-  SmallVector<ReassociationIndices> ri;
-  Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder,
-                                 ri, elemTypes, /*operandIdx=*/0);
-  Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder,
-                                 ri, elemTypes, /*operandIdx=*/1);
-  Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
-                                    ri, elemTypes, /*operandIdx=*/2);
-  if (transpose) {
-    std::swap(newLhs, newRhs);
-  }
-  Type newResultType = newResult.getType();
-  auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
-  Operation *result;
-  if (cDims->batch.empty()) {
-    result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
-                                             ValueRange{newLhs, newRhs},
-                                             ValueRange{newResult});
-  } else {
-    result = builder.create<linalg::BatchMmt4DOp>(
-        linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-        ValueRange{newResult});
-  }
-  if (!ri.empty()) {
-    result = builder.create<tensor::CollapseShapeOp>(
-        linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
-  }
-  return result;
-}
-
 } // namespace mlir::iree_compiler::IREE::Codegen
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
index 1bee3ec74032..8a8c309d4ba5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
@@ -75,29 +75,6 @@ struct TileMxNxK {
 MaterializeEncodingInfo
 getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK);
 
-//===----------------------------------------------------------------------===//
-// Operation Lowering Utilities.
-//===----------------------------------------------------------------------===//
-
-// TODO(hanchung): The below methods are exposed to public because they are
-// shared between MaterializeEncodingIntoPackUnPack.cpp.cpp and
-// CPUEncodingExternalModels.cpp. They will be moved to other places after all
-// the CPU backends implement their layout attributes.
-
-/// Returns the best TileMxNxK from `enumeratedTiles` pool. If the
-/// `hostDefinedUpperBound` is not empty, the chosen tile sizes can not be
-/// greater than the values.
-/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
-/// information to host. For now, they are defined by host.
-TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                           IREE::Encoding::MatmulNarrowDim narrowDim,
-                           ArrayRef<int64_t> hostDefinedUpperBound = {});
-
-FailureOr<Operation *>
-lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
-                               ValueRange operands, bool transposeNarrowN,
-                               LayoutAttrInterface layoutAttr);
-
 } // namespace mlir::iree_compiler::IREE::Codegen
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
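One behavioral note on the move: the deleted helper took a `transposeNarrowN` flag, but both in-tree callers passed `true`, so the relocated copy in `CPUEncodingExternalModels.cpp` (next file) drops the parameter and derives the transpose from the result encoding alone. The underlying trick, sketched with invented shapes:

```cpp
// Narrow-N results reduce to the narrow-M case via (A * B)^T == B^T * A^T:
//   C = A * B with C : tensor<512x1xf32>          // narrow N == 1
// is lowered as if computing
//   C^T = B^T * A^T with C^T : tensor<1x512xf32>  // narrow M == 1
// getMmt4dOperand derives each operand for the transposed problem, and the
// std::swap(newLhs, newRhs) in lowerContractionOpWithEncoding finishes the
// job, so tile enumeration only ever needs narrow-M entries.
```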
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
index a847abd5a1b2..89de2e6dcc16 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
@@ -24,13 +24,292 @@ using Codegen::TileMxNxK;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Utilities.
+//===----------------------------------------------------------------------===//
+
+static RankedTensorType dropEncoding(RankedTensorType type) {
+  return RankedTensorType::get(type.getShape(), type.getElementType());
+}
+
+static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
+                                         ValueRange convertedInputOperands,
+                                         ValueRange convertedOutputOperands) {
+  SmallVector<Value> operands;
+  operands.append(convertedInputOperands.begin(),
+                  convertedInputOperands.end());
+  operands.append(convertedOutputOperands.begin(),
+                  convertedOutputOperands.end());
+  return mlir::clone(builder, op,
+                     {dropEncoding(cast<RankedTensorType>(
+                         convertedOutputOperands[0].getType()))},
+                     operands);
+}
+
+static RankedTensorType
+getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed,
+                SmallVectorImpl<ReassociationIndices> &ri) {
+  if (!isBatched) {
+    ri.assign({{0, 1}, {2, 3}});
+    if (!isTransposed) {
+      return RankedTensorType::get(
+          {1, type.getDimSize(0), 1, type.getDimSize(1)},
+          type.getElementType());
+    }
+    return RankedTensorType::get(
+        {type.getDimSize(0), 1, type.getDimSize(1), 1},
+        type.getElementType());
+  }
+
+  ri.assign({{0}, {1, 2}, {3, 4}});
+  if (!isTransposed) {
+    return RankedTensorType::get(
+        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
+        type.getElementType());
+  }
+  return RankedTensorType::get(
+      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
+      type.getElementType());
+}
+
+/// Given an input Value and a desired output element type, create and return
+/// an element-wise linalg::GenericOp that extends the input Value to the
+/// output element type.
+static Value createElementWiseExtUIOp(OpBuilder &builder, Value input,
+                                      Location loc, Type outElemType) {
+  auto inputType = cast<RankedTensorType>(input.getType());
+  SmallVector<AffineMap> maps(
+      2, builder.getMultiDimIdentityMap(inputType.getRank()));
+  SmallVector<utils::IteratorType> iteratorTypes(
+      inputType.getRank(), utils::IteratorType::parallel);
+  auto castedType = inputType.clone(outElemType);
+  SmallVector<OpFoldResult> inputMixedSizes =
+      tensor::getMixedSizes(builder, loc, input);
+  Value init =
+      builder.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
+  return builder
+      .create<linalg::GenericOp>(
+          loc, castedType, input, init, maps, iteratorTypes,
+          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+            Value castRes =
+                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
+                    ->getResult(0);
+            b.create<linalg::YieldOp>(nestedLoc, castRes);
+          })
+      .getResult(0);
+}
+
+/// If needed, expand the input Value, and return the resulting input with the
+/// canonical mmt4d input shape. If the input element type is unsigned, create
+/// a producer Linalg::GenericOp on the input that unsigned extends the input
+/// to the output element type. This extension is required to keep the
+/// unsignedness information on the input for ukernels. If `transpose` is
+/// true, the `linalgOp`'s indexing maps are transposed.
+static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
+                             bool transpose, OpBuilder &builder,
+                             SmallVectorImpl<ReassociationIndices> &ri,
+                             ArrayRef<Type> elemTypes, int operandIdx) {
+  assert(linalgOp.getNumDpsInputs() == 2);
+  assert(linalgOp.getNumDpsInits() == 1);
+  auto cDims = linalg::inferContractionDims(linalgOp);
+  Location loc = linalgOp->getLoc();
+  Value expandedValue = value;
+  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
+  // operand is a vector and must be extended.
+  if ((cDims->m.empty() && operandIdx != 1) ||
+      (cDims->n.empty() && operandIdx != 0)) {
+    auto type = cast<RankedTensorType>(value.getType());
+    RankedTensorType newType = getExpandedType(
+        type, /*isBatched=*/!cDims->batch.empty(),
+        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()),
+        ri);
+    expandedValue =
+        builder.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
+  }
+  if (elemTypes[operandIdx].isUnsignedInteger()) {
+    return createElementWiseExtUIOp(builder, expandedValue, loc,
+                                    elemTypes.back());
+  }
+  return expandedValue;
+}
+
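How the expansion predicate above plays out per operand, under the contraction dims that `linalg::inferContractionDims` reports (a summary of the branch, not additional code):

```cpp
// vecmat (cDims->m.empty()):
//   operandIdx 0 (LHS, a vector): expanded (0 != 1)
//   operandIdx 1 (RHS, a matrix): left as-is
//   operandIdx 2 (result)       : expanded; isTransposed = transpose ^ false
// matvec (cDims->n.empty()):
//   operandIdx 0 (LHS, a matrix): left as-is
//   operandIdx 1 (RHS, a vector): expanded
//   operandIdx 2 (result)       : expanded; isTransposed = transpose ^ true
```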
+/// Returns the best TileMxNxK from the `enumeratedTiles` pool. If
+/// `hostDefinedUpperBound` is not empty, the chosen tile sizes cannot be
+/// greater than those values.
+/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
+/// information to host. For now, they are defined by host.
+TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
+                           IREE::Encoding::MatmulNarrowDim narrowDim,
+                           ArrayRef<int64_t> hostDefinedUpperBound = {}) {
+  assert((hostDefinedUpperBound.empty() ||
+          hostDefinedUpperBound.size() >= 3) &&
+         "expected hostDefinedUpperBound is empty or has upper bound for {M, "
+         "N, K}");
+  // Handle narrow-N by transposing to reduce to narrow-M. Note: the
+  // enumeratedTiles currently only enumerate narrow-M cases.
+  if (narrowDim.isN()) {
+    SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
+    std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
+    narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
+    TileMxNxK tile =
+        chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
+    std::swap(tile.M, tile.N);
+    return tile;
+  }
+  // Handle kDynamic: currently this is only used with VMVX, where there is
+  // only one enumerated tile and it has all three M/N/K dimensions dynamic,
+  // so for now we only support that. Generalize that as needed when more
+  // dynamic tile sizes are used outside of VMVX, e.g. perhaps some day with
+  // Arm SVE. Decide how to incorporate the handling of kDynamic in the
+  // cost-model evaluation below to decide when to prefer a dynamic vs a
+  // static tile shape.
+  for (auto tile : enumeratedTiles) {
+    if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
+        ShapedType::isDynamic(tile.K)) {
+      assert(enumeratedTiles.size() == 1);
+      assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
+             ShapedType::isDynamic(tile.K));
+      return tile;
+    }
+  }
+  // We're going to "rate" the enumerated tiles.
+  struct RatedTileMxNxK : TileMxNxK {
+    RatedTileMxNxK() {}
+    RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
+    // Penalize tiles that are wider in the M dimension than matmulNarrowM.
+    int64_t paddingPenalty = 0;
+    // Favor larger tiles, as long as they still minimize paddingPenalty.
+    int64_t productMxNxK = 0;
+  };
+  SmallVector<RatedTileMxNxK> ratedTiles;
+  ratedTiles.reserve(enumeratedTiles.size());
+  int64_t bestPaddingPenalty = INT64_MAX;
+  int64_t mUB = INT64_MAX;
+  int64_t nUB = INT64_MAX;
+  int64_t kUB = INT64_MAX;
+  if (!hostDefinedUpperBound.empty()) {
+    mUB = hostDefinedUpperBound[0];
+    nUB = hostDefinedUpperBound[1];
+    kUB = hostDefinedUpperBound[2];
+  }
+  for (auto tile : enumeratedTiles) {
+    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
+      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
+                 llvm::interleaveComma(
+                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
+                 llvm::dbgs()
+                 << ") is skipped because it is not valid for upper_bound (";
+                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
+                                       llvm::dbgs());
+                 llvm::dbgs() << ")\n");
+      continue;
+    }
+    RatedTileMxNxK ratedTile(tile);
+    ratedTile.paddingPenalty = 0;
+    // If we are choosing a tile for a narrow-M case, we want to minimize
+    // padding along the M dimension.
+    // The PowerOf2Ceil is so that we are OK with padding up to the next
+    // power of two, we just try to avoid padding beyond that. For example,
+    // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
+    // are OK with the tile that has M==8 even though it requires some
+    // padding. Otherwise, we would be penalizing the tiles with M==8,4,2 and
+    // we would end up selecting the vecmat tile (M==1) for that case!
+    if (narrowDim) {
+      ratedTile.paddingPenalty =
+          std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
+    }
+    ratedTile.productMxNxK = tile.M * tile.N * tile.K;
+    ratedTiles.push_back(ratedTile);
+    LLVM_DEBUG(llvm::dbgs() << "candidate: ";
+               llvm::interleaveComma(
+                   ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
+               llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty
+                            << "\n");
+    bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
+  }
+  RatedTileMxNxK bestRatedTile;
+  for (auto ratedTile : ratedTiles) {
+    // Choose only among tiles that minimize paddingPenalty. Among those,
+    // maximize productMxNxK.
+    if (ratedTile.paddingPenalty == bestPaddingPenalty &&
+        bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
+      bestRatedTile = ratedTile;
+    }
+  }
+  // Sanity check. This assert can only fail if there's a programming mistake
+  // locally here.
+  assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
+  return bestRatedTile;
+}
+
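The kDynamic early return encodes today's invariant rather than a policy. A hedged sketch of the VMVX case the comment refers to; the tile literals are what the invariant implies, not values copied from the enumerators:

```cpp
// VMVX with ukernels is expected to enumerate exactly one all-dynamic tile;
// chooseMatmulTile returns it unconditionally and the concrete tile sizes
// are resolved at runtime.
SmallVector<TileMxNxK> tiles = {TileMxNxK{ShapedType::kDynamic,
                                          ShapedType::kDynamic,
                                          ShapedType::kDynamic}};
TileMxNxK chosen = chooseMatmulTile(tiles, /*narrowDim=*/{});
// chosen.M, chosen.N, chosen.K are all ShapedType::kDynamic.
```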
+FailureOr<Operation *>
+lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
+                               ValueRange operands,
+                               IREE::Codegen::LayoutAttrInterface layoutAttr) {
+  if (!linalgOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
+  auto inputs = linalgOp.getDpsInputOperands();
+  auto outputs = linalgOp.getDpsInits();
+
+  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
+  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
+  auto resultType = cast<RankedTensorType>(outputs[0].getType());
+  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
+  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
+  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
+  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
+    return failure();
+  }
+
+  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
+      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
+      resultEncoding.getOperandIndex().getValue() !=
+          IREE::Encoding::MATMUL_RESULT) {
+    return failure();
+  }
+
+  MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
+      cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
+
+  if (isIdentityLayout(encodingInfo)) {
+    return dropEncodingAndCloneOp(builder, linalgOp,
+                                  operands.take_front(inputs.size()),
+                                  operands.drop_front(inputs.size()));
+  }
+
+  bool transpose = isNarrowNResult(resultEncoding);
+  SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
+  SmallVector<ReassociationIndices> ri;
+  Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder,
+                                 ri, elemTypes, /*operandIdx=*/0);
+  Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder,
+                                 ri, elemTypes, /*operandIdx=*/1);
+  Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
+                                    ri, elemTypes, /*operandIdx=*/2);
+  if (transpose) {
+    std::swap(newLhs, newRhs);
+  }
+  Type newResultType = newResult.getType();
+  auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
+  Operation *result;
+  if (cDims->batch.empty()) {
+    result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
+                                             ValueRange{newLhs, newRhs},
+                                             ValueRange{newResult});
+  } else {
+    result = builder.create<linalg::BatchMmt4DOp>(
+        linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
+        ValueRange{newResult});
+  }
+  if (!ri.empty()) {
+    result = builder.create<tensor::CollapseShapeOp>(
+        linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
+  }
+  return result;
+}
+
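Putting the helpers together, here is an invented end-to-end trace for a non-batched `ui8 x i8 -> i32` vecmat; the op names are the real linalg/tensor ops the code above creates, the shapes and SSA names are illustrative:

```cpp
// %lhs4d  = tensor.expand_shape %lhs [[0, 1], [2, 3]]   // getMmt4dOperand #0
// %lhsExt = linalg.generic ... arith.extui ...          // ui8 -> i32
// %acc4d  = tensor.expand_shape %acc [[0, 1], [2, 3]]   // getMmt4dOperand #2
// %mm     = linalg.mmt4d ins(%lhsExt, %rhs) outs(%acc4d)
// %res    = tensor.collapse_shape %mm [[0, 1], [2, 3]]  // undo ri expansion
```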
 //===----------------------------------------------------------------------===//
 // Interface methods implementaion for iree_cpu.cpu_encoding_layout.
 //===----------------------------------------------------------------------===//
 
 // Enumerate tile sizes to choose from on riscv32.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in IREE::CPU::chooseMatmulTile.
+// are handled by transposition in chooseMatmulTile.
 static SmallVector<TileMxNxK>
 enumerateMatmulTileRiscv32(DictionaryAttr config) {
   if (hasUkernel(config)) {
@@ -47,7 +326,7 @@ enumerateMatmulTileRiscv32(DictionaryAttr config) {
 
 // Enumerate tile sizes to choose from on arm64.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in IREE::CPU::chooseMatmulTile.
+// are handled by transposition in chooseMatmulTile.
 static SmallVector<TileMxNxK> enumerateMatmulTileArm64(TypeRange elementTypes,
                                                        DictionaryAttr config) {
   // Data-tiling for SVE is not implemented yet.
@@ -137,7 +416,7 @@ static SmallVector<TileMxNxK> enumerateMatmulTileArm64(TypeRange elementTypes,
 
 // Enumerate tile sizes to choose from on x86-64.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in IREE::CPU::chooseMatmulTile.
+// are handled by transposition in chooseMatmulTile.
 static SmallVector<TileMxNxK> enumerateMatmulTileX86_64(TypeRange elementTypes,
                                                         DictionaryAttr config) {
   assert(elementTypes.size() == 3);
@@ -309,8 +588,8 @@ struct CPUDeviceEncodingLayoutAttrInterface
       return nullptr;
     }
 
-    FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
-        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+    FailureOr<Operation *> newOp = lowerContractionOpWithEncoding(
+        b, linalgOp, convertedOperands,
         cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
     return newOp.value_or(nullptr);
   }
@@ -393,8 +672,8 @@ struct VMVXDeviceEncodingLayoutAttrInterface
       return nullptr;
     }
 
-    FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
-        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+    FailureOr<Operation *> newOp = lowerContractionOpWithEncoding(
+        b, linalgOp, convertedOperands,
         cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
     return newOp.value_or(nullptr);
   }