
Get GEMMs working without minimize_global_loads #167

Merged
harsh-nod merged 2 commits into iree-org:main from fix_gemm on Sep 26, 2024

Conversation

@harsh-nod (Contributor) commented Sep 25, 2024

This PR removes the need for propagating indices using
post expansion. The new approach propagates the MMA
indices to the MMA dimensions of all tensors (rather
than just MMA nodes) and then specializes them depending
on whether they lie within the backward slices of the
LHS and RHS or forward slices of the ACC.
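
For context, "backward slice" and "forward slice" here mean transitive producers and consumers in the traced graph. A rough, generic sketch over a torch.fx-style graph (illustrative only; the helper names are made up and this is not the PR's actual API):

    # Hypothetical helpers over fx nodes, not the PR's actual API.
    def backward_slice(node):
        """All nodes the given node transitively depends on (its producers)."""
        seen, stack = set(), [node]
        while stack:
            n = stack.pop()
            for inp in n.all_input_nodes:
                if inp not in seen:
                    seen.add(inp)
                    stack.append(inp)
        return seen

    def forward_slice(node):
        """All nodes that transitively consume the given node (its users)."""
        seen, stack = set(), [node]
        while stack:
            n = stack.pop()
            for user in n.users:
                if user not in seen:
                    seen.add(user)
                    stack.append(user)
        return seen

    # Conceptually, mma_slices would then be:
    # {MMA_LHS: backward_slice(mma.lhs), MMA_RHS: backward_slice(mma.rhs),
    #  MMA_ACC: forward_slice(mma.acc)}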

@harsh-nod force-pushed the fix_gemm branch 2 times, most recently from d9dffbe to d6f9844 on September 25, 2024 02:47
# Mark which MMA operand slice (LHS/RHS/ACC), if any, this node belongs to.
operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
for key in mma_slices:
    if custom.fx_node in mma_slices[key]:
        operand_map[key] = 1
Contributor (review comment on the snippet above)
IIUC, for each iteration of for key in mma_slices we want operand_map to be
{MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}, then
{MMA_LHS: 0, MMA_RHS: 1, MMA_ACC: 0}, and then
{MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 1}.

But in the current state wouldn't this be
{MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}
{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 0}
{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 1}

Although if that is indeed what we are going for, can you explain the intuition behind it? :)

Contributor Author

So if a node is determined to be in the backward slice of the LHS, then we want to specialize it by substituting {MMA_LHS = 1, all else 0}. For RHS, we want {MMA_RHS = 1, all else 0}. For ACC, {MMA_ACC = 1, all else 0}. And if it's not in the backward slices of the LHS and RHS or the forward slice of the ACC, then {all = 0}. You can think of this as an alternative to propagation. Because we set the full indices everywhere, we need to specialize them depending on some constraints, and for that we use the forward/backward slices of the MMA operands.
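
A minimal sketch of that specialization rule, assuming the index is a sympy expression and reusing the names from the snippet under review (the specialize_index helper itself is hypothetical):

    def specialize_index(node, index_expr, mma_slices):
        # Hypothetical helper; mma_slices maps MMA_LHS/MMA_RHS to backward slices
        # and MMA_ACC to the forward slice, as described above.
        operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
        for key, slice_nodes in mma_slices.items():
            if node in slice_nodes:
                # Exactly one specialization variable is 1; everything else stays 0.
                return index_expr.subs({**operand_map, key: 1})
        # Node is unrelated to the MMA: substitute 0 for all specialization variables.
        return index_expr.subs(operand_map)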

@raikonenfnu (Contributor) commented Sep 25, 2024

Makes sense. In that case I think we need to move the operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0} initialization inside the loop, above the if custom.fx_node in mma_slices[key]:. Otherwise the previous state carries over, i.e. we will get:

iter_0 setting MMA_LHS, {MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}
iter_1 setting MMA_RHS,  {MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 0}
iter_2 setting MMA_ACC, {MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 1}
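
A sketch of the reordering being suggested (same names as the snippet under review; not the actual diff, and the trailing specialization step is elided):

    for key in mma_slices:
        # Re-initialize each iteration so a match from a previous key cannot carry over.
        operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
        if custom.fx_node in mma_slices[key]:
            operand_map[key] = 1
            # ... use operand_map to specialize this node's index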

Contributor

Realized I put the wrong state in the previous comment; updated it to make more sense haha

Contributor Author

Ah yes, we would get carry-over, except for the fact that we return as soon as we get a match. So that guarantees that our dictionary's values will always have only one non-zero entry (= 1).

Contributor

OK, that makes sense; that's why it's implicitly functionally equivalent. Can we still bring it down though, for better clarity/straightforwardness? :)

@raikonenfnu (Contributor) left a comment

Code looks good overall, left some comments.

Also quick conceptual Q, so in this PR are we just obtaining the slices to collect all ops that will be impacted by indexing of MMA, and then adjusting their index_seq appropriately S.T it will work with the MMA indexing?

@raikonenfnu (Contributor)
> Code looks good overall, left some comments.
>
> Also quick conceptual Q, so in this PR are we just obtaining the slices to collect all ops that will be impacted by indexing of MMA, and then adjusting their index_seq appropriately S.T it will work with the MMA indexing?

Seems like the "specialization"/"adjusting" of index is only related with subbing in the MMA_ACC/MMA_LHS/MMA_RHS to respective op's index_seq, if this is not happening before, wouldn't we still have a symbolic indexing and the program won't run at all?

@raikonenfnu (Contributor) commented Sep 25, 2024

> Code looks good overall, left some comments.
> Also quick conceptual Q, so in this PR are we just obtaining the slices to collect all ops that will be impacted by indexing of MMA, and then adjusting their index_seq appropriately S.T it will work with the MMA indexing?
>
> Seems like the "specialization"/"adjusting" of index is only related with subbing in the MMA_ACC/MMA_LHS/MMA_RHS to respective op's index_seq, if this is not happening before, wouldn't we still have a symbolic indexing and the program won't run at all?

Ohhh actually were we circumventing around that issue by using self.lhs.index = self.lhs_index (similarly for rhs and acc)?

@harsh-nod (Contributor Author) commented Sep 25, 2024

> Code looks good overall, left some comments.
> Also quick conceptual Q, so in this PR are we just obtaining the slices to collect all ops that will be impacted by indexing of MMA, and then adjusting their index_seq appropriately S.T it will work with the MMA indexing?
>
> Seems like the "specialization"/"adjusting" of index is only related with subbing in the MMA_ACC/MMA_LHS/MMA_RHS to respective op's index_seq, if this is not happening before, wouldn't we still have a symbolic indexing and the program won't run at all?
>
> Ohhh actually were we circumventing around that issue by using self.lhs.index = self.lhs_index (similarly for rhs and acc)?

So what was happening before is that we were setting the MMA indices only to the MMA op and then propagating this to the operands during post-expansion. The problem with this was that it required us to do even more propagation to get the shared memory IGEMM case working (since we had to propagate the indices all the way to the global read). As an alternative to propagation, we are now setting globally per dimension indices that include the effects of all the constraints. But the problem with this is you end up with a whole bunch of Piecewise functions that you can't reason about. So we use the slices to determine how to convert these Piecewise functions to indices. As I said in another comment, this will be useful when we deal with multiple MMAs because we will now have overlapping slices and can do some sort of "equality saturation" to determine how to resolve multiple indices per node.

This was not a problem before because we restricted these MMA_{LHS/RHS/ACC} variables only to MMA nodes and their neighbors.
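
As a loose illustration of what specializing one of these Piecewise indices looks like (plain sympy, with made-up index expressions; not the kernel's actual code):

    import sympy

    MMA_LHS, MMA_RHS, MMA_ACC, tid = sympy.symbols("MMA_LHS MMA_RHS MMA_ACC tid")

    # Illustrative per-dimension index folding every MMA constraint into one expression;
    # the real index expressions live in the kernel IR.
    index = sympy.Piecewise(
        (tid % 16, sympy.Eq(MMA_LHS, 1)),                   # LHS-style access
        (sympy.floor(tid / 16), sympy.Eq(MMA_RHS, 1)),      # RHS-style access
        (4 * sympy.floor(tid / 16), sympy.Eq(MMA_ACC, 1)),  # ACC-style access
        (tid, True),
    )

    # A node in the backward slice of the LHS gets {MMA_LHS: 1, all else 0},
    # which collapses the Piecewise into a concrete index.
    print(index.subs({MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}))  # Mod(tid, 16)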

@raikonenfnu (Contributor)
> (quoting the explanation above)

Makes sense, thanks! :)

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
@raikonenfnu (Contributor) left a comment

LGTM!

@harsh-nod merged commit d37c6a4 into iree-org:main Sep 26, 2024
8 checks passed
@Hardcode84 mentioned this pull request Sep 26, 2024
IanNod pushed a commit to IanNod/iree-turbine that referenced this pull request Sep 30, 2024
IanNod pushed a commit to IanNod/iree-turbine that referenced this pull request Sep 30, 2024
IanNod pushed a commit to IanNod/iree-turbine that referenced this pull request Sep 30, 2024
stellaraccident pushed a commit that referenced this pull request Oct 13, 2024