[DispatchCreation] Run preprocessing before elementwise fusion #18920
Conversation
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
This is the state of the IR (from the linked issue) before reshape propagation:

%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst : tensor<30522x128xf16>) outs(%1 : tensor<30522x128xf32>) {
^bb0(%in: f16, %out: f32):
%8 = arith.extf %in : f16 to f32
linalg.yield %8 : f32
} -> tensor<30522x128xf32>
%expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 30522, 128] : tensor<30522x128xf32> into tensor<1x30522x128xf32>
%collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<1x128xi32> into tensor<128xi32>
%3 = tensor.empty() : tensor<128x128xf32>
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<128xi32>) outs(%3 : tensor<128x128xf32>) {
^bb0(%in: i32, %out: f32):
%8 = arith.index_cast %in : i32 to index
%9 = linalg.index 1 : index
%extracted = tensor.extract %expanded[%c0, %8, %9] : tensor<1x30522x128xf32>
linalg.yield %extracted : f32
} -> tensor<128x128xf32>
Note that the `tensor.expand_shape` sits between the elementwise `linalg.generic` and the `tensor.extract` inside the second `linalg.generic`, so `GatherFusionPattern` cannot fuse the two until the reshape is propagated out of the way.
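For illustration, here is roughly what the producer looks like once the reshape has been bubbled above the elementwise op (a hand-written sketch, not compiler output; `%cst_expanded` and `%init` are hypothetical names):

// The expand_shape now applies to the constant input, and the generic
// produces the expanded tensor directly.
%cst_expanded = tensor.expand_shape %cst [[0, 1], [2]] output_shape [1, 30522, 128] : tensor<30522x128xf16> into tensor<1x30522x128xf16>
%init = tensor.empty() : tensor<1x30522x128xf32>
%expanded = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_expanded : tensor<1x30522x128xf16>) outs(%init : tensor<1x30522x128xf32>) {
^bb0(%in: f16, %out: f32):
  %8 = arith.extf %in : f16 to f32
  linalg.yield %8 : f32
} -> tensor<1x30522x128xf32>

The `tensor.extract` in the consumer now reads a `linalg.generic` result directly, which is the shape `GatherFusionPattern` matches.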
It makes sense to me. Just one question: since this is a prerequisite for element-wise fusion, perhaps we should also delete the passes below (which run right before `addDispatchRegionCreationPreprocessingPasses`), because they are moved into the preprocessing passes. What do you think?
(You probably want to run canonicalizer and CSE at the beginning of the pipeline.)
iree/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
Lines 297 to 301 in e66171a
FunctionLikeNest(passManager)
    // Preprocess the input to a form more amenable for fusion.
    .addPass(DispatchCreation::createFusionPreprocessingPass)
    .addPass(IREE::Flow::createCanonicalizerPass)
    .addPass(mlir::createCSEPass);
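A sketch of what that suggestion could look like, with the three passes above folded into the start of `addDispatchRegionCreationPreprocessingPasses` (hand-written illustration only, using the pass names from the snippet; the surrounding passes are elided and the exact placement is up to the patch):

static void addDispatchRegionCreationPreprocessingPasses(
    OpPassManager &passManager) {
  FunctionLikeNest(passManager)
      // Clean up the input IR before any fusion-oriented preprocessing.
      .addPass(IREE::Flow::createCanonicalizerPass)
      .addPass(mlir::createCSEPass)
      // Preprocess the input to a form more amenable for fusion.
      .addPass(DispatchCreation::createFusionPreprocessingPass)
      .addPass(IREE::Flow::createCanonicalizerPass)
      .addPass(mlir::createCSEPass);
  // ... elementwise fusion and the rest of the preprocessing passes ...
}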
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
That makes sense to me. I'll move them inside and delete the extra preprocessing passes.
@c-rhodes this should fix the problem you were encountering on #17226 (comment).
Test please :)
Can confirm this fixes the issue, thanks for the fix!
There's a regression in amdgpu_rocm_mi300_gfx942, but I can see this in other PRs, so I think it's safe to say it's not because of this PR. I'll go ahead and land the fix, cheers!
This might have regressed VAE decode time on MI250.
That was also reported in the checks on the PR: https://github.com/iree-org/iree/actions/runs/11561437104/job/32180922269#step:8:128
@ScottTodd thank you for the comment, I didn't realize this was merged. I was going to look into the regressions before merging but don't have a fix yet. This may cause issues for others' PRs, so I'll open a revert and re-land once the regressions have been resolved. I looked into this a bit yesterday, and I'm a bit confused why it causes runtime regressions with no change to the dispatch count.
Apologies for jumping the gun and landing this one! I noticed the regression but thought I saw the same one on other PRs and disregarded it.
No worries, I also thought it was just being flaky at first :) |
This PR got merged before I was able to resolve the perf regressions in VAE decode on MI250; see @ScottTodd's comment on the original PR. I need time to resolve the regressions, but this can be relanded once they are fixed. Reverts #18920.
I think it makes sense to run `FusionPreprocessingPass` before `ElementwiseOpFusionPass` because it helps put the IR in a better state for fusion (e.g. interchanging `linalg.generic` indexing maps). But also, reshapes have been propagated to the edges of the program, which allows the `GatherFusionPattern` to be more effective. Fixes compilation error from #17226 (comment).

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Signed-off-by: Elias Joseph <eljoseph@amd.com>
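For context, the indexing-map interchange mentioned above looks roughly like this (a hand-written illustration with made-up shapes and a placeholder `arith.negf` body, not output from the pass): a `linalg.generic` whose result map is a permutation gets its loops interchanged so the result map becomes the identity, which makes it easier to match against consumers during fusion.

// Before: the output indexing map is the permutation (d0, d1) -> (d1, d0).
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%a : tensor<8x4xf32>) outs(%b : tensor<4x8xf32>) {
^bb0(%in: f32, %out: f32):
  %n = arith.negf %in : f32
  linalg.yield %n : f32
} -> tensor<4x8xf32>

// After interchanging d0 and d1: same computation, but the output map is
// now the identity and the permutation has moved to the input.
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%a : tensor<8x4xf32>) outs(%b : tensor<4x8xf32>) {
^bb0(%in: f32, %out: f32):
  %n = arith.negf %in : f32
  linalg.yield %n : f32
} -> tensor<4x8xf32>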
This PR got merged before I was able to resolve the perf regressions in VAE decode on MI250; see @ScottTodd's comment on the original PR. I need time to resolve the regressions, but this can be relanded once they are fixed. Reverts #18920.

Signed-off-by: Elias Joseph <eljoseph@amd.com>