Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Jan 15, 2025
1 parent 0671d4e commit fe8befc
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 53 deletions.
21 changes: 7 additions & 14 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

static constexpr int64_t kNumBitsInByte = 8;

namespace mlir::iree_compiler::IREE::Encoding {

EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
Expand Down Expand Up @@ -147,7 +149,7 @@ static int32_t getRoundedElementByteWidth(Type type) {
unsigned bitsUnaligned = getTypeBitWidth(type);
assert(bitsUnaligned > 0 && "0-width types unsupported");
// Round up to 8-bit aligned bytes.
unsigned byteAligned = (bitsUnaligned + 8 - 1) / 8;
unsigned byteAligned = llvm::alignTo(bitsUnaligned, kNumBitsInByte);
// Round up to the next power of two (unless already a power of two).
return llvm::PowerOf2Ceil(byteAligned);
}
Expand All @@ -162,7 +164,7 @@ Value PackedStorageAttr::calculateStorageSizeInBytes(
bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(type);
int64_t staticCount = 1;
if (!isPackedStorage) {
staticCount *= elementBits * 8;
staticCount *= elementBits * kNumBitsInByte;
}

for (unsigned i = 0; i < type.getRank(); ++i) {
Expand All @@ -175,18 +177,10 @@ Value PackedStorageAttr::calculateStorageSizeInBytes(
for (auto dim : dynamicDims) {
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
}
// Sub-byte packing requires putting multiple elements in the same byte.

if (isPackedStorage) {
assert(8 % elementBits == 0);
unsigned byteElements = 8 / elementBits;
// TODO(antiagainst): We may want to emit runtime check to make sure this is
// divisible.
auto divisor = builder.create<arith::ConstantIndexOp>(loc, byteElements);
if (!isPackedStorage && dynamicDims.empty() &&
(staticCount * elementBits) % 8 != 0) {
return nullptr;
}
value = builder.createOrFold<arith::CeilDivUIOp>(loc, value, divisor);
auto divisor = builder.create<arith::ConstantIndexOp>(loc, kNumBitsInByte);
value = builder.createOrFold<arith::CeilDivSIOp>(loc, value, divisor);
}
return value;
}
Expand Down Expand Up @@ -227,7 +221,6 @@ Value EncodingAttr::calculateStorageSizeInBytes(Location loc,
pad(k, roundDimsTo[2]);
}

constexpr int64_t kNumBitsInByte = 8;
unsigned elementBits = getTypeBitWidth(type.getElementType());
// Deal with unpacked storage of i1.
if (elementBits == 1 && !IREE::Encoding::hasPackedStorageAttr(type)) {
Expand Down
48 changes: 24 additions & 24 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,29 @@ include "mlir/IR/EnumAttr.td"
//===---------------------------------------------------------------------===//

class IREEEncoding_Attr<string name, list<Trait> traits = []>
: AttrDef<IREEEncoding_Dialect, name, traits>;
: AttrDef<IREEEncoding_Dialect, name, traits>;

class IREEEncoding_I32EnumAttr<string name, string summary,
list<I32EnumAttrCase> cases>
class IREEEncoding_I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, summary, cases> {
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";
let genSpecializedAttr = 0;
}

class IREEEncoding_EnumAttr<EnumAttrInfo enumInfo, string name = "">
: EnumAttr<IREEEncoding_Dialect, enumInfo, name>;
: EnumAttr<IREEEncoding_Dialect, enumInfo, name>;

// Enums for tagging operand operation in an EncodingAttr
def MATMUL : I32EnumAttrCase<"matmul", 0>;
def CONV : I32EnumAttrCase<"conv", 1>;
def CONV : I32EnumAttrCase<"conv", 1>;

def EncodingOpType
: IREEEncoding_I32EnumAttr<"EncodingOpType",
"Tracks the type of operation of the operand.", [
MATMUL,
CONV,
]>;
def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType",
"Tracks the type of operation of the operand.", [
MATMUL,
CONV,
]>;

def EncodingOpTypeAttr : IREEEncoding_EnumAttr<EncodingOpType, "optype">;
def EncodingOpTypeAttr:
IREEEncoding_EnumAttr<EncodingOpType, "optype">;

def PackedStorageAttr
: IREEEncoding_Attr<"PackedStorage",
Expand All @@ -50,20 +49,21 @@ def PackedStorageAttr
]>]> {
let mnemonic = "packed_storage";
let summary = [{Indicates packed storage data type.}];
let description =
[{This attribute indicates this is a back - to -
back packed storage in memory.This attribute takes no arguments.}];
let description = [{
This attribute indicates this is a back-to-back packed storage in memory.
This attribute takes no arguments.
}];
let genVerifyDecl = 0;
}

def EncodingAttr
: IREEEncoding_Attr<"Encoding",
[DeclareAttrInterfaceMethods<
IREEEncoding_EncodingLayoutAttrInterface, [
"calculateStorageSizeInBytes",
]>]> {
def EncodingAttr :
IREEEncoding_Attr<"Encoding", [
DeclareAttrInterfaceMethods<IREEEncoding_EncodingLayoutAttrInterface, [
"calculateStorageSizeInBytes",
]>
]> {
let mnemonic = "encoding";
let summary = [{information to decide how to data - tile a tensor}];
let summary = [{information to decide how to data-tile a tensor}];
let description = [{
This attribute describes the change in the layout for
a given tensor to execute subsequent operations on
Expand Down Expand Up @@ -116,8 +116,8 @@ def EncodingAttr
AffineMap getMapForOperandIndex() const;

/// Given the dim position of the encoding `user_indexing_maps`, returns the
/// matching index of the given encoding's tensor, using
/// getMapForOperandIndex bcast_map and user_indexing_map.
/// matching index of the given encoding's tensor, using getMapForOperandIndex
/// bcast_map and user_indexing_map.
std::optional<unsigned> mapDimToOperandIndex(int64_t dimPos) const;

/// Returns an integer array with values in `round_dims_to`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ struct TensorExportBufferViewOpPattern
}

auto loc = exportOp.getLoc();
// Drop the encoding of packed_storage here, as it is no longer needed
// afterwards.
auto tensorType =
dropEncoding(llvm::cast<RankedTensorType>(adaptor.getSourceEncoding()));
auto dynamicDims = adaptor.getSourceEncodingDims();
Expand Down
11 changes: 6 additions & 5 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir::iree_compiler {
using IREE::Encoding::getEncodingAttr;
}

namespace mlir::iree_compiler::IREE::Stream {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1517,7 +1513,12 @@ LogicalResult TensorCloneOp::verify() {
// information.
auto sourceEncoding = llvm::cast<RankedTensorType>(op.getSourceEncoding());
auto resultEncoding = llvm::cast<RankedTensorType>(op.getResultEncoding());
if (getEncodingAttr(sourceEncoding) != getEncodingAttr(resultEncoding)) {
if (IREE::Encoding::hasPackedStorageAttr(sourceEncoding) !=
IREE::Encoding::hasPackedStorageAttr(resultEncoding)) {
return op.emitOpError()
<< "clones attribute #iree_encoding.packed_storage mismatch";
}
if (sourceEncoding.getEncoding() != resultEncoding.getEncoding()) {
return op.emitOpError() << "clones changing tensor encoding from "
<< sourceEncoding.getEncoding() << " to "
<< resultEncoding.getEncoding() << "; not allowed";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ struct ConvertToStreamPass final
if (llvm::isa<IREE::Flow::ChannelType>(type)) {
return IREE::Stream::ChannelType::get(context);
}
if (auto rankedType = llvm::dyn_cast_or_null<RankedTensorType>(type)) {
if (auto rankedType = llvm::dyn_cast<RankedTensorType>(type)) {
// Drop packed_storage attr if any, as we don't need them anymore.
if (IREE::Encoding::hasPackedStorageAttr(rankedType)) {
return RankedTensorType::get(rankedType.getShape(),
rankedType.getElementType());
Expand Down
19 changes: 10 additions & 9 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@
namespace mlir::iree_compiler {

bool needToPackSubByteElements(Type type) {
unsigned bitWidth = isa<TensorType>(type)
? IREE::Util::getTypeBitWidth(
dyn_cast<TensorType>(type).getElementType())
: IREE::Util::getTypeBitWidth(type);

auto rankedTensorType = llvm::dyn_cast_or_null<RankedTensorType>(type);
bool isPackedStorage = rankedTensorType &&
IREE::Encoding::hasPackedStorageAttr(rankedTensorType);

auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(type);
if (!rankedTensorType) {
return false;
}
unsigned bitWidth =
IREE::Util::getTypeBitWidth(rankedTensorType.getElementType());

// i1 with packed memory layout does not need to be extended.
if (bitWidth == 1 && isPackedStorage) {
if (bitWidth == 1 && IREE::Encoding::hasPackedStorageAttr(rankedTensorType)) {
return true;
}

// Require the original bit width to be some power of two for now to avoid
// trickiness and weirdness of packing and cross-byte access.
// Also disallow boolean values for now--they may require separate interface
Expand Down

0 comments on commit fe8befc

Please sign in to comment.