Propagate reshapes through generics with reduction iterators (iree-org#18857)

Removes the constraint in `BubbleUpExpandShapes` that prevented moving
tensor reshape ops through reduction `linalg.generic` ops. This increases
the dimensionality of the reduction ops (exposing more fusion
opportunities) and raises the chance that the reshapes get propagated to
the edges of the program.
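
As an illustration, here is a minimal sketch of the kind of rewrite this enables (hypothetical shapes, maps, and value names, not taken from the patch or its tests; syntax approximate): a `tensor.expand_shape` that consumes the result of a reduction `linalg.generic` is bubbled above the generic, and the generic is expanded to iterate over the higher-rank shape.

  // Before: the reshape follows the reduction.
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%in : tensor<4096x32xf32>) outs(%init : tensor<4096xf32>) {
  ^bb0(%a: f32, %acc: f32):
    %0 = arith.addf %a, %acc : f32
    linalg.yield %0 : f32
  } -> tensor<4096xf32>
  %expanded = tensor.expand_shape %sum [[0, 1]] output_shape [1, 4096]
      : tensor<4096xf32> into tensor<1x4096xf32>

  // After bubbling: the reshape moves onto the operand, and the reduction
  // gains an extra parallel dimension.
  %in_exp = tensor.expand_shape %in [[0, 1], [2]] output_shape [1, 4096, 32]
      : tensor<4096x32xf32> into tensor<1x4096x32xf32>
  %sum_exp = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%in_exp : tensor<1x4096x32xf32>) outs(%init_exp : tensor<1x4096xf32>) {
  ^bb0(%a: f32, %acc: f32):
    %0 = arith.addf %a, %acc : f32
    linalg.yield %0 : f32
  } -> tensor<1x4096xf32>

The expanded reduction now rank-matches elementwise producers that have themselves been expanded, which is where the additional fusion opportunities come from.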

Closes iree-org#18854

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
IanWood1 authored Oct 30, 2024
1 parent 14f58e0 commit 78481a6
Showing 3 changed files with 8 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pkgci_regression_test.yml
@@ -222,7 +222,7 @@ jobs:
   --goldentime-rocm-vae-ms 337.0 \
   --goldendispatch-rocm-unet 1531 \
   --goldendispatch-rocm-clip 1139 \
-  --goldendispatch-rocm-vae 247 \
+  --goldendispatch-rocm-vae 246 \
   --goldensize-rocm-unet-bytes 2280000 \
   --goldensize-rocm-clip-bytes 860000 \
   --goldensize-rocm-vae-bytes 840000 \
@@ -243,7 +243,7 @@ jobs:
   --goldentime-rocm-vae-ms 80.0 \
   --goldendispatch-rocm-unet 1531 \
   --goldendispatch-rocm-clip 1139 \
-  --goldendispatch-rocm-vae 247 \
+  --goldendispatch-rocm-vae 246 \
   --goldensize-rocm-unet-bytes 2270000 \
   --goldensize-rocm-clip-bytes 860000 \
   --goldensize-rocm-vae-bytes 840000 \
@@ -80,13 +80,13 @@ util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1:
 // CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
 // CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
 // CHECK: %[[GEN0:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK: arith.extui
 // CHECK: arith.uitofp
 // CHECK: arith.subf
 // CHECK: arith.mulf
 // CHECK: %[[GEN1:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "reduction", "reduction"]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "reduction", "reduction"]
 // CHECK-SAME: ins(
 // CHECK-SAME: %[[GEN0]]
 // CHECK-SAME: outs(
@@ -95,5 +95,4 @@ util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1:
 // CHECK: flow.dispatch.tensor.store %[[GEN1]]
 // CHECK: util.func public @grouped_quantized_matmul(
 // CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
-// CHECK: %[[RS:.+]] = flow.tensor.reshape %[[T0]] : tensor<4096xf32> -> tensor<1x1x4096xf32>
-// CHECK: util.return %[[RS]]
+// CHECK: util.return %[[T0]]
@@ -57,24 +57,18 @@ void BubbleUpExpandShapesPass::runOnOperation() {
           return false;
         }

-        // Do not fuse producer generic op if it has more than one user
-        // or any reduction iterators.
         if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
-          return producerGenericOp->hasOneUse() &&
-                 llvm::all_of(producerGenericOp.getIteratorTypesArray(),
-                              linalg::isParallelIterator);
+          return true;
         }

         // Do not fuse with any producer linalg named ops for now.
         if (isa<linalg::LinalgOp>(producer)) {
           return false;
         }

-        // Do not fuse with consumer linalg named ops or reductions.
+        // Do not fuse with consumer linalg named ops.
         if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
-          return isa<linalg::GenericOp>(consumerLinalgOp) &&
-                 llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
-                              linalg::isParallelIterator);
+          return isa<linalg::GenericOp>(consumerLinalgOp);
         }
         // Fuse in all other cases.
         return true;
