Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add propagation of thread dependent index sequences #292

Merged
merged 1 commit into from
Nov 27, 2024

Conversation

harsh-nod
Copy link
Contributor

@harsh-nod harsh-nod commented Nov 22, 2024

This PR adds a major refactor of the index sequence analysis. Specifically,

  1. Index sequence computation is broken down into two phases - a thread independent index which is calculated with one pass through the graph.
  2. And a thread dependent index which is computed by propagating indices from nodes such as MMA, Read or Write.
  3. A heuristic is added to determine how to propagate information in case of multiple nodes competing for a given node based on dimensional analysis
  4. An additional unit test is added for attention expansion
  5. An e2e test for attention with bias is added

@harsh-nod harsh-nod force-pushed the attn_bias branch 7 times, most recently from 70e166d to a3b6f19 Compare November 24, 2024 19:32
@harsh-nod harsh-nod changed the title Add attention with bias tests Add propagation of thread dependent index sequences Nov 24, 2024
@harsh-nod harsh-nod force-pushed the attn_bias branch 2 times, most recently from 8831e4d to 2ef7724 Compare November 25, 2024 00:13
@harsh-nod harsh-nod force-pushed the attn_bias branch 2 times, most recently from 9a03b35 to bb59e95 Compare November 25, 2024 03:41
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Comment on lines +75 to +77
if dst_op:
for node in propagated_resolutions:
get_custom(node).index = dst_op.index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would we use it without dst_op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. This was when I was using it in 2 places where dst_op was optional. I can modify this so that it doesn't handle that case.

thread_dependent_index: dict[IndexSymbol, IndexSequence],
) -> dict[IndexSymbol, IndexSequence]:
combined_index = {k: v for k, v in thread_independent_index.items()}
for k in combined_index:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to add some docs, to different examples/cases on when we'd add the start offsets. (perhaps read -> read or offseted_read + mma?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, we split the index assignment into 2 phases. The first is a thread-independent index assignment where we set the indices based on work group constraints and tiling constraints and anything that in general does not have any thread level dependence. Once we have this index then we go through and propagate the thread dependent index which comes either from MMA nodes or if there are no MMA nodes it comes from reads and writes and we then add the thread dependent index to the thread independent index. So we are always going to be adding these two offsets together.

vector_shapes = (
custom.vector_shapes if custom.vector_shapes else source_vector_shapes
)
sources.append((custom, source_index, vector_shapes))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use source.index? Seems like the only time sourc.index != source_index is when source.index has underlying indices before getting propagated. But wouldn't we want that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So source.index contains the unified index which is the sum of the thread dependent index and the thread independent index. During propagation we only want to propagate the thread dependent index and that's why we only propagate source_index.

Copy link
Contributor

@raikonenfnu raikonenfnu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good and nice refactoring, just couple clarifying questions. But we can land for now. :)

@raikonenfnu raikonenfnu merged commit d9d2e7b into iree-org:main Nov 27, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants