diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index 71fe957a0830..48ee4d4d80c1 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -191,7 +191,7 @@ void BubbleUpExpandShapesPass::runOnOperation() { auto dimExpr = getAffineDimExpr(dim, operandMap.getContext()); if (std::optional maybeDim = operandMap.getResultPosition(dimExpr); - maybeDim && !reassoc[maybeDim.value()].empty()) { + maybeDim && reassoc[maybeDim.value()].size() > 1) { return false; } } @@ -204,6 +204,8 @@ void BubbleUpExpandShapesPass::runOnOperation() { // that can be done later) of reshape ops. tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); bubbleExpandShapePatterns.insert(context); + tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns, + context); GreedyRewriteConfig rewriteConfig; rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir index b3fdd9038b49..244b02e917e9 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir @@ -455,9 +455,9 @@ util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 : #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> -util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) { +util.func public @dont_sink_through_k2(%0 : tensor<128x64x1x1x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) { %13 = tensor.empty() : tensor<4x32x64x128xf16> - %collapsed_12 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4]] : tensor<128x64x128x1x1xf16> into tensor<128x64x128xf16> + %collapsed_12 = tensor.collapse_shape %0 [[0], [1, 2, 3], [4]] : tensor<128x64x1x1x128xf16> into tensor<128x64x128xf16> %17 = tensor.empty() : tensor<128x64x128xf16> %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%2, %1, %collapsed_12, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) { ^bb0(%score: f16): @@ -482,7 +482,6 @@ util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : ten // CHECK-SAME: ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] : // CHECK: util.return %[[ATTENTION]] - // ----- util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor, %arg2: tensor) -> tensor { diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir index b014d59f881c..d654df337520 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir @@ -21,3 +21,51 @@ util.func public @unsupported_bubbble_expand_through_extract(%arg0 : tensor<2x40 // CHECK-LABEL: @unsupported_bubbble_expand_through_extract // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[EXTRACT]] + +// ----- + +// Checks two things +// 1. Propagation of reshapes across attention operations +// 2. Use of folders to convert (expand(collapse)) -> (collapse) +util.func public @attention_v_reshape_propagation(%arg0: index, + %arg1: tensor<4x8x4x128x?xf16>, %arg2: tensor<128x?x128xf16>, + %arg3: tensor<128x?x128xf16>, %arg4: f16, %arg5: tensor<128x?x?xf16>) + -> tensor<4x?x32x128xf16> { + %0 = tensor.empty(%arg0) : tensor<4x?x32x128xf16> + %1 = tensor.empty(%arg0) : tensor<128x?x128xf16> + %collapsed = tensor.collapse_shape %arg1 [[0, 1, 2], [3], [4]] + : tensor<4x8x4x128x?xf16> into tensor<128x128x?xf16> + %4 = iree_linalg_ext.attention { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]} + ins(%arg2, %arg3, %collapsed, %arg4, %arg5 + : tensor<128x?x128xf16>, tensor<128x?x128xf16>, tensor<128x128x?xf16>, + f16, tensor<128x?x?xf16>) + outs(%1 : tensor<128x?x128xf16>) { + ^bb0(%arg6: f32): + iree_linalg_ext.yield %arg6 : f32 + } -> tensor<128x?x128xf16> + %expanded = tensor.expand_shape %4 [[0, 1], [2], [3]] + output_shape [4, 32, %arg0, 128] + : tensor<128x?x128xf16> into tensor<4x32x?x128xf16> + %5 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%expanded : tensor<4x32x?x128xf16>) outs(%0 : tensor<4x?x32x128xf16>) { ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x?x32x128xf16> + util.return %5 : tensor<4x?x32x128xf16> +} +// CHECK-LABEL: func public @attention_v_reshape_propagation +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x8x4x128x?xf16> +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %[[ARG1]], +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ATTENTION]] +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] +// CHECK: return %[[COLLAPSE]]