diff --git a/BENCHMARKS.md b/BENCHMARKS.md index 5123419c50..190eaf8e82 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -37,6 +37,8 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py ## Triton layers +Please note that, as of November 2022, these layers are not optimized for typical production GPUs (they have not been actively developed for some time and were mostly tested on a laptop GPU). Better performance is likely possible with some minor changes, as other libraries have demonstrated since xformers was released. + ### Fused softmax You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12. diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png index 4b01969654..19b0236d21 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png index 4d969d9602..d087b6fc85 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png index dc93c1a99b..d79eef6632 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png index 41e95e9656..d90c5949ae 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png index 33e6949735..27d547e54f 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png index a992f5b973..0e0e0ea386 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png index 6219071691..f8163a18ab 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png differ diff --git 
a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png index cb092efbac..5845bd24a2 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png index 12b9e19fb4..ae8ec6a71c 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png index 681a8bf7d3..6bc8e34414 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png index fcc4480db5..d0bcd447e5 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png index 8889349a46..d8306f41a5 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png index bb1f3168ca..2d6e6b446e 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png index ba8109f341..d0432b3699 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png index f37a780e6b..04b69772da 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png index 4ad91ec132..5b8b6a2609 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png differ diff --git 
a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png index a78c14e75f..64c429ef69 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png index cd635ce765..e118413241 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png index f5b0633a12..210bef5221 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png index 3678c34b18..a491e87a48 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png index 8dd7ea07a9..effbd7c15c 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png index f5a7465510..7a104de755 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png index 4b60ff3e91..05ebd12a35 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png index d7e3f28351..6d69ce39d1 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png index 144f79774e..0877126b95 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png index 27f4d3cc73..0ea00f8f3b 100644 
Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png index 8bab52d57f..2eb50535c0 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png index ed4765510a..ade66789b2 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_smelu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_smelu.png index e1189c2379..acf74a4510 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_smelu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_smelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png index 2b19a25464..be4f749910 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_star_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_star_relu.png new file mode 100644 index 0000000000..696a14931a Binary files /dev/null and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_star_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png index b27e8854d4..976734b210 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png index 19d31b8b38..e7b83deacc 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png index db9fa9fae7..fc8aa1eabc 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png index 07d3b4da5a..35374d8f47 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_smelu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_smelu.png index bbd6c36328..4d687c2e52 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_smelu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_smelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png index 1005338fb7..57a23a3773 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png differ diff --git 
a/docs/plots/fused_linear/FusedLinear_fp16_FW_star_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_star_relu.png new file mode 100644 index 0000000000..fe0bf8ddb3 Binary files /dev/null and b/docs/plots/fused_linear/FusedLinear_fp16_FW_star_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_gelu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_gelu.png index 86531131fc..a7c66db23c 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_gelu.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_gelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_leaky_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_leaky_relu.png index 60c2036656..1b116a5fdd 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_leaky_relu.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_leaky_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_none.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_none.png index c8a223a57b..c40830aad9 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_none.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_none.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_relu.png index 048287650a..dc011a1f2a 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_relu.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_smelu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_smelu.png new file mode 100644 index 0000000000..ec18489aa7 Binary files /dev/null and b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_smelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_squared_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_squared_relu.png index d0fcf5e2bd..bfe79c75e4 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_squared_relu.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_squared_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_star_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_star_relu.png new file mode 100644 index 0000000000..e6248fe7c5 Binary files /dev/null and b/docs/plots/fused_linear/FusedLinear_fp32_FW_BW_star_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_gelu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_gelu.png index d116457c58..c6c75327d9 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_gelu.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_gelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_leaky_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_leaky_relu.png index 54db85af26..dfa0244378 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_leaky_relu.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_leaky_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_none.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_none.png index 93e255972b..81aa4bb2e2 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_none.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_none.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_relu.png index 8077fb6840..c7c46b1edd 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_relu.png and 
b/docs/plots/fused_linear/FusedLinear_fp32_FW_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_smelu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_smelu.png new file mode 100644 index 0000000000..70f35ffe26 Binary files /dev/null and b/docs/plots/fused_linear/FusedLinear_fp32_FW_smelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_squared_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_squared_relu.png index 577d9cf87c..d2bc34ad28 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp32_FW_squared_relu.png and b/docs/plots/fused_linear/FusedLinear_fp32_FW_squared_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp32_FW_star_relu.png b/docs/plots/fused_linear/FusedLinear_fp32_FW_star_relu.png new file mode 100644 index 0000000000..59f1416cee Binary files /dev/null and b/docs/plots/fused_linear/FusedLinear_fp32_FW_star_relu.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png index 2bb20cd112..6a4a14f0af 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png index 17600ed1e5..1cb22007ed 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png index 4ccf849881..5572853ca4 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png index 61cb30654b..9f92b6393b 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png differ diff --git a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png index 297ec3c667..d8ac3c74a4 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png and b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png differ diff --git a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png index 256b8d9f37..3783eb4a57 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png and b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png differ diff --git a/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png b/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png index 4000421a80..4324aefaca 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png and b/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png differ diff --git a/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png b/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png index 6012e52e74..d928e28a41 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png and b/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png differ diff --git a/requirements-test.txt b/requirements-test.txt index 9798451267..b6cc8d5aac 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -27,4 +27,4 @@ hydra-core >= 1.1 fairscale >= 0.4.5 # Dependency for 
fused layers, optional -triton==2.0.0.dev20221014 +triton==2.0.0.dev20221105 diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index 13265b6d99..18a9efc18f 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -17,6 +17,8 @@ if _triton_available: try: + import triton + from xformers.triton import dropout as triton_dropout from xformers.triton.dropout import FusedDropoutBias from xformers.triton.utils import gpu_capabilities_older_than_70 @@ -130,7 +132,7 @@ def test_dropout(shape, amp, bias, p): torch.cuda.manual_seed(0) y_2 = triton_dropout(x, p=0.5) - assert torch.allclose(y_1, y_2) + triton.testing.assert_almost_equal(y_1, y_2) @pytest.mark.skipif(not _gpu_available, reason="GPU is not available") diff --git a/tests/test_triton_fused_linear.py b/tests/test_triton_fused_linear.py index 9499edaf72..479ab9ba29 100644 --- a/tests/test_triton_fused_linear.py +++ b/tests/test_triton_fused_linear.py @@ -65,7 +65,8 @@ def test_fused_matmul(shape, dtype): ) # Now check that adding an activation to the mix still produces valid results - # NOTE: SquaredReLU fails, some outlier representation issue + # NOTE: SquaredReLU fails, some outlier representation issue but the eyeballed results look reasonable + # could be due to a different accumulation out of the box (tf32 for instance) for activation in filter( lambda x: x not in (Activation.SquaredReLU, Activation.StarReLU), Activation ): diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index 9a165b2f7d..d36a05646b 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -18,7 +18,7 @@ from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw BLOCK_M = 32 -BLOCK_N = 128 +BLOCK_N = 64 # NOTE: This should ideally be GPU dependent, big impact on perf # Helper to handle the SPMD launch grid and error cases @@ -36,7 +36,7 @@ def forward(ctx, x, p, bias, activation, trainable_bias): def grid(meta): return ( - triton.cdiv(M, meta["BLOCK_M"]), # 4 x + triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]), ) @@ -101,7 +101,7 @@ def backward( # - over N we compromise in between trying to use as much memory paralellism as possible, # (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too # big because of register spilling - N_BLOCKS_M = triton.cdiv(M, BLOCK_M) # 4x + N_BLOCKS_M = triton.cdiv(M, BLOCK_M) if ctx.trainable_bias: grad_bias = torch.empty( diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py index edde0c5ff1..ebfbbbd785 100644 --- a/xformers/triton/k_dropout.py +++ b/xformers/triton/k_dropout.py @@ -31,9 +31,6 @@ triton.Config({}, num_warps=16), ] -MAX_INT32 = 2147483647 -MAX_UINT32 = 4294967295 - # fmt: off @triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]}) @@ -66,7 +63,7 @@ def k_dropout_fw( # fmt: on row_id = tl.program_id(axis=0) - rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) # 4x + rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) col_id = tl.program_id(axis=1) cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) @@ -106,20 +103,12 @@ def k_dropout_fw( # get the random keep mask rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) seed_int = tl.load(SEEDS + col_id) + r = tl.rand(seed_int, rand_offsets) + keep_mask = r > p - if 1: - r = tl.rand(seed_int, rand_offsets) - keep_mask = r > p - - # prune and normalize in one go - keep = tl.reshape(keep_mask, x.shape) - output = tl.where(keep, (x * p_scale).to(x.dtype), 0.) 
- else: - r0, r1, r2, r3 = tl.randint4x(seed_int, rand_offsets) - r = tl.cat(tl.cat(r0, r1), tl.cat(r2, r3)) - r = r.to(tl.uint32, bitcast=True) - r = tl.reshape(r, x.shape) - output = tl.where(r > p * MAX_UINT32, x * p_scale, 0.) + # prune and normalize in one go + keep = tl.reshape(keep_mask, x.shape) + output = tl.where(keep, (x * p_scale).to(x.dtype), 0.) tl.store(y_ptrs, output, mask=block_mask) # output @@ -158,7 +147,7 @@ def k_dropout_bw( # fmt: on row_id = tl.program_id(axis=0) - rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) # 4x + rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) col_id = tl.program_id(axis=1) cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) @@ -206,16 +195,9 @@ def k_dropout_bw( # from the same seeds, so the same drop mask is applied here rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) seed_int = tl.load(SEEDS + col_id) - if 1: - r = tl.rand(seed_int, rand_offsets) - r = tl.reshape(r, grad_out.shape) - output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.) - else: - r0, r1, r2, r3 = tl.randint4x(seed_int, rand_offsets) - r = tl.cat(tl.cat(r0, r1), tl.cat(r2, r3)) - r = r.to(tl.uint32, bitcast=True) - r = tl.reshape(r, inputs.shape) - output = tl.where(r > p * MAX_UINT32, grad_out * p_scale, 0.) + r = tl.rand(seed_int, rand_offsets) + r = tl.reshape(r, grad_out.shape) + output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.) # write-back tl.store(grad_in_ptrs, output, mask=block_mask) diff --git a/xformers/triton/k_fused_matmul_bw.py b/xformers/triton/k_fused_matmul_bw.py index 278509888e..f323cda9fc 100644 --- a/xformers/triton/k_fused_matmul_bw.py +++ b/xformers/triton/k_fused_matmul_bw.py @@ -18,18 +18,16 @@ squared_relu_grad, star_relu_grad, ) -from xformers.triton.sum_strided import sum_2d_dim_0 # fmt: off @triton.autotune( configs=[ - triton.Config({"BLOCK_N": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_N": 64}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_N": 64}, num_stages=4, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=2), + triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=4), ], key=["N"], ) @@ -155,9 +153,9 @@ def fused_matmul_backward( # just before the activation grad_out_ = grad_act - # The following ops can also be handled by triton - grad_in = grad_out_ @ weight + # The following ops can also be handled by pytorch + grad_in = triton.ops.matmul(grad_out_, weight) grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None - grad_bias = sum_2d_dim_0(grad_out_) if trainable_bias else None + grad_bias = torch.sum(grad_out_, dim=0) if trainable_bias else None return grad_in.reshape_as(inputs), grad_weight, grad_bias diff --git a/xformers/triton/k_fused_matmul_fw.py b/xformers/triton/k_fused_matmul_fw.py index 0e242b82bb..f14aad35e2 100644 --- a/xformers/triton/k_fused_matmul_fw.py +++ b/xformers/triton/k_fused_matmul_fw.py @@ -21,24 +21,64 @@ # CREDITS: Initially inspired by the Triton tutorial on matrix multiplications +def get_configs(block_k): + return [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": block_k}, + num_stages=4, + num_warps=2, + ), + triton.Config( + 
{"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": block_k}, + num_stages=4, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": block_k}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": block_k}, + num_stages=3, + num_warps=8, + ), + ] + + # fmt: off @triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 16, "BLOCK_N": 16}, num_stages=5, num_warps=1), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_stages=5, num_warps=1), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}, num_stages=3, num_warps=8), - ], + configs=[c for block_k in [32, 64] for c in get_configs(block_k)], key=["M", "N", "K"], ) +@triton.heuristics({ + 'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0, +}) @triton.jit def kernel_fma( # Pointers to matrices @@ -53,6 +93,7 @@ def kernel_fma( # Meta-parameters BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + EVEN_N: tl.constexpr, BIAS: tl.constexpr, SAVE_ACT_INPUTS: tl.constexpr, ACTIVATION: tl.constexpr, @@ -110,7 +151,10 @@ def kernel_fma( acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) if BIAS: - bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) + if EVEN_N: + bias = tl.load(bias + rn).to(tl.float32) + else: + bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) acc += bias[None, :] # block level matrix multiplication. @@ -125,6 +169,9 @@ def kernel_fma( acc += tl.dot(a, w) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + # optional: save the activation inputs if SAVE_ACT_INPUTS: act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :] @@ -184,7 +231,6 @@ def fused_matmul( # 1D launch kernel where each block gets its own program. grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - BLOCK_K = 32 if K < 1024 else 64 # fmt: off kernel_fma[grid]( @@ -196,7 +242,6 @@ def fused_matmul( ACTIVATION=activation, # optional fused activation BIAS=bias is not None, # optional fused bias GROUP_M=8, # speed optimization: group the programs - BLOCK_K=BLOCK_K, SAVE_ACT_INPUTS=save_act_inputs, is_fp16=x_.dtype == torch.float16 )