Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Disable logsumexp test for MemoryEfficientAttentionCutlassFwdFlashBwO…
…p on "fctls_bflsh: New OP that combines cutlass's fw + flash's bw" **PERFORMANCE** <details> <summary>A100 bw</summary> ``` [--------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------] | flash[flshatt] | vanilla | fwbw[fctls_bflsh] | 48_chunk3_31735f9[cutlass] 1 threads: ------------------------------------------------------------------------------------------------------------ f16 B=384, M=197, H=1, K=64 | 232.7 | 1813.9 | 240.3 | 391.7 f16 B=1024, M=197, H=1, K=64 | 577.0 | 4746.9 | 582.8 | 876.9 f16 B=32, M=197, H=16, K=64 | 296.2 | 2434.6 | 303.2 | 459.9 f16 B=32, M=197, H=16, K=128 | 682.8 | 4504.9 | 688.6 | 792.5 f16 B=16, M=197, H=16, K=64 | 164.9 | 1246.6 | 172.4 | 235.4 f16 B=16, M=197, H=16, K=128 | 385.9 | 2272.5 | 394.1 | 455.4 f16 B=1, M=4096, H=160, K=128 | 54810.6 | 45967.1 | 54876.8 | 62454.4 f16 B=2, M=4096, H=160, K=128 | 84422.1 | | 84371.5 | 98791.3 f16 B=1, M=8192, H=160, K=128 | 216095.0 | | 216170.6 | 248498.9 f16 B=2, M=8192, H=160, K=128 | 330754.4 | | 331201.2 | 389207.8 f16 B=1024, M=82, H=8, K=64 | 1621.7 | 3820.0 | 1625.6 | 1872.4 f16 B=150, M=256, H=16, K=64 | 1625.7 | 4551.9 | 1629.7 | 2126.4 f16 B=64, M=256, H=12, K=64 | 567.7 | 1493.7 | 569.7 | 741.2 f16 B=256, M=4096, H=16, K=64 | 441302.3 | | 441526.8 | 597391.6 f16 B=16, M=128, H=16, K=16 | 114.4 | 266.0 | 148.1 | 93.1 f16 B=16, M=128, H=16, K=32 | 112.6 | 269.7 | 243.3 | 127.9 f16 B=16, M=128, H=16, K=64 | 113.6 | 267.8 | 149.5 | 131.4 f16 B=16, M=128, H=16, K=128 | 158.6 | 298.0 | 160.3 | 175.6 f16 B=16, M=512, H=16, K=16 | 323.5 | 1203.4 | 325.9 | 558.2 f16 B=16, M=512, H=16, K=32 | 435.2 | 1305.2 | 436.8 | 653.5 f16 B=16, M=512, H=16, K=64 | 703.1 | 1543.5 | 706.4 | 848.8 f16 B=16, M=512, H=16, K=128 | 1586.5 | 1982.6 | 1588.1 | 1735.4 f16 B=16, M=1024, H=16, K=16 | 1252.2 | 4273.8 | 1251.6 | 2236.4 f16 B=16, M=1024, H=16, K=32 | 1621.6 | 4494.4 | 1623.8 | 2430.8 f16 B=16, M=1024, H=16, K=64 | 2376.6 | 5007.3 | 2381.7 | 3007.2 f16 B=16, M=1024, H=16, K=128 | 5647.1 | 5956.1 | 5650.4 | 6296.2 f16 B=64, M=128, H=16, K=16 | 145.6 | 439.2 | 148.6 | 165.5 f16 B=64, M=128, H=16, K=32 | 212.1 | 544.4 | 214.4 | 210.4 f16 B=64, M=128, H=16, K=64 | 310.1 | 767.3 | 312.5 | 330.4 f16 B=64, M=128, H=16, K=128 | 562.3 | 1226.6 | 564.7 | 605.5 f16 B=64, M=512, H=16, K=16 | 1202.3 | 4481.6 | 1203.0 | 2004.7 f16 B=64, M=512, H=16, K=32 | 1543.5 | 4966.9 | 1548.2 | 2379.3 f16 B=64, M=512, H=16, K=64 | 2421.7 | 5886.7 | 2424.0 | 3129.6 f16 B=64, M=512, H=16, K=128 | 5446.6 | 7713.8 | 5452.7 | 6054.1 f16 B=64, M=1024, H=16, K=16 | 4725.5 | 16880.0 | 4719.8 | 7929.4 f16 B=64, M=1024, H=16, K=32 | 5716.0 | 17869.0 | 5722.7 | 8876.8 f16 B=64, M=1024, H=16, K=64 | 8155.3 | 19924.2 | 8163.5 | 11198.7 f16 B=64, M=1024, H=16, K=128 | 19213.5 | 23735.2 | 19228.8 | 21618.9 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
- Loading branch information