Commit d90c505

Reshape propagation to enable broadcast(transpose) -> attention(q, kt, vt) fusion. (#19661)

Fixes some minor bugs and adds a missing pattern to enable reshape
propagation that moves reshapes out of the transpose and attention
operations so that they can subsequently be fused.

This does not by itself fuse the transpose and attention yet; that
will be addressed in a follow-up.

Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar authored Jan 13, 2025
1 parent cac7a96 commit d90c505
Showing 3 changed files with 53 additions and 4 deletions.
@@ -191,7 +191,7 @@ void BubbleUpExpandShapesPass::runOnOperation() {
       auto dimExpr = getAffineDimExpr(dim, operandMap.getContext());
       if (std::optional<int64_t> maybeDim =
               operandMap.getResultPosition(dimExpr);
-          maybeDim && !reassoc[maybeDim.value()].empty()) {
+          maybeDim && reassoc[maybeDim.value()].size() > 1) {
         return false;
       }
     }
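The one-line change in the hunk above tightens the early-exit test: a reassociation group of a valid collapse/expand always holds at least one index, so the old !empty() check fired for every dimension the operand map produces, while the intent is to bail out only when that dimension is actually expanded into several dimensions, i.e. when its group holds more than one index. Below is a small, self-contained C++ sketch of the two predicates; it uses plain standard-library containers in place of MLIR's ReassociationIndices, reuses the reassociation [[0], [1, 2, 3], [4]] that also appears in the test update further down, and every name in it is hypothetical rather than taken from the pass.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for MLIR's ReassociationIndices: one group of source dims per
// result dim. {{0}, {1, 2, 3}, {4}} collapses tensor<128x64x1x1x128xf16>
// into tensor<128x64x128xf16>; only result dim 1 gathers several source dims.
using ReassociationGroups = std::vector<std::vector<int64_t>>;

// New predicate from the hunk: a dim blocks the bubbling only if it is
// genuinely expanded, i.e. its group has more than one index.
static bool isExpandedDim(const ReassociationGroups &reassoc, size_t dim) {
  return reassoc[dim].size() > 1;
}

int main() {
  ReassociationGroups reassoc = {{0}, {1, 2, 3}, {4}};
  for (size_t dim = 0; dim < reassoc.size(); ++dim) {
    bool oldCheck = !reassoc[dim].empty(); // always true for a valid group
    bool newCheck = isExpandedDim(reassoc, dim);
    std::cout << "dim " << dim << ": old !empty() -> " << oldCheck
              << ", new size() > 1 -> " << newCheck << "\n";
  }
  return 0;
}

With the old check every mapped dimension vetoed the reshape bubbling; with the new one only dim 1, the dimension that is really expanded, does.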
@@ -204,6 +204,8 @@ void BubbleUpExpandShapesPass::runOnOperation() {
   // that can be done later) of reshape ops.
   tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
   bubbleExpandShapePatterns.insert<BubbleExpandThroughExtract>(context);
+  tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                                     context);
 
   GreedyRewriteConfig rewriteConfig;
   rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
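For readers unfamiliar with the surrounding code, here is a minimal sketch of how a pattern set assembled this way is typically driven. Only the two tensor-dialect registration calls and the greedy-rewrite config mirror the hunk above; the enclosing function, the omission of the IREE-specific BubbleExpandThroughExtract pattern, and the final driver call are assumptions made to keep the example self-contained, not the pass's actual code.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch: collect the tensor.empty folding patterns plus the
// tensor.expand_shape canonicalizations (which, among other things, fold
// expand(collapse) chains), then apply them greedily with no iteration limit.
void applyBubbleExpandShapePatterns(mlir::Operation *rootOp) {
  mlir::MLIRContext *context = rootOp->getContext();
  mlir::RewritePatternSet bubbleExpandShapePatterns(context);
  mlir::tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
  mlir::tensor::ExpandShapeOp::getCanonicalizationPatterns(
      bubbleExpandShapePatterns, context);

  mlir::GreedyRewriteConfig rewriteConfig;
  rewriteConfig.maxIterations = mlir::GreedyRewriteConfig::kNoLimit;
  // Assumed driver call; the real pass may report failure differently.
  if (failed(mlir::applyPatternsAndFoldGreedily(
          rootOp, std::move(bubbleExpandShapePatterns), rewriteConfig))) {
    rootOp->emitError("reshape bubbling patterns failed to converge");
  }
}

The newly registered canonicalizations are what the test added below relies on for the expand(collapse) -> collapse folding it checks.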
@@ -455,9 +455,9 @@ util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 :
 #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
 #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
-util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
+util.func public @dont_sink_through_k2(%0 : tensor<128x64x1x1x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
   %13 = tensor.empty() : tensor<4x32x64x128xf16>
-  %collapsed_12 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4]] : tensor<128x64x128x1x1xf16> into tensor<128x64x128xf16>
+  %collapsed_12 = tensor.collapse_shape %0 [[0], [1, 2, 3], [4]] : tensor<128x64x1x1x128xf16> into tensor<128x64x128xf16>
   %17 = tensor.empty() : tensor<128x64x128xf16>
   %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%2, %1, %collapsed_12, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) {
   ^bb0(%score: f16):
@@ -482,7 +482,6 @@ util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : ten
 // CHECK-SAME: ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] :
 // CHECK: util.return %[[ATTENTION]]
 
-
 // -----
 
 util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
@@ -21,3 +21,51 @@ util.func public @unsupported_bubbble_expand_through_extract(%arg0 : tensor<2x40
 // CHECK-LABEL: @unsupported_bubbble_expand_through_extract
 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[EXTRACT]]
+
+// -----
+
+// Checks two things
+// 1. Propagation of reshapes across attention operations
+// 2. Use of folders to convert (expand(collapse)) -> (collapse)
+util.func public @attention_v_reshape_propagation(%arg0: index,
+    %arg1: tensor<4x8x4x128x?xf16>, %arg2: tensor<128x?x128xf16>,
+    %arg3: tensor<128x?x128xf16>, %arg4: f16, %arg5: tensor<128x?x?xf16>)
+    -> tensor<4x?x32x128xf16> {
+  %0 = tensor.empty(%arg0) : tensor<4x?x32x128xf16>
+  %1 = tensor.empty(%arg0) : tensor<128x?x128xf16>
+  %collapsed = tensor.collapse_shape %arg1 [[0, 1, 2], [3], [4]]
+      : tensor<4x8x4x128x?xf16> into tensor<128x128x?xf16>
+  %4 = iree_linalg_ext.attention {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> ()>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]}
+      ins(%arg2, %arg3, %collapsed, %arg4, %arg5
+          : tensor<128x?x128xf16>, tensor<128x?x128xf16>, tensor<128x128x?xf16>,
+            f16, tensor<128x?x?xf16>)
+      outs(%1 : tensor<128x?x128xf16>) {
+  ^bb0(%arg6: f32):
+    iree_linalg_ext.yield %arg6 : f32
+  } -> tensor<128x?x128xf16>
+  %expanded = tensor.expand_shape %4 [[0, 1], [2], [3]]
+      output_shape [4, 32, %arg0, 128]
+      : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>
+  %5 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%expanded : tensor<4x32x?x128xf16>) outs(%0 : tensor<4x?x32x128xf16>) { ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x?x32x128xf16>
+  util.return %5 : tensor<4x?x32x128xf16>
+}
+// CHECK-LABEL: func public @attention_v_reshape_propagation
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x8x4x128x?xf16>
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %[[ARG1]],
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ATTENTION]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CHECK: return %[[COLLAPSE]]
