[DT][NFC] Localize CPU specific encoding materialization logic. (#19452)
This revision moves the CPU materialization logic from
Dialect/Codegen/Utils/Utils.[h|cpp] to CPUEncodingExternalModels. These
methods were public only during the transition; now that all the CPU
layout attributes are implemented, they no longer need to be exposed.

Additionally, it removes the outdated fallback logic from the
MaterializeContractionOp pattern and removes the `transposeNarrowN` input
argument from the `lowerContractionOpWithEncoding` method, because all the
CPU backends enable the transposeNarrowN feature.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
hanhanW authored Dec 13, 2024
1 parent c618134 commit ad938ae
Showing 4 changed files with 294 additions and 318 deletions.
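For context before the file diffs: after this change, MaterializeContractionOp dispatches exclusively through `IREE::Codegen::LayoutAttrInterface::lowerOp`, so each backend's lowering lives behind its layout attribute. The sketch below is a hypothetical illustration of such a hook, with the signature inferred from the call site in this diff; the real implementations live in CPUEncodingExternalModels.cpp and are not shown here.

// Hypothetical sketch only: the hook signature is inferred from the call
// site below (layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands));
// the actual external models live in CPUEncodingExternalModels.cpp.
struct ExampleCPULayoutMaterialization final
    : IREE::Codegen::LayoutAttrInterface::ExternalModel<
          ExampleCPULayoutMaterialization, ExampleCPULayoutAttr> {
  Operation *lowerOp(Attribute attr, OpBuilder &b, Operation *op,
                     TypeRange convertedResTypes,
                     ValueRange convertedOperands) const {
    // Emit the data-tiled form of the contraction (e.g. linalg.mmt4d) on the
    // already-converted operands. No transposeNarrowN flag is needed: all CPU
    // backends transpose narrow-N cases to narrow-M up front.
    return b.create<linalg::Mmt4DOp>(
        op->getLoc(), convertedResTypes,
        /*inputs=*/convertedOperands.take_front(2),
        /*outputs=*/convertedOperands.drop_front(2));
  }
};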
@@ -10,6 +10,7 @@

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
@@ -740,25 +741,14 @@ class MaterializeContractionOp
     auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
         this->getTypeConverter());
 
-    if (auto layoutAttr = converter->getLayoutAttr()) {
-      SmallVector<Type> convertedResTypes;
-      for (auto init : op.getDpsInits()) {
-        convertedResTypes.push_back(converter->convertType(init.getType()));
-      }
-      Operation *newOp =
-          layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
-      rewriter.replaceOp(op, newOp->getResults());
-      return success();
-    }
-
-    FailureOr<Operation *> convertedOp =
-        IREE::Codegen::lowerContractionOpWithEncoding(
-            rewriter, op, operands, converter->getTransposeNarrowN(),
-            converter->getLayoutAttr());
-    if (failed(convertedOp)) {
-      return failure();
-    }
-    rewriter.replaceOp(op.getOperation(), convertedOp.value()->getResult(0));
+    IREE::Codegen::LayoutAttrInterface layoutAttr = converter->getLayoutAttr();
+    SmallVector<Type> convertedResTypes;
+    for (auto init : op.getDpsInits()) {
+      convertedResTypes.push_back(converter->convertType(init.getType()));
+    }
+    Operation *newOp =
+        layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
+    rewriter.replaceOp(op, newOp->getResults());
     return success();
   }

270 changes: 0 additions & 270 deletions compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
@@ -305,274 +305,4 @@ getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK) {
return encodingInfo;
}

static RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}

static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
ValueRange convertedInputOperands,
ValueRange convertedOutputOperands) {
SmallVector<Value> operands;
operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
operands.append(convertedOutputOperands.begin(),
convertedOutputOperands.end());
return mlir::clone(builder, op,
{dropEncoding(cast<RankedTensorType>(
convertedOutputOperands[0].getType()))},
operands);
}

static RankedTensorType
getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed,
SmallVectorImpl<ReassociationIndices> &ri) {
if (!isBatched) {
ri.assign({{0, 1}, {2, 3}});
if (!isTransposed) {
return RankedTensorType::get(
{1, type.getDimSize(0), 1, type.getDimSize(1)},
type.getElementType());
}
return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
type.getElementType());
}

ri.assign({{0}, {1, 2}, {3, 4}});
if (!isTransposed) {
return RankedTensorType::get(
{type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
type.getElementType());
}
return RankedTensorType::get(
{type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
type.getElementType());
}
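To make the shape mapping concrete, here is a small standalone demo (plain C++, no MLIR dependencies; illustrative only, not part of this commit) that mirrors the four cases above, including the reassociation groupings:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors getExpandedType: a vector operand of a (batch) vecmat/matvec is
// expanded to the canonical mmt4d rank by inserting unit dimensions.
std::vector<int64_t> expandedShape(const std::vector<int64_t> &shape,
                                   bool isBatched, bool isTransposed) {
  if (!isBatched) {
    // Reassociation {{0, 1}, {2, 3}}.
    return isTransposed ? std::vector<int64_t>{shape[0], 1, shape[1], 1}
                        : std::vector<int64_t>{1, shape[0], 1, shape[1]};
  }
  // Reassociation {{0}, {1, 2}, {3, 4}}: the batch dim stays its own group.
  return isTransposed
             ? std::vector<int64_t>{shape[0], shape[1], 1, shape[2], 1}
             : std::vector<int64_t>{shape[0], 1, shape[1], 1, shape[2]};
}

int main() {
  // Non-batched, non-transposed: {4, 8} -> {1, 4, 1, 8}.
  assert((expandedShape({4, 8}, false, false) ==
          std::vector<int64_t>{1, 4, 1, 8}));
  // Non-batched, transposed: the unit dims trail each original dim.
  assert((expandedShape({4, 8}, false, true) ==
          std::vector<int64_t>{4, 1, 8, 1}));
  // Batched, non-transposed: {2, 4, 8} -> {2, 1, 4, 1, 8}.
  assert((expandedShape({2, 4, 8}, true, false) ==
          std::vector<int64_t>{2, 1, 4, 1, 8}));
  return 0;
}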

/// Given an input Value and a desired output element type, create and return
/// an element-wise linalg::GenericOp that extends the input Value to the
/// output element type.
static Value createElementWiseExtUIOp(OpBuilder &builder, Value input,
Location loc, Type outElemType) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<AffineMap> maps(
2, builder.getMultiDimIdentityMap(inputType.getRank()));
SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
utils::IteratorType::parallel);
auto castedType = inputType.clone(outElemType);
SmallVector<OpFoldResult> inputMixedSizes =
tensor::getMixedSizes(builder, loc, input);
Value init =
builder.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
return builder
.create<linalg::GenericOp>(
loc, castedType, input, init, maps, iteratorTypes,
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
Value castRes =
b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
->getResult(0);
b.create<linalg::YieldOp>(nestedLoc, castRes);
})
.getResult(0);
}

/// If needed, expand the input Value and return the resulting input with
/// the canonical mmt4d input shape. If the input element type is unsigned,
/// create a producer linalg::GenericOp on the input that unsigned-extends the
/// input to the output element type. This extension is required to keep the
/// unsignedness information on the input for ukernels. If `transpose` is true,
/// the `linalgOp`'s indexing maps are transposed.
static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
bool transpose, OpBuilder &builder,
SmallVectorImpl<ReassociationIndices> &ri,
ArrayRef<Type> elemTypes, int operandIdx) {
assert(linalgOp.getNumDpsInputs() == 2);
assert(linalgOp.getNumDpsInits() == 1);
auto cDims = linalg::inferContractionDims(linalgOp);
Location loc = linalgOp->getLoc();
Value expandedValue = value;
// If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
// operand is a vector and must be expanded to the canonical mmt4d shape.
if ((cDims->m.empty() && operandIdx != 1) ||
(cDims->n.empty() && operandIdx != 0)) {
auto type = cast<RankedTensorType>(value.getType());
RankedTensorType newType = getExpandedType(
type, /*isBatched=*/!cDims->batch.empty(),
/*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
expandedValue =
builder.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
}
if (elemTypes[operandIdx].isUnsignedInteger()) {
return createElementWiseExtUIOp(builder, expandedValue, loc,
elemTypes.back());
}
return expandedValue;
}
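The expansion condition above is easy to misread, so here is a tiny standalone check (plain C++; illustrative only) of which operands get expanded for vecmat and matvec shapes. `mEmpty`/`nEmpty` stand in for `cDims->m.empty()`/`cDims->n.empty()`:

#include <cassert>

// Mirrors the condition in getMmt4dOperand: LHS is operand 0, RHS is 1,
// and the result/accumulator is 2.
bool needsExpansion(bool mEmpty, bool nEmpty, int operandIdx) {
  return (mEmpty && operandIdx != 1) || (nEmpty && operandIdx != 0);
}

int main() {
  // vecmat (no M dims): LHS (0) and result (2) are vectors; RHS (1) is not.
  assert(needsExpansion(true, false, 0));
  assert(!needsExpansion(true, false, 1));
  assert(needsExpansion(true, false, 2));
  // matvec (no N dims): RHS (1) and result (2) are vectors; LHS (0) is not.
  assert(!needsExpansion(false, true, 0));
  assert(needsExpansion(false, true, 1));
  assert(needsExpansion(false, true, 2));
  return 0;
}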

TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
IREE::Encoding::MatmulNarrowDim narrowDim,
ArrayRef<int64_t> hostDefinedUpperBound) {
assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) &&
"expected hostDefinedUpperBound is empty or has upper bound for {M, "
"N, K}");
// Handle narrow-N by transposing to reduce to narrow-M. Note: the
// enumeratedTiles currently only enumerate narrow-M cases.
if (narrowDim.isN()) {
SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
TileMxNxK tile =
chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
std::swap(tile.M, tile.N);
return tile;
}
// Handle kDynamic: currently this is only used with VMVX, where there is only
// one enumerated tile and it has all three M/N/K dimensions dynamic, so for
// now we only support that case. Generalize as needed when more dynamic tile
// sizes are used outside of VMVX, e.g. perhaps some day with Arm SVE, and
// decide how to incorporate kDynamic into the cost-model evaluation below to
// choose between a dynamic and a static tile shape.
for (auto tile : enumeratedTiles) {
if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
ShapedType::isDynamic(tile.K)) {
assert(enumeratedTiles.size() == 1);
assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
ShapedType::isDynamic(tile.K));
return tile;
}
}
// We're going to "rate" the enumerated tiles.
struct RatedTileMxNxK : TileMxNxK {
RatedTileMxNxK() {}
RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
// Penalize tiles that are wider in the M dimension than matmulNarrowM.
int64_t paddingPenalty = 0;
// Favor larger tiles, as long as they still minimize paddingPenalty.
int64_t productMxNxK = 0;
};
SmallVector<RatedTileMxNxK> ratedTiles;
ratedTiles.reserve(enumeratedTiles.size());
int64_t bestPaddingPenalty = INT64_MAX;
int64_t mUB = INT64_MAX;
int64_t nUB = INT64_MAX;
int64_t kUB = INT64_MAX;
if (!hostDefinedUpperBound.empty()) {
mUB = hostDefinedUpperBound[0];
nUB = hostDefinedUpperBound[1];
kUB = hostDefinedUpperBound[2];
}
for (auto tile : enumeratedTiles) {
if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
llvm::interleaveComma(
ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
llvm::dbgs()
<< ") is skipped because it is not valid for upper_bound (";
llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
llvm::dbgs());
llvm::dbgs() << ")\n");
continue;
}
RatedTileMxNxK ratedTile(tile);
ratedTile.paddingPenalty = 0;
    // If we are choosing a tile for a narrow-M case, we want to minimize
    // padding along the M dimension.
    // The PowerOf2Ceil is so that we are OK with padding up to the next
    // power of two; we just try to avoid padding beyond that. For example,
    // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
    // are OK with the tile that has M==8 even though it requires some padding.
    // Otherwise, we would be penalizing the tiles with M==8,4,2 and would
    // end up selecting the vecmat tile (M==1) for that case!
if (narrowDim) {
ratedTile.paddingPenalty =
std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
}
ratedTile.productMxNxK = tile.M * tile.N * tile.K;
ratedTiles.push_back(ratedTile);
LLVM_DEBUG(llvm::dbgs() << "candidate: "; llvm::interleaveComma(
ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty << "\n");
bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
}
RatedTileMxNxK bestRatedTile;
for (auto ratedTile : ratedTiles) {
// Choose only among tiles that minimize paddingPenalty. Among those,
// maximize productMxNxK.
if (ratedTile.paddingPenalty == bestPaddingPenalty &&
bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
bestRatedTile = ratedTile;
}
}
// Sanity check. This assert can only fail if there's a programming mistake
// locally here.
assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
return bestRatedTile;
}
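The selection rule — minimize the power-of-two-aware padding penalty, then maximize tile volume — can be reproduced in isolation. A standalone sketch (plain C++; the tile values are made up for illustration):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

struct Tile { int64_t M, N, K; };

int64_t powerOf2Ceil(int64_t v) { // stands in for llvm::PowerOf2Ceil
  int64_t p = 1;
  while (p < v) p <<= 1;
  return p;
}

// Mirrors chooseMatmulTile's cost model for a narrow-M matmul: tiles padding
// past powerOf2Ceil(narrowM) are penalized; ties are broken by tile volume.
Tile pickTile(const std::vector<Tile> &tiles, int64_t narrowM) {
  int64_t bestPenalty = INT64_MAX, bestProduct = 0;
  Tile best{};
  for (const Tile &t : tiles) {
    int64_t penalty = std::max<int64_t>(t.M - powerOf2Ceil(narrowM), 0);
    int64_t product = t.M * t.N * t.K;
    if (penalty < bestPenalty ||
        (penalty == bestPenalty && product > bestProduct)) {
      bestPenalty = penalty;
      bestProduct = product;
      best = t;
    }
  }
  return best;
}

int main() {
  // narrowM == 7 rounds up to 8, so the M==8 tile has zero penalty and the
  // largest volume among penalty-free tiles; M==16 is penalized by 8.
  std::vector<Tile> tiles = {{16, 16, 1}, {8, 16, 1}, {4, 16, 1}, {1, 16, 1}};
  Tile t = pickTile(tiles, /*narrowM=*/7);
  assert(t.M == 8 && t.N == 16 && t.K == 1);
  return 0;
}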

FailureOr<Operation *>
lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
ValueRange operands, bool transposeNarrowN,
LayoutAttrInterface layoutAttr) {
if (!linalgOp.hasPureTensorSemantics()) {
return failure();
}

auto inputs = linalgOp.getDpsInputOperands();
auto outputs = linalgOp.getDpsInits();

auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
auto resultType = cast<RankedTensorType>(outputs[0].getType());
auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
return failure();
}

if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
resultEncoding.getOperandIndex().getValue() !=
IREE::Encoding::MATMUL_RESULT) {
return failure();
}

MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
cast<RankedTensorType>(linalgOp->getResultTypes()[0]));

if (isIdentityLayout(encodingInfo)) {
return dropEncodingAndCloneOp(builder, linalgOp,
operands.take_front(inputs.size()),
operands.drop_front(inputs.size()));
}

bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
SmallVector<ReassociationIndices> ri;
Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder, ri,
elemTypes, /*operandIdx=*/0);
Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder, ri,
elemTypes, /*operandIdx=*/1);
Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
ri, elemTypes, /*operandIdx=*/2);
if (transpose) {
std::swap(newLhs, newRhs);
}
Type newResultType = newResult.getType();
auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
Operation *result;
if (cDims->batch.empty()) {
result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
ValueRange{newLhs, newRhs},
ValueRange{newResult});
} else {
result = builder.create<linalg::BatchMmt4DOp>(
linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
ValueRange{newResult});
}
if (!ri.empty()) {
result = builder.create<tensor::CollapseShapeOp>(
linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
}
return result;
}

} // namespace mlir::iree_compiler::IREE::Codegen
23 changes: 0 additions & 23 deletions compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
@@ -75,29 +75,6 @@ struct TileMxNxK {
MaterializeEncodingInfo
getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK);

//===----------------------------------------------------------------------===//
// Operation Lowering Utilities.
//===----------------------------------------------------------------------===//

// TODO(hanchung): The methods below are public because they are shared
// between MaterializeEncodingIntoPackUnPack.cpp and
// CPUEncodingExternalModels.cpp. They will be moved elsewhere once all
// the CPU backends implement their layout attributes.

/// Returns the best TileMxNxK from the `enumeratedTiles` pool. If
/// `hostDefinedUpperBound` is not empty, the chosen tile sizes cannot exceed
/// those values.
/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
/// information to the host. For now, it is defined by the host.
TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
IREE::Encoding::MatmulNarrowDim narrowDim,
ArrayRef<int64_t> hostDefinedUpperBound = {});

FailureOr<Operation *>
lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
ValueRange operands, bool transposeNarrowN,
LayoutAttrInterface layoutAttr);

} // namespace mlir::iree_compiler::IREE::Codegen

#endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
