From 78481a6ed98c9be1dd9c33eda0572e391a4d8d89 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Wed, 30 Oct 2024 13:44:42 -0700 Subject: [PATCH] Propagate reshapes through generics with reduction iterators (#18857) Removes the constraint in `BubbleUpExpandShapes` that prevents moving tensor reshape ops through reduction `linalg.generic` ops. This has the benefit of increasing the dimensionality of reduction ops (more fusion opportunities) as well as increasing the chance these ops will be moved to the edge of the program. Closes https://github.com/iree-org/iree/issues/18854 --------- Signed-off-by: Ian Wood --- .github/workflows/pkgci_regression_test.yml | 4 ++-- .../Dialect/Flow/Transforms/test/pipeline_tests.mlir | 7 +++---- .../DispatchCreation/BubbleUpExpandShapes.cpp | 12 +++--------- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 9849c574dd72..a11107771012 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -222,7 +222,7 @@ jobs: --goldentime-rocm-vae-ms 337.0 \ --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ - --goldendispatch-rocm-vae 247 \ + --goldendispatch-rocm-vae 246 \ --goldensize-rocm-unet-bytes 2280000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ @@ -243,7 +243,7 @@ jobs: --goldentime-rocm-vae-ms 80.0 \ --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ - --goldendispatch-rocm-vae 247 \ + --goldendispatch-rocm-vae 246 \ --goldensize-rocm-unet-bytes 2270000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir index 8973ba5a0278..0c4430fe7dfe 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir @@ -80,13 +80,13 @@ util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1: // CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]] // CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]] // CHECK: %[[GEN0:.+]] = linalg.generic -// CHECK-SAME: ["parallel", "parallel", "parallel"] +// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel"] // CHECK: arith.extui // CHECK: arith.uitofp // CHECK: arith.subf // CHECK: arith.mulf // CHECK: %[[GEN1:.+]] = linalg.generic -// CHECK-SAME: ["parallel", "reduction", "reduction"] +// CHECK-SAME: ["parallel", "parallel", "parallel", "reduction", "reduction"] // CHECK-SAME: ins( // CHECK-SAME: %[[GEN0]] // CHECK-SAME: outs( @@ -95,5 +95,4 @@ util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1: // CHECK: flow.dispatch.tensor.store %[[GEN1]] // CHECK: util.func public @grouped_quantized_matmul( // CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]] -// CHECK: %[[RS:.+]] = flow.tensor.reshape %[[T0]] : tensor<4096xf32> -> tensor<1x1x4096xf32> -// CHECK: util.return %[[RS]] +// CHECK: util.return %[[T0]] diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index 79ae8d3b2ba8..9ee67d637c06 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -57,12 +57,8 @@ void BubbleUpExpandShapesPass::runOnOperation() { return false; } - // Do not fuse producer generic op if it has more than one user - // or any reduction iterators. if (auto producerGenericOp = dyn_cast(producer)) { - return producerGenericOp->hasOneUse() && - llvm::all_of(producerGenericOp.getIteratorTypesArray(), - linalg::isParallelIterator); + return true; } // Do not fuse with any producer linalg named ops for now. @@ -70,11 +66,9 @@ void BubbleUpExpandShapesPass::runOnOperation() { return false; } - // Do not fuse with consumer linalg named ops or reductions. + // Do not fuse with consumer linalg named ops. if (auto consumerLinalgOp = dyn_cast(consumer)) { - return isa(consumerLinalgOp) && - llvm::all_of(consumerLinalgOp.getIteratorTypesArray(), - linalg::isParallelIterator); + return isa(consumerLinalgOp); } // Fuse in all other cases. return true;