Skip to content

Commit

Permalink
[Dispatch] Clone chain of ops into dispatch (#19723)
Browse files Browse the repository at this point in the history
This change modifies `cloneProducersToRegion` to iteratively clone
producers into the region. For example, if there is a chain A -> B ->
Dispatch, where A and B are cloneable, `A` is only detected as a
cloneable op after `B` has been cloned.


Also traverses backwards to delete trivially dead ops so that consumers
get erased first.

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored Jan 17, 2025
1 parent f31cc72 commit e1f010c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -915,17 +915,20 @@ getCloneableOps(IREE::Flow::DispatchRegionOp regionOp) {
/// Clone producers into the dispatch region.
LogicalResult cloneProducersToRegion(RewriterBase &rewriter,
IREE::Flow::DispatchRegionOp regionOp) {
SmallVector<Operation *> cloneableOps = getCloneableOps(regionOp);
bool sortResult = mlir::computeTopologicalSorting(cloneableOps);
(void)sortResult;
assert(sortResult && "could not compute topological sorting");

for (Operation *producer : llvm::reverse(cloneableOps)) {
if (failed(
clonePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp))) {
return failure();
SmallVector<Operation *> cloneableOps;
do {
cloneableOps = getCloneableOps(regionOp);
bool sortResult = mlir::computeTopologicalSorting(cloneableOps);
(void)sortResult;
assert(sortResult && "could not compute topological sorting");

for (Operation *producer : llvm::reverse(cloneableOps)) {
if (failed(clonePrecedingOpIntoDispatchRegion(rewriter, producer,
regionOp))) {
return failure();
}
}
}
} while (!cloneableOps.empty());

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand All @@ -34,10 +35,13 @@ struct CloneProducersIntoDispatchRegionsPass final
return signalPassFailure();
});

funcOp->walk([&](Operation *op) {
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
if (isOpTriviallyDead(op)) {
return rewriter.eraseOp(op);
}
});

funcOp->walk([&](Operation *op) {
if (!IREE::Flow::isNonNullAndOutsideDispatch(op) ||
!isa<linalg::GenericOp>(op)) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,51 @@ util.func public @clone_gather_like(%arg0: tensor<4x1x4xi64>, %arg1: tensor<1638
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK: ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
// CHECK: flow.return %[[ATTENTION]]

// -----

// Tests cloning a CHAIN of producers into a dispatch region: the bit-extend
// generic (%3) consumes the gather-like generic (%1), and only after %3 is
// cloned into the region does %1 become cloneable (see the CHECK lines below,
// which expect both generics inside the flow.dispatch.region).
util.func public @clone_bit_ext_of_gather_like(%arg0: tensor<128256x4096xf16>, %arg1: tensor<4x?xi64>, %arg2: tensor<4096xf32>) -> tensor<4x?xf32> {
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.000000e-01 : f32
  // Dynamic second dimension of the index tensor drives all tensor.empty sizes.
  %dim = tensor.dim %arg1, %c1 : tensor<4x?xi64>
  %0 = tensor.empty(%dim) : tensor<4x?x4096xf16>
  // Gather-like producer: uses the i64 indices in %arg1 to tensor.extract rows
  // from %arg0 (dynamic extract inside the generic body).
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x?xi64>) outs(%0 : tensor<4x?x4096xf16>) {
  ^bb0(%in: i64, %out: f16):
    %7 = arith.index_cast %in : i64 to index
    %8 = linalg.index 2 : index
    %extracted = tensor.extract %arg0[%7, %8] : tensor<128256x4096xf16>
    linalg.yield %extracted : f16
  } -> tensor<4x?x4096xf16>
  %2 = tensor.empty(%dim) : tensor<4x?x4096xf32>
  // Bit-extension producer: elementwise arith.extf f16 -> f32 of the gathered
  // values; this is the direct producer of the op inside the dispatch.
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1 : tensor<4x?x4096xf16>) outs(%2 : tensor<4x?x4096xf32>) {
  ^bb0(%in: f16, %out: f32):
    %7 = arith.extf %in : f16 to f32
    linalg.yield %7 : f32
  } -> tensor<4x?x4096xf32>
  %4 = tensor.empty(%dim) : tensor<4x?xf32>
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x?xf32>) -> tensor<4x?xf32>
  // The dispatch region initially contains only the reduction; the pass under
  // test is expected to pull %3 and then %1 inside.
  %6 = flow.dispatch.region -> (tensor<4x?xf32>{%dim}) {
    // Reduction over the extended values (powf then accumulate along d2).
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3 : tensor<4x?x4096xf32>) outs(%5 : tensor<4x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = math.powf %in, %cst_0 : f32
      %9 = arith.addf %8, %out : f32
      linalg.yield %9 : f32
    } -> tensor<4x?xf32>
    flow.return %7 : tensor<4x?xf32>
  }
  util.return %6 : tensor<4x?xf32>
}

// CHECK-LABEL: util.func public @clone_bit_ext_of_gather_like
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[GATHER0:.+]] = linalg.generic
// CHECK: %[[EXTRACT:.+]] = tensor.extract
// CHECK: linalg.yield %[[EXTRACT]]
// CHECK: %[[EXT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GATHER0]] : tensor<4x?x4096xf16>)
// CHECK: %[[EXTF:.+]] = arith.extf
// CHECK: linalg.yield %[[EXTF]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXT]] : tensor<4x?x4096xf32>)
// CHECK: flow.return %[[RES]]

0 comments on commit e1f010c

Please sign in to comment.