[TKW] Hoist loop invariant reads #296

Merged 6 commits on Nov 27, 2024

raikonenfnu
Contributor

In flash attention, Q's reduction dimension is typically small, so we only tile the reduction across the K2 dimension, which is the reduction dimension of the second GEMM (P and V). Q therefore does not change between loop iterations, and we can hoist the read of Q from global memory out of the for loop. This gives a sizable speedup: hoisting Q combined with a global->register read for Q typically yields about 2x.
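As a rough illustration of why the hoist is safe and what it saves, here is a minimal pure-Python sketch. It is a toy model, not TKW code: `GlobalBuffer`, `scores_unhoisted`, and `scores_hoisted` are hypothetical names, and the read counter merely stands in for global memory traffic.

```python
class GlobalBuffer:
    """Hypothetical stand-in for a global-memory buffer that counts reads."""

    def __init__(self, data):
        self.data = data
        self.reads = 0

    def read(self):
        self.reads += 1
        return self.data


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def scores_unhoisted(q_buf, k_tiles):
    out = []
    for k_tile in k_tiles:
        # Q does not depend on the loop iteration, yet it is re-read each time.
        q = q_buf.read()
        out.append([dot(q, k) for k in k_tile])
    return out


def scores_hoisted(q_buf, k_tiles):
    # The loop-invariant read is performed once, before the loop.
    q = q_buf.read()
    out = []
    for k_tile in k_tiles:
        out.append([dot(q, k) for k in k_tile])
    return out


q_data = [1.0, 2.0]
k_tiles = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [2.0, 2.0]]]
buf_a, buf_b = GlobalBuffer(q_data), GlobalBuffer(q_data)
same_result = scores_unhoisted(buf_a, k_tiles) == scores_hoisted(buf_b, k_tiles)
```

Both versions compute the same scores; the unhoisted one reads Q once per tile, the hoisted one reads it once in total.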

To implement the optimization above, we needed to add:

  1. Expand hoisting.py to also look for any Read that:
    • is independent of the induction variable
    • reads memory that is read-only, i.e. has no Write among its users (important for correctness, since it guarantees the data being read is constant across loop iterations)
    • reads memory that is a captured_var (i.e. the memory can be traced to outside the loop)
  2. Implement a method to hoist reads properly:
    • Copy the Read to the rootOp
    • Replace the copied Read's memory, which is a captured_var, with its counterpart in the rootOp by querying the Reduction's implicit_capture
    • Remove unused captured_vars from the Reduction; otherwise scf.for will be indexing/loading from the wrong bindings.
  3. Updated chained_gemm_tests in lit_tests/codegen.py specifically to test for the hoisted reads from global memory.
  4. Updated lit_tests/attention.py, since this change generates a new schedule.
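The criteria and hoisting steps listed above can be sketched on a toy IR. This is only an illustration under stated assumptions: `Node`, `Loop`, `hoistable_reads`, and `hoist_reads` are hypothetical names and do not reflect the actual data structures or signatures in hoisting.py.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical IR node; not the real TKW op classes."""
    name: str
    op: str                              # e.g. "read" or "write"
    memory: Optional[str] = None         # memory name this node touches
    index_syms: frozenset = frozenset()  # symbols appearing in its index


@dataclass
class Loop:
    """Hypothetical stand-in for a Reduction."""
    induction_var: str
    body: list
    implicit_captures: dict  # captured name inside the loop -> outer memory


def hoistable_reads(loop):
    # Memories written inside the loop are not read-only, so reads of them stay.
    written = {n.memory for n in loop.body if n.op == "write"}
    return [
        n for n in loop.body
        if n.op == "read"
        and loop.induction_var not in n.index_syms  # independent of induction var
        and n.memory not in written                 # no Write users
        and n.memory in loop.implicit_captures      # traceable outside the loop
    ]


def hoist_reads(outer, loop):
    for node in hoistable_reads(loop):
        loop.body.remove(node)
        # Copy the read outside and rebind its memory to the outer counterpart.
        outer_memory = loop.implicit_captures[node.memory]
        outer.append(Node(node.name, "read", outer_memory, node.index_syms))
    # Drop captures that the loop body no longer uses.
    still_used = {n.memory for n in loop.body}
    loop.implicit_captures = {
        k: v for k, v in loop.implicit_captures.items() if k in still_used
    }


outer = []
loop = Loop(
    induction_var="K2_iter",
    body=[
        Node("q", "read", "q_cap", frozenset({"M", "K1"})),
        Node("k", "read", "k_cap", frozenset({"K2_iter", "K1"})),
        Node("acc_w", "write", "acc_cap", frozenset({"M"})),
        Node("acc_r", "read", "acc_cap", frozenset({"M"})),
    ],
    implicit_captures={"q_cap": "Q", "k_cap": "K", "acc_cap": "ACC"},
)
hoist_reads(outer, loop)
```

In this toy example only the read of `q` is hoisted: the read of `k` depends on the induction variable, and `acc_cap` is also written inside the loop. The `q_cap` capture is dropped afterwards because nothing in the loop body uses it any more.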

Contributor

@Hardcode84 left a comment

LGTM besides a couple of comments

lit_tests/kernel/wave/codegen.py (outdated, resolved)
lit_tests/kernel/wave/codegen.py (outdated, resolved)
iree/turbine/kernel/wave/hoisting.py (resolved)
Contributor

@harsh-nod left a comment

lgtm! thanks for landing this. Just a minor comment on refactoring but overall looks good!

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@raikonenfnu merged commit e3b6c87 into iree-org:main on Nov 27, 2024
8 checks passed