From 29229dfb84c36d90a7d60b0e03f61a7b6e0a8d58 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Wed, 4 Dec 2024 00:02:44 +0000
Subject: [PATCH] [GPU] Add gather fusion tests for vector distribution (#19209)

VectorDistribution now supports gather fusion on producers. This PR adds
pipeline tests for that. There are still numerical issues, tracked
separately, related to the distribution of gather.
---
 .../pipeline_vector_distribute_gfx942.mlir    | 142 ++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
index 184d49799faf..4396888ad90b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
@@ -884,6 +884,77 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#pipeline_layout = #hal.pipeline.layout<bindings = [
  #hal.pipeline.binding<storage_buffer, ReadOnly>,
  #hal.pipeline.binding<storage_buffer, ReadOnly>,
  #hal.pipeline.binding<storage_buffer, ReadOnly>,
  #hal.pipeline.binding<storage_buffer>]>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64>
#config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 128, 0]}>

hal.executable public @matmul_gather_rhs {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
    hal.executable.export public @matmul_gather_rhs ordinal(0) layout(#pipeline_layout) {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @matmul_gather_rhs() attributes {translation_info = #translation} {
        %cst = arith.constant 0.000000e+00 : f32
        %c0 = arith.constant 0 : index
        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>>
        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x64xi64>>
        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>>
        %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096x4096xf16>>
        %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>> -> tensor<4096x64xf16>
        %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4096, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x64xi64>> -> tensor<4096x64xi64>
        %6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [4096, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>> -> tensor<4096x64xf16>
        %7 = tensor.empty() : tensor<4096x4096xf16>
        %8 = tensor.empty() : tensor<4096x4096xf32>
        %9 = tensor.empty() : tensor<4096x64xf16>
        %10 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<4096x64xi64>) outs(%9 : tensor<4096x64xf16>) {
        ^bb0(%in: i64, %out: f16):
          %14 = linalg.index 0 : index
          %15 = arith.index_cast %in : i64 to index
          %extracted = tensor.extract %4[%14, %15] : tensor<4096x64xf16>
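          // The column index %15 comes from the loaded i64 tensor rather than
          // from an affine map, so this producer is a gather along the column
          // dimension of %4; this is what vector distribution must fuse.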
          linalg.yield %extracted : f16
        } -> tensor<4096x64xf16>
        %11 = linalg.fill ins(%cst : f32) outs(%8 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
        %12 = linalg.generic {indexing_maps = [#map1, #map2, #map3],
                              iterator_types = ["parallel", "parallel", "reduction"]}
                              ins(%6, %10 : tensor<4096x64xf16>, tensor<4096x64xf16>)
                              outs(%11 : tensor<4096x4096xf32>)
                              attrs = {lowering_config = #config} {
        ^bb0(%in: f16, %in_0: f16, %out: f32):
          %14 = arith.extf %in : f16 to f32
          %15 = arith.extf %in_0 : f16 to f32
          %16 = arith.mulf %14, %15 : f32
          %17 = arith.addf %out, %16 : f32
          linalg.yield %17 : f32
        } -> tensor<4096x4096xf32>
        %13 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%12 : tensor<4096x4096xf32>) outs(%7 : tensor<4096x4096xf16>) {
        ^bb0(%in: f32, %out: f16):
          %14 = arith.truncf %in : f32 to f16
          linalg.yield %14 : f16
        } -> tensor<4096x4096xf16>
        flow.dispatch.tensor.store %13, %3, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : tensor<4096x4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096x4096xf16>>
        return
      }
    }
  }
}

// CHECK-LABEL: func.func @matmul_gather_rhs
// CHECK: vector.gather
// CHECK-COUNT-32: amdgpu.mfma

// -----

#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
#translation = #iree_codegen.translation_info
@@ -1169,3 +1240,74 @@ hal.executable private @online_attention_split_k2 {
// MEMORY-LABEL: func.func @online_attention_split_k2()
// MEMORY-COUNT-3: memref.alloc
// MEMORY-NOT: memref.alloc

// -----

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
#map5 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer>]>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64>

#qk_config = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, promote_operands = [0, 1], subgroup_m_count = 2 : i64, subgroup_n_count = 1 : i64}>}
#pv_config = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, promote_operands = [1], subgroup_m_count = 2 : i64, subgroup_n_count = 1 : i64}>}
#config = #iree_gpu.lowering_config<{promote_operands = [0, 1, 2], reduction = [0, 0, 0, 0, 0, 64], workgroup = [1, 1, 64, 64, 0, 0]}>

module {
  hal.executable public @attention_gather_k {
    hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
      hal.executable.export public @attention_gather_k ordinal(0) layout(#pipeline_layout) {
      ^bb0(%arg0: !hal.device):
        %x, %y, %z = flow.dispatch.workgroup_count_from_slice
        hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @attention_gather_k() attributes {translation_info = #translation} {
          %cst = arith.constant 1.250000e-01 : f16
          %c0 = arith.constant 0 : index
          %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>>
          %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xi64>>
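          // binding(1) holds the i64 indices used to gather the K operand out
          // of binding(0) along its sequence dimension (dim 2).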
          %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>>
          %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>>
          %4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(4) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x10x4096x64xf16>>
          %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>> -> tensor<2x10x4096x64xf16>
          %6 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xi64>> -> tensor<2x10x4096x64xi64>
          %7 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>> -> tensor<2x10x4096x64xf16>
          %8 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>> -> tensor<2x10x4096x64xf16>
          %9 = tensor.empty() : tensor<2x10x4096x64xf16>
          %10 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<2x10x4096x64xi64>) outs(%9 : tensor<2x10x4096x64xf16>) {
          ^bb0(%in: i64, %out: f16):
            %12 = linalg.index 0 : index
            %13 = linalg.index 1 : index
            %14 = arith.index_cast %in : i64 to index
            %15 = linalg.index 3 : index
            %extracted = tensor.extract %5[%12, %13, %14, %15] : tensor<2x10x4096x64xf16>
            linalg.yield %extracted : f16
          } -> tensor<2x10x4096x64xf16>
          %11 = iree_linalg_ext.attention {
              indexing_maps = [#map1, #map2, #map3, #map4, #map5],
              decomposition_config = { qk_attrs = #qk_config, pv_attrs = #pv_config },
              lowering_config = #config} ins(%7, %10, %8, %cst : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16) outs(%9 : tensor<2x10x4096x64xf16>) {
          ^bb0(%arg0: f32):
            iree_linalg_ext.yield %arg0 : f32
          } -> tensor<2x10x4096x64xf16>
          flow.dispatch.tensor.store %11, %4, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : tensor<2x10x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x4096x64xf16>>
          return
        }
      }
    }
  }
}

// CHECK-LABEL: func.func @attention_gather_k
// CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
// CHECK: vector.gather
// CHECK-SAME: into vector<4x1x1x1x1x8xf16>
// CHECK: scf.yield

// MEMORY-LABEL: func.func @attention_gather_k
// MEMORY-COUNT-3: memref.alloc
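
For context, the CHECK lines above match the distributed gather only by name.
A rough sketch of the op form follows; the function, SSA names, and 1-D shapes
are invented for illustration and this is not pipeline output (the actual
gather in @attention_gather_k is distributed to a vector<4x1x1x1x1x8xf16>):

  func.func @gather_sketch(%rhs: memref<4096x64xf16>, %indices: vector<8xindex>) -> vector<8xf16> {
    %c0 = arith.constant 0 : index
    %mask = arith.constant dense<true> : vector<8xi1>
    %pass = arith.constant dense<0.000000e+00> : vector<8xf16>
    // Each lane loads %rhs at a data-dependent offset taken from %indices;
    // this is the op that fused tensor.extract producers lower to, since the
    // access cannot be expressed as a contiguous vector.transfer_read.
    %g = vector.gather %rhs[%c0, %c0] [%indices], %mask, %pass
           : memref<4096x64xf16>, vector<8xindex>, vector<8xi1>, vector<8xf16> into vector<8xf16>
    return %g : vector<8xf16>
  }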