Revert "Support fusing broadcast transposes with attention" (#19835)
Reverts #19828 
Fixes #19833
IanWood1 authored Jan 28, 2025
1 parent 6a5c12e commit 9870a6d
Showing 2 changed files with 20 additions and 86 deletions.
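In effect, #19828 had generalized this pattern to fuse any single-yield linalg.generic producer whose indexing maps form a projected permutation (which covers broadcasts) by composing maps. This revert restores the narrower version that only folds pure transposes into the attention op's indexing maps, and drops the two broadcast-fusion tests that were added with it.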
@@ -9,7 +9,6 @@
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -59,44 +58,42 @@ struct FuseTransposeWithAttentionOp final

   LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = nullptr;
-    linalg::LinalgOp producer;
+    OpOperand *transposeOperand = nullptr;
+    linalg::LinalgOp transposeOp;
     for (OpOperand *input : attentionOp.getDpsInputOperands()) {
       if (controlFn && !controlFn(input)) {
         continue;
       }
 
-      auto maybeProducer = input->get().getDefiningOp<linalg::GenericOp>();
-      if (maybeProducer && maybeProducer.isSingleYieldOp()) {
-        producer = maybeProducer;
-        operand = input;
+      auto maybeTransposeOp = input->get().getDefiningOp<linalg::LinalgOp>();
+      if (maybeTransposeOp && isaTranspose(maybeTransposeOp) &&
+          maybeTransposeOp->hasOneUse()) {
+        transposeOp = maybeTransposeOp;
+        transposeOperand = input;
         break;
       }
     }
-    if (!operand) {
-      return rewriter.notifyMatchFailure(attentionOp, "no operand found");
+    if (!transposeOperand) {
+      return rewriter.notifyMatchFailure(attentionOp, "no transpose operand");
     }
 
-    int64_t inputIndex = operand->getOperandNumber();
-
-    auto producerMaps = producer.getIndexingMapsArray();
-    AffineMap producerInputMap = producerMaps[0];
-    AffineMap producerResultMap = producerMaps[1];
-    if (!producerInputMap.isProjectedPermutation() ||
-        !producerResultMap.isPermutation()) {
-      return failure();
-    }
+    int64_t inputIndex = transposeOperand->getOperandNumber();
+    SmallVector<int64_t> perm = getPermutation(transposeOp);
+    auto invPerm = invertPermutationVector(perm);
 
     rewriter.modifyOpInPlace(attentionOp, [&]() {
       SmallVector<AffineMap> newIndexingMaps =
          attentionOp.getIndexingMapsArray();
-      AffineMap consumerInputMap = attentionOp.getMatchingIndexingMap(operand);
-      AffineMap composedMap =
-          producerInputMap.compose(inversePermutation(producerResultMap));
-      newIndexingMaps[inputIndex] = composedMap.compose(consumerInputMap);
+      AffineMap inputMap = attentionOp.getMatchingIndexingMap(transposeOperand);
+      SmallVector<AffineExpr> newExprs =
+          applyPermutation(inputMap.getResults(), invPerm);
+      AffineMap transposedMap =
+          AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(),
+                         newExprs, rewriter.getContext());
+      newIndexingMaps[inputIndex] = transposedMap;
       attentionOp.setIndexingMapsAttr(
          rewriter.getAffineMapArrayAttr(newIndexingMaps));
-      attentionOp.setOperand(inputIndex, producer.getDpsInputs()[0]);
+      attentionOp.setOperand(inputIndex, transposeOp.getDpsInputs()[0]);
     });
 
     return success();
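For readers reconstructing the restored behavior: the pattern folds a pure transpose feeding one attention input by permuting that operand's indexing-map results with the inverse of the transpose permutation, so the un-transposed tensor can be read directly. Below is a minimal standalone C++ sketch of just that index shuffle. It is illustrative only: the dimension names, rank, and permutation are invented, and the helpers are plain re-implementations standing in for MLIR's invertPermutationVector and applyPermutation, not the MLIR APIs themselves.

// Standalone illustration (hypothetical, not part of this commit).
// Build: c++ -std=c++17 permute_map.cpp && ./a.out   -> prints: d0 d7 d5
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Inverse permutation: inv[perm[i]] = i, so inv undoes perm.
static std::vector<int64_t> invertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (int64_t i = 0, e = perm.size(); i < e; ++i)
    inv[perm[i]] = i;
  return inv;
}

// out[i] = in[perm[i]], mirroring applyPermutation(inputMap.getResults(), invPerm).
static std::vector<std::string> applyPerm(const std::vector<std::string> &in,
                                          const std::vector<int64_t> &perm) {
  std::vector<std::string> out(in.size());
  for (size_t i = 0; i < in.size(); ++i)
    out[i] = in[perm[i]];
  return out;
}

int main() {
  // Say the fused transpose swaps the last two dims of a rank-3 operand.
  std::vector<int64_t> perm = {0, 2, 1};
  // And attention indexes the transposed operand as (d0, d5, d7).
  std::vector<std::string> mapResults = {"d0", "d5", "d7"};
  // Reading the un-transposed operand instead requires (d0, d7, d5).
  for (const std::string &r : applyPerm(mapResults, invertPermutation(perm)))
    std::cout << r << ' ';
  std::cout << '\n';
  return 0;
}

The reverted version handled the more general case (broadcasts as well as transposes) by composing affine maps instead of permuting result expressions, which is what the deleted producerInputMap/producerResultMap logic above did.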
@@ -208,8 +208,6 @@ util.func public @fuse_generic_gather2(
 // CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
 // CHECK-NEXT: linalg.yield %[[RES4]] : f32
 
-// -----
-
 util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
   // Dequantize int-quantization of V
   %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
@@ -260,64 +258,3 @@ util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]
-
-// -----
-
-util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> {
-  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x8x128x?xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    linalg.yield %in : f16
-  } -> tensor<4x8x4x128x?xf16>
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) {
-  ^bb0(%arg7: f32):
-    iree_linalg_ext.yield %arg7 : f32
-  } -> tensor<4x8x4x?x32x128xf16>
-  util.return %1 : tensor<4x8x4x?x32x128xf16>
-}
-
-// CHECK-LABEL: func public @fuse_attention_with_broadcast
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
-// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d7)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
-// CHECK: util.return %[[ATTENTION]]
-
-
-// -----
-
-util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x128xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> {
-  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4, d1)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x8x128xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    linalg.yield %in : f16
-  } -> tensor<4x8x4x128x?xf16>
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) {
-  ^bb0(%arg7: f32):
-    iree_linalg_ext.yield %arg7 : f32
-  } -> tensor<4x8x4x?x32x128xf16>
-  util.return %1 : tensor<4x8x4x?x32x128xf16>
-}
-
-// CHECK-LABEL: func public @fuse_attention_with_broadcast_transpose
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
-// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d5)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
-// CHECK: util.return %[[ATTENTION]]
