Reshape propagation to enable broadcast(transpose) -> attention(q, kt, vt) fusion.

Fixes some minor bugs and adds a missing pattern to enable reshape
propagation that moves reshapes out of the transpose and attention
operations so that the two can subsequently be fused.

Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar committed Jan 10, 2025
1 parent 801e2c1 · commit 568e472
Showing 3 changed files with 53 additions and 4 deletions.
File 1 of 3
@@ -191,7 +191,7 @@ void BubbleUpExpandShapesPass::runOnOperation() {
       auto dimExpr = getAffineDimExpr(dim, operandMap.getContext());
       if (std::optional<int64_t> maybeDim =
               operandMap.getResultPosition(dimExpr);
-          maybeDim && !reassoc[maybeDim.value()].empty()) {
+          maybeDim && reassoc[maybeDim.value()].size() > 1) {
         return false;
       }
     }
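Note on the hunk above: a reassociation group is never empty for a mapped dimension, so the old !empty() test also fired for dimensions that are merely carried through (singleton groups); size() > 1 fires only for dimensions that are genuinely expanded. The snippet below is not part of the commit — it is a hypothetical example reusing the shapes from the updated dont_sink_through_k2 test to show which group the new condition flags.

util.func public @expanded_groups_example(%src : tensor<128x64x128xf16>)
    -> tensor<128x64x1x1x128xf16> {
  // Reassociation [[0], [1, 2, 3], [4]]: the first and last groups are
  // singletons, so those dimensions are passed through unchanged; only the
  // middle group is a real expansion, and only for it does
  // reassoc[dim].size() > 1 hold.
  %e = tensor.expand_shape %src [[0], [1, 2, 3], [4]]
      output_shape [128, 64, 1, 1, 128]
      : tensor<128x64x128xf16> into tensor<128x64x1x1x128xf16>
  util.return %e : tensor<128x64x1x1x128xf16>
}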
@@ -204,6 +204,8 @@ void BubbleUpExpandShapesPass::runOnOperation() {
   // that can be done later) of reshape ops.
   tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
   bubbleExpandShapePatterns.insert<BubbleExpandThroughExtract>(context);
+  tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                                     context);

   GreedyRewriteConfig rewriteConfig;
   rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
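The two lines added above also pull the tensor.expand_shape canonicalization patterns into the bubbling pattern set; the new test in the third file relies on these folders to turn an expand_shape of a collapse_shape into a single collapse_shape. A minimal sketch of that folding, with a hypothetical function name and static shapes borrowed from the new test:

util.func public @compose_expand_of_collapse(%arg0 : tensor<4x8x4x128xf16>)
    -> tensor<4x32x128xf16> {
  %c = tensor.collapse_shape %arg0 [[0, 1, 2], [3]]
      : tensor<4x8x4x128xf16> into tensor<128x128xf16>
  %e = tensor.expand_shape %c [[0, 1], [2]] output_shape [4, 32, 128]
      : tensor<128x128xf16> into tensor<4x32x128xf16>
  // Once the canonicalization patterns run, this expand(collapse) pair should
  // compose into a single tensor.collapse_shape %arg0 [[0], [1, 2], [3]].
  util.return %e : tensor<4x32x128xf16>
}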
File 2 of 3
@@ -455,9 +455,9 @@ util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 :
 #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
 #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>

-util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
+util.func public @dont_sink_through_k2(%0 : tensor<128x64x1x1x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
   %13 = tensor.empty() : tensor<4x32x64x128xf16>
-  %collapsed_12 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4]] : tensor<128x64x128x1x1xf16> into tensor<128x64x128xf16>
+  %collapsed_12 = tensor.collapse_shape %0 [[0], [1, 2, 3], [4]] : tensor<128x64x1x1x128xf16> into tensor<128x64x128xf16>
   %17 = tensor.empty() : tensor<128x64x128xf16>
   %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%2, %1, %collapsed_12, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) {
   ^bb0(%score: f16):
@@ -482,7 +482,6 @@ util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : ten
 // CHECK-SAME: ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] :
 // CHECK: util.return %[[ATTENTION]]

-
 // -----

 util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
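Why this negative test changed: under the old !empty() condition any dimension present in the operand map blocked propagation, while size() > 1 blocks only dimensions that are actually expanded. Moving the unit dimensions into the middle reassociation group puts the expansion on what is presumably the K2 dimension of the V operand, so the case still exercises the "don't sink" path with the tightened check. A side-by-side sketch (hypothetical function; layouts copied from the old and new test operands):

util.func public @k2_layouts(%a : tensor<128x64x128x1x1xf16>,
    %b : tensor<128x64x1x1x128xf16>)
    -> (tensor<128x64x128xf16>, tensor<128x64x128xf16>) {
  // Old operand: the multi-index group [2, 3, 4] lands on the last result dim.
  %old = tensor.collapse_shape %a [[0], [1], [2, 3, 4]]
      : tensor<128x64x128x1x1xf16> into tensor<128x64x128xf16>
  // Updated operand: the multi-index group [1, 2, 3] lands on the middle
  // result dim, the one the test wants to keep un-propagated.
  %new = tensor.collapse_shape %b [[0], [1, 2, 3], [4]]
      : tensor<128x64x1x1x128xf16> into tensor<128x64x128xf16>
  util.return %old, %new : tensor<128x64x128xf16>, tensor<128x64x128xf16>
}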
File 3 of 3
@@ -21,3 +21,51 @@ util.func public @unsupported_bubbble_expand_through_extract(%arg0 : tensor<2x40
 // CHECK-LABEL: @unsupported_bubbble_expand_through_extract
 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[EXTRACT]]
+
+// -----
+
+// Checks two things
+// 1. Propagation of reshapes across attention operations
+// 2. Use of folders to convert (expand(collapse)) -> (collapse)
+util.func public @attention_v_reshape_propagation(%arg0: index,
+    %arg1: tensor<4x8x4x128x?xf16>, %arg2: tensor<128x?x128xf16>,
+    %arg3: tensor<128x?x128xf16>, %arg4: f16, %arg5: tensor<128x?x?xf16>)
+    -> tensor<4x?x32x128xf16> {
+  %0 = tensor.empty(%arg0) : tensor<4x?x32x128xf16>
+  %1 = tensor.empty(%arg0) : tensor<128x?x128xf16>
+  %collapsed = tensor.collapse_shape %arg1 [[0, 1, 2], [3], [4]]
+      : tensor<4x8x4x128x?xf16> into tensor<128x128x?xf16>
+  %4 = iree_linalg_ext.attention {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> ()>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]}
+      ins(%arg2, %arg3, %collapsed, %arg4, %arg5
+        : tensor<128x?x128xf16>, tensor<128x?x128xf16>, tensor<128x128x?xf16>,
+          f16, tensor<128x?x?xf16>)
+      outs(%1 : tensor<128x?x128xf16>) {
+  ^bb0(%arg6: f32):
+    iree_linalg_ext.yield %arg6 : f32
+  } -> tensor<128x?x128xf16>
+  %expanded = tensor.expand_shape %4 [[0, 1], [2], [3]]
+      output_shape [4, 32, %arg0, 128]
+      : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>
+  %5 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%expanded : tensor<4x32x?x128xf16>) outs(%0 : tensor<4x?x32x128xf16>) { ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x?x32x128xf16>
+  util.return %5 : tensor<4x?x32x128xf16>
+}
+// CHECK-LABEL: func public @attention_v_reshape_propagation
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x8x4x128x?xf16>
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %[[ARG1]],
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ATTENTION]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CHECK: return %[[COLLAPSE]]
