-
Notifications
You must be signed in to change notification settings - Fork 651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MemEff: Accumulate in f32 for bw #467
Conversation
[ghstack-poisoned]
ghstack-source-id: aff4f021abbadcf565a40f952eea873c7e5d3f09 Pull Request resolved: #467
[ghstack-poisoned]
ghstack-source-id: e35fdbfbadbdb88090dd7632719b6c3f071568a0 Pull Request resolved: #467
[ghstack-poisoned]
ghstack-source-id: 64b7aef3cdfba4dcaea31c651bb663a6421faec0 Pull Request resolved: #467
[ghstack-poisoned]
ghstack-source-id: 48369de3f8b94eb3c190ac2b0a1b3ddf6003e5ff Pull Request resolved: #467
[ghstack-poisoned]
ghstack-source-id: cded45d653d00d9147d31099040a18b66403b064 Pull Request resolved: #467
Codecov ReportBase: 89.79% // Head: 89.79% // No change to project coverage 👍
Additional details and impacted files@@ Coverage Diff @@
## gh/danthe3rd/48/base #467 +/- ##
=====================================================
Coverage 89.79% 89.79%
=====================================================
Files 80 80
Lines 4839 4839
=====================================================
Hits 4345 4345
Misses 494 494
Flags with carried forward coverage won't be shown. Click here to find out more. Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
[ghstack-poisoned]
ghstack-source-id: 037be211463d6b313f7c86804f6a89799f355d9d Pull Request resolved: #467
[ghstack-poisoned]
ghstack-source-id: d13ff3e407510b98ea4daf9cf38a74d966fa4d59 Pull Request resolved: #467
**PERFORMANCE** <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 02ddda11d414d63ce8d3693c60e6bf4430c2a5d9 Pull Request resolved: #467
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 0dee5d90397264d6fd96a9a22b4a1788e9de425c Pull Request resolved: #467
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this Daniel!
Could you maybe adapt the tests to show that with this change the backward has now better numerics?
Also, would it make sense to split the PR out of this stack so that we can get the other PRs merged?
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 4759ee88b471d7e10e40c72d9b9d4a441c82aeaf Pull Request resolved: #467
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 29c8918b52a893018f19cebeb5367bdd6202405b Pull Request resolved: #467
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 18c8ba5fd2ce14141a4a39a1e6900707d17c19f0 Pull Request resolved: #467
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 5af56eef32bfba5defa0673145be8c293b78c76f Pull Request resolved: #467
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: b7b4fe67e89e46e691615a42676b04c3bc0779a6 Pull Request resolved: #467
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 165c9a375a81b8c14ee7490c6f590d2f68b23df0 Pull Request resolved: #467
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: a7de549161fc563736f803189c103c8e5c7545e7 Pull Request resolved: #467
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks and sorry for the delay!
Let's get this merged!
**PERFORMANCE** This makes performance worse in f16 :( But I think we need it for stability <details> <summary>bw P100/V100 (f32/f16)</summary> ``` [---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------] | 48_accf32_6f8e2f15 | 56_base_02bf6b4e | vanilla | 57_tmpT_b516aec4 1 threads: -------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6289.6 | 6936.0 | 2183.6 | 6940.0 f32 B=384, M=197, H=1, K=88 | 8793.3 | 9446.6 | 2175.1 | 9429.2 f16 B=384, M=197, H=1, K=80 | 5989.9 | 6596.2 | 2146.8 | 6608.6 f32 B=384, M=197, H=1, K=80 | 8427.1 | 8993.1 | 2134.0 | 9030.5 f16 B=384, M=197, H=1, K=64 | 3347.0 | 3527.4 | 1799.3 | 3538.3 f32 B=384, M=197, H=1, K=64 | 5563.1 | 5984.2 | 1801.6 | 5980.9 f16 B=1024, M=197, H=1, K=88 | 15680.5 | 17424.8 | 5671.9 | 17452.6 f32 B=1024, M=197, H=1, K=88 | 23784.2 | 25542.8 | 5664.0 | 25578.6 f16 B=1024, M=197, H=1, K=80 | 14935.5 | 16581.3 | 5559.5 | 16587.3 f32 B=1024, M=197, H=1, K=80 | 22767.6 | 24354.1 | 5550.7 | 24362.0 f16 B=1024, M=197, H=1, K=64 | 8331.1 | 8671.5 | 4644.9 | 8695.3 f32 B=1024, M=197, H=1, K=64 | 15061.6 | 16148.8 | 4650.0 | 16201.4 f16 B=512, M=197, H=1, K=80 | 7594.7 | 8306.6 | 2824.6 | 8336.5 f32 B=512, M=197, H=1, K=80 | 11857.0 | 12660.2 | 2807.4 | 12709.2 f16 B=32, M=197, H=16, K=80 | 7779.8 | 8342.0 | 2820.8 | 8393.5 f32 B=32, M=197, H=16, K=80 | 11846.5 | 12487.8 | 2806.0 | 12549.1 f16 B=32, M=197, H=16, K=64 | 4258.4 | 4445.6 | 2374.0 | 4461.2 f32 B=32, M=197, H=16, K=64 | 7828.9 | 8434.9 | 2376.0 | 8472.2 f16 B=32, M=197, H=16, K=128 | 9025.1 | 9671.9 | 3159.3 | 9707.2 f32 B=32, M=197, H=16, K=128 | 14139.8 | 14920.7 | 3157.7 | 14928.9 f16 B=256, M=197, H=1, K=88 | 4608.9 | 5118.8 | 1478.4 | 5119.2 f32 B=256, M=197, H=1, K=88 | 6174.0 | 6644.3 | 1477.6 | 6642.7 f16 B=16, M=197, H=16, K=88 | 4618.4 | 5073.4 | 1479.5 | 5062.9 f32 B=16, M=197, H=16, K=88 | 6114.2 | 6471.7 | 1474.9 | 6478.7 f16 B=16, M=197, H=16, K=64 | 2490.3 | 2557.5 | 1225.9 | 2550.0 f32 B=16, M=197, H=16, K=64 | 3918.9 | 4208.2 | 1227.0 | 4195.7 f16 B=16, M=197, H=16, K=128 | 5210.0 | 5649.3 | 1635.2 | 5648.1 f32 B=16, M=197, H=16, K=128 | 7103.1 | 7451.6 | 1645.5 | 7445.1 f16 B=1, M=4096, H=160, K=128 | 1014229.1 | 1106182.6 | | 1108040.6 f32 B=1, M=4096, H=160, K=128 | 1258173.2 | 1243183.8 | | 1241548.7 f16 B=2, M=4096, H=160, K=128 | 1642279.2 | 1753736.9 | | 1771655.3 f32 B=2, M=4096, H=160, K=128 | 2505435.4 | 2477353.1 | | 2473773.4 f16 B=1, M=8192, H=160, K=128 | 4050128.4 | 4415962.1 | | 4428194.9 f32 B=1, M=8192, H=160, K=128 | 5042352.6 | 4970582.0 | | 4965069.9 f16 B=2, M=8192, H=160, K=128 | 6600732.3 | 7026378.5 | | 7068051.8 f16 B=1024, M=82, H=8, K=64 | 21572.8 | 22531.6 | 9059.1 | 22418.4 f32 B=1024, M=82, H=8, K=64 | 38178.4 | 45927.2 | 9070.3 | 45708.0 f16 B=150, M=256, H=16, K=64 | 21436.5 | 21927.4 | 12938.5 | 22001.0 f32 B=150, M=256, H=16, K=64 | 33024.2 | 33196.3 | 13249.2 | 33199.6 f16 B=64, M=256, H=12, K=64 | 6869.8 | 7048.6 | 4200.6 | 7073.6 f32 B=64, M=256, H=12, K=64 | 10719.9 | 10832.1 | 4271.3 | 10843.6 f16 B=1, M=4096, H=16, K=40 | 134722.6 | 145429.8 | 20587.0 | 143743.8 f32 B=1, M=4096, H=16, K=40 | 143015.1 | 147850.6 | 20625.8 | 148272.8 f16 B=1, M=16384, H=16, K=40 | 2149850.4 | 2323732.4 | | 2301489.4 f32 B=1, M=16384, H=16, K=40 | 2286478.1 | 2369812.4 | | 2375911.8 f16 B=16, M=128, H=16, K=16 | 497.2 | 502.9 | 623.8 | 503.6 f32 B=16, M=128, H=16, K=16 | 573.5 | 609.8 | 617.5 | 611.6 f16 B=16, M=128, H=16, K=32 | 563.9 | 573.1 | 624.2 | 575.2 f32 B=16, M=128, H=16, K=32 | 661.5 | 702.2 | 620.4 | 703.2 f16 B=16, M=128, H=16, K=64 | 708.9 | 722.6 | 619.9 | 724.3 f32 B=16, M=128, H=16, K=64 | 916.2 | 953.5 | 618.6 | 953.6 f16 B=16, M=128, H=16, K=128 | 1465.8 | 1542.8 | 624.2 | 1545.7 f32 B=16, M=128, H=16, K=128 | 1829.1 | 1872.3 | 616.2 | 1873.9 f16 B=16, M=128, H=16, K=256 | 3796.3 | 4002.5 | 1010.7 | 4008.4 f32 B=16, M=128, H=16, K=256 | 3838.8 | 3957.1 | 1203.6 | 3951.8 f16 B=16, M=512, H=16, K=16 | 7680.4 | 7775.1 | 4848.7 | 7752.8 f32 B=16, M=512, H=16, K=16 | 9402.0 | 9926.1 | 4926.7 | 9914.5 f16 B=16, M=512, H=16, K=32 | 8897.5 | 8999.0 | 5055.9 | 9003.9 f32 B=16, M=512, H=16, K=32 | 10762.1 | 11320.9 | 5065.4 | 11341.8 f16 B=16, M=512, H=16, K=64 | 10936.9 | 11214.8 | 5484.9 | 11254.4 f32 B=16, M=512, H=16, K=64 | 15091.9 | 15210.3 | 5552.0 | 15196.3 f16 B=16, M=512, H=16, K=128 | 23491.0 | 25234.8 | 7317.1 | 25300.3 f32 B=16, M=512, H=16, K=128 | 30524.6 | 30320.0 | 7487.0 | 30200.6 f16 B=16, M=512, H=16, K=256 | 50389.7 | 54525.2 | 14015.3 | 54931.1 f32 B=16, M=512, H=16, K=256 | 62155.6 | 61046.4 | 14285.5 | 61081.1 f16 B=16, M=1024, H=16, K=16 | 31289.9 | 31951.9 | 18778.4 | 32044.9 f32 B=16, M=1024, H=16, K=16 | 37744.4 | 39586.6 | 18929.9 | 39739.9 f16 B=16, M=1024, H=16, K=32 | 35770.8 | 36651.1 | 19620.2 | 36909.0 f32 B=16, M=1024, H=16, K=32 | 43211.5 | 45544.5 | 19518.4 | 45664.8 f16 B=16, M=1024, H=16, K=64 | 43865.1 | 44864.8 | 21286.0 | 45345.7 f32 B=16, M=1024, H=16, K=64 | 60710.5 | 60966.9 | 21634.1 | 60921.9 f16 B=16, M=1024, H=16, K=128 | 94633.7 | 101796.0 | 28502.1 | 102871.6 f32 B=16, M=1024, H=16, K=128 | 124093.6 | 122196.9 | 28520.3 | 122043.0 f16 B=16, M=1024, H=16, K=256 | 194780.7 | 212303.0 | 55419.1 | 214643.5 f32 B=16, M=1024, H=16, K=256 | 250799.1 | 245196.2 | 55634.0 | 245534.2 f16 B=64, M=128, H=16, K=16 | 1658.0 | 1661.4 | 1331.0 | 1662.7 f32 B=64, M=128, H=16, K=16 | 2129.2 | 2266.2 | 1371.8 | 2269.4 f16 B=64, M=128, H=16, K=32 | 1904.3 | 1903.5 | 1384.2 | 1906.4 f32 B=64, M=128, H=16, K=32 | 2496.5 | 2643.4 | 1445.1 | 2639.8 f16 B=64, M=128, H=16, K=64 | 2393.1 | 2432.3 | 1505.2 | 2437.8 f32 B=64, M=128, H=16, K=64 | 3471.1 | 3561.2 | 1590.0 | 3565.7 f16 B=64, M=128, H=16, K=128 | 4969.1 | 5266.3 | 1988.4 | 5272.4 f32 B=64, M=128, H=16, K=128 | 6880.3 | 7040.6 | 2121.0 | 7024.3 f16 B=64, M=128, H=16, K=256 | 12635.3 | 13334.7 | 3859.5 | 13350.7 f32 B=64, M=128, H=16, K=256 | 14185.7 | 14514.4 | 4553.8 | 14503.1 f16 B=64, M=512, H=16, K=16 | 26278.4 | 26189.9 | 18916.0 | 26110.0 f32 B=64, M=512, H=16, K=16 | 35128.2 | 37075.9 | 19191.4 | 37174.5 f16 B=64, M=512, H=16, K=32 | 30414.2 | 30938.2 | 19734.7 | 31071.7 f32 B=64, M=512, H=16, K=32 | 40483.2 | 42922.5 | 19843.9 | 42868.5 f16 B=64, M=512, H=16, K=64 | 37248.2 | 38179.7 | 21640.1 | 38327.0 f32 B=64, M=512, H=16, K=64 | 57666.8 | 57675.3 | 21909.1 | 57697.2 f16 B=64, M=512, H=16, K=128 | 80113.6 | 86165.7 | 28765.5 | 86368.3 f32 B=64, M=512, H=16, K=128 | 115672.5 | 115161.0 | 28910.3 | 115320.0 f16 B=64, M=512, H=16, K=256 | 169250.0 | 183791.8 | 56315.7 | 183735.0 f32 B=64, M=512, H=16, K=256 | 236594.6 | 233093.6 | 56853.4 | 233170.2 f16 B=64, M=1024, H=16, K=16 | 106022.3 | 109588.2 | 74303.6 | 109410.4 f32 B=64, M=1024, H=16, K=16 | 141241.8 | 148854.1 | | 149651.5 f16 B=64, M=1024, H=16, K=32 | 120899.1 | 125044.6 | 77828.6 | 125716.9 f32 B=64, M=1024, H=16, K=32 | 162478.5 | 173906.1 | | 173216.4 f16 B=64, M=1024, H=16, K=64 | 149044.1 | 152290.6 | 85821.6 | 152748.5 f32 B=64, M=1024, H=16, K=64 | 233195.9 | 231479.6 | | 231533.9 f16 B=64, M=1024, H=16, K=128 | 319761.5 | 344076.5 | 113579.4 | 345058.5 f32 B=64, M=1024, H=16, K=128 | 470172.3 | 466330.7 | | 463507.4 f16 B=64, M=1024, H=16, K=256 | 658362.1 | 717070.2 | | 723057.4 f32 B=64, M=1024, H=16, K=256 | 955624.0 | 935114.3 | | 935945.4 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1811.3 | 1686.3 | 1375.2 | 1699.6 f32 B=384, M=197, H=1, K=88 | 4315.7 | 4665.1 | 2257.1 | 4663.1 f16 B=384, M=197, H=1, K=80 | 1733.0 | 1616.3 | 1281.7 | 1619.8 f32 B=384, M=197, H=1, K=80 | 3965.0 | 4226.4 | 2171.7 | 4228.3 f16 B=384, M=197, H=1, K=64 | 1135.2 | 1084.4 | 1043.8 | 1083.0 f32 B=384, M=197, H=1, K=64 | 2673.2 | 2883.0 | 1744.5 | 2878.2 f16 B=1024, M=197, H=1, K=88 | 4721.0 | 4396.6 | 3725.3 | 4404.1 f32 B=1024, M=197, H=1, K=88 | 10531.3 | 11443.8 | 6106.5 | 11464.2 f16 B=1024, M=197, H=1, K=80 | 4520.2 | 4216.2 | 3329.2 | 4223.7 f32 B=1024, M=197, H=1, K=80 | 9573.5 | 10301.4 | 5757.6 | 10305.3 f16 B=1024, M=197, H=1, K=64 | 2788.1 | 2660.1 | 2674.6 | 2663.5 f32 B=1024, M=197, H=1, K=64 | 6556.7 | 7102.4 | 4516.6 | 7096.8 f16 B=512, M=197, H=1, K=80 | 2377.2 | 2228.1 | 1685.0 | 2231.4 f32 B=512, M=197, H=1, K=80 | 5259.3 | 5639.7 | 2887.8 | 5636.6 f16 B=32, M=197, H=16, K=80 | 2403.0 | 2201.1 | 1798.2 | 2204.9 f32 B=32, M=197, H=16, K=80 | 5402.3 | 5667.7 | 3046.3 | 5662.4 f16 B=32, M=197, H=16, K=64 | 1552.7 | 1486.3 | 1451.6 | 1485.2 f32 B=32, M=197, H=16, K=64 | 3622.5 | 3911.0 | 2427.0 | 3912.9 f16 B=32, M=197, H=16, K=128 | 2776.6 | 2611.9 | 2211.3 | 2613.3 f32 B=32, M=197, H=16, K=128 | 6647.9 | 7082.3 | 4088.5 | 7104.8 f16 B=256, M=197, H=1, K=88 | 1357.5 | 1285.5 | 941.3 | 1287.6 f32 B=256, M=197, H=1, K=88 | 2874.1 | 3085.1 | 1543.2 | 3090.1 f16 B=16, M=197, H=16, K=88 | 1349.1 | 1264.5 | 964.7 | 1263.1 f32 B=16, M=197, H=16, K=88 | 2803.8 | 2967.0 | 1647.4 | 2972.0 f16 B=16, M=197, H=16, K=64 | 765.5 | 728.4 | 765.3 | 731.0 f32 B=16, M=197, H=16, K=64 | 1834.1 | 1969.7 | 1282.4 | 1974.4 f16 B=16, M=197, H=16, K=128 | 1509.1 | 1432.6 | 1139.1 | 1433.9 f32 B=16, M=197, H=16, K=128 | 3406.5 | 3606.2 | 2048.6 | 3613.3 f16 B=1, M=4096, H=160, K=128 | 168807.1 | 148652.9 | | 149343.8 f32 B=1, M=4096, H=160, K=128 | 549864.6 | 586699.9 | | 585699.9 f16 B=2, M=4096, H=160, K=128 | 339010.5 | 298827.7 | | 298808.4 f32 B=2, M=4096, H=160, K=128 | 1106963.8 | 1176218.3 | | 1179173.2 f16 B=1, M=8192, H=160, K=128 | 679742.4 | 594323.2 | | 595580.1 f32 B=1, M=8192, H=160, K=128 | 2195491.3 | 2340248.4 | | 2343505.7 f16 B=2, M=8192, H=160, K=128 | 1364983.7 | 1193787.8 | | 1192596.1 f16 B=1024, M=82, H=8, K=64 | 9052.8 | 8762.6 | 5804.2 | 8757.8 f32 B=1024, M=82, H=8, K=64 | 14726.3 | 16270.6 | 11059.7 | 16215.9 f16 B=150, M=256, H=16, K=64 | 5662.9 | 5519.4 | 7557.8 | 5524.5 f32 B=150, M=256, H=16, K=64 | 16700.3 | 17612.7 | 16426.0 | 17640.4 f16 B=64, M=256, H=12, K=64 | 1849.1 | 1793.1 | 2383.5 | 1798.7 f32 B=64, M=256, H=12, K=64 | 5451.5 | 5766.4 | 4975.3 | 5775.8 f16 B=1, M=4096, H=16, K=40 | 47263.3 | 47850.4 | 8315.6 | 47777.2 f32 B=1, M=4096, H=16, K=40 | 113099.4 | 113164.5 | 19536.0 | 113930.0 f16 B=1, M=16384, H=16, K=40 | 757091.8 | 770365.7 | | 765401.3 f32 B=1, M=16384, H=16, K=40 | 1806827.7 | 1816302.6 | | 1819162.1 f16 B=16, M=128, H=16, K=16 | 219.2 | 218.5 | 480.6 | 231.8 f32 B=16, M=128, H=16, K=16 | 301.2 | 308.2 | 498.9 | 307.9 f16 B=16, M=128, H=16, K=32 | 227.6 | 215.4 | 473.8 | 220.1 f32 B=16, M=128, H=16, K=32 | 395.8 | 401.0 | 455.1 | 400.8 f16 B=16, M=128, H=16, K=64 | 225.6 | 214.8 | 510.1 | 229.9 f32 B=16, M=128, H=16, K=64 | 561.7 | 583.1 | 598.9 | 581.2 f16 B=16, M=128, H=16, K=128 | 404.6 | 392.0 | 524.0 | 394.8 f32 B=16, M=128, H=16, K=128 | 1103.0 | 1140.8 | 1015.4 | 1142.0 f16 B=16, M=128, H=16, K=256 | 1045.3 | 1049.1 | 889.6 | 1047.4 f32 B=16, M=128, H=16, K=256 | 2181.3 | 2270.9 | 1869.2 | 2265.7 f16 B=16, M=512, H=16, K=16 | 1731.5 | 1585.7 | 1908.5 | 1586.2 f32 B=16, M=512, H=16, K=16 | 4513.8 | 4696.7 | 4222.1 | 4695.6 f16 B=16, M=512, H=16, K=32 | 1942.2 | 1823.4 | 2086.3 | 1809.7 f32 B=16, M=512, H=16, K=32 | 5596.1 | 5819.4 | 4588.1 | 5833.6 f16 B=16, M=512, H=16, K=64 | 2450.2 | 2340.8 | 2580.6 | 2353.2 f32 B=16, M=512, H=16, K=64 | 7619.5 | 7853.1 | 5536.8 | 7875.1 f16 B=16, M=512, H=16, K=128 | 4884.8 | 4473.5 | 3388.7 | 4487.9 f32 B=16, M=512, H=16, K=128 | 15010.0 | 15557.1 | 8979.4 | 15513.1 f16 B=16, M=512, H=16, K=256 | 12973.9 | 11134.7 | 5418.2 | 11106.3 f32 B=16, M=512, H=16, K=256 | 29751.1 | 30979.1 | 16856.3 | 31012.8 f16 B=16, M=1024, H=16, K=16 | 6802.6 | 6185.1 | 6996.0 | 6191.0 f32 B=16, M=1024, H=16, K=16 | 18098.2 | 18954.5 | 16129.1 | 19166.8 f16 B=16, M=1024, H=16, K=32 | 7531.9 | 7065.0 | 7436.0 | 7067.5 f32 B=16, M=1024, H=16, K=32 | 21999.9 | 22901.5 | 17040.0 | 22899.4 f16 B=16, M=1024, H=16, K=64 | 9312.0 | 8854.3 | 8605.6 | 8863.9 f32 B=16, M=1024, H=16, K=64 | 29837.7 | 30895.1 | 20355.3 | 30878.5 f16 B=16, M=1024, H=16, K=128 | 18979.0 | 16951.0 | 10561.3 | 16995.2 f32 B=16, M=1024, H=16, K=128 | 58738.3 | 60861.2 | 33427.0 | 60599.8 f16 B=16, M=1024, H=16, K=256 | 49681.9 | 41833.3 | 17329.5 | 41921.6 f32 B=16, M=1024, H=16, K=256 | 117362.4 | 121004.8 | 60515.8 | 122046.9 f16 B=64, M=128, H=16, K=16 | 432.2 | 411.1 | 642.1 | 411.3 f32 B=64, M=128, H=16, K=16 | 1028.9 | 1057.9 | 1233.4 | 1056.6 f16 B=64, M=128, H=16, K=32 | 522.5 | 500.7 | 813.6 | 499.7 f32 B=64, M=128, H=16, K=32 | 1403.5 | 1443.4 | 1535.6 | 1443.6 f16 B=64, M=128, H=16, K=64 | 750.0 | 739.8 | 1185.6 | 741.0 f32 B=64, M=128, H=16, K=64 | 2013.9 | 2110.1 | 2156.8 | 2105.4 f16 B=64, M=128, H=16, K=128 | 1421.2 | 1387.9 | 1915.5 | 1388.3 f32 B=64, M=128, H=16, K=128 | 3946.5 | 4156.7 | 3780.1 | 4156.3 f16 B=64, M=128, H=16, K=256 | 3811.5 | 3810.6 | 3448.7 | 3811.0 f32 B=64, M=128, H=16, K=256 | 7983.9 | 8432.7 | 7304.0 | 8415.5 f16 B=64, M=512, H=16, K=16 | 6157.4 | 5523.9 | 7461.8 | 5528.1 f32 B=64, M=512, H=16, K=16 | 16118.2 | 17004.7 | 16651.7 | 16984.5 f16 B=64, M=512, H=16, K=32 | 6985.1 | 6483.1 | 8278.5 | 6495.1 f32 B=64, M=512, H=16, K=32 | 20471.9 | 21230.7 | 18420.5 | 21269.0 f16 B=64, M=512, H=16, K=64 | 9045.8 | 8633.5 | 10337.1 | 8674.1 f32 B=64, M=512, H=16, K=64 | 27675.0 | 29282.8 | 22883.7 | 29136.9 f16 B=64, M=512, H=16, K=128 | 17594.2 | 15805.9 | 14700.6 | 15788.2 f32 B=64, M=512, H=16, K=128 | 54612.4 | 57815.7 | 39974.3 | 57951.8 f16 B=64, M=512, H=16, K=256 | 47452.6 | 40093.5 | 27087.5 | 40175.6 f32 B=64, M=512, H=16, K=256 | 108880.3 | 115953.1 | 77794.2 | 115951.0 f16 B=64, M=1024, H=16, K=16 | 24369.0 | 21533.0 | 28448.6 | 21556.6 f32 B=64, M=1024, H=16, K=16 | 64649.7 | 68791.8 | | 68407.4 f16 B=64, M=1024, H=16, K=32 | 27143.1 | 25683.7 | 30252.7 | 25727.9 f32 B=64, M=1024, H=16, K=32 | 79967.5 | 83351.5 | | 83084.7 f16 B=64, M=1024, H=16, K=64 | 34667.0 | 32592.7 | 36991.2 | 32659.8 f32 B=64, M=1024, H=16, K=64 | 108282.2 | 113858.1 | | 114286.0 f16 B=64, M=1024, H=16, K=128 | 68519.5 | 59757.3 | 48834.4 | 59817.7 f32 B=64, M=1024, H=16, K=128 | 215465.9 | 227335.8 | | 227204.0 f16 B=64, M=1024, H=16, K=256 | 183070.4 | 150960.1 | | 150947.1 f32 B=64, M=1024, H=16, K=256 | 425832.9 | 453717.2 | | 453349.2 Times are in microseconds (us). ``` </details> <details> <summary>bw A100 (f32/f16)</summary> ``` [----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------] | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass] 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 613.4 | | 2264.9 | 612.8 | 578.3 f32 B=384, M=197, H=1, K=88 | 2335.2 | | 1843.0 | 2438.4 | 2425.1 f16 B=384, M=197, H=1, K=80 | 583.6 | | 1922.9 | 577.6 | 548.4 f32 B=384, M=197, H=1, K=80 | 2241.3 | | 1787.8 | 2333.2 | 2333.1 f16 B=384, M=197, H=1, K=64 | 405.6 | 232.5 | 1809.8 | 386.4 | 366.1 f32 B=384, M=197, H=1, K=64 | 1259.7 | | 1675.6 | 1309.6 | 1316.8 f16 B=1024, M=197, H=1, K=88 | 1538.2 | | 5964.7 | 1550.8 | 1454.4 f32 B=1024, M=197, H=1, K=88 | 6031.4 | | 4559.5 | 6325.4 | 6332.7 f16 B=1024, M=197, H=1, K=80 | 1458.8 | | 5038.2 | 1463.8 | 1379.0 f32 B=1024, M=197, H=1, K=80 | 5786.2 | | 4412.6 | 6059.0 | 6079.9 f16 B=1024, M=197, H=1, K=64 | 929.5 | 575.9 | 4735.6 | 862.1 | 821.6 f32 B=1024, M=197, H=1, K=64 | 3289.9 | | 4119.9 | 3434.4 | 3441.4 f16 B=512, M=197, H=1, K=80 | 744.5 | | 2544.7 | 735.3 | 695.8 f32 B=512, M=197, H=1, K=80 | 2889.4 | | 2286.1 | 3008.8 | 3029.0 f16 B=32, M=197, H=16, K=80 | 741.6 | | 2569.0 | 723.3 | 693.4 f32 B=32, M=197, H=16, K=80 | 2878.6 | | 2355.3 | 3003.0 | 3024.1 f16 B=32, M=197, H=16, K=64 | 478.5 | 295.7 | 2429.2 | 456.1 | 426.7 f32 B=32, M=197, H=16, K=64 | 1784.4 | | 2196.9 | 1863.7 | 1860.2 f16 B=32, M=197, H=16, K=128 | 887.0 | 682.3 | 4492.9 | 857.5 | 853.8 f32 B=32, M=197, H=16, K=128 | 3546.6 | | 2807.3 | 3734.9 | 3737.1 f16 B=256, M=197, H=1, K=88 | 445.1 | | 1528.9 | 445.0 | 422.2 f32 B=256, M=197, H=1, K=88 | 1678.8 | | 1207.6 | 1746.6 | 1752.8 f16 B=16, M=197, H=16, K=88 | 441.5 | | 1544.2 | 437.3 | 419.5 f32 B=16, M=197, H=16, K=88 | 1668.3 | | 1250.4 | 1742.7 | 1746.1 f16 B=16, M=197, H=16, K=64 | 247.4 | 165.6 | 1242.5 | 233.2 | 217.5 f32 B=16, M=197, H=16, K=64 | 1051.1 | | 1125.3 | 1099.2 | 1096.6 f16 B=16, M=197, H=16, K=128 | 498.4 | 386.2 | 2264.5 | 488.0 | 480.5 f32 B=16, M=197, H=16, K=128 | 1950.2 | | 1446.7 | 2039.0 | 2028.6 f16 B=1, M=4096, H=160, K=128 | 55915.0 | 54620.4 | 45909.5 | 63407.6 | 51298.8 f32 B=1, M=4096, H=160, K=128 | 238514.6 | | | 232677.5 | 232672.5 f16 B=2, M=4096, H=160, K=128 | 93612.0 | 84238.0 | | 100433.1 | 84858.1 f32 B=2, M=4096, H=160, K=128 | 375037.4 | | | 364234.1 | 364407.2 f16 B=1, M=8192, H=160, K=128 | 223261.8 | 215499.8 | | 251806.6 | 202133.7 f32 B=1, M=8192, H=160, K=128 | 946708.9 | | | 924986.2 | 924988.9 f16 B=2, M=8192, H=160, K=128 | 367969.2 | 330092.8 | | 395881.3 | 332000.6 f32 B=2, M=8192, H=160, K=128 | 1492031.4 | | | 1448691.1 | 1449146.1 f16 B=1024, M=82, H=8, K=64 | 1890.2 | 1620.3 | 3819.7 | 1861.3 | 1764.6 f32 B=1024, M=82, H=8, K=64 | 8428.2 | | 8735.1 | 8831.3 | 8867.4 f16 B=150, M=256, H=16, K=64 | 2292.3 | 1625.3 | 4555.9 | 2109.8 | 2019.4 f32 B=150, M=256, H=16, K=64 | 6252.4 | | 12948.1 | 6288.9 | 6281.3 f16 B=64, M=256, H=12, K=64 | 782.2 | 567.4 | 1498.0 | 731.2 | 699.7 f32 B=64, M=256, H=12, K=64 | 2141.2 | | 4266.6 | 2160.6 | 2160.2 f16 B=1, M=4096, H=16, K=40 | 23504.2 | | 4196.1 | 23699.0 | 23008.9 f32 B=1, M=4096, H=16, K=40 | 73699.5 | | 17755.3 | 73261.8 | 73078.5 f16 B=1, M=16384, H=16, K=40 | 391408.9 | | | 439777.3 | 407653.7 f32 B=1, M=16384, H=16, K=40 | 1196173.6 | | | 1181547.9 | 1181625.1 f16 B=256, M=4096, H=16, K=64 | 733221.8 | 439627.5 | | 603237.1 | 565905.8 f16 B=16, M=128, H=16, K=16 | 130.0 | 113.1 | 265.2 | 125.5 | 124.4 f32 B=16, M=128, H=16, K=16 | 161.5 | | 373.1 | 160.6 | 162.8 f16 B=16, M=128, H=16, K=32 | 125.8 | 111.5 | 263.8 | 122.1 | 125.8 f32 B=16, M=128, H=16, K=32 | 189.8 | | 412.6 | 196.3 | 196.2 f16 B=16, M=128, H=16, K=64 | 126.0 | 112.4 | 265.8 | 120.8 | 125.7 f32 B=16, M=128, H=16, K=64 | 272.1 | | 498.7 | 285.9 | 283.3 f16 B=16, M=128, H=16, K=128 | 181.3 | 158.4 | 298.5 | 178.0 | 186.5 f32 B=16, M=128, H=16, K=128 | 509.6 | | 673.9 | 521.1 | 521.5 f16 B=16, M=128, H=16, K=256 | 774.0 | | 541.4 | 775.5 | 757.0 f32 B=16, M=128, H=16, K=256 | 975.2 | | 1162.6 | 994.5 | 994.5 f16 B=16, M=512, H=16, K=16 | 621.0 | 322.6 | 1204.9 | 555.0 | 519.9 f32 B=16, M=512, H=16, K=16 | 2148.0 | | 4414.4 | 2178.9 | 2180.4 f16 B=16, M=512, H=16, K=32 | 709.9 | 435.6 | 1306.2 | 653.1 | 602.8 f32 B=16, M=512, H=16, K=32 | 2335.6 | | 4640.7 | 2336.0 | 2336.3 f16 B=16, M=512, H=16, K=64 | 917.8 | 702.7 | 1545.8 | 849.9 | 797.4 f32 B=16, M=512, H=16, K=64 | 2965.5 | | 5125.0 | 2986.9 | 2988.3 f16 B=16, M=512, H=16, K=128 | 1644.0 | 1584.1 | 1983.4 | 1757.0 | 1548.0 f32 B=16, M=512, H=16, K=128 | 6152.7 | | 6099.1 | 6067.7 | 6068.6 f16 B=16, M=512, H=16, K=256 | 8178.3 | | 2899.1 | 7895.5 | 7977.4 f32 B=16, M=512, H=16, K=256 | 11894.0 | | 10639.5 | 11635.0 | 11624.7 f16 B=16, M=1024, H=16, K=16 | 2420.5 | 1240.8 | 4259.2 | 2234.6 | 2048.9 f32 B=16, M=1024, H=16, K=16 | 8467.2 | | 16650.4 | 8512.2 | 8510.2 f16 B=16, M=1024, H=16, K=32 | 2675.7 | 1618.9 | 4491.2 | 2441.3 | 2230.0 f32 B=16, M=1024, H=16, K=32 | 9010.0 | | 17301.0 | 9012.1 | 9015.7 f16 B=16, M=1024, H=16, K=64 | 3328.4 | 2370.3 | 4994.5 | 3032.2 | 2820.0 f32 B=16, M=1024, H=16, K=64 | 11566.8 | | 18714.2 | 11494.1 | 11492.8 f16 B=16, M=1024, H=16, K=128 | 5867.9 | 5632.5 | 5952.8 | 6401.1 | 5440.8 f32 B=16, M=1024, H=16, K=128 | 23345.4 | | 21523.7 | 22859.1 | 22870.3 f16 B=16, M=1024, H=16, K=256 | 30619.2 | | 7893.1 | 29884.9 | 29060.4 f32 B=16, M=1024, H=16, K=256 | 45211.4 | | 38093.0 | 43435.3 | 43423.8 f16 B=64, M=128, H=16, K=16 | 159.6 | 145.2 | 439.9 | 161.2 | 167.0 f32 B=64, M=128, H=16, K=16 | 493.4 | | 1270.0 | 502.7 | 503.2 f16 B=64, M=128, H=16, K=32 | 208.5 | 212.1 | 545.3 | 206.2 | 204.9 f32 B=64, M=128, H=16, K=32 | 601.2 | | 1427.0 | 610.5 | 610.6 f16 B=64, M=128, H=16, K=64 | 329.0 | 310.8 | 766.0 | 327.2 | 314.7 f32 B=64, M=128, H=16, K=64 | 867.7 | | 1743.5 | 889.2 | 889.0 f16 B=64, M=128, H=16, K=128 | 635.5 | 562.1 | 1226.7 | 613.3 | 650.7 f32 B=64, M=128, H=16, K=128 | 1774.4 | | 2386.5 | 1800.0 | 1799.6 f16 B=64, M=128, H=16, K=256 | 2839.8 | | 2122.8 | 2765.2 | 2821.9 f32 B=64, M=128, H=16, K=256 | 3419.2 | | 4320.7 | 3458.8 | 3457.0 f16 B=64, M=512, H=16, K=16 | 2316.2 | 1202.3 | 4487.1 | 1983.9 | 1868.4 f32 B=64, M=512, H=16, K=16 | 6686.0 | | 16991.5 | 6709.9 | 6713.2 f16 B=64, M=512, H=16, K=32 | 2701.7 | 1541.9 | 4975.9 | 2346.4 | 2184.3 f32 B=64, M=512, H=16, K=32 | 7460.5 | | 17859.9 | 7429.6 | 7438.0 f16 B=64, M=512, H=16, K=64 | 3461.2 | 2418.2 | 5886.1 | 3083.1 | 2942.6 f32 B=64, M=512, H=16, K=64 | 9553.4 | | 19768.6 | 9526.5 | 9516.6 f16 B=64, M=512, H=16, K=128 | 5875.5 | 5443.1 | 7711.0 | 6141.3 | 5513.6 f32 B=64, M=512, H=16, K=128 | 21317.0 | | 23651.2 | 20931.9 | 20932.3 f16 B=64, M=512, H=16, K=256 | 31238.2 | | 11490.6 | 28626.8 | 28954.0 f32 B=64, M=512, H=16, K=256 | 41124.8 | | 42468.0 | 39562.8 | 39579.5 f16 B=64, M=1024, H=16, K=16 | 9142.8 | 4707.2 | 16882.8 | 7887.0 | 7314.5 f32 B=64, M=1024, H=16, K=16 | 26512.3 | | 66311.7 | 26497.5 | 26459.8 f16 B=64, M=1024, H=16, K=32 | 10420.8 | 5698.3 | 17875.0 | 8851.4 | 7974.0 f32 B=64, M=1024, H=16, K=32 | 28300.0 | | 69088.7 | 28102.1 | 28098.2 f16 B=64, M=1024, H=16, K=64 | 12948.5 | 8119.0 | 19944.3 | 11064.2 | 10477.6 f32 B=64, M=1024, H=16, K=64 | 35600.6 | | 74762.8 | 35316.4 | 35361.2 f16 B=64, M=1024, H=16, K=128 | 20820.5 | 19184.2 | 23699.3 | 21954.8 | 19220.0 f32 B=64, M=1024, H=16, K=128 | 80800.3 | | 86003.8 | 78521.9 | 78393.9 f16 B=64, M=1024, H=16, K=256 | 114411.1 | | 32958.3 | 103287.2 | 104304.2 f32 B=64, M=1024, H=16, K=256 | 155731.5 | | 153011.0 | 148071.6 | 148165.6 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: f713589d43273c6785ba6e3ae92e0974ef8ccfba Pull Request resolved: #467
Stack from ghstack (oldest at bottom):
PERFORMANCE
This makes performance worse in f16 :(
But I think we need it for stability
bw P100/V100 (f32/f16)
bw A100 (f32/f16)