Commit 6f069c9

Disable logsumexp test for MemoryEfficientAttentionCutlassFwdFlashBwOp on "fctls_bflsh: New OP that combines cutlass's fw + flash's bw"


**PERFORMANCE**

<details>
<summary>A100 bw</summary>

```
[--------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------]
                                     |  flash[flshatt]  |  vanilla  |  fwbw[fctls_bflsh]  |  48_chunk3_31735f9[cutlass]
1 threads: ------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |        232.7     |   1813.9  |          240.3      |              391.7         
      f16 B=1024, M=197, H=1, K=64   |        577.0     |   4746.9  |          582.8      |              876.9         
      f16 B=32, M=197, H=16, K=64    |        296.2     |   2434.6  |          303.2      |              459.9         
      f16 B=32, M=197, H=16, K=128   |        682.8     |   4504.9  |          688.6      |              792.5         
      f16 B=16, M=197, H=16, K=64    |        164.9     |   1246.6  |          172.4      |              235.4         
      f16 B=16, M=197, H=16, K=128   |        385.9     |   2272.5  |          394.1      |              455.4         
      f16 B=1, M=4096, H=160, K=128  |      54810.6     |  45967.1  |        54876.8      |            62454.4         
      f16 B=2, M=4096, H=160, K=128  |      84422.1     |           |        84371.5      |            98791.3         
      f16 B=1, M=8192, H=160, K=128  |     216095.0     |           |       216170.6      |           248498.9         
      f16 B=2, M=8192, H=160, K=128  |     330754.4     |           |       331201.2      |           389207.8         
      f16 B=1024, M=82, H=8, K=64    |       1621.7     |   3820.0  |         1625.6      |             1872.4         
      f16 B=150, M=256, H=16, K=64   |       1625.7     |   4551.9  |         1629.7      |             2126.4         
      f16 B=64, M=256, H=12, K=64    |        567.7     |   1493.7  |          569.7      |              741.2         
      f16 B=256, M=4096, H=16, K=64  |     441302.3     |           |       441526.8      |           597391.6         
      f16 B=16, M=128, H=16, K=16    |        114.4     |    266.0  |          148.1      |               93.1         
      f16 B=16, M=128, H=16, K=32    |        112.6     |    269.7  |          243.3      |              127.9         
      f16 B=16, M=128, H=16, K=64    |        113.6     |    267.8  |          149.5      |              131.4         
      f16 B=16, M=128, H=16, K=128   |        158.6     |    298.0  |          160.3      |              175.6         
      f16 B=16, M=512, H=16, K=16    |        323.5     |   1203.4  |          325.9      |              558.2         
      f16 B=16, M=512, H=16, K=32    |        435.2     |   1305.2  |          436.8      |              653.5         
      f16 B=16, M=512, H=16, K=64    |        703.1     |   1543.5  |          706.4      |              848.8         
      f16 B=16, M=512, H=16, K=128   |       1586.5     |   1982.6  |         1588.1      |             1735.4         
      f16 B=16, M=1024, H=16, K=16   |       1252.2     |   4273.8  |         1251.6      |             2236.4         
      f16 B=16, M=1024, H=16, K=32   |       1621.6     |   4494.4  |         1623.8      |             2430.8         
      f16 B=16, M=1024, H=16, K=64   |       2376.6     |   5007.3  |         2381.7      |             3007.2         
      f16 B=16, M=1024, H=16, K=128  |       5647.1     |   5956.1  |         5650.4      |             6296.2         
      f16 B=64, M=128, H=16, K=16    |        145.6     |    439.2  |          148.6      |              165.5         
      f16 B=64, M=128, H=16, K=32    |        212.1     |    544.4  |          214.4      |              210.4         
      f16 B=64, M=128, H=16, K=64    |        310.1     |    767.3  |          312.5      |              330.4         
      f16 B=64, M=128, H=16, K=128   |        562.3     |   1226.6  |          564.7      |              605.5         
      f16 B=64, M=512, H=16, K=16    |       1202.3     |   4481.6  |         1203.0      |             2004.7         
      f16 B=64, M=512, H=16, K=32    |       1543.5     |   4966.9  |         1548.2      |             2379.3         
      f16 B=64, M=512, H=16, K=64    |       2421.7     |   5886.7  |         2424.0      |             3129.6         
      f16 B=64, M=512, H=16, K=128   |       5446.6     |   7713.8  |         5452.7      |             6054.1         
      f16 B=64, M=1024, H=16, K=16   |       4725.5     |  16880.0  |         4719.8      |             7929.4         
      f16 B=64, M=1024, H=16, K=32   |       5716.0     |  17869.0  |         5722.7      |             8876.8         
      f16 B=64, M=1024, H=16, K=64   |       8155.3     |  19924.2  |         8163.5      |            11198.7         
      f16 B=64, M=1024, H=16, K=128  |      19213.5     |  23735.2  |        19228.8      |            21618.9  

Times are in microseconds (us).
```
</details>
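For context, the "vanilla" column above corresponds to a straightforward softmax-attention forward plus backward. The following is a minimal NumPy sketch of that reference computation with its textbook analytic gradients; it uses toy shapes, a single head, and `time.perf_counter` instead of `torch.utils.benchmark`, so it only illustrates what is being timed, not how these numbers were produced.

```python
import time
import numpy as np

def vanilla_attention_fwd_bwd(q, k, v, d_out, scale):
    # Forward: S = Q K^T * scale, P = softmax(S), O = P V.
    s = q @ k.T * scale
    s -= s.max(axis=-1, keepdims=True)  # for numerical stability
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    out = p @ v
    # Backward: standard softmax-attention gradients given dO.
    d_v = p.T @ d_out
    d_p = d_out @ v.T
    d_s = p * (d_p - (d_p * p).sum(axis=-1, keepdims=True))
    d_q = d_s @ k * scale
    d_k = d_s.T @ q * scale
    return out, d_q, d_k, d_v

# One (M, K) point loosely modeled on the table above, single head, no batch.
M, K = 128, 64
rng = np.random.default_rng(0)
q, k, v, d_out = (rng.standard_normal((M, K)).astype(np.float32) for _ in range(4))

t0 = time.perf_counter()
out, d_q, d_k, d_v = vanilla_attention_fwd_bwd(q, k, v, d_out, 1.0 / np.sqrt(K))
elapsed_us = (time.perf_counter() - t0) * 1e6
```

The specialized ops in the table beat this formulation mainly by never materializing the full `M x M` probability matrix `P`.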

[ghstack-poisoned]
danthe3rd committed Oct 7, 2022
1 parent 7e21543 commit 6f069c9
Showing 1 changed file with 4 additions and 1 deletion: tests/test_mem_eff_attention.py
```diff
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -419,7 +419,10 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
         k,
         kv,
     ) = op_device_dtype_B_Mq_Mkv_H_K_Kv
-    if op.FORWARD_OPERATOR is None:
+    if (
+        op.FORWARD_OPERATOR is None
+        or op is xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp
+    ):
         return
     query, key, value, attn_bias = create_tensors(
         *op_device_dtype_B_Mq_Mkv_H_K_Kv, fmt="BMK"
```
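The skipped `test_logsumexp` compares the logsumexp that an op's forward returns against a reference computed from the raw attention scores. A hedged NumPy sketch of the quantity being checked (toy shapes; the real test sweeps the shapes packed into `op_device_dtype_B_Mq_Mkv_H_K_Kv`):

```python
import numpy as np

rng = np.random.default_rng(0)
B, M, K = 2, 16, 8  # toy batch, sequence length, head dim
q = rng.standard_normal((B, M, K)).astype(np.float32)
k = rng.standard_normal((B, M, K)).astype(np.float32)

scale = 1.0 / np.sqrt(K)
scores = np.einsum("bmk,bnk->bmn", q, k) * scale  # (B, M, M) attention scores

# Reference logsumexp over the key dimension, computed stably.
row_max = scores.max(axis=-1, keepdims=True)
lse = (row_max + np.log(np.exp(scores - row_max).sum(axis=-1, keepdims=True)))[..., 0]

# Attention probabilities recovered from the LSE must match plain softmax:
# softmax(s)_ij = exp(s_ij - lse_i).
p_from_lse = np.exp(scores - lse[..., None])
p_softmax = np.exp(scores - row_max) / np.exp(scores - row_max).sum(axis=-1, keepdims=True)
```

Storing the per-row LSE is what lets the backward pass recompute the probabilities without saving the full score matrix, which is why the combined cutlass-forward/flash-backward op needs the forward's LSE to be compatible with flash's backward.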
