[Flow] Make sink reshapes changes less conservative. (#17706)
While deciding if a reshape needs "sinking" for a `tensor.expand_shape`
-> `linalg.*` sequence, the first check was whether the `linalg.*`
operation could already fuse with one of its existing producers. That
check was overly aggressive: such fusion only kicks in when the
iteration domains match. Eventually the actual dispatch formation logic
needs to be commoned into a single place to do this better, but that is
kicked to a follow-up.
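Concretely, the pattern in question looks like the following minimal sketch (hypothetical shapes and ops, not taken from this patch): a `tensor.expand_shape` sits between a matmul producer and an elementwise `linalg.generic` consumer, and sinking the reshape below the consumer lets producer and consumer fuse into one dispatch, with the reshape re-applied to the fused result.

func.func @sketch(%a: tensor<4x8xf32>, %b: tensor<8x16xf32>,
    %init: tensor<4x16xf32>) -> tensor<1x4x16xf32> {
  %m = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
      outs(%init : tensor<4x16xf32>) -> tensor<4x16xf32>
  // The reshape currently blocks %m from fusing with the generic below;
  // sinking moves it past the generic.
  %expanded = tensor.expand_shape %m [[0, 1], [2]] output_shape [1, 4, 16]
      : tensor<4x16xf32> into tensor<1x4x16xf32>
  %empty = tensor.empty() : tensor<1x4x16xf32>
  %abs = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%expanded : tensor<1x4x16xf32>) outs(%empty : tensor<1x4x16xf32>) {
    ^bb0(%in: f32, %out: f32):
      %r = math.absf %in : f32
      linalg.yield %r : f32
  } -> tensor<1x4x16xf32>
  return %abs : tensor<1x4x16xf32>
}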

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar authored Jun 29, 2024
1 parent 7090f64 commit 4ad00ef
Showing 2 changed files with 166 additions and 19 deletions.
101 changes: 82 additions & 19 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
@@ -64,35 +64,98 @@ static bool shouldSinkExpandShapeOp(OpOperand *opOperand) {
   if (!isNonNullAndOutsideDispatch({reshapeOp, consumer})) {
     return false;
   }
   auto consumerGenericOp = dyn_cast<linalg::GenericOp>(consumer);
   if (!consumerGenericOp) {
     return false;
   }
   // Only sink across parallel generic ops for now.
   if (consumerGenericOp.getNumParallelLoops() !=
       consumerGenericOp.getNumLoops()) {
     return false;
   }
 
   // Do not sink reshapes across dequantize operations since they are
-  // cloned into their producers.
+  // cloned into their consumers.
   if (isDequantizationLikeOp(consumer)) {
     return false;
   }
 
-  // If the op is already fusable with producer using tile and fuse,
-  // do nothing.
-  if (llvm::any_of(consumer->getOpOperands(), [](OpOperand &opOperand) {
-        Operation *currProducer = opOperand.get().getDefiningOp();
-        Operation *currConsumer = opOperand.getOwner();
-        return isFusableUsingTileAndFuse(currProducer, currConsumer) &&
-               // The check for the producer having a single use is not fully
-               // worked out. Ideally we can fuse with a producer irrespective
-               // of number of uses, but is a good thumb rule in practice.
-               llvm::hasSingleElement(currProducer->getUses());
-      })) {
-    return false;
-  }
-
-  // Do not sink if consumer is a contraction/matmul like op.
-  if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumer)) {
-    if (linalg::isaContractionOpInterface(linalgConsumerOp))
-      return false;
-  }
-
-  return isFusableUsingTileAndFuse(reshapeOp.getSrc().getDefiningOp(),
-                                   consumer);
+  // First check that the expand_shape producer and consumer can be fused.
+  Operation *reshapeProducer = reshapeOp.getSrc().getDefiningOp();
+  if (!reshapeProducer) {
+    return false;
+  }
+  if (!isFusableUsingTileAndFuse(reshapeOp.getSrc().getDefiningOp(),
+                                 consumer)) {
+    return false;
+  }
+
+  // If the op is already fusable with a producer using tile and fuse,
+  // do nothing.
+  for (OpOperand &opOperand : consumer->getOpOperands()) {
+    Operation *currProducer = opOperand.get().getDefiningOp();
+    if (!currProducer) {
+      continue;
+    }
+
+    // The check for the producer having a single use is not fully
+    // worked out. Ideally we can fuse with a producer irrespective of its
+    // number of uses, but this is a good rule of thumb in practice.
+    if (!llvm::hasSingleElement(currProducer->getUses())) {
+      continue;
+    }
+
+    // Check if a producer can already be tiled and fused with the consumer.
+    if (!isFusableUsingTileAndFuse(currProducer, consumer)) {
+      continue;
+    }
+
+    // There is already a tile-and-fusable producer to fuse with. Still prefer
+    // fusing with the producer whose parallel iteration space rank matches
+    // the consumer parallel iteration space rank to avoid loss of parallelism.
+    if (auto currLinalgProducer = dyn_cast<linalg::LinalgOp>(currProducer)) {
+      auto reshapeLinalgProducer = dyn_cast<linalg::LinalgOp>(reshapeProducer);
+      if (!reshapeLinalgProducer) {
+        // For now we prefer to fold with a Linalg op. So if the reshape
+        // producer is not a Linalg op, bail.
+        return false;
+      }
+
+      // This logic does not seem to work well when the reshape producer is
+      // an elementwise operation. For one, we should never see a reshape
+      // "after" an elementwise operation, since bubbling the expand_shape up
+      // should already account for it and fuse the elementwise producer of
+      // the reshape with the (also elementwise) consumer. Needs more
+      // investigation, but this check removes regressions and lit test
+      // failures.
+      if (reshapeLinalgProducer.getNumLoops() ==
+              reshapeLinalgProducer.getNumParallelLoops() &&
+          currLinalgProducer.getNumLoops() !=
+              currLinalgProducer.getNumParallelLoops()) {
+        return false;
+      }
+
+      unsigned currConsumerNumParallelLoops =
+          consumerGenericOp.getNumParallelLoops();
+      unsigned currProducerNumParallelLoops =
+          currLinalgProducer.getNumParallelLoops();
+      if (currProducerNumParallelLoops == currConsumerNumParallelLoops) {
+        // If the producer has the same number of parallel loops as the
+        // consumer, then this is the operand to fuse along. So do nothing.
+        return false;
+      }
+      // If the producer has fewer parallel loops than the consumer,
+      // ignore this operand.
+      if (currProducerNumParallelLoops < currConsumerNumParallelLoops) {
+        continue;
+      }
+      unsigned reshapeProducerNumParallelLoops =
+          reshapeLinalgProducer.getNumParallelLoops();
+      if (currProducerNumParallelLoops < reshapeProducerNumParallelLoops) {
+        return false;
+      }
+    }
+  }
+  return true;
 }
 
 void SinkReshapesPass::runOnOperation() {
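For reference, here is a toy, self-contained mirror of the per-operand rank comparisons in the loop above (a hypothetical helper, not part of the patch; the elementwise special case is omitted), with plain integers standing in for the ops' parallel-loop counts:

#include <cassert>

// Verdict for one tile-and-fuse-fusable candidate producer: DontSink means
// prefer fusing the consumer with this producer and leave the reshape alone;
// Ignore means this producer does not decide, and the scan continues.
enum class Verdict { DontSink, Ignore };

static Verdict compareRanks(unsigned candidateParallelLoops,
                            unsigned consumerParallelLoops,
                            unsigned reshapeProducerParallelLoops) {
  if (candidateParallelLoops == consumerParallelLoops)
    return Verdict::DontSink; // rank match: fuse along this operand instead
  if (candidateParallelLoops < consumerParallelLoops)
    return Verdict::Ignore;   // fusing here would lose parallelism
  if (candidateParallelLoops < reshapeProducerParallelLoops)
    return Verdict::DontSink;
  return Verdict::Ignore;
}

int main() {
  // Numbers from @better_producer_estimate below: the reduction producer has
  // 2 parallel loops, the consumer 4, the batch_matmul 3 -> Ignore, so the
  // reshape sinks and the batch_matmul fuses with the consumer.
  assert(compareRanks(2, 4, 3) == Verdict::Ignore);
  return 0;
}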
@@ -127,3 +127,87 @@ func.func @do_not_sink_across_dequantize_ops(%arg0: tensor<?x?xf32>) -> tensor<2
// CHECK: %[[DEQUANT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND]] :
// CHECK: return %[[DEQUANT]]

// -----

// Check that the reshape sinks based on a better estimate of which
// producer -> consumer pairs are fusable.
func.func @better_producer_estimate(%lhs : tensor<2x4096x640xi32>, %rhs : tensor<2x640x640xi32>,
%fill0 : tensor<2x4096x640xi32>, %fill1 : tensor<2x4096xi32>) -> tensor<2x4096x640x1xf16> {
%bmm = linalg.batch_matmul_transpose_b ins(%lhs, %rhs : tensor<2x4096x640xi32>, tensor<2x640x640xi32>)
outs(%fill0 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
%reduction = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%lhs : tensor<2x4096x640xi32>) outs(%fill1 : tensor<2x4096xi32>) {
^bb0(%in: i32, %out: i32):
%12 = arith.addi %in, %out : i32
linalg.yield %12 : i32
} -> tensor<2x4096xi32>
%expanded = tensor.expand_shape %bmm [[0], [1], [2, 3]] output_shape [2, 4096, 640, 1]
: tensor<2x4096x640xi32> into tensor<2x4096x640x1xi32>
%empty = tensor.empty() : tensor<2x4096x640x1xf16>
%quant = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%expanded, %reduction : tensor<2x4096x640x1xi32>, tensor<2x4096xi32>)
outs(%empty : tensor<2x4096x640x1xf16>) {
^bb0(%in: i32, %in_3: i32, %out: f16):
%14 = arith.subi %in, %in_3 : i32
%16 = arith.sitofp %14 : i32 to f32
%18 = arith.truncf %16 : f32 to f16
linalg.yield %18 : f16
} -> tensor<2x4096x640x1xf16>
return %quant : tensor<2x4096x640x1xf16>
}
// CHECK-LABEL: func @better_producer_estimate(
// CHECK: %[[BMM:.+]] = linalg.batch_matmul_transpose_b
// CHECK: %[[REDUCTION:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[BMM]], %[[REDUCTION]] :
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[GENERIC]]
// CHECK: return %[[EXPANDED]]
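// A plausible reading of the new heuristic (not spelled out in the patch):
// the %reduction producer is also tile-and-fuse fusable with the consumer,
// but it has fewer parallel loops (2) than the consumer (4), so it no longer
// blocks sinking; the reshape sinks and the batch_matmul (3 parallel loops)
// fuses with the consumer instead.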

// -----

func.func @reduce_broadcast(%arg0: tensor<4x768xf32>, %arg1: tensor<4xf32>,
%arg2: tensor<4xf32>, %arg3: tensor<1x4x768xf32>) -> tensor<1x4x768xf32> {
%cst = arith.constant 9.000000e+00 : f32
%cst_0 = arith.constant 8.000000e+00 : f32
%0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0, %arg1 : tensor<4x768xf32>, tensor<4xf32>)
outs(%arg2 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%3 = arith.subf %in, %in_2 : f32
%4 = arith.mulf %3, %3 : f32
%5 = arith.addf %out, %4 : f32
linalg.yield %5 : f32
} -> tensor<4xf32>
%expanded = tensor.expand_shape %0 [[0, 1]] output_shape [1, 4]
: tensor<4xf32> into tensor<1x4xf32>
%1 = tensor.empty() : tensor<1x4x768xf32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg3, %expanded : tensor<1x4x768xf32>, tensor<1x4xf32>)
outs(%1 : tensor<1x4x768xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%9 = arith.mulf %in, %in_2 : f32
linalg.yield %9 : f32
} -> tensor<1x4x768xf32>
return %2 : tensor<1x4x768xf32>
}
// CHECK-LABEL: func @reduce_broadcast(
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC1]] :
// CHECK: tensor.expand_shape %[[GENERIC2]]
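// A plausible reading (not spelled out in the patch): the consumer's other
// operands are a function argument and a tensor.empty, neither of which is a
// fusable producer, so nothing blocks sinking and the expand_shape moves
// below the broadcast-multiply.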
