Skip to content

Commit

Permalink
[NFC][GPU] Simplify definitions of MMA attributes (#19905)
Browse files Browse the repository at this point in the history
The tablegen had some strange auto-generated polymorphism with implicit
parsing of certain fields. None of it provided any benefit, so it is
simplified down to just the MMA enum. This also replaces the enum
attribute with an enum parameter, removing the extra `.getValue()`
indirection when accessing the enum.
  • Loading branch information
qedawkins authored Feb 5, 2025
1 parent 56bb652 commit 0a2862c
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ static TileSwizzle
getSwizzleBeforeMovingCrossThreadOutermost(IREE::GPU::DataTiledMMAAttr mma,
IREE::GPU::MMAFragment fragment) {
auto swizzle = getIntrinsicSwizzleBeforeMovingCrossThreadOutermost(
mma.getIntrinsic().getValue(), fragment);
mma.getIntrinsic(), fragment);
switch (fragment) {
case IREE::GPU::MMAFragment::Lhs:
// A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
Expand Down
121 changes: 48 additions & 73 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,30 +306,43 @@ struct OpaqueMmaLayout {
Type cType;
};

static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
OpaqueMmaLayout o;
std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
static std::tuple<int64_t, int64_t, int64_t>
getMNKShapeFromIntrinsic(MMAIntrinsic intrinsic) {
if (is_AMD(intrinsic)) {
auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
} else {
std::tie(o.mSize, o.nSize, o.kSize) = getUnsupportedMNKShape(intrinsic);
return {lhs.outer[0] * lhs.thread[0] * lhs.element[0],
rhs.outer[1] * rhs.thread[1] * rhs.element[1],
lhs.outer[1] * lhs.thread[1] * lhs.element[1]};
}
return getUnsupportedMNKShape(intrinsic);
}

/// Returns the M dimension size of `intrinsic`.
int64_t getMSize(MMAIntrinsic intrinsic) {
  const auto shape = getMNKShapeFromIntrinsic(intrinsic);
  return std::get<0>(shape);
}
/// Returns the N dimension size of `intrinsic`.
int64_t getNSize(MMAIntrinsic intrinsic) {
  const auto shape = getMNKShapeFromIntrinsic(intrinsic);
  return std::get<1>(shape);
}
/// Returns the K dimension size of `intrinsic`.
int64_t getKSize(MMAIntrinsic intrinsic) {
  const auto shape = getMNKShapeFromIntrinsic(intrinsic);
  return std::get<2>(shape);
}

/// Builds the opaque MMA layout — element types and M/N/K sizes — for the
/// given intrinsic.
static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
                                          MMAIntrinsic intrinsic) {
  OpaqueMmaLayout layout;
  std::tie(layout.aType, layout.bType, layout.cType) =
      getABCElementTypes(context, intrinsic);
  std::tie(layout.mSize, layout.nSize, layout.kSize) =
      getMNKShapeFromIntrinsic(intrinsic);
  return layout;
}

/// Dispatches single-subgroup-layout lookup on the concrete MMA attribute
/// kind behind the MmaInterfaceAttr interface.
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
                                                MMAFragment fragment) {
  if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
    return getSingleSubgroupLayout(mmaAttr.getIntrinsic(), fragment);
  }
  if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
    return getSingleSubgroupLayout(vmmaAttr.getIntrinsic(), fragment);
  }
  assert(false && "unhandled MMA Interface type.");
  return {};
}
Expand All @@ -339,43 +352,12 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
// MMA Attributes
//===----------------------------------------------------------------------===//

/// Custom assembly parser for MMAAttr: expects `<INTRINSIC>` where INTRINSIC
/// is an MMAIntrinsic enum keyword, parsed via FieldParser. Returns a null
/// Attribute on any parse failure.
Attribute MMAAttr::parse(AsmParser &p, Type type) {
// Opening '<'.
if (failed(p.parseLess()))
return {};

FailureOr<MMAIntrinsicAttr> mmaIntrinsic =
FieldParser<MMAIntrinsicAttr>::parse(p);
if (failed(mmaIntrinsic)) {
p.emitError(p.getCurrentLocation(), "failed to parse mfma type identifier");
return {};
}

// Closing '>'.
if (failed(p.parseGreater()))
return {};

// Rebuild through the derived-field getter so stored sizes/types stay
// consistent with the intrinsic.
return get(p.getContext(), mmaIntrinsic->getValue());
}

/// Prints the attribute as `<INTRINSIC>`, the inverse of MMAAttr::parse.
void MMAAttr::print(AsmPrinter &p) const {
auto &os = p.getStream();
os << "<";
os << stringifyMMAIntrinsic(getIntrinsic().getValue());
os << ">";
}

/// Convenience builder: derives all stored layout fields (M/N/K sizes and
/// A/B/C element types) from the intrinsic, so callers only supply the enum.
MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
auto layout = getOpaqueMMALayout(context, type);
return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
layout.nSize, layout.kSize, layout.aType, layout.bType,
layout.cType);
}

/// Returns the (A, B, C) element types, derived from the intrinsic rather
/// than from stored fields.
std::tuple<Type, Type, Type> MMAAttr::getABCElementTypes() const {
  return IREE::GPU::getABCElementTypes(getContext(), getIntrinsic());
}

/// Returns the (M, N, K) shape, derived from the intrinsic rather than from
/// stored fields.
std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
  return getMNKShapeFromIntrinsic(getIntrinsic());
}

template <typename MMAIntrinsicType>
Expand All @@ -394,24 +376,24 @@ static VectorType getVectorType(MLIRContext *context,
/// Returns the per-fragment vector types (A, B, C) for this intrinsic.
std::tuple<VectorType, VectorType, VectorType>
MMAAttr::getABCVectorTypes() const {
  MLIRContext *context = getContext();
  MMAIntrinsic intrinsic = getIntrinsic();
  VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
  VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
  VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
  return {aVecType, bVecType, cVecType};
}

/// Returns the block size of the underlying intrinsic.
int64_t MMAAttr::getBlockSize() const {
  return IREE::GPU::getBlockSize(getIntrinsic());
}

/// Returns the subgroup size required by the underlying intrinsic.
int64_t MMAAttr::getSubgroupSize() const {
  return getIntrinsicSubgroupSize(getIntrinsic());
}

FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
// Explicit distribution currently unsupported for NV intrinsics.
MMAIntrinsic intrinsic = getIntrinsic().getValue();
MMAIntrinsic intrinsic = getIntrinsic();
if (intrinsic == MMAIntrinsic::NV_WMMA_F16_16x16x16_F16 ||
intrinsic == MMAIntrinsic::NV_WMMA_F32_16x16x16_F16) {
return failure();
Expand All @@ -421,7 +403,7 @@ FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {

// Get virtual intrinsics that is composed/based on queried op.
SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
switch (getIntrinsic().getValue()) {
switch (getIntrinsic()) {
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16};
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
Expand Down Expand Up @@ -475,8 +457,8 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
if (cType != resultType) {
return failure();
}
if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(),
resultType, lhs, rhs, acc)) {
if (Value value = createMmaOp(builder, loc, getIntrinsic(), resultType, lhs,
rhs, acc)) {
return value;
}
return failure();
Expand Down Expand Up @@ -562,7 +544,7 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
SmallVector<OpFoldResult> &strides) const {

MMASingleSubgroupLayout subgroupLayout =
getSingleSubgroupLayout(getIntrinsic().getValue(), fragment);
getSingleSubgroupLayout(getIntrinsic(), fragment);
SmallVector<OpFoldResult> canonicalOffsets;
SmallVector<OpFoldResult> canonicalSizes;
if (failed(populateCanonicalOffsetsSizesAndStrides(
Expand Down Expand Up @@ -597,13 +579,13 @@ sliceSwizzledShape(const TileSwizzle &swizzle,

/// Returns the (A, B, C) element types of the base intrinsic.
std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes() const {
  MLIRContext *ctx = getContext();
  auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic());
  return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
}

/// Returns the data-tiled (M, N, K) shape: the base intrinsic shape scaled
/// by the per-dimension intrinsic and subgroup expansion counts. Note that K
/// has no subgroup expansion factor.
std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
  MLIRContext *ctx = getContext();
  auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic());
  return {opaqueLayout.mSize * getIntrinsicsM() * getSubgroupsM(),
          opaqueLayout.nSize * getIntrinsicsN() * getSubgroupsN(),
          opaqueLayout.kSize * getIntrinsicsK()};
}
Expand All @@ -624,7 +606,7 @@ DataTiledMMAAttr::getABCVectorTypes() const {
}

/// Returns the subgroup size required by the base intrinsic.
int64_t DataTiledMMAAttr::getSubgroupSize() const {
  return getIntrinsicSubgroupSize(getIntrinsic());
}

FailureOr<IREE::GPU::MMAScope> DataTiledMMAAttr::getMmaScope() const {
Expand Down Expand Up @@ -672,8 +654,7 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
// In layoutThreadSizes, intrinsic level dimensions are mixed with expansion
// to multiple subgroups, so in order to tell if there are additional
// distribution-only thread dimensions, we need to get back to the intrinsic.
TileSwizzle intrinsicSwizzle =
getIntrinsicSwizzle(getIntrinsic().getValue(), fragment);
TileSwizzle intrinsicSwizzle = getIntrinsicSwizzle(getIntrinsic(), fragment);

SmallVector<int64_t> intrinsicLayoutThreadSizes =
sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) {
Expand Down Expand Up @@ -826,7 +807,7 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
SmallVector<Value> intrinsicsAcc =
distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);

MMAIntrinsic intrinsic = getIntrinsic().getValue();
MMAIntrinsic intrinsic = getIntrinsic();
VectorType intrinCType =
getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc);

Expand Down Expand Up @@ -878,12 +859,6 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
// VirtualMMA Attributes
//===----------------------------------------------------------------------===//

/// Convenience builder: wraps the raw VirtualMMAIntrinsic enum in a
/// VirtualMMAIntrinsicAttr and forwards to the generated attribute getter.
VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
VirtualMMAIntrinsic type) {
auto intrinsicAttr = VirtualMMAIntrinsicAttr::get(context, type);
return VirtualMMAAttr::get(context, intrinsicAttr);
}

static std::tuple<int64_t, int64_t, int64_t>
getMNKShape(VirtualMMAIntrinsic type) {
// V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
Expand Down Expand Up @@ -936,14 +911,14 @@ static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,

/// Returns the (A, B, C) element types of the virtual intrinsic.
std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes() const {
  MLIRContext *ctx = getContext();
  auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic());
  return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
}

std::tuple<VectorType, VectorType, VectorType>
VirtualMMAAttr::getABCVectorTypes() const {
MLIRContext *context = getContext();
VirtualMMAIntrinsic intrinsic = getIntrinsic().getValue();
VirtualMMAIntrinsic intrinsic = getIntrinsic();
VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
Expand All @@ -952,12 +927,12 @@ VirtualMMAAttr::getABCVectorTypes() const {

/// Returns the (M, N, K) shape of the virtual intrinsic.
std::tuple<int64_t, int64_t, int64_t> VirtualMMAAttr::getMNKShape() const {
  MLIRContext *ctx = getContext();
  auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic());
  return {opaqueLayout.mSize, opaqueLayout.nSize, opaqueLayout.kSize};
}

int64_t VirtualMMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
switch (getIntrinsic()) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
Expand All @@ -980,7 +955,7 @@ LogicalResult VirtualMMAAttr::populateOperandOffsetsSizesStrides(
SmallVector<OpFoldResult> &strides) const {

MMASingleSubgroupLayout subgroupLayout =
getSingleSubgroupLayout(getIntrinsic().getValue(), fragment);
getSingleSubgroupLayout(getIntrinsic(), fragment);
SmallVector<OpFoldResult> canonicalOffsets;
SmallVector<OpFoldResult> canonicalSizes;
if (failed(populateCanonicalOffsetsSizesAndStrides(
Expand All @@ -995,7 +970,7 @@ LogicalResult VirtualMMAAttr::populateOperandOffsetsSizesStrides(
}

int64_t VirtualMMAAttr::getIntrinsicsK() const {
switch (getIntrinsic().getValue()) {
switch (getIntrinsic()) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
return 2;
Expand Down Expand Up @@ -1025,7 +1000,7 @@ FailureOr<Value> VirtualMMAAttr::buildMmaOperation(OpBuilder &builder,
if (cType != resultType) {
return failure();
}
switch (getIntrinsic().getValue()) {
switch (getIntrinsic()) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
Expand Down Expand Up @@ -1064,7 +1039,7 @@ FailureOr<Value> VirtualMMAAttr::buildMmaOperation(OpBuilder &builder,
}

int64_t VirtualMMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
switch (getIntrinsic()) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ struct MMASingleSubgroupLayout {
SmallVector<int64_t, 2> element;
};

// Per-dimension accessors for the (M, N, K) shape of an MMA intrinsic.
int64_t getMSize(MMAIntrinsic intrinsic);
int64_t getNSize(MMAIntrinsic intrinsic);
int64_t getKSize(MMAIntrinsic intrinsic);

MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
MMAFragment fragment);

Expand Down
Loading

0 comments on commit 0a2862c

Please sign in to comment.