
[Wave] Add self_index, predicate, and selectOp to implement causal attention #452

Merged: 23 commits into iree-org:main on Feb 5, 2025

Conversation

@raikonenfnu (Contributor) commented on Feb 4, 2025

  • Extracted the core pieces of self_index, predicate, and selectOp (plus LIT tests for predicate and select) written by @nicolasvasilache and @ftynse, which are required for the causal mask, and removed the pieces unrelated to causal masking.
  • Implemented a numerically correct causal attention kernel based on the original from @nicolasvasilache.
  • Added GPR_NUM partitioning support for SelfIndex so causal attention works on more MMA intrinsics (e.g. 32x32x8, which has GPR_NUMs).
  • Refactored tkw.slt/sgt/sge/sle to operator.lt/gt/ge/le to keep the number of tkw ops down and for user ergonomics.
  • Refactored the vanilla kernel to support both vanilla and causal attention in a single kernel, controlled by an is_causal flag (see the sketch after this list).
  • Added support for handle_op to take multiple ops that map to the same function.
  • Added a bunch of LIT tests.
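
To make the is_causal path concrete, here is a minimal sketch of how these pieces compose inside the attention loop. It is illustrative, not the PR's exact kernel: the self_index/broadcast lines mirror the diff excerpts below, while ZERO, NEG_INF, and the exact select/add wiring are assumptions.

```
if is_causal:
    # Row and column indices of the current score tile.
    m_index = tkw.self_index(M, tkl.i64, elements_per_thread=1)
    m_index = tkw.broadcast(m_index, target_shape=[M, K2])
    k2_index = tkw.self_index(K2, tkl.i64)
    # operator.ge now traces to a tkw comparison (see the discussion below).
    mask = m_index >= k2_index
    # Keep scores on/below the diagonal; push the rest to -inf before softmax.
    bias = tkw.select(mask, ZERO, NEG_INF)  # ZERO/NEG_INF: assumed registers
    x_j = x_j + bias
```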

@raikonenfnu requested a review from harsh-nod on February 4, 2025 06:50
@raikonenfnu force-pushed the causal_attention branch 3 times, most recently from 504fb96 to 7ba8dcf on February 4, 2025 07:35
```
elif isinstance(custom, SelfIndex):
    # TODO: Add support on how to handle strided reads.
    new_node = SelfIndex(
        custom.idx, custom.dtype, custom.elements_per_thread
    )
```
Contributor:
FYI: in line 202 above, elements_per_thread is taken as custom.index[custom.idx].size instead of the op argument. I don't fully understand what's going on here, so only pointing out the inconsistency.

Contributor (Author):

Good catch here; got some more clarification from Nicolas and fixed this up to be more consistent.

```
@@ -1205,13 +1283,77 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
    return result


@handle_binary_op(operator.gt)
```
Contributor:

Note: originally, I intentionally didn't overload comparison operators. In my previous experience with Python bindings and MLIR, this only brings pain and suffering.

Also, arith dialect operations are only expected to work on signless integers, and will likely fail verification or assert if you try feeding them signed/unsigned types, so using the signedness bit to differentiate comparison types is most likely not going to work. It is just not exercised anywhere.
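
As a small illustration of that point (a sketch, using the same arith_d binding alias as the diff excerpt further down): signedness lives in the cmpi predicate, not in the operand type, so both forms below take the same signless operands.

```
# Both compares take the same signless integer operands; only the
# predicate distinguishes signed from unsigned comparison.
signed_gt = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs)
unsigned_gt = arith_d.cmpi(arith_d.CmpIPredicate.ugt, lhs, rhs)
```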

Contributor:

@raikonenfnu +1 on not being too smart here and biasing towards being more explicit, as we are mixing sympy, tracing, and tkw functions.

Contributor:

I think we should support both >, <, ... and tkw.gt, tkw.le, etc. One of the big pluses was readability, and my two cents are that making it more readable is a good thing.

Contributor:

Can you add an assertion to disallow unsigned types?

Contributor (Author):

Done and done!
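
A hedged sketch of the shape such a guard can take (the PR's exact wording and helper names may differ; _is_integer_like_type is taken from the diff below):

```
# Reject unsigned element types up front; signedness is encoded in the
# cmpi predicate, and the operands stay signless.
if _is_integer_like_type(element_type) and element_type.is_unsigned:
    raise ValueError(f"Unsigned comparisons are not supported: {element_type}")
result = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs)
```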

```
    element_type.is_signed or element_type.is_signless
):
    result = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned():
```
Contributor:

Note that is_unsigned is a property; it should not be called. Here and elsewhere.

Contributor (Author):

good catch, thanks! :)
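
For the record, a quick illustration of the property-vs-call distinction in the MLIR Python bindings (the import path is assumed to match how the project vendors the bindings):

```
from iree.compiler.ir import Context, IntegerType

with Context():
    ty = IntegerType.get_unsigned(32)
    assert ty.is_unsigned      # property access: correct
    # ty.is_unsigned()         # TypeError: 'bool' object is not callable
```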

```
 offset = vector_d.splat(vector_index_type, start)
-shifted = arith_d.AddIOp(scaled, offset)
+shifted = arith_d.addi(scaled, offset)
```
Contributor:

If you prefer the pretty form, should you also use arith_d.index_cast below?

Contributor (Author):

makes sense, done!
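
For context, the "pretty form" refers to the snake_case value builders, which return the op's result Value directly, versus the CamelCase op classes (a sketch; binding aliases as used in the diff):

```
shifted = arith_d.addi(scaled, offset)             # value builder: returns a Value
shifted = arith_d.AddIOp(scaled, offset).result    # op class: same result, via the op
idx = arith_d.index_cast(IndexType.get(), value)   # pretty form of IndexCastOp
```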

```
@@ -95,6 +98,12 @@ def repeat(
    k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
    inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
    x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
    if is_causal:
        m_index = tkw.self_index(M, tkl.i64, elements_per_thread=1)
        m_index = tkw.broadcast(m_index, target_shape=[M, K2])
```
Contributor:

So, the big difference I am seeing here from what I had written is that we broadcast to [M, K2].
Was that the part that triggered incorrect codegen, and how should we guard against such mistakes in the future?

Contributor (Author):

Added a target_shape verifier here in __post_init__.
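
A minimal sketch of what such a __post_init__ verifier can look like (class and field names here are illustrative, not the PR's exact ones):

```
from dataclasses import dataclass
from typing import Sequence

@dataclass
class Broadcast:
    source_shape: Sequence[str]   # symbolic dims of the input, e.g. ["M"]
    target_shape: Sequence[str]   # e.g. ["M", "K2"]

    def __post_init__(self):
        # A broadcast may only add dims; dropping a source dim is an error.
        missing = [d for d in self.source_shape if d not in self.target_shape]
        if missing:
            raise ValueError(
                f"broadcast drops dims {missing}: "
                f"{list(self.source_shape)} -> {list(self.target_shape)}"
            )
```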

Contributor:

Note: I would keep the comment that M will resolve to 1 after mapping, and therefore the broadcast will be legal thanks to static information carried by the transformations. This is not at all trivial to understand and should be pointed out, IMO.

Contributor (Author):

makes sense, done, thanks!

```
@@ -197,7 +198,10 @@ def has_gpr_offsets(node: fx.Node) -> bool:
        dim: simplify_index(custom.index.get(dim, custom.index[dim]))
        for dim in custom.index
    }
    elements_per_thread = subs_idxc(custom.elements_per_thread)
    if isinstance(custom, SelfIndex):
```
Contributor:

This is quite dark magic to me :)
Could you explain (as a reply to this comment) what is happening here?

Contributor (Author):

The main reason we are doing this is to support the 32x32x8 MFMA layout. Let me give a short description of that layout. The C matrix (shaped 32x32) is in col-major order, and on a given column, the thread ownership looks something like:

```
[base_tid, base_tid, base_tid, base_tid,
 base_tid+32, base_tid+32, base_tid+32, base_tid+32,
 base_tid, base_tid, base_tid, base_tid,
 ...,
 base_tid+32, base_tid+32, base_tid+32, base_tid+32]
```

where base_tid is different for every column (0, 1, 2, ..., 31).
This means that on a given column the data will be owned by two different threads, which is fine.
But the more difficult part, if you notice, is that within a column, thread ownership switches every 4 elements. These contiguous 4 elements owned by a single thread are what we call a GPR_CHUNK. In this case each GPR_CHUNK has 4 elements, so to reference each element within a chunk we go by its GPR_NUM, which here is (0, 1, 2, 3).

We now know that for this layout (32x32x8), each thread in a column owns 4 GPR_CHUNKs, where every chunk has a size of 4 elements and a relative offset of 0, 8, 16, or 24. Hence we cannot use a single self_index with the original base offset and size 16 (equivalent to C[base_thread_offset:base_thread_offset+16, :]); rather, we need 4 self_index ops with different relative offsets, which we later combine into one with our ReshapeOp:

```
slice0 = C[base_thread_offset : base_thread_offset+4, :]
slice1 = C[base_thread_offset+8 : base_thread_offset+12, :]
slice2 = C[base_thread_offset+16 : base_thread_offset+20, :]
slice3 = C[base_thread_offset+24 : base_thread_offset+28, :]
combined_slice = Reshape([slice0, slice1, slice2, slice3])
```

ReshapeOp with multiple inputs in Wave has been used as a "combiner"-like op.
Hope that makes sense! :)
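
To make the arithmetic above concrete, here is a tiny self-contained sketch (variable names are illustrative) that reproduces the set of rows one thread owns in a column under this layout:

```
GPR_CHUNK_SIZE = 4    # contiguous elements per chunk
CHUNK_STRIDE = 8      # chunks sit at relative offsets 0, 8, 16, 24
NUM_CHUNKS = 4

base_thread_offset = 0  # this thread's base row within the column
rows_owned = [
    base_thread_offset + chunk * CHUNK_STRIDE + gpr_num
    for chunk in range(NUM_CHUNKS)
    for gpr_num in range(GPR_CHUNK_SIZE)
]
# -> [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27]
```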

ftynse and others added 20 commits February 4, 2025 11:59
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
This revision uses `tkw.apply_expr` to circumvent type mismatches such as:

```
ValueError: Expected an fx.Node but got <class 'int'>
ValueError: Expected an fx.Node but got <class 'Symbol'>
```

This further requires supporting index_cast in `tkw.cast`, and playground/vanilla_attention.py now produces valid IR.
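
A hedged sketch of the workaround's shape (tkw.apply_expr is named by this commit; the call site and symbol names here are illustrative, not the commit's exact code):

```
# Keep the sympy-level arithmetic inside apply_expr so the traced graph
# only ever sees fx.Node operands, not raw ints or sympy Symbols.
k2_index = tkw.apply_expr(k2_index, lambda x: x + START_OFFSET)
```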

Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Refactor and add a default torch implementation against which we check with allclose.
Set sizes to known good values that pass the checks; it is easy to fall off the cliff with various size combinations.
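
For illustration, a minimal torch reference of the kind described (a sketch; the repo's actual reference implementation may differ):

```
import torch

def causal_attention_ref(q, k, v):
    # Plain scaled-dot-product attention with an upper-triangular -inf mask.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    m, n = scores.shape[-2], scores.shape[-1]
    mask = torch.triu(
        torch.ones(m, n, dtype=torch.bool, device=q.device), diagonal=1
    )
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)
```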

Additionally, with the following, one can remove the inplace hack.
```
pip install -r pytorch-rocm-requirements.txt  -e .
```

Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@harsh-nod (Contributor) left a comment:

thanks! this is great work!

@raikonenfnu raikonenfnu merged commit 7038127 into iree-org:main Feb 5, 2025
9 of 10 checks passed
xintin pushed a commit to xintin/iree-turbine that referenced this pull request on Feb 14, 2025:
[Wave] Add self_index, predicate, and selectOp to implement causal attention (iree-org#452)


Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Alex Zinenko <git@ozinenko.com>
Co-authored-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Signed-off-by: xintin <vermagaurav9408@gmail.com>