Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paged Decode Attention #387

Merged
merged 3 commits into from
Jan 17, 2025
Merged

Paged Decode Attention #387

merged 3 commits into from
Jan 17, 2025

Conversation

harsh-nod
Copy link
Contributor

@harsh-nod harsh-nod commented Jan 13, 2025

This PR adds a functional paged decode attention kernel. To get this working, the following changes were made:

  1. Add support for sympy.Min, sympy.Piecewise
  2. Add paged decode attention kernel with test with vLLM and Sglang reference
  3. Bookkeeping of lifted variables to track implicit captures
  4. Updates to expansion to handle MMA reductions where the none of the MMA reduction dims match the reduction dim of the parent op
  5. Updates to expansion to expand symbols that have no constraints on them
  6. Better handling of SetSymbol and ApplyExpr in expansion and other passes
  7. Adds the notion of a primary and non-primary workgroup constraint. Primary workgroup constraints are used to determine the grid shape. Non-primary constraints can use the same workgroup id but are not used to determine the grid shape. This is a weaker form of the symbolic alias where symbols can share workgroup ids and nothing else.

@harsh-nod harsh-nod force-pushed the paged_attn_v2 branch 15 times, most recently from 5e6ef37 to 5779cd1 Compare January 16, 2025 22:10
@harsh-nod harsh-nod changed the title Paged Attention v2 Paged Decode Attention Jan 16, 2025
@harsh-nod harsh-nod requested review from Hardcode84, raikonenfnu and martin-luecke and removed request for Hardcode84 and raikonenfnu January 17, 2025 04:26
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
iree/turbine/kernel/ops/wave_ops.py Outdated Show resolved Hide resolved
iree/turbine/kernel/ops/wave_ops.py Show resolved Hide resolved
iree/turbine/kernel/wave/utils.py Outdated Show resolved Hide resolved
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
@harsh-nod harsh-nod merged commit bd8a1f8 into iree-org:main Jan 17, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants