Skip to content

Commit

Permalink
[i1] Remove command line option to enable packed storage
Browse files Browse the repository at this point in the history
* only use `#iree_encoding.packed_storage` to designate if an `i1` tensor is of packed memory layout.
* remove `iree-experimental-packed-i1-storage` command line option.
* teach type converters to allow casting into packed tensor types

Signed-off-by: Alan Li <me@alanli.org>
  • Loading branch information
lialan committed Feb 5, 2025
1 parent dc4e900 commit 3d5d0bb
Show file tree
Hide file tree
Showing 25 changed files with 196 additions and 214 deletions.
13 changes: 12 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
addConversion([](FloatType floatType) { return floatType; });
addConversion([](MemRefType memrefType) { return memrefType; });
addConversion([=](RankedTensorType type) -> RankedTensorType {
MaterializeEncodingInfo encodingInfo = getEncodingInfo(type);
if (IREE::Encoding::hasPackedStorageAttr(type)) {
return type;
}
// For a given tensor type with an encoding, return the materialized
// type to use for it. If no encoding is set, then return the tensor type
// itself.
MaterializeEncodingInfo encodingInfo = getEncodingInfo(type);
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return dropEncoding(type);
}
Expand Down Expand Up @@ -92,6 +95,14 @@ MaterializeEncodingTypeConverter::getEncodingInfo(RankedTensorType type) const {
}

RankedTensorType dropEncoding(RankedTensorType type) {
assert(!IREE::Encoding::hasPackedStorageAttr(type) &&
"not expected `packed_storage` attribute.");
return RankedTensorType::get(type.getShape(), type.getElementType());
}

RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type) {
if (!IREE::Encoding::hasPackedStorageAttr(type))
return type;
return RankedTensorType::get(type.getShape(), type.getElementType());
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -77,6 +78,9 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
/// Returns the RankedTensorType without encodings.
RankedTensorType dropEncoding(RankedTensorType type);

/// Returns the RankedTensorType without packed storage encoding (if any).
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type);

/// Returns the deserialized MaterializeEncodingInfo if the `layouts` field is
/// present in encodings and it only has a single layout. Otherwise, returns
/// std::nullopt.
Expand Down
26 changes: 16 additions & 10 deletions compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
Expand Down Expand Up @@ -65,9 +66,8 @@ static Value convertElementType(OpBuilder &b, Location loc, Type targetType,
/// std::nullopt.
static std::optional<Type> getLegalizedType(Type t) {
if (auto shapedType = llvm::dyn_cast<RankedTensorType>(t)) {
Type elementType = shapedType.getElementType();
std::optional<Type> legalizedElementType =
legalizeStorageElementType(elementType);
legalizeTensorStorageElementType(shapedType);
if (!legalizedElementType)
return std::nullopt;
return RankedTensorType::get(shapedType.getShape(),
Expand Down Expand Up @@ -121,7 +121,7 @@ struct ConstantOpTypeConversion
constantOp, "expected attribute type to be shaped type");
}
std::optional<Type> legalizedElementType =
legalizeStorageElementType(attrType.getElementType());
legalizeTensorStorageElementType(attrType);
if (!legalizedElementType) {
return rewriter.notifyMatchFailure(constantOp,
"cannot legalize elementType");
Expand Down Expand Up @@ -227,8 +227,10 @@ struct GenericOpTypePropagation
signatureConverter.addInputs(index, argType);
continue;
}
auto inputOperandType =
llvm::cast<RankedTensorType>(genericOp->getOperandTypes()[index]);
std::optional<Type> legalizedArgType =
legalizeStorageElementType(argType);
legalizeTensorStorageElementType(inputOperandType);
if (!legalizedArgType) {
return genericOp.emitOpError("failed to get legalized type for arg ")
<< index;
Expand Down Expand Up @@ -258,8 +260,8 @@ struct GenericOpTypePropagation
modifyYield = true;
OpOperand *yieldOperand =
modifiedOp.getMatchingYieldValue(modifiedOpOperand);
std::optional<Type> legalizedType =
legalizeStorageElementType(yieldOperand->get().getType());
std::optional<Type> legalizedType = legalizeTensorStorageElementType(
modifiedOpOperand->get().getType());
if (!legalizedType) {
return genericOp.emitOpError(
"failed to get legalized type for yield value");
Expand Down Expand Up @@ -289,7 +291,7 @@ struct LinalgFillTypePropagation
ConversionPatternRewriter &rewriter) const final {
Value value = adaptor.getInputs().front();
std::optional<Type> legalizedElementType =
legalizeStorageElementType(value.getType());
legalizeTensorStorageElementType(adaptor.getOutputs()[0].getType());
if (!legalizedElementType) {
return fillOp.emitOpError("failed to get legalized type for value");
}
Expand Down Expand Up @@ -355,8 +357,8 @@ struct IREELinalgExtScatterTypePropagation
// type.
TypeConverter::SignatureConversion signatureConverter(
modifiedOpRegion.getNumArguments());
Type argType = modifiedOpRegion.getArguments()[0].getType();
std::optional<Type> legalizedArgType = legalizeStorageElementType(argType);
std::optional<Type> legalizedArgType =
legalizeTensorStorageElementType(inputType);
if (!legalizedArgType) {
return scatterOp.emitOpError("failed to get legalized type for argument");
}
Expand Down Expand Up @@ -418,8 +420,12 @@ struct IREELinalgExtSortTypePropagation
TypeConverter::SignatureConversion signatureConverter(
modifiedOpRegion.getNumArguments());
for (auto [index, arg] : llvm::enumerate(modifiedOpRegion.getArguments())) {
// Refer to input types of the original operation to determine the
// corresponding legal arg type.
auto convertType = index % 2 == 0 ? sortOp->getOperandTypes()[index / 2]
: sortOp->getResultTypes()[index / 2];
std::optional<Type> legalizedArgType =
legalizeStorageElementType(arg.getType());
legalizeTensorStorageElementType(convertType);
if (!legalizedArgType) {
return sortOp.emitOpError("failed to get legalized type for argument");
}
Expand Down
11 changes: 9 additions & 2 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,15 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

bool hasPackedStorageAttr(RankedTensorType type) {
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
bool hasPackedStorageAttr(Type type) {
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
auto encoding = tensorType.getEncoding();
if (!encoding) {
return false;
}
return dyn_cast_or_null<PackedStorageAttr>(encoding) != nullptr;
}
return false;
}

FailureOr<linalg::ContractionDimensions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace mlir::iree_compiler::IREE::Encoding {
EncodingAttr getEncodingAttr(RankedTensorType type);

/// Returns true if the type contains packed_storage attribute.
bool hasPackedStorageAttr(RankedTensorType type);
bool hasPackedStorageAttr(Type type);

/// Returns the ContractionDimensions for the encoding user_indexing_maps.
FailureOr<linalg::ContractionDimensions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ iree_compiler_cc_library(
],
deps = [
":Utils",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/Conversion",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree_cc_library(
MLIRSCFDialect
MLIRTransformUtils
MLIRTransforms
iree::compiler::Codegen::Common
iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::Conversion
iree::compiler::Dialect::HAL::IR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.h"

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Dialect/HAL/Analysis/Captures.h"
#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
Expand Down Expand Up @@ -478,7 +479,8 @@ struct TensorExportBufferViewOpPattern
}

auto loc = exportOp.getLoc();
auto tensorType = llvm::cast<RankedTensorType>(adaptor.getSourceEncoding());
auto tensorType = dropPackedStorageEncodingIfAny(
llvm::cast<RankedTensorType>(adaptor.getSourceEncoding()));
auto dynamicDims = adaptor.getSourceEncodingDims();

// NOTE: we should have verified supported encodings/types at entry into the
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ iree_compiler_cc_library(
":StreamInterfacesGen",
":StreamOpsGen",
":StreamTypesGen",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ iree_cc_library(
MLIRTensorDialect
MLIRTransformUtils
MLIRViewLikeInterface
iree::compiler::Codegen::Common
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
PUBLIC
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
Expand All @@ -27,6 +28,10 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir::iree_compiler {
using IREE::Encoding::getEncodingAttr;
}

namespace mlir::iree_compiler::IREE::Stream {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1903,7 +1908,7 @@ LogicalResult TensorCloneOp::verify() {
// information.
auto sourceEncoding = llvm::cast<RankedTensorType>(op.getSourceEncoding());
auto resultEncoding = llvm::cast<RankedTensorType>(op.getResultEncoding());
if (sourceEncoding.getEncoding() != resultEncoding.getEncoding()) {
if (getEncodingAttr(sourceEncoding) != getEncodingAttr(resultEncoding)) {
return op.emitOpError() << "clones changing tensor encoding from "
<< sourceEncoding.getEncoding() << " to "
<< resultEncoding.getEncoding() << "; not allowed";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
Expand All @@ -22,6 +23,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -247,6 +249,12 @@ struct ConvertToStreamPass final
if (llvm::isa<IREE::Flow::ChannelType>(type)) {
return IREE::Stream::ChannelType::get(context);
}
if (auto rankedType = llvm::dyn_cast_or_null<RankedTensorType>(type)) {
if (IREE::Encoding::hasPackedStorageAttr(rankedType)) {
return RankedTensorType::get(rankedType.getShape(),
rankedType.getElementType());
}
}
return !llvm::isa<TensorType>(type) ? type : Type{};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
// Aligns the element type of a tensor<> to a byte-aligned power of 2 bit width.
static RankedTensorType alignTensorType(RankedTensorType originalType) {
Type elementType = originalType.getElementType();
Type alignedType = legalizeStorageElementType(elementType);
Type alignedType = legalizeTensorStorageElementType(originalType);
if (alignedType == elementType)
return originalType;
return RankedTensorType::get(originalType.getShape(), alignedType,
Expand Down Expand Up @@ -168,7 +168,9 @@ static Value canonicalizeFillPattern(Value pattern, OpBuilder &builder) {
// %i8_val = (%i8_val << 2) | %i2_val
// %i8_val = (%i8_val << 2) | %i2_val
// %i8_val = (%i8_val << 2) | %i2_val
if (needToPackSubByteElementBitWidth(elementBitWidth)) {
bool patternIsPacked =
IREE::Encoding::hasPackedStorageAttr(pattern.getType());
if (!patternIsPacked && needToPackSubByteElementBitWidth(elementBitWidth)) {
Type i8Type = builder.getI8Type();
Value bitwidth = builder.createOrFold<arith::ConstantOp>(
loc, i8Type, builder.getIntegerAttr(i8Type, elementBitWidth));
Expand Down Expand Up @@ -655,7 +657,8 @@ struct EncodeHostTensorsPass
static IREE::Flow::DispatchTensorType
alignDispatchTensorType(IREE::Flow::DispatchTensorType originalType) {
Type elementType = originalType.getBoundElementType();
Type alignedType = legalizeStorageElementType(elementType);
Type alignedType =
legalizeTensorStorageElementType(originalType.asRankedTensorType());
if (alignedType == elementType)
return originalType;
return IREE::Flow::DispatchTensorType::get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ iree_lit_test_suite(
"encode_host_tensors.mlir",
"encode_host_tensors_encoding.mlir",
"encode_host_tensors_packing.mlir",
"encode_host_tensors_packing_i1_experimental_clopt.mlir",
"fold_globals.mlir",
"fold_uniform_operands.mlir",
"fuse_dispatch_bindings.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ iree_lit_test_suite(
"encode_host_tensors.mlir"
"encode_host_tensors_encoding.mlir"
"encode_host_tensors_packing.mlir"
"encode_host_tensors_packing_i1_experimental_clopt.mlir"
"fold_globals.mlir"
"fold_uniform_operands.mlir"
"fuse_dispatch_bindings.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: iree-opt --split-input-file --iree-stream-encode-host-tensors %s | FileCheck %s

#packed = #iree_encoding.packed_storage
func.func @unaligned_i1_size() -> index {
%0 = stream.tensor.sizeof tensor<12xi1, #packed> : index
return %0 : index
}
// CHECK: func @unaligned_i1_size() -> index {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: return %[[C2]] : index

// -----

#packed = #iree_encoding.packed_storage
func.func @aligned_i1_size() -> index {
%0 = stream.tensor.sizeof tensor<24xi1, #packed> : index
return %0 : index
}

// CHECK: func @aligned_i1_size() -> index {
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: return %[[C3]] : index

// -----

#packed = #iree_encoding.packed_storage
func.func @packed_i1_input_output(%input : tensor<16xi1, #packed>) -> tensor<16xi1, #packed> {
return %input : tensor<16xi1, #packed>
}
Loading

0 comments on commit 3d5d0bb

Please sign in to comment.