This looks kind of similar to #374, but I think it's actually not. If you do a chained gemm into a tensor with a dimension that isn't spread over workgroups or tiles, only half the output gets written. The following test (available in GMNGeoffrey@0fef747413), reduced from the flash attention 2 backward pass, fails:
python test
produces this fx trace:
fx trace
and this IR (some SSA variables renamed for readability):
mlir
There's only one block, so the writes to `dv` are for `[(x/16)*4:(x/16)*4+4, x%16]`, but `dv` is `K2xN = 16x32` and the second index maxes out at 15, so half of the `N` dimension never gets written. Looking at it another way, there are 16x32=512 elements of `dv`, but each thread is only writing 4, so it's impossible for them all to get filled in. Interestingly, if I add a no-op distribution over the `N` dimension (`tkw.WorkgroupConstraint(N, N, 1)`), then we get the correct number of writes and the test passes.
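To make the coverage argument concrete, here is a small standalone sketch in plain Python (no Wave or IREE calls) that enumerates the cells each thread would write under the index expression above. The thread count of 64 is an assumption on my part; the issue does not state it, but it is consistent with exactly half of `dv` being written. `written_cells` is a hypothetical helper for illustration only.

```python
# Shape of dv as given in the issue: K2 x N = 16 x 32.
K2, N = 16, 32
# Assumption: a single block of 64 threads (not stated in the issue).
NUM_THREADS = 64


def written_cells(num_threads=NUM_THREADS):
    """Collect the (row, col) cells written by all threads.

    Thread x writes rows (x/16)*4 .. (x/16)*4+3 at column x%16,
    mirroring the index [(x/16)*4:(x/16)*4+4, x%16] from the issue.
    """
    cells = set()
    for x in range(num_threads):
        row0 = (x // 16) * 4
        col = x % 16
        for row in range(row0, row0 + 4):
            cells.add((row, col))
    return cells


cells = written_cells()
written_cols = {col for _, col in cells}
missing_cols = [col for col in range(N) if col not in written_cols]

print(f"cells written: {len(cells)} of {K2 * N}")
print(f"columns never written: {missing_cols}")
```

Under that assumption every row of `dv` is touched, but only columns 0 through 15, which matches the observation that the upper half of the `N` dimension is never filled in.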