Get GEMMs working without minimize_global_loads #167
Conversation
Force-pushed d9dffbe to d6f9844
This PR removes the need to propagate indices during post-expansion. The new approach propagates the MMA indices to the MMA dimensions of all tensors (rather than just the MMA nodes) and then specializes them depending on whether they lie within the backward slices of the LHS and RHS or the forward slice of the ACC. Signed-off-by: Harsh Menon <harsh@nod-labs.com>
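The "slices" here are just graph slices. As a minimal sketch of how backward/forward slices can be collected on a torch.fx-style graph (the helper names are mine, not the PR's actual API):

```python
def backward_slice(node):
    """Collect node and everything it transitively depends on."""
    seen, stack = set(), [node]
    while stack:
        n = stack.pop()
        if n in seen:
            continue
        seen.add(n)
        stack.extend(n.all_input_nodes)  # torch.fx: producers of n
    return seen

def forward_slice(node):
    """Collect node and everything that transitively consumes it."""
    seen, stack = set(), [node]
    while stack:
        n = stack.pop()
        if n in seen:
            continue
        seen.add(n)
        stack.extend(n.users)  # torch.fx: consumers of n
    return seen
```

`mma_slices` would then map MMA_LHS and MMA_RHS to the backward slices of the LHS and RHS operands, and MMA_ACC to the forward slice of the accumulator.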
```python
operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
for key in mma_slices:
    if custom.fx_node in mma_slices[key]:
        operand_map[key] = 1
```
IIUC, we want each iteration of `for key in mma_slices` to produce
`{MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}`, then
`{MMA_LHS: 0, MMA_RHS: 1, MMA_ACC: 0}`, and then
`{MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 1}`.
But in the current state, wouldn't this be
`{MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}`
`{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 0}`
`{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 1}`?
Although if that is indeed what we are going for, can you explain the intuition behind it? :)
So if a node is determined to be in the backward slice of the LHS, then we want to specialize it by substituting {MMA_LHS = 1, all else 0}. For RHS, we want {MMA_RHS = 1, all else 0}. For ACC, {MMA_ACC = 1, all else 0}. And if it's not in the backward slices of the LHS and RHS or the forward slice of the ACC, then {all = 0}. You can think of this as an alternative to propagation: because we set the entire indices everywhere, we need to specialize them depending on some constraints, and for that we use the forward/backward slices of the MMA operands.
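As a toy illustration of that substitution (a sympy sketch; the index expression here is invented for the example, not taken from the codebase):

```python
import sympy

MMA_LHS, MMA_RHS, MMA_ACC = sympy.symbols("MMA_LHS MMA_RHS MMA_ACC")
T0 = sympy.Symbol("T0")  # stand-in for a thread id

# Hypothetical per-dimension index mentioning all three MMA variables.
index = MMA_LHS * (T0 % 16) + MMA_RHS * (T0 // 4) + MMA_ACC * (4 * (T0 // 16))

# Node found in the LHS backward slice: keep only the LHS term.
print(index.subs({MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}))  # Mod(T0, 16)

# Node in no slice at all: everything collapses to 0.
print(index.subs({MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}))  # 0
```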
Makes sense. In that case I think we need to move the `operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}` inside the loop, above the `if custom.fx_node in mma_slices[key]:`. Otherwise the previous state carries over, i.e. we will get:
iter_0 setting MMA_LHS: `{MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}`
iter_1 setting MMA_RHS: `{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 0}`
iter_2 setting MMA_ACC: `{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 1}`
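i.e., something along these lines (a sketch of the proposed reordering; the names come from the snippet under review, and the use of `operand_map` is elided):

```python
for key in mma_slices:
    # Re-initialize each iteration so an earlier match can't leak into later ones.
    operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
    if custom.fx_node in mma_slices[key]:
        operand_map[key] = 1
        # ... consume operand_map here ...
```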
Realized I put the wrong states in the previous comment; updated it to make more sense haha
Ah yes, we would get carry-over, except for the fact that we return as soon as we get a match. So that guarantees that our dictionary's values will always have only one non-zero entry (= 1).
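In context, the flow is roughly this (a sketch; `specialize_index` is a hypothetical stand-in for whatever consumes `operand_map`):

```python
def resolve_operand_map(custom, mma_slices):
    # A single initialization is safe because we return on the first match,
    # so at most one entry ever flips to 1.
    operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
    for key in mma_slices:
        if custom.fx_node in mma_slices[key]:
            operand_map[key] = 1
            return specialize_index(custom.index, operand_map)
    # Not in any slice: all MMA variables substitute to 0.
    return specialize_index(custom.index, operand_map)
```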
OK, that makes sense; that's why it's implicitly functionally equivalent. Can we still move it down though, for better clarity/straightforwardness? :)
Code looks good overall, left some comments.
Also a quick conceptual Q: so in this PR, are we just obtaining the slices to collect all ops that will be impacted by the indexing of the MMA, and then adjusting their index_seq appropriately such that it will work with the MMA indexing?
It seems like the "specialization"/"adjusting" of the index is only about subbing MMA_ACC/MMA_LHS/MMA_RHS into the respective op's index_seq. If this was not happening before, wouldn't we still have symbolic indexing, and the program wouldn't run at all?
Ohhh actually, were we circumventing that issue by using …
So what was happening before is that we were setting the MMA indices only on the MMA op and then propagating them to the operands during post-expansion. The problem with this was that it required even more propagation to get the shared memory IGEMM case working (since we had to propagate the indices all the way to the global read). As an alternative to propagation, we now set global per-dimension indices that include the effects of all the constraints. But the problem with this is that you end up with a whole bunch of Piecewise functions that you can't reason about. So we use the slices to determine how to convert these Piecewise functions to indices. As I said in another comment, this will be useful when we deal with multiple MMAs, because we will then have overlapping slices and can do some sort of "equality saturation" to determine how to resolve multiple indices per node. This was not a problem before because we restricted these MMA_{LHS/RHS/ACC} variables only to MMA nodes and their neighbors.
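To make the Piecewise point concrete, a toy sympy example (the expression is invented for illustration, not the PR's actual index):

```python
import sympy

MMA_LHS = sympy.Symbol("MMA_LHS")
T0 = sympy.Symbol("T0")

# A per-dimension index that is Piecewise in the MMA variable: hard to
# reason about until the variable is pinned down by the slice check.
index = sympy.Piecewise((T0 % 16, sympy.Eq(MMA_LHS, 1)), (T0, True))

print(index.subs({MMA_LHS: 1}))  # Mod(T0, 16)
print(index.subs({MMA_LHS: 0}))  # T0
```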
Makes sense, thanks! :)
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
LGTM!