Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Jan 14, 2025
1 parent 181837e commit d4463b3
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 53 deletions.
6 changes: 0 additions & 6 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,4 @@ RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}

/// Returns |type| untouched unless IREE::Encoding::hasPackedStorageAttr
/// reports true for it; in that case a RankedTensorType is rebuilt from only
/// the shape and element type of |type|.
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type) {
  const bool hasPackedStorage = IREE::Encoding::hasPackedStorageAttr(type);
  return hasPackedStorage
             ? RankedTensorType::get(type.getShape(), type.getElementType())
             : type;
}

} // namespace mlir::iree_compiler
3 changes: 0 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
/// Returns the RankedTensorType without encodings.
RankedTensorType dropEncoding(RankedTensorType type);

/// Returns the RankedTensorType without packed storage encoding (if any).
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type);

/// Utility method to convert from `set_encoding` op to `pack` operation.
/// NOTE: `source` could be returned when packing is not needed.
FailureOr<Value> lowerSetEncodingOpToPackOp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,8 @@ struct TensorExportBufferViewOpPattern
}

auto loc = exportOp.getLoc();
auto tensorType = dropPackedStorageEncodingIfAny(
llvm::cast<RankedTensorType>(adaptor.getSourceEncoding()));
auto tensorType =
dropEncoding(llvm::cast<RankedTensorType>(adaptor.getSourceEncoding()));
auto dynamicDims = adaptor.getSourceEncodingDims();

// NOTE: we should have verified supported encodings/types at entry into the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ static Value canonicalizeFillPattern(Value pattern, OpBuilder &builder) {
// %i8_val = (%i8_val << 2) | %i2_val
// %i8_val = (%i8_val << 2) | %i2_val
// %i8_val = (%i8_val << 2) | %i2_val
bool patternIsPacked =
IREE::Encoding::hasPackedStorageAttr(pattern.getType());
if (!patternIsPacked && needToPackSubByteElementBitWidth(elementBitWidth)) {
if (needToPackSubByteElements(pattern.getType())) {
Type i8Type = builder.getI8Type();
Value bitwidth = builder.createOrFold<arith::ConstantOp>(
loc, i8Type, builder.getIntegerAttr(i8Type, elementBitWidth));
Expand Down
55 changes: 22 additions & 33 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,40 @@

namespace mlir::iree_compiler {

bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
/// Returns true if values of |type| are sub-byte and should be tightly packed
/// together.
///
/// For tensor types the decision is based on the element type's bit width (via
/// IREE::Util::getTypeBitWidth); any other type is queried directly through
/// getIntOrFloatBitWidth().
bool needToPackSubByteElements(Type type) {
  // Resolve the tensor-ness once instead of an isa<> check followed by a
  // second dyn_cast<> on the same value.
  auto tensorType = dyn_cast<TensorType>(type);
  unsigned bitWidth =
      tensorType ? IREE::Util::getTypeBitWidth(tensorType.getElementType())
                 : type.getIntOrFloatBitWidth();

  // |type| has already been dereferenced above, so a plain dyn_cast is
  // sufficient here; a null-tolerant cast would only suggest a case that
  // cannot reach this point.
  auto rankedTensorType = dyn_cast<RankedTensorType>(type);
  bool isPackedStorage = rankedTensorType &&
                         IREE::Encoding::hasPackedStorageAttr(rankedTensorType);
  // i1 with packed memory layout does not need to be extended.
  if (bitWidth == 1 && isPackedStorage) {
    return true;
  }
  // Require the original bit width to be some power of two for now to avoid
  // trickiness and weirdness of packing and cross-byte access.
  // Also disallow boolean values for now--they may require separate interface
  // choices.
  return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1;
}

/// Returns true if the elements of |shapedType| are sub-byte and should be
/// tightly packed together.
bool needToPackSubByteElements(RankedTensorType shapedType) {
  const unsigned elementBits =
      IREE::Util::getTypeBitWidth(shapedType.getElementType());
  // An i1 tensor tagged with the packed storage attribute is reported as
  // packed directly; every other case defers to the bit-width-only rule.
  const bool isPackedI1 =
      elementBits == 1 && IREE::Encoding::hasPackedStorageAttr(shapedType);
  return isPackedI1 || needToPackSubByteElementBitWidth(elementBits);
}
Type legalizeTensorStorageElementType(Type type) {
auto tensorType = llvm::cast<TensorType>(type);
auto elementType = tensorType.getElementType();

static Type legalizeStorageElementTypeImpl(Type elementType,
bool isPackedStorage) {
// Only handle integers; floats in MLIR all have aligned widths (today).
auto intType = dyn_cast<IntegerType>(elementType);
if (!intType)
return elementType;

unsigned bitWidth = intType.getWidth();
if (bitWidth == 1 && isPackedStorage) {
return elementType;
}

// For sub-byte elements, default to pack them into bytes.
if (needToPackSubByteElementBitWidth(bitWidth))
if (needToPackSubByteElements(type))
return elementType;

unsigned bitWidth = intType.getWidth();
// Otherwise, extend them to the next power-of-two bit width.
unsigned alignedBitWidth =
IREE::Util::getRoundedElementByteWidth(intType) * 8;
Expand All @@ -59,12 +60,6 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
intType.getSignedness());
}

/// Legalizes the element type of the tensor |type| for storage by forwarding
/// its element type, together with whether |type| carries the packed storage
/// attribute, to legalizeStorageElementTypeImpl.
///
/// |type| must be a TensorType; llvm::cast asserts otherwise.
Type legalizeTensorStorageElementType(Type type) {
  Type elementType = llvm::cast<TensorType>(type).getElementType();
  const bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(type);
  return legalizeStorageElementTypeImpl(elementType, isPackedStorage);
}

Value calculateStorageElementCountInBytes(Location loc,
RankedTensorType shapedType,
ValueRange dynamicDims,
Expand All @@ -80,13 +75,9 @@ Value calculateStorageElementCountInBytes(Location loc,
Type alignedElementType = legalizeTensorStorageElementType(shapedType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(shapedType);
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;

// Calculate all static dims first, if any.
int64_t staticCount = 1;
if (!isI1WithPackedStorage &&
!needToPackSubByteElementBitWidth(elementBits)) {
if (needToPackSubByteElements(shapedType)) {
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
}

Expand All @@ -101,12 +92,13 @@ Value calculateStorageElementCountInBytes(Location loc,
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
}
// Sub-byte packing requires putting multiple elements in the same byte.
if (isI1WithPackedStorage || needToPackSubByteElementBitWidth(elementBits)) {
if (needToPackSubByteElements(shapedType)) {
assert(8 % elementBits == 0);
unsigned byteElements = 8 / elementBits;
// TODO(antiagainst): We may want to emit runtime check to make sure this is
// divisible.
auto divisor = builder.create<arith::ConstantIndexOp>(loc, byteElements);
bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(shapedType);
if (!isPackedStorage && dynamicDims.empty() &&
(staticCount * elementBits) % 8 != 0) {
return nullptr;
Expand All @@ -124,11 +116,8 @@ Value calculateStorageElementOffsetInBytes(Location loc,
Type alignedElementType = legalizeTensorStorageElementType(originalType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(originalType);
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;

// Sub-byte packing requires putting multiple elements in the same byte.
if (isI1WithPackedStorage || needToPackSubByteElementBitWidth(elementBits)) {
if (needToPackSubByteElements(originalType)) {
Value byteElements =
builder.create<arith::ConstantIndexOp>(loc, 8 / elementBits);
// TODO(antiagainst): We may want to emit runtime check to make sure this is
Expand Down
8 changes: 2 additions & 6 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@

namespace mlir::iree_compiler {

/// Returns true if the given |bitWidth|, if appearing at runtime-kernel
/// interface, is less than a byte that should be tightly packed together.
bool needToPackSubByteElementBitWidth(unsigned bitWidth);

/// Returns true if the given |shapedType|, if appearing at runtime-kernel
/// Returns true if the given |type|, if appearing at runtime-kernel
/// interface, has sub-byte element types that should be tightly packed
/// together.
bool needToPackSubByteElements(RankedTensorType shapedType);
bool needToPackSubByteElements(Type type);

/// Legalizes the given |elementType| for storage.
///
Expand Down

0 comments on commit d4463b3

Please sign in to comment.