diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index 578cd5921c59..71a8272d341b 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -15,7 +15,7 @@ // GFX942: target = #iree_gpu.target, , , , , , , , ], +// GFX942-SAME: mma = [, , , , , , , , , , , ], // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, // GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647], @@ -26,10 +26,10 @@ // GFX941-SAME: features = "+sramecc,-xnack" // GFX940: target = #iree_gpu.target, , , , , , , , ], +// GFX940-SAME: mma = [, , , , , , , , , , , ], // GFX1100: target = #iree_gpu.target, , ] +// GFX1100-SAME: mma = [, , , , ] // GFX1100-SAME: subgroup_size_choices = [32, 64] stream.executable public @reduce_dispatch { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index a584fabdaa0c..4f9b1ffab3b6 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -216,9 +216,13 @@ static std::tuple getABCElementTypes(MLIRContext *context, Type f16 = Float16Type::get(context); Type bf16 = BFloat16Type::get(context); Type f32 = Float32Type::get(context); + Type f64 = Float64Type::get(context); Type i8 = IntegerType::get(context, 8); Type i32 = IntegerType::get(context, 32); switch (intrinsic) { + case MMAIntrinsic::MFMA_F64_16x16x4_F64: { + return {f64, f64, f64}; + } case MMAIntrinsic::MFMA_F32_16x16x4_F32: { return {f32, f32, f32}; } @@ -228,6 +232,12 @@ static std::tuple getABCElementTypes(MLIRContext *context, case MMAIntrinsic::MFMA_F32_32x32x8_F16: { return {f16, f16, f32}; 
} + case MMAIntrinsic::MFMA_F32_16x16x8_BF16: { + return {bf16, bf16, f32}; + } + case MMAIntrinsic::MFMA_F32_32x32x4_BF16: { + return {bf16, bf16, f32}; + } case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { return {bf16, bf16, f32}; } @@ -240,6 +250,12 @@ static std::tuple getABCElementTypes(MLIRContext *context, case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: { return {f8E5M2FNUZ, f8E5M2FNUZ, f32}; } + case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: { + return {f8E4M3FNUZ, f8E5M2FNUZ, f32}; + } + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: { + return {f8E5M2FNUZ, f8E4M3FNUZ, f32}; + } case MMAIntrinsic::MFMA_I32_16x16x32_I8: { return {i8, i8, i32}; } @@ -258,6 +274,12 @@ static std::tuple getABCElementTypes(MLIRContext *context, case MMAIntrinsic::WMMA_F16_16x16x16_F16: { return {f16, f16, f16}; } + case MMAIntrinsic::WMMA_F32_16x16x16_BF16: { + return {bf16, bf16, f32}; + } + case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: { + return {bf16, bf16, bf16}; + } case MMAIntrinsic::WMMA_I32_16x16x16_I8: { return {i8, i8, i32}; } @@ -505,6 +527,43 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, /*element=*/{4, 1}}; } + case MMAIntrinsic::MFMA_F64_16x16x4_F64: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, + /*element=*/{1, 1}}; + case MMAFragment::Rhs: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{1, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{4, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{1, 1}}; + } + case MMAIntrinsic::MFMA_F32_16x16x8_BF16: { + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, + /*element=*/{1, 2}}; + case MMAFragment::Rhs: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{2, 1}}; + case MMAFragment::Acc: + 
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{4, 1}}; + } + } + case MMAIntrinsic::MFMA_F32_32x32x4_BF16: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32}, + /*element=*/{1, 2}}; + case MMAFragment::Rhs: + return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{2, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{4, 1}}; + } case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_16x16x16_F16: case MMAIntrinsic::MFMA_F32_16x16x16_BF16: @@ -535,6 +594,8 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: switch (fragment) { case MMAFragment::Lhs: @@ -560,6 +621,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, /*element=*/{4, 1}}; } case MMAIntrinsic::WMMA_F32_16x16x16_F16: + case MMAIntrinsic::WMMA_F32_16x16x16_BF16: case MMAIntrinsic::WMMA_I32_16x16x16_I8: switch (fragment) { case MMAFragment::Lhs: @@ -573,6 +635,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, /*element=*/{1, 1}}; } case MMAIntrinsic::WMMA_F16_16x16x16_F16: + case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: switch (fragment) { case MMAFragment::Lhs: return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0}, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index fd6d02a5ffc4..ac03a9fa5fa2 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -146,17 +146,26 @@ 
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x1021>; def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 0x10C0>; def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 0x10C1>; +// Introduced in CDNA2 +def MFMA_F32_16x16x8_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x8_BF16", 0x1120>; +def MFMA_F32_32x32x4_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x4_BF16", 0x1121>; +def MFMA_F64_16x16x4_F64 : I32EnumAttrCase<"MFMA_F64_16x16x4_F64", 0x1100>; + // Introduced in CDNA3 def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x1220>; def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x1221>; def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x1230>; +def MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ", 0x1231>; def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x1232>; +def MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ", 0x1233>; def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x12C0>; def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x12C1>; // Introduced in RDNA3 def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 0x1820>; def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 0x1821>; +def WMMA_F32_16x16x16_BF16 : I32EnumAttrCase<"WMMA_F32_16x16x16_BF16", 0x1822>; +def WMMA_BF16_16x16x16_BF16 : I32EnumAttrCase<"WMMA_BF16_16x16x16_BF16", 0x1823>; def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 0x18C0>; // NV intrinsics @@ -172,17 +181,26 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", MFMA_I32_16x16x16_I8, MFMA_I32_32x32x8_I8, + // Introduced in CDNA2 + MFMA_F32_16x16x8_BF16, + MFMA_F32_32x32x4_BF16, + MFMA_F64_16x16x4_F64, + // Introduced in CDNA3 MFMA_F32_16x16x16_BF16, MFMA_F32_32x32x8_BF16, 
MFMA_F32_16x16x32_F8E5M2FNUZ, + MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ, MFMA_F32_16x16x32_F8E4M3FNUZ, + MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ, MFMA_I32_16x16x32_I8, MFMA_I32_32x32x16_I8, // RDNA3 intrinsics WMMA_F32_16x16x16_F16, WMMA_F16_16x16x16_F16, + WMMA_F32_16x16x16_BF16, + WMMA_BF16_16x16x16_BF16, WMMA_I32_16x16x16_I8, // NV intrinsics diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index 82fd46d9be2c..e198e216ece7 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -133,13 +133,19 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, const WgpDetails *getCDNA3WgpDetails() { static const MMAIntrinsic cdna3MMAOps[] = { + // Introduced in CDNA1, still present in CDNA3 MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16, MMAIntrinsic::MFMA_F32_32x32x8_F16, + // Introduced in CDNA2, still present in CDNA3 + MMAIntrinsic::MFMA_F64_16x16x4_F64, + // Introduced in CDNA3 MMAIntrinsic::MFMA_F32_16x16x16_BF16, MMAIntrinsic::MFMA_F32_32x32x8_BF16, - MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ, MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ, + MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ, + MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ, + MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ, MMAIntrinsic::MFMA_I32_16x16x32_I8, MMAIntrinsic::MFMA_I32_32x32x16_I8, }; @@ -162,10 +168,16 @@ const WgpDetails *getCDNA3WgpDetails() { const WgpDetails *getCDNA2WgpDetails() { static const MMAIntrinsic cdna2MMAOps[] = { + // Introduced in CDNA1 + MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16, MMAIntrinsic::MFMA_F32_32x32x8_F16, MMAIntrinsic::MFMA_I32_16x16x16_I8, MMAIntrinsic::MFMA_I32_32x32x8_I8, + // Introduced in CDNA2 + MMAIntrinsic::MFMA_F32_16x16x8_BF16, + 
MMAIntrinsic::MFMA_F32_32x32x4_BF16, + MMAIntrinsic::MFMA_F64_16x16x4_F64, }; static const WgpDetails cdna2Wgp = {allComputeBits, allStorageBits, @@ -183,8 +195,9 @@ const WgpDetails *getCDNA2WgpDetails() { const WgpDetails *getCDNA1WgpDetails() { static const MMAIntrinsic cdna1MMAOps[] = { - MMAIntrinsic::MFMA_F32_16x16x16_F16, - MMAIntrinsic::MFMA_F32_32x32x8_F16, + MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16, + MMAIntrinsic::MFMA_F32_32x32x8_F16, MMAIntrinsic::MFMA_I32_16x16x16_I8, + MMAIntrinsic::MFMA_I32_32x32x8_I8, }; static const WgpDetails cdna1Wgp = {allComputeBits, allStorageBits, @@ -202,9 +215,10 @@ const WgpDetails *getCDNA1WgpDetails() { const WgpDetails *getRDNA3WgpDetails() { static const MMAIntrinsic rdna3MMAOps[] = { - MMAIntrinsic::WMMA_F32_16x16x16_F16, - MMAIntrinsic::WMMA_F16_16x16x16_F16, + MMAIntrinsic::WMMA_F32_16x16x16_F16, MMAIntrinsic::WMMA_F16_16x16x16_F16, + MMAIntrinsic::WMMA_F32_16x16x16_BF16, MMAIntrinsic::WMMA_BF16_16x16x16_BF16, MMAIntrinsic::WMMA_I32_16x16x16_I8, + }; static const WgpDetails rdna3Wgp = {allComputeBits, allStorageBits, diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index dd4036e22354..9e8642ed7abb 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -1741,6 +1741,36 @@ iree_generated_e2e_runner_test( "requires-gpu-cdna3" ) +iree_generated_e2e_runner_test( + NAME + e2e_matmul_cdna3_dt_f64 + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=f64" + "--acc_type=f64" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + "--iree-input-demote-f64-to-f32=false" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + endif()
elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")