From 3d083eb8c54f1be8467cffafbffe303b77dbf7af Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 15 Jan 2025 13:00:28 +0000 Subject: [PATCH] part 2 --- .../Dialect/Encoding/IR/EncodingAttrs.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 16f003ac16de..7185f2e85c00 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -21,6 +21,8 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +static constexpr int64_t kNumBitsInByte = 8; + namespace mlir::iree_compiler::IREE::Encoding { EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, @@ -162,7 +164,7 @@ Value PackedStorageAttr::calculateStorageSizeInBytes( bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(type); int64_t staticCount = 1; if (!isPackedStorage) { - staticCount *= elementBits * 8; + staticCount *= elementBits * kNumBitsInByte; } for (unsigned i = 0; i < type.getRank(); ++i) { @@ -175,18 +177,10 @@ Value PackedStorageAttr::calculateStorageSizeInBytes( for (auto dim : dynamicDims) { value = builder.createOrFold(loc, value, dim); } - // Sub-byte packing requires putting multiple elements in the same byte. + if (isPackedStorage) { - assert(8 % elementBits == 0); - unsigned byteElements = 8 / elementBits; - // TODO(antiagainst): We may want to emit runtime check to make sure this is - // divisible. - auto divisor = builder.create(loc, byteElements); - if (!isPackedStorage && dynamicDims.empty() && - (staticCount * elementBits) % 8 != 0) { - return nullptr; - } - value = builder.createOrFold(loc, value, divisor); + auto divisor = builder.create(loc, kNumBitsInByte); + value = builder.createOrFold(loc, value, divisor); } return value; } @@ -227,7 +221,6 @@ Value EncodingAttr::calculateStorageSizeInBytes(Location loc, pad(k, roundDimsTo[2]); } - constexpr int64_t kNumBitsInByte = 8; unsigned elementBits = getTypeBitWidth(type.getElementType()); // Deal with unpacked storage of i1. if (elementBits == 1 && !IREE::Encoding::hasPackedStorageAttr(type)) {