[Wave] Add self_index, predicate, and selectOp to implement causal attention #452
Conversation
raikonenfnu commented Feb 4, 2025 • edited
- Extracted the core pieces of self_index, predicate, and selectOp (plus LIT tests for predicate and select) written by @nicolasvasilache and @ftynse that are required for the causal mask, and removed the pieces unrelated to the causal mask. A conceptual sketch of how these ops express the mask follows this list.
- Implemented a numerically correct causal attention kernel based on the original from @nicolasvasilache
- Added GPR_NUM partitioning support for SelfIndex so causal attention works on more MMA intrinsics (e.g. 32x32x8, which has GPR_NUMs)
- Refactored tkw.slt/sgt/sge/sle into operator.lt/gt/ge/le to keep the number of tkw ops down and for user ergonomics
- Refactored the vanilla kernel to support both causal and non-causal attention in a single kernel, controlled by an is_causal flag
- Added support in handle_op for multiple ops that map to the same function
- Added a bunch of LIT tests
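As referenced above, here is a conceptual, plain-Python model of what self_index, the comparison predicate, and select compute for a causal mask. This is an illustration only, not tkw kernel code, and the shapes are made up.

```python
# Conceptual model (plain Python, not tkw): a score at position (m, k2)
# survives only if the key index k2 is not ahead of the query index m.
M, K2 = 4, 4
scores = [[1.0 for _ in range(K2)] for _ in range(M)]

# self_index over M: every element knows its own row index; broadcast to [M, K2].
m_index = [[m for _ in range(K2)] for m in range(M)]
# self_index over K2: every element knows its own column index.
k2_index = [[k2 for k2 in range(K2)] for _ in range(M)]

# predicate (k2 <= m) + select: keep the score, otherwise write -inf so the
# masked positions vanish after the softmax.
masked = [
    [scores[m][k2] if k2_index[m][k2] <= m_index[m][k2] else float("-inf")
     for k2 in range(K2)]
    for m in range(M)
]
```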
Force-pushed from 504fb96 to 7ba8dcf
elif isinstance(custom, SelfIndex):
    # TODO: Add support on how to handle strided reads.
    new_node = SelfIndex(
        custom.idx, custom.dtype, custom.elements_per_thread
FYI: in line 202 above, elements_per_thread is taken as custom.index[custom.idx].size instead of the op argument. I don't fully understand what's going on here, so I'm only pointing out the inconsistency.
Good catch here. I got some more clarification from Nicolas and fixed this up to be more consistent.
iree/turbine/kernel/wave/codegen.py
Outdated
@@ -1205,13 +1283,77 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
    return result


@handle_binary_op(operator.gt)
Note: originally, I intentionally didn't overload comparison operators. In my previous experience with Python bindings and MLIR, this only brings pain and suffering.
Also, arith dialect operations are only expected to work on signless integers, and will likely fail verification or assert if you try feeding them signed/unsigned values, so using the signedness bit to differentiate comparison types is most likely not going to work. It is just not exercised anywhere.
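To make the point concrete, here is a small illustrative snippet (not from this PR): in the MLIR Python bindings, the signed/unsigned distinction is carried by the cmpi predicate while the operand type stays signless. The module scaffolding and the iree.compiler import path are assumptions about the environment.

```python
# Illustrative: signedness lives in the predicate (sgt vs. ugt), not the type.
from iree.compiler.ir import Context, InsertionPoint, IntegerType, Location, Module
from iree.compiler.dialects import arith, func

with Context(), Location.unknown():
    module = Module.create()
    i32 = IntegerType.get_signless(32)  # arith ops expect signless operands
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(i32, i32)
        def greater_than(lhs, rhs):
            signed_gt = arith.cmpi(arith.CmpIPredicate.sgt, lhs, rhs)    # signed >
            unsigned_gt = arith.cmpi(arith.CmpIPredicate.ugt, lhs, rhs)  # unsigned >
            return arith.andi(signed_gt, unsigned_gt)

    print(module)
```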
@raikonenfnu +1 on not being too smart here and biasing towards more explicit code, since we are mixing sympy, tracing, and tkw functions.
I think we should support both >, <, ... and tkw.gt, tkw.le, etc. One of the big pluses was readability, and my two cents are that making it more readable is a good thing.
Can you add an assertion to disallow unsigned types?
Done and done!
iree/turbine/kernel/wave/codegen.py
Outdated
    element_type.is_signed or element_type.is_signless
):
    result = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned():
Note that is_unsigned is a property; it should not be called. Here and elsewhere.
good catch, thanks! :)
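For context, a small runnable illustration of the point being made here: is_unsigned, is_signed, and is_signless are properties on IntegerType, so they are read without call parentheses. The iree.compiler import path is an assumption about the environment.

```python
from iree.compiler.ir import Context, IntegerType

with Context():
    u8 = IntegerType.get_unsigned(8)
    i8 = IntegerType.get_signless(8)
    # Properties, not methods: no call parentheses.
    assert u8.is_unsigned and not u8.is_signless
    assert i8.is_signless and not i8.is_unsigned
```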
iree/turbine/kernel/wave/codegen.py
Outdated
offset = vector_d.splat(vector_index_type, start)
shifted = arith_d.AddIOp(scaled, offset)
shifted = arith_d.addi(scaled, offset)
If you prefer the pretty form, should you also use arith_d.index_cast below?
makes sense, done!
@@ -95,6 +98,12 @@ def repeat(
    k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
    inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
    x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
    if is_causal:
        m_index = tkw.self_index(M, tkl.i64, elements_per_thread=1)
        m_index = tkw.broadcast(m_index, target_shape=[M, K2])
So, the big difference I am seeing here from what I had written is that we broadcast to [M, K2]. Was that the part that triggered incorrect codegen, and how should we guard against such mistakes in the future?
Added a target_shape verifier here in __post_init__.
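A minimal, self-contained sketch of what a target_shape verifier in __post_init__ can look like. The Broadcast-like dataclass, field names, and error messages here are illustrative assumptions, not the op definition in the Wave codebase.

```python
# Hypothetical stand-in for the real op: a dataclass that validates its
# target_shape as soon as it is constructed.
from dataclasses import dataclass
from typing import Sequence


@dataclass
class BroadcastSketch:
    source_shape: Sequence[str]  # symbolic dims of the input, e.g. ("M",)
    target_shape: Sequence[str]  # symbolic dims to broadcast to, e.g. ("M", "K2")

    def __post_init__(self):
        if not self.target_shape:
            raise ValueError("broadcast requires a non-empty target_shape")
        missing = set(self.source_shape) - set(self.target_shape)
        if missing:
            raise ValueError(
                f"cannot broadcast {tuple(self.source_shape)} to "
                f"{tuple(self.target_shape)}: {sorted(missing)} not in target"
            )


BroadcastSketch(("M",), ("M", "K2"))  # OK
# BroadcastSketch(("M",), ("K2",))    # would raise ValueError at construction
```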
@@ -95,6 +98,12 @@ def repeat(
    k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
    inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
    x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
    if is_causal:
        m_index = tkw.self_index(M, tkl.i64, elements_per_thread=1)
        m_index = tkw.broadcast(m_index, target_shape=[M, K2])
Note: I would keep the comment about how M will resolve to 1 after mapping, and therefore the broadcast will be legal thanks to the static information carried by the transformations. This is not at all trivial to understand and should be pointed out, IMO.
makes sense, done, thanks!
@@ -197,7 +198,10 @@ def has_gpr_offsets(node: fx.Node) -> bool:
        dim: simplify_index(custom.index.get(dim, custom.index[dim]))
        for dim in custom.index
    }
    elements_per_thread = subs_idxc(custom.elements_per_thread)
    if isinstance(custom, SelfIndex):
This is quite dark magic to me :)
Could you explain (as a reply to this comment) what is happening here?
The main reason we are doing this is to support the 32x32x8 MFMA layout, so let me give a short description of it. The C matrix (shaped 32x32) is laid out in column-major order, and within a given column the thread ownership looks something like
[base_tid, base_tid, base_tid, base_tid, base_tid+32, base_tid+32, base_tid+32, base_tid+32, base_tid, base_tid, base_tid, base_tid, ..., base_tid+32, base_tid+32, base_tid+32, base_tid+32]
where base_tid is different on every column (0, 1, 2, ..., 31).
This means the data in a given column is owned by two different threads, which is fine.
The more difficult part, if you notice, is that within a column, thread ownership switches every 4 elements. The contiguous 4 elements owned by a single thread are what we call a GPR_CHUNK. In this case each GPR_CHUNK has 4 elements, so to reference an element within a chunk we go by its GPR_NUM, which here is (0, 1, 2, 3).
We now know that for this layout (32x32x8), each thread in a column owns 4 GPR_CHUNKs, where every chunk has a size of 4 elements and a relative offset of 0, 8, 16, 24. => We cannot use a single self_index with the original base offset and a size of 16 (equivalent to C[base_thread_offset:base_thread_offset+16, :]); rather, we need 4 self_index ops with different relative offsets, which we later combine into one with our ReshapeOp:
slice0 = C[base_thread_offset:base_thread_offset+4, :]
slice1 = C[base_thread_offset+8:base_thread_offset+12, :]
slice2 = C[base_thread_offset+16:base_thread_offset+20, :]
slice3 = C[base_thread_offset+24:base_thread_offset+28, :]
combined_slice = Reshape([slice0, slice1, slice2, slice3])
ReshapeOp with multiple inputs has been used in Wave as a "combiner"-like op.
Hope that makes sense! :)
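To spell out the arithmetic (illustrative only, not code from the PR): for the 32x32x8 C layout described above, a thread's 16 elements in a column split into 4 GPR chunks of 4 contiguous rows, with a stride of 8 rows between chunk starts.

```python
# Row indices covered by one thread's four GPR chunks in a 32x32x8 column,
# mirroring slice0..slice3 above before they are combined by ReshapeOp.
ELEMS_PER_THREAD = 16
GPR_CHUNK_SIZE = 4
CHUNK_STRIDE = 8  # rows between the starts of consecutive chunks

num_chunks = ELEMS_PER_THREAD // GPR_CHUNK_SIZE                # 4
chunk_offsets = [i * CHUNK_STRIDE for i in range(num_chunks)]  # [0, 8, 16, 24]

base_thread_offset = 0  # example base row for this thread's column
slices = [
    list(range(base_thread_offset + off, base_thread_offset + off + GPR_CHUNK_SIZE))
    for off in chunk_offsets
]
# slices == [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19], [24, 25, 26, 27]]
combined_slice = [row for chunk in slices for row in chunk]  # "Reshape" combiner
```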
Force-pushed from c28560f to 5531c94
This revision uses `tkw.apply_expr` to circumvent type mismatches such as:

```
ValueError: Expected an fx.Node but got <class 'int'>
ValueError: Expected an fx.Node but got <class 'Symbol'>
```

This further requires supporting index_cast in `tkw.cast`; playground/vanilla_attention.py now produces valid IR.

Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Refactor and add a default torch implementation against which we allclose. Set sizes to known good values that pass the checks; it is easy to fall off the cliff with various size combinations. Additionally, with the following, one can remove the inplace hack:

```
pip install -r pytorch-rocm-requirements.txt -e .
```

Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
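For illustration, a hedged sketch of the kind of reference check this commit describes: a plain PyTorch causal attention used as the allclose ground truth. Shapes, tolerances, and the wave_out name are assumptions, not values from the repository.

```python
import torch


def reference_causal_attention(q, k, v):
    # Plain softmax(QK^T / sqrt(d)) V with an upper-triangular -inf mask.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    mask = torch.triu(
        torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device), diagonal=1
    )
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


# Example check against some kernel output `wave_out` of matching shape:
# torch.testing.assert_close(wave_out, reference_causal_attention(q, k, v),
#                            atol=1e-3, rtol=1e-3)
```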
Thanks! This is great work!
…tention (iree-org#452)

- Extracted core pieces of self_index, predicate, and selectOp, and LIT tests for predicate and select written by @nicolasvasilache and @ftynse, which are required for the causal mask, and removed pieces unrelated to the causal mask.
- Implemented a numerically correct causal attention kernel based on the original from @nicolasvasilache
- Added GPR_NUM partitioning support for SelfIndex to allow causal attention to work on more MMA intrinsics (e.g. 32x32x8, which has GPR_NUMs)
- Refactored tkw.slt/sgt/sge/sle to operator.lt/gt/ge/le to preserve the number of tkw ops and for user ergonomics
- Refactored the vanilla kernel to support both causal and non-causal attention in a single kernel, controlled by an is_causal flag
- Added support in handle_op for multiple ops that map to the same function
- Added a bunch of LIT tests

---------

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Alex Zinenko <git@ozinenko.com>
Co-authored-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Signed-off-by: xintin <vermagaurav9408@gmail.com>