diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel index 2d8f6f618cc7..4550c1c2e8b7 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel @@ -17,6 +17,7 @@ iree_td_library( name = "td_files", srcs = enforce_glob( [ + "EncodingAttrs.td", "EncodingBase.td", "EncodingOps.td", ], @@ -39,6 +40,7 @@ iree_td_library( iree_compiler_cc_library( name = "IR", srcs = [ + "EncodingAttrs.cpp", "EncodingAttrs.cpp.inc", "EncodingDialect.cpp", "EncodingDialect.cpp.inc", @@ -54,6 +56,7 @@ iree_compiler_cc_library( "EncodingEnums.h.inc", "EncodingOps.h", "EncodingOps.h.inc", + "EncodingTypes.h", "EncodingTypes.h.inc", ], deps = [ @@ -101,7 +104,7 @@ iree_gentbl_cc_library( ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "EncodingBase.td", + td_file = "EncodingAttrs.td", deps = [":td_files"], ) @@ -169,7 +172,7 @@ iree_gentbl_cc_library( ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "EncodingBase.td", + td_file = "EncodingAttrs.td", deps = [":td_files"], ) diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt index 0321f89561e8..7544ec551a42 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt @@ -20,8 +20,10 @@ iree_cc_library( "EncodingEnums.h.inc" "EncodingOps.h" "EncodingOps.h.inc" + "EncodingTypes.h" "EncodingTypes.h.inc" SRCS + "EncodingAttrs.cpp" "EncodingAttrs.cpp.inc" "EncodingDialect.cpp" "EncodingDialect.cpp.inc" @@ -63,7 +65,7 @@ iree_tablegen_library( NAME EncodingEnumsGen TD_FILE - "EncodingBase.td" + "EncodingAttrs.td" OUTS --gen-enum-decls EncodingEnums.h.inc --gen-enum-defs EncodingEnums.cpp.inc @@ -85,7 +87,7 @@ iree_tablegen_library( NAME EncodingTypesGen TD_FILE - "EncodingBase.td" + "EncodingAttrs.td" OUTS --gen-attrdef-decls --attrdefs-dialect=iree_encoding EncodingAttrs.h.inc --gen-attrdef-defs --attrdefs-dialect=iree_encoding EncodingAttrs.cpp.inc diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp new file mode 100644 index 000000000000..ec163eaf14c7 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -0,0 +1,160 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" + +#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir::iree_compiler::IREE::Encoding { + +EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, + EncodingOpType opType, ArrayRef elemTypes, + ArrayRef maps, + std::optional bcastMap, + ArrayRef roundDimsTo) { + Builder b(ctx); + auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType); + auto roundDimsToAttr = roundDimsTo.empty() + ? DenseI64ArrayAttr() + : b.getDenseI64ArrayAttr(roundDimsTo); + auto bcastMapAttr = bcastMap.has_value() + ? AffineMapAttr::get(bcastMap.value()) + : AffineMapAttr(); + return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr, + b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps), + bcastMapAttr, roundDimsToAttr); +} + +AffineMap EncodingAttr::getMapForOperandIndex() { + auto index = getOperandIndex().getValue().getZExtValue(); + switch (index) { + case MATMUL_LHS: + case MATMUL_RHS: + case MATMUL_RESULT: { + auto indexingMap = + llvm::cast(getUserIndexingMaps()[index]).getAffineMap(); + if (auto bcastMap = getBcastMap()) { + indexingMap = bcastMap.getAffineMap().compose(indexingMap); + } + return indexingMap; + } + default: + return AffineMap(); + } +} + +std::optional EncodingAttr::mapDimToOperandIndex(int64_t dimPos) { + return getMapForOperandIndex().getResultPosition( + getAffineDimExpr(dimPos, getContext())); +} + +MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, + int narrowThreshold) { + linalg::ContractionDimensions cDims = + linalg::inferContractionDims(linalgOp).value(); + auto map = linalgOp.getIndexingMapsArray().back(); + auto outType = llvm::cast(linalgOp.getDpsInits()[0].getType()); + auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t { + return outType.getDimSize( + map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext())) + .value()); + }; + // M or N can be empty instead of having an explicit dim size of 1 for matvec + // and vecmat, so set to 1 if empty. + int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]); + int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]); + + MatmulNarrowDim narrowM, narrowN; + if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) { + narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize}; + } + if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) { + narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize}; + } + + return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN; +} + +ArrayRef EncodingAttr::getRoundDimsToArray() { + auto roundDimsTo = getRoundDimsTo(); + if (!roundDimsTo) { + return {}; + } + return llvm::cast(roundDimsTo).asArrayRef(); +} + +SmallVector EncodingAttr::getElementTypesArray() { + return llvm::map_to_vector(getElementTypes().getValue(), [](Attribute a) { + return llvm::cast(a).getValue(); + }); +} + +EncodingAttr EncodingAttr::clone(AffineMap bcastMap) { + return get(bcastMap.getContext(), getOperandIndex(), getOpType(), + getElementTypes(), getUserIndexingMaps(), + AffineMapAttr::get(bcastMap), getRoundDimsTo()); +} + +MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) { + if (encoding.getOpType().getValue() != EncodingOpType::matmul) { + return {}; + } + ArrayRef roundDimsTo = encoding.getRoundDimsToArray(); + if (roundDimsTo.empty()) { + return {}; + } + int m = roundDimsTo[0]; + int n = roundDimsTo[1]; + if (m < n) { + return {MatmulNarrowDim::Dim::M, m}; + } + if (n < m) { + return {MatmulNarrowDim::Dim::N, n}; + } + return {}; +} + +EncodingAttr getEncodingAttr(RankedTensorType type) { + return dyn_cast_or_null(type.getEncoding()); +} + +FailureOr +getEncodingContractionDims(EncodingAttr encoding) { + auto indexingMapsAttr = encoding.getUserIndexingMaps(); + SmallVector indexingMaps = llvm::map_to_vector( + indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap { + return cast(m).getAffineMap(); + }); + return linalg::inferContractionDims(indexingMaps); +} + +std::string stringifyOperandIndex(IntegerAttr valueAttr) { + auto value = valueAttr.getValue().getZExtValue(); + switch (value) { + case MATMUL_LHS: + return "LHS"; + case MATMUL_RHS: + return "RHS"; + case MATMUL_RESULT: + return "RESULT"; + default: + assert(false && "invalid index"); + return ""; + } +} + +} // namespace mlir::iree_compiler::IREE::Encoding diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td new file mode 100644 index 000000000000..3ec4bd0d0408 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td @@ -0,0 +1,104 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECT_ENCODING_ATTRS +#define IREE_DIALECT_ENCODING_ATTRS + +include "iree/compiler/Dialect/Encoding/IR/EncodingBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" + +//===---------------------------------------------------------------------===// +// Data layout encoding attributes +//===---------------------------------------------------------------------===// + +class IREEEncoding_Attr traits = []> + : AttrDef; + +class IREEEncoding_I32EnumAttr cases> + : I32EnumAttr { + let cppNamespace = "::mlir::iree_compiler::IREE::Encoding"; + let genSpecializedAttr = 0; +} + +class IREEEncoding_EnumAttr + : EnumAttr; + +// Enums for tagging operand operation in an EncodingAttr +def MATMUL : I32EnumAttrCase<"matmul", 0>; +def CONV : I32EnumAttrCase<"conv", 1>; + +def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType", + "Tracks the type of operation of the operand.", [ + MATMUL, + CONV, + ]>; + +def EncodingOpTypeAttr: + IREEEncoding_EnumAttr; + +def EncodingAttr : + IREEEncoding_Attr<"Encoding"> { + let mnemonic = "encoding"; + let summary = [{information to decide how to data-tile a tensor}]; + let description = [{ + This attribute describes the change in the layout for + a given tensor to execute subsequent operations on + the tiled layout. The encoding serves as a way to + represent the change in the way the data is laid out in + memory without changing the logical rank/extent of + the tensor itself. When required, the encoding + can be used to explicitly manifest the layout change + through operations like pack/unpack. + }]; + + let assemblyFormat = "`<` struct(params) `>`"; + + let parameters = (ins + AttrParameter<"IntegerAttr", "this tensor operand's index in the parameter list">:$operand_index, + AttrParameter<"EncodingOpTypeAttr", "operand type">:$op_type, + AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types, + OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps, + OptionalParameter<"AffineMapAttr", "Indexing map that represents the broadcasting dims in the producer">:$bcast_map, + // TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now. + OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to + ); + + let builders = [ + AttrBuilder<(ins "int64_t":$operandIndex, + "EncodingOpType":$opType, + "ArrayRef":$elemTypes, + CArg<"ArrayRef", "{}">:$maps, + CArg<"std::optional", "{}">:$bcastMap, + CArg<"ArrayRef", "{}">:$roundDimsTo)> + ]; + + let extraClassDeclaration = [{ + /// Returns the bcast_map composed with the user_indexing_map for the + /// operand_index. The dimensions of the returned map are those of the + /// data-tiled op's iteration space, and the results of the map are in + /// the domain of the encoded tensor type. + AffineMap getMapForOperandIndex(); + + /// Given the dim position of the encoding `user_indexing_maps`, returns the + /// matching index of the given encoding's tensor, using getMapForOperandIndex + /// bcast_map and user_indexing_map. + std::optional mapDimToOperandIndex(int64_t dimPos); + + /// Returns an integer array with values in `round_dims_to`. + ArrayRef getRoundDimsToArray(); + + /// Returns a vector with values in `element_types`. + SmallVector getElementTypesArray(); + + /// Clones an encoding with a new bcast_map + EncodingAttr clone(AffineMap bcastMap); + }]; + + let genVerifyDecl = 0; +} + +#endif // IREE_DIALECT_ENCODING_ATTRS diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td index 8788dc33bd66..8d52e3b4e32e 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td @@ -8,8 +8,6 @@ #define IREE_DIALECT_ENCODING_BASE include "mlir/IR/OpBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/EnumAttr.td" //===----------------------------------------------------------------------===// // Dialect definition @@ -28,94 +26,4 @@ def IREEEncoding_Dialect : Dialect { let useDefaultAttributePrinterParser = 1; } -//===---------------------------------------------------------------------===// -// Data layout encoding attributes -//===---------------------------------------------------------------------===// - -class IREEEncoding_Attr traits = []> - : AttrDef; - -class IREEEncoding_I32EnumAttr cases> - : I32EnumAttr { - let cppNamespace = "::mlir::iree_compiler::IREE::Encoding"; - let genSpecializedAttr = 0; -} - -class IREEEncoding_EnumAttr - : EnumAttr; - -// Enums for tagging operand operation in an EncodingAttr -def MATMUL : I32EnumAttrCase<"matmul", 0>; -def CONV : I32EnumAttrCase<"conv", 1>; - -def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType", - "Tracks the type of operation of the operand.", [ - MATMUL, - CONV, - ]>; - -def EncodingOpTypeAttr: - IREEEncoding_EnumAttr; - -def EncodingAttr : - IREEEncoding_Attr<"Encoding"> { - let mnemonic = "encoding"; - let summary = [{information to decide how to data-tile a tensor}]; - let description = [{ - This attribute describes the change in the layout for - a given tensor to execute subsequent operations on - the tiled layout. The encoding serves as a way to - represent the change in the way the data is laid out in - memory without changing the logical rank/extent of - the tensor itself. When required, the encoding - can be used to explicitly manifest the layout change - through operations like pack/unpack. - }]; - - let assemblyFormat = "`<` struct(params) `>`"; - - let parameters = (ins - AttrParameter<"IntegerAttr", "this tensor operand's index in the parameter list">:$operand_index, - AttrParameter<"EncodingOpTypeAttr", "operand type">:$op_type, - AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types, - OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps, - OptionalParameter<"AffineMapAttr", "Indexing map that represents the broadcasting dims in the producer">:$bcast_map, - // TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now. - OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to - ); - - let builders = [ - AttrBuilder<(ins "int64_t":$operandIndex, - "EncodingOpType":$opType, - "ArrayRef":$elemTypes, - CArg<"ArrayRef", "{}">:$maps, - CArg<"std::optional", "{}">:$bcastMap, - CArg<"ArrayRef", "{}">:$roundDimsTo)> - ]; - - let extraClassDeclaration = [{ - /// Returns the bcast_map composed with the user_indexing_map for the - /// operand_index. The dimensions of the returned map are those of the - /// data-tiled op's iteration space, and the results of the map are in - /// the domain of the encoded tensor type. - AffineMap getMapForOperandIndex(); - - /// Given the dim position of the encoding `user_indexing_maps`, returns the - /// matching index of the given encoding's tensor, using getMapForOperandIndex - /// bcast_map and user_indexing_map. - std::optional mapDimToOperandIndex(int64_t dimPos); - - /// Returns an integer array with values in `round_dims_to`. - ArrayRef getRoundDimsToArray(); - - /// Returns a vector with values in `element_types`. - SmallVector getElementTypesArray(); - - /// Clones an encoding with a new bcast_map - EncodingAttr clone(AffineMap bcastMap); - }]; - - let genVerifyDecl = 0; -} - #endif // IREE_DIALECT_ENCODING_BASE diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp index c480bf3ed127..46925b570ec9 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/SourceMgr.h" @@ -16,13 +17,12 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/Transforms/InliningUtils.h" -using namespace mlir; -using namespace mlir::iree_compiler::IREE::Encoding; - -#include "iree/compiler/Dialect/Encoding/IR/EncodingEnums.cpp.inc" // IWYU pragma: keep - #define GET_ATTRDEF_CLASSES -#include "iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp.inc" // IWYU pragma: keep +#include "iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp.inc" +#include "iree/compiler/Dialect/Encoding/IR/EncodingEnums.cpp.inc" +#undef GET_ATTRDEF_CLASSES + +namespace mlir::iree_compiler::IREE::Encoding { // Used to control inlining behavior. struct IREEEncodingInlinerInterface : public DialectInlinerInterface { @@ -60,4 +60,6 @@ void IREEEncodingDialect::initialize() { >(); } +} // namespace mlir::iree_compiler::IREE::Encoding + #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp.inc" diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.h index 2912aa5e6405..7ff318f9cc83 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.h +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.h @@ -9,6 +9,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Support/TypeID.h" // clang-format off: must be included after all LLVM/MLIR headers #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h.inc" // IWYU pragma: keep diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp index c1311517d600..bc367765cfc8 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp @@ -6,19 +6,13 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Affine/Utils.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -92,153 +86,11 @@ LogicalResult UnsetEncodingOp::reifyResultShapes( return success(); } +} // namespace mlir::iree_compiler::IREE::Encoding + //===----------------------------------------------------------------------===// -// encoding.encoding +// TableGen definitions (intentionally last) //===----------------------------------------------------------------------===// -EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, - EncodingOpType opType, ArrayRef elemTypes, - ArrayRef maps, - std::optional bcastMap, - ArrayRef roundDimsTo) { - Builder b(ctx); - auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType); - auto roundDimsToAttr = roundDimsTo.empty() - ? DenseI64ArrayAttr() - : b.getDenseI64ArrayAttr(roundDimsTo); - auto bcastMapAttr = bcastMap.has_value() - ? AffineMapAttr::get(bcastMap.value()) - : AffineMapAttr(); - return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr, - b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps), - bcastMapAttr, roundDimsToAttr); -} - -AffineMap EncodingAttr::getMapForOperandIndex() { - auto index = getOperandIndex().getValue().getZExtValue(); - switch (index) { - case MATMUL_LHS: - case MATMUL_RHS: - case MATMUL_RESULT: { - auto indexingMap = - llvm::cast(getUserIndexingMaps()[index]).getAffineMap(); - if (auto bcastMap = getBcastMap()) { - indexingMap = bcastMap.getAffineMap().compose(indexingMap); - } - return indexingMap; - } - default: - return AffineMap(); - } -} - -std::optional EncodingAttr::mapDimToOperandIndex(int64_t dimPos) { - return getMapForOperandIndex().getResultPosition( - getAffineDimExpr(dimPos, getContext())); -} - -MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, - int narrowThreshold) { - linalg::ContractionDimensions cDims = - linalg::inferContractionDims(linalgOp).value(); - auto map = linalgOp.getIndexingMapsArray().back(); - auto outType = llvm::cast(linalgOp.getDpsInits()[0].getType()); - auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t { - return outType.getDimSize( - map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext())) - .value()); - }; - // M or N can be empty instead of having an explicit dim size of 1 for matvec - // and vecmat, so set to 1 if empty. - int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]); - int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]); - - MatmulNarrowDim narrowM, narrowN; - if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) { - narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize}; - } - if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) { - narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize}; - } - - return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN; -} - -ArrayRef EncodingAttr::getRoundDimsToArray() { - auto roundDimsTo = getRoundDimsTo(); - if (!roundDimsTo) { - return {}; - } - return llvm::cast(roundDimsTo).asArrayRef(); -} - -SmallVector EncodingAttr::getElementTypesArray() { - return llvm::map_to_vector(getElementTypes().getValue(), [](Attribute a) { - return llvm::cast(a).getValue(); - }); -} - -EncodingAttr EncodingAttr::clone(AffineMap bcastMap) { - return get(bcastMap.getContext(), getOperandIndex(), getOpType(), - getElementTypes(), getUserIndexingMaps(), - AffineMapAttr::get(bcastMap), getRoundDimsTo()); -} - -MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) { - if (encoding.getOpType().getValue() != EncodingOpType::matmul) { - return {}; - } - ArrayRef roundDimsTo = encoding.getRoundDimsToArray(); - if (roundDimsTo.empty()) { - return {}; - } - int m = roundDimsTo[0]; - int n = roundDimsTo[1]; - if (m < n) { - return {MatmulNarrowDim::Dim::M, m}; - } - if (n < m) { - return {MatmulNarrowDim::Dim::N, n}; - } - return {}; -} - -//===---------------------------------------------------------------------===// -// Encoding Dialect Helpers -//===---------------------------------------------------------------------===// - -EncodingAttr getEncodingAttr(RankedTensorType type) { - return dyn_cast_or_null(type.getEncoding()); -} - -FailureOr -getEncodingContractionDims(EncodingAttr encoding) { - auto indexingMapsAttr = encoding.getUserIndexingMaps(); - SmallVector indexingMaps = llvm::map_to_vector( - indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap { - return cast(m).getAffineMap(); - }); - return linalg::inferContractionDims(indexingMaps); -} - -std::string stringifyOperandIndex(IntegerAttr valueAttr) { - auto value = valueAttr.getValue().getZExtValue(); - switch (value) { - case MATMUL_LHS: - return "LHS"; - case MATMUL_RHS: - return "RHS"; - case MATMUL_RESULT: - return "RESULT"; - default: - assert(false && "invalid index"); - return ""; - } -} - -} // namespace mlir::iree_compiler::IREE::Encoding - -// clang-format off #define GET_OP_CLASSES #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp.inc" // IWYU pragma: keep -// clang-format: on diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h index 9a0810ed78fe..fd89887cf1da 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGOPS_H_ #define IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGOPS_H_ +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -18,75 +19,9 @@ #include "mlir/Interfaces/TilingInterface.h" // clang-format off - -#include "iree/compiler/Dialect/Encoding/IR/EncodingEnums.h.inc" // IWYU pragma: export - -#define GET_ATTRDEF_CLASSES -#include "iree/compiler/Dialect/Encoding/IR/EncodingAttrs.h.inc" // IWYU pragma: export - #define GET_OP_CLASSES #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h.inc" // IWYU pragma: export - +#undef GET_OP_CLASSES // clang-format on -//===---------------------------------------------------------------------===// -// Encoding Dialect Helpers -//===---------------------------------------------------------------------===// - -namespace mlir::iree_compiler::IREE::Encoding { - -/// Returns the encoding attribute from the type if there is an encoding. -/// Otherwise, returns null. -EncodingAttr getEncodingAttr(RankedTensorType type); - -/// Returns the ContractionDimensions for the encoding user_indexing_maps. -FailureOr -getEncodingContractionDims(EncodingAttr encoding); - -/// Assign a name to operand indices for clarity -const int64_t MATMUL_LHS = 0; -const int64_t MATMUL_RHS = 1; -const int64_t MATMUL_RESULT = 2; - -/// Convert operand index to strings for printing -std::string stringifyOperandIndex(IntegerAttr); - -/// Designates a dimension in a matmul (either the M or the N dimension) as -/// being "narrow", i.e. small enough that we bother lowering the amount of -/// padding along that dimension compared to how padding we apply to -/// sufficiently large dimensions. -struct MatmulNarrowDim { - // Enumerates dimensions of a matmul that may be labelled as narrow. - enum class Dim { - None, - M, - N, - }; - Dim dim = Dim::None; // Which dimension is designated by *this. - int64_t size = 0; // Size of the designated dimension, or kDynamic. - - explicit operator bool() const { return dim != Dim::None; } - bool isM() const { return dim == Dim::M; } - bool isN() const { return dim == Dim::N; } -}; - -/// Returns the narrow dim in a given `linalgOp`, with respect to the given -/// `narrowThreshold` below which a dimension is eligible to be considered -/// narrow. If both M and N are narrow, M is returned. If neither M nor N are -/// narrow, this returns a default-constructed falsish value. -MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, - int narrowThreshold); - -/// Returns the narrow dim in a given `encoding`. This works by inspecting -/// the `round_dims_to` array attribute in the `encoding`. If the -/// `round_dims_to` of one dimension (M or N) is smaller than the other, then -/// that's the narrow dimension, because the only way it would have been set -/// to be smaller in the first place, is if we previously flagged that dimension -/// as narrow. If the `round_dims_to` of the M and N dimensions agree, then -/// neither is a narrow dimension and this returns a default-constructed falsish -/// value. -MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding); - -} // namespace mlir::iree_compiler::IREE::Encoding - #endif // IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGOPS_H_ diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h new file mode 100644 index 000000000000..7e3ed08ffa01 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h @@ -0,0 +1,90 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGTYPES_H_ +#define IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGTYPES_H_ + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" + +// clang-format off +#include "iree/compiler/Dialect/Encoding/IR/EncodingEnums.h.inc" // IWYU pragma: export +#define GET_ATTRDEF_CLASSES +#include "iree/compiler/Dialect/Encoding/IR/EncodingAttrs.h.inc" // IWYU pragma: export +#undef GET_ATTRDEF_CLASSES +#define GET_TYPEDEF_CLASSES +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h.inc" // IWYU pragma: export +#undef GET_TYPEDEF_CLASSES +// clang-format on + +//===---------------------------------------------------------------------===// +// Encoding Dialect Helpers +//===---------------------------------------------------------------------===// + +namespace mlir::iree_compiler::IREE::Encoding { + +/// Returns the encoding attribute from the type if there is an encoding. +/// Otherwise, returns null. +EncodingAttr getEncodingAttr(RankedTensorType type); + +/// Returns the ContractionDimensions for the encoding user_indexing_maps. +FailureOr +getEncodingContractionDims(EncodingAttr encoding); + +/// Assign a name to operand indices for clarity +const int64_t MATMUL_LHS = 0; +const int64_t MATMUL_RHS = 1; +const int64_t MATMUL_RESULT = 2; + +/// Convert operand index to strings for printing +std::string stringifyOperandIndex(IntegerAttr); + +/// Designates a dimension in a matmul (either the M or the N dimension) as +/// being "narrow", i.e. small enough that we bother lowering the amount of +/// padding along that dimension compared to how padding we apply to +/// sufficiently large dimensions. +struct MatmulNarrowDim { + // Enumerates dimensions of a matmul that may be labelled as narrow. + enum class Dim { + None, + M, + N, + }; + Dim dim = Dim::None; // Which dimension is designated by *this. + int64_t size = 0; // Size of the designated dimension, or kDynamic. + + explicit operator bool() const { return dim != Dim::None; } + bool isM() const { return dim == Dim::M; } + bool isN() const { return dim == Dim::N; } +}; + +/// Returns the narrow dim in a given `linalgOp`, with respect to the given +/// `narrowThreshold` below which a dimension is eligible to be considered +/// narrow. If both M and N are narrow, M is returned. If neither M nor N are +/// narrow, this returns a default-constructed falsish value. +MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, + int narrowThreshold); + +/// Returns the narrow dim in a given `encoding`. This works by inspecting +/// the `round_dims_to` array attribute in the `encoding`. If the +/// `round_dims_to` of one dimension (M or N) is smaller than the other, then +/// that's the narrow dimension, because the only way it would have been set +/// to be smaller in the first place, is if we previously flagged that dimension +/// as narrow. If the `round_dims_to` of the M and N dimensions agree, then +/// neither is a narrow dimension and this returns a default-constructed falsish +/// value. +MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding); + +} // namespace mlir::iree_compiler::IREE::Encoding + +#endif // IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGTYPES_H_