[TKW] Detect contiguous access pattern in mapped reads/writes #304
1ec5fc1 to 565f81f
    return start_indices


def _simplify_sympy_expr(expr: IndexExpr) -> IndexExpr:
To accomplish this, you would do the following:
> next_indices[2] - prev_indices[2]
-floor(floor((Mod($T0, 4))/80 + floor($ARGC*HF*WF*BLOCK_K/8)/80)/3) + floor(floor((Mod($T0, 4))/80 + floor($ARGC*HF*WF*BLOCK_K/8)/80 + 1/640)/3)
> sympy.cse(next_indices[2] - prev_indices[2])[0]
[(x0, (Mod($T0, 4))/80 + floor($ARGC*HF*WF*BLOCK_K/8)/80)]
> sympy.cse(next_indices[2] - prev_indices[2])[1]
[-floor(floor(x0)/3) + floor(floor(x0 + 1/640)/3)]
Finally, if you want to add the assumption that x0 ~ O(1), you can just subs x0 -> 1 and you will get 0.
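The `sympy.cse` steps above can be reproduced roughly as follows. This is only a sketch under assumptions: the symbols `T0` and `A` are hypothetical stand-ins for `$T0` and the `$ARGC*HF*WF*BLOCK_K` product, their integer and non-negative assumptions are illustrative, and the final substitution is applied to a hand-written copy of the reduced expression rather than to whatever `cse` happens to return.

```python
import sympy

# Hypothetical stand-ins for $T0 and the $ARGC*HF*WF*BLOCK_K product.
t0 = sympy.Symbol("T0", integer=True, nonnegative=True)
a = sympy.Symbol("A", integer=True, nonnegative=True)

# Part shared by the previous and the next index.
common = sympy.Mod(t0, 4) / 80 + sympy.floor(a / 8) / 80
prev_index = sympy.floor(sympy.floor(common) / 3)
next_index = sympy.floor(sympy.floor(common + sympy.Rational(1, 640)) / 3)
diff = next_index - prev_index

# cse returns a list of (symbol, subexpression) pairs and the reduced expressions.
replacements, reduced = sympy.cse(diff)

# Substituting the replacements back in reverse order restores the input.
restored = reduced[0]
for symbol, subexpr in reversed(replacements):
    restored = restored.subs(symbol, subexpr)

# Hand-written form of the reduced expression quoted in the comment above.
x0 = sympy.Symbol("x0")
reduced_by_hand = (
    sympy.floor(sympy.floor(x0 + sympy.Rational(1, 640)) / 3)
    - sympy.floor(sympy.floor(x0) / 3)
)

# Under the x0 ~ O(1) assumption, plug in x0 = 1.
value_at_one = reduced_by_hand.subs(x0, 1)
```

Note that the substitution is only as good as the assumption behind it: the difference is not zero for every real `x0`, for example just below a multiple of 3.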
So, as any sane person should do, I wrote fuzzing tests (see latest commit), and they indeed found some failing cases. But after tightening some checks (and marking induction vars as non-negative) and running 100,000s of tests, I'm now reasonably sure my transformations are correct.
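A fuzz test of this kind might look roughly like the sketch below. It is not the test from the commit: `fuzz_equivalent` is a hypothetical helper, and the two rewrites being checked are illustrative identities, not the simplifications in this PR. Inputs are drawn from non-negative integers only, mirroring the non-negativity assumption mentioned above.

```python
import random

import sympy


def fuzz_equivalent(original, simplified, symbols, trials=1000, max_value=1024, seed=0):
    """Return a counterexample assignment, or None if all trials agree."""
    rng = random.Random(seed)
    for _ in range(trials):
        # Non-negative integer inputs only.
        env = {s: rng.randint(0, max_value) for s in symbols}
        if original.subs(env) != simplified.subs(env):
            return env
    return None


x = sympy.Symbol("x", integer=True, nonnegative=True)

# A rewrite that is valid for non-negative integers.
good = fuzz_equivalent(sympy.floor(sympy.floor(x / 4) / 3), sympy.floor(x / 12), [x])

# A rewrite that looks plausible but breaks when x % 4 == 3.
bad = fuzz_equivalent(sympy.Mod(x + 1, 4), sympy.Mod(x, 4) + 1, [x])
```

Here `good` is `None` because no disagreement is found, while `bad` holds an assignment on which the two expressions differ.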
Thanks that sounds good, but I think we have been thinking about this the wrong way.
So essentially what we are looking to do here is determine whether the mapping is "contiguity-preserving" (which means that elements that are contiguous pre-mapping will remain contiguous post-mapping). This is a function entirely of the mapping and should not depend at all on the particular access pattern of the underlying data. So I think these patterns to simplify the index sequences are not required, and instead we should be looking just at the mapping. For example, look at one IGEMM mapping, which maps a tensor of shape [M, K] to a tensor of shape [N, H, W, C]:
x_mapping = tkw.IndexMapping(
    num_iterators=2,
    inputs={
        N: i // SZ_OUT,
        C: j % C,
        H: (i % SZ_OUT) % W_OUT * stride + (j // C) % WF,
        W: (i % SZ_OUT) // W_OUT * stride + (j // C) // WF,
    },
    outputs={M: i, K: j},
)
we can immediately deduce that this mapping is contiguity-preserving because j gets mapped to j % C.
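The claim can be sanity-checked numerically with plain integers, without tkw. The sketch below re-implements the mapping expressions as an ordinary function; `igemm_input_index`, `preserves_contiguity`, and all concrete sizes are hypothetical. It also makes one extra assumption explicit: vectors start at multiples of the vector width and the width divides C, so a vector never straddles a wrap of `j % C`.

```python
def igemm_input_index(i, j, *, C, WF, W_OUT, SZ_OUT, stride):
    """Map (i, j) in [M, K] to (n, h, w, c) in [N, H, W, C], as in the mapping above."""
    n = i // SZ_OUT
    h = (i % SZ_OUT) % W_OUT * stride + (j // C) % WF
    w = (i % SZ_OUT) // W_OUT * stride + (j // C) // WF
    c = j % C
    return (n, h, w, c)


def preserves_contiguity(vector_width, *, M, K, **sizes):
    """Check that every aligned vector along j stays contiguous along c."""
    for i in range(M):
        for j0 in range(0, K, vector_width):
            n, h, w, c = igemm_input_index(i, j0, **sizes)
            for k in range(1, vector_width):
                expected = (n, h, w, c + k)
                if igemm_input_index(i, j0 + k, **sizes) != expected:
                    return False
    return True


# Illustrative sizes: 3x3 filter, C = 8, so K = 3 * 3 * 8 = 72.
sizes = dict(C=8, WF=3, W_OUT=4, SZ_OUT=16, stride=1)
aligned_ok = preserves_contiguity(4, M=32, K=72, **sizes)
straddles_c = preserves_contiguity(16, M=32, K=72, **sizes)
```

With a width of 4 every vector stays inside one run of C, so `aligned_ok` is `True`; with a width of 16 a vector crosses the `j % C` wrap and `straddles_c` is `False`.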
Yeah, sympy.cse looks useful, but I wasn't able to construct working code quickly. I will need to think more on it offline.
Sounds good, let's land this PR. I have a WIP PR that should simplify this a lot using sympy.cse.
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
ce8b6de to ffca63d
Just some minor changes, but overall looks good. How did the cse simplification help?
let's land for now. I think we can greatly simplify the checks by using sympy.cse but we can explore that after landing this PR.
Add utility to detect contiguous access pattern in mapped reads/writes so we can use contiguous vector ops instead of gathers/scatters during lowering.
Iterate over elements_per_thread and check that each subsequent mapped index differs from the previous one only by 1 in the fastest-changing dim. It also needs some custom sympy simplifications to work on IGEMM conv.
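The check described above can be sketched with sympy as follows. This is not the PR's implementation: `is_contiguous` and the example index functions are hypothetical, and the sketch reads the rule as "difference of exactly 1 in the last (fastest-changing) dim and 0 in every other dim". Expressions involving nested floor/Mod, as in IGEMM conv, generally do not reduce this easily, which is why the PR adds custom simplifications.

```python
import sympy


def is_contiguous(index_fn, elements_per_thread):
    """index_fn(k) returns the mapped index tuple of the k-th element of a thread."""
    prev = index_fn(0)
    for k in range(1, elements_per_thread):
        cur = index_fn(k)
        diffs = [sympy.simplify(c - p) for c, p in zip(cur, prev)]
        # Fastest-changing dim is assumed to be the last one.
        if diffs[-1] != 1 or any(d != 0 for d in diffs[:-1]):
            return False
        prev = cur
    return True


# Hypothetical thread id symbol.
t0 = sympy.Symbol("T0", integer=True, nonnegative=True)

# Each thread reads 4 consecutive elements of a row.
row_major = is_contiguous(lambda k: (t0 // 16, (t0 % 16) * 4 + k), 4)

# Each thread reads every second element, which would need a gather.
strided = is_contiguous(lambda k: (t0 // 16, (t0 % 16) * 8 + 2 * k), 4)
```

In this example `row_major` is `True` and `strided` is `False`.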