From 7094ea5e750266e16c16c7aecac8fc03294ecaa3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 May 2024 09:38:19 +0300 Subject: [PATCH] whisper : use flash attention (#2152) * whisper : use flash attention in the encoder * whisper : add kv_pad * whisper : remove extra backend instance (huh?) * whisper : use FA for cross-attention * whisper : use FA for self-attention * whisper : simplify encoder FA * whisper : add flash_attn runtime parameter * scripts : add bench log * scripts : add M1 Pro bench log --- examples/bench/bench.cpp | 17 +- examples/command/command.cpp | 7 +- examples/lsp/lsp.cpp | 8 +- examples/main/main.cpp | 9 +- examples/server/server.cpp | 7 +- examples/stream/stream.cpp | 7 +- examples/talk-llama/talk-llama.cpp | 9 +- examples/talk/talk.cpp | 7 +- examples/wchess/wchess.cmd/wchess.cmd.cpp | 7 +- scripts/bench-all-gg.txt | 298 +++++++++++++++ scripts/bench-all.sh | 25 +- whisper.cpp | 429 ++++++++++++++-------- whisper.h | 1 + 13 files changed, 658 insertions(+), 173 deletions(-) create mode 100644 scripts/bench-all-gg.txt diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index b77621ac884..cac9385c82f 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -12,7 +12,8 @@ struct whisper_params { std::string model = "models/ggml-base.en.bin"; - bool use_gpu = true; + bool use_gpu = true; + bool flash_attn = false; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " %-7s 0 - whisper\n", ""); fprintf(stderr, " %-7s 1 - memcpy\n", ""); fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); @@ -59,7 +62,9 @@ int whisper_bench_full(const whisper_params & params) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); diff --git a/examples/command/command.cpp b/examples/command/command.cpp index ec749d60247..cd6cc023994 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -44,6 +44,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -80,6 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } @@ -118,6 +120,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -696,7 +699,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); diff --git a/examples/lsp/lsp.cpp b/examples/lsp/lsp.cpp index e5f8360f83d..3df54266a25 100644 --- a/examples/lsp/lsp.cpp +++ b/examples/lsp/lsp.cpp @@ -31,6 +31,7 @@ struct whisper_params { bool print_special = false; bool print_energy = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -74,6 +75,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else { @@ -105,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, "\n"); @@ -436,7 +439,10 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); // init audio diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d11c1c3f81b..45eb17fe7f3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -70,6 +70,7 @@ struct whisper_params { bool no_timestamps = false; bool log_score = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string prompt; @@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } else if ( arg == "--grammar") { params.grammar = argv[++i]; } else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } @@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); @@ -977,7 +980,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { cparams.dtw_token_timestamps = true; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e3b96698228..c78b3026e18 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -75,6 +75,7 @@ struct whisper_params { bool print_progress = false; bool no_timestamps = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string prompt = ""; @@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } // server params else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--host") { sparams.hostname = argv[++i]; } @@ -502,7 +504,10 @@ int main(int argc, char ** argv) { } // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + if (!params.dtw.empty()) { cparams.dtw_token_timestamps = true; cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index b82e379dc61..60c1b0894e4 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -36,6 +36,7 @@ struct whisper_params { bool tinydiarize = false; bool save_audio = false; // save audio to wav file bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); @@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false"); fprintf(stderr, "\n"); } @@ -153,7 +156,9 @@ int main(int argc, char ** argv) { } struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index 838d6f56357..4aab62b9a6f 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -66,6 +66,7 @@ struct whisper_params { bool no_timestamps = true; bool verbose_prompt = false; bool use_gpu = true; + bool flash_attn = false; std::string person = "Georgi"; std::string bot_name = "LLaMA"; @@ -105,6 +106,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; } else if (arg == "--session") { params.path_session = argv[++i]; } @@ -123,7 +125,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { } } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -154,6 +155,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str()); fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str()); @@ -285,7 +287,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams); if (!ctx_wsp) { @@ -316,6 +320,7 @@ int main(int argc, char ** argv) { lcparams.n_ctx = 2048; lcparams.seed = 1; lcparams.n_threads = params.n_threads; + lcparams.flash_attn = params.flash_attn; struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams); diff --git a/examples/talk/talk.cpp b/examples/talk/talk.cpp index c1c6f8ba0b2..3e34e5724ff 100644 --- a/examples/talk/talk.cpp +++ b/examples/talk/talk.cpp @@ -32,6 +32,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; + bool flash_attn = false; std::string person = "Santa"; std::string language = "en"; @@ -64,6 +65,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } @@ -99,6 +101,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); @@ -188,7 +191,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams); diff --git a/examples/wchess/wchess.cmd/wchess.cmd.cpp b/examples/wchess/wchess.cmd/wchess.cmd.cpp index f66b1765f5b..09e53f13172 100644 --- a/examples/wchess/wchess.cmd/wchess.cmd.cpp +++ b/examples/wchess/wchess.cmd/wchess.cmd.cpp @@ -32,6 +32,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -61,6 +62,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during decoding\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -92,6 +94,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } @@ -183,7 +186,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (!ctx) { diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt new file mode 100644 index 00000000000..6fd5605a2bd --- /dev/null +++ b/scripts/bench-all-gg.txt @@ -0,0 +1,298 @@ +## M1 Pro + +make -j && ./scripts/bench-all.sh 8 + +Running memcpy benchmark + +memcpy: 39.10 GB/s (heat-up) +memcpy: 44.75 GB/s ( 1 thread) +memcpy: 44.78 GB/s ( 1 thread) +memcpy: 44.97 GB/s ( 2 thread) +memcpy: 48.04 GB/s ( 3 thread) +memcpy: 50.55 GB/s ( 4 thread) +memcpy: 55.20 GB/s ( 5 thread) +memcpy: 65.60 GB/s ( 6 thread) +memcpy: 70.64 GB/s ( 7 thread) +memcpy: 73.34 GB/s ( 8 thread) +sum: -5120002535.000000 + + +make -j && ./scripts/bench-all.sh 1 0 0 + +Running ggml_mul_mat benchmark with 1 threads + + 64 x 64: Q4_0 237.1 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs) + 64 x 64: Q5_0 136.4 GFLOPS (128 runs) | Q5_1 135.6 GFLOPS (128 runs) | Q8_0 243.1 GFLOPS (128 runs) + 64 x 64: F16 140.4 GFLOPS (128 runs) | F32 316.6 GFLOPS (128 runs) + 128 x 128: Q4_0 496.6 GFLOPS (128 runs) | Q4_1 348.6 GFLOPS (128 runs) + 128 x 128: Q5_0 273.2 GFLOPS (128 runs) | Q5_1 274.1 GFLOPS (128 runs) | Q8_0 505.1 GFLOPS (128 runs) + 128 x 128: F16 300.4 GFLOPS (128 runs) | F32 653.9 GFLOPS (128 runs) + 256 x 256: Q4_0 791.7 GFLOPS (128 runs) | Q4_1 615.3 GFLOPS (128 runs) + 256 x 256: Q5_0 651.0 GFLOPS (128 runs) | Q5_1 674.7 GFLOPS (128 runs) | Q8_0 803.1 GFLOPS (128 runs) + 256 x 256: F16 869.6 GFLOPS (128 runs) | F32 957.2 GFLOPS (128 runs) + 512 x 512: Q4_0 973.3 GFLOPS (128 runs) | Q4_1 897.9 GFLOPS (128 runs) + 512 x 512: Q5_0 1078.8 GFLOPS (128 runs) | Q5_1 998.4 GFLOPS (128 runs) | Q8_0 752.4 GFLOPS (128 runs) + 512 x 512: F16 892.5 GFLOPS (128 runs) | F32 1399.6 GFLOPS (128 runs) +1024 x 1024: Q4_0 1402.7 GFLOPS (128 runs) | Q4_1 1218.5 GFLOPS (128 runs) +1024 x 1024: Q5_0 1444.8 GFLOPS (128 runs) | Q5_1 1444.7 GFLOPS (128 runs) | Q8_0 1395.7 GFLOPS (128 runs) +1024 x 1024: F16 1524.1 GFLOPS (128 runs) | F32 1726.6 GFLOPS (128 runs) +2048 x 2048: Q4_0 1479.4 GFLOPS ( 87 runs) | Q4_1 1378.5 GFLOPS ( 81 runs) +2048 x 2048: Q5_0 1454.6 GFLOPS ( 85 runs) | Q5_1 1462.9 GFLOPS ( 86 runs) | Q8_0 1483.2 GFLOPS ( 87 runs) +2048 x 2048: F16 1488.0 GFLOPS ( 87 runs) | F32 1538.2 GFLOPS ( 90 runs) +4096 x 4096: Q4_0 1509.7 GFLOPS ( 11 runs) | Q4_1 1433.0 GFLOPS ( 11 runs) +4096 x 4096: Q5_0 1422.4 GFLOPS ( 11 runs) | Q5_1 1437.0 GFLOPS ( 11 runs) | Q8_0 1523.0 GFLOPS ( 12 runs) +4096 x 4096: F16 1551.3 GFLOPS ( 12 runs) | F32 1451.0 GFLOPS ( 11 runs) + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M1 Pro | METAL | tiny | 1 | 0 | 39.21 | 1.74 | 0.61 | 0.04 | 22c96b4 | +| M1 Pro | METAL | base | 1 | 0 | 70.76 | 2.60 | 0.93 | 0.06 | 22c96b4 | +| M1 Pro | METAL | small | 1 | 0 | 217.28 | 6.42 | 2.14 | 0.17 | 22c96b4 | +| M1 Pro | METAL | medium | 1 | 0 | 596.74 | 14.43 | 4.75 | 0.45 | 22c96b4 | + + +make -j && ./scripts/bench-all.sh 1 1 1 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M1 Pro | METAL | tiny | 1 | 1 | 30.77 | 1.59 | 0.54 | 0.03 | 22c96b4 | +| M1 Pro | METAL | base | 1 | 1 | 60.42 | 2.29 | 0.81 | 0.05 | 22c96b4 | +| M1 Pro | METAL | small | 1 | 1 | 183.82 | 5.12 | 1.81 | 0.14 | 22c96b4 | +| M1 Pro | METAL | medium | 1 | 1 | 517.92 | 11.60 | 4.01 | 0.38 | 22c96b4 | + + +## M2 Ultra + +make -j && ./scripts/bench-all.sh 8 + +Running memcpy benchmark + +memcpy: 46.58 GB/s (heat-up) +memcpy: 54.16 GB/s ( 1 thread) +memcpy: 54.23 GB/s ( 1 thread) +memcpy: 99.63 GB/s ( 2 thread) +memcpy: 140.59 GB/s ( 3 thread) +memcpy: 176.52 GB/s ( 4 thread) +memcpy: 158.90 GB/s ( 5 thread) +memcpy: 163.00 GB/s ( 6 thread) +memcpy: 189.69 GB/s ( 7 thread) +memcpy: 197.15 GB/s ( 8 thread) +sum: -5120002007.000000 + + +make -j && ./scripts/bench-all.sh 1 + +Running ggml_mul_mat benchmark with 1 threads + + 64 x 64: Q4_0 245.8 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs) + 64 x 64: Q5_0 115.7 GFLOPS (128 runs) | Q5_1 125.9 GFLOPS (128 runs) | Q8_0 215.8 GFLOPS (128 runs) + 64 x 64: F16 139.5 GFLOPS (128 runs) | F32 337.2 GFLOPS (128 runs) + 128 x 128: Q4_0 494.8 GFLOPS (128 runs) | Q4_1 350.4 GFLOPS (128 runs) + 128 x 128: Q5_0 257.1 GFLOPS (128 runs) | Q5_1 261.4 GFLOPS (128 runs) | Q8_0 509.4 GFLOPS (128 runs) + 128 x 128: F16 302.3 GFLOPS (128 runs) | F32 672.8 GFLOPS (128 runs) + 256 x 256: Q4_0 795.7 GFLOPS (128 runs) | Q4_1 663.7 GFLOPS (128 runs) + 256 x 256: Q5_0 737.8 GFLOPS (128 runs) | Q5_1 757.6 GFLOPS (128 runs) | Q8_0 827.7 GFLOPS (128 runs) + 256 x 256: F16 872.6 GFLOPS (128 runs) | F32 956.3 GFLOPS (128 runs) + 512 x 512: Q4_0 1188.0 GFLOPS (128 runs) | Q4_1 1085.0 GFLOPS (128 runs) + 512 x 512: Q5_0 1421.1 GFLOPS (128 runs) | Q5_1 1454.9 GFLOPS (128 runs) | Q8_0 1191.4 GFLOPS (128 runs) + 512 x 512: F16 1577.4 GFLOPS (128 runs) | F32 1982.0 GFLOPS (128 runs) +1024 x 1024: Q4_0 2342.6 GFLOPS (128 runs) | Q4_1 1955.8 GFLOPS (128 runs) +1024 x 1024: Q5_0 2306.7 GFLOPS (128 runs) | Q5_1 2217.0 GFLOPS (128 runs) | Q8_0 2230.7 GFLOPS (128 runs) +1024 x 1024: F16 2593.8 GFLOPS (128 runs) | F32 3269.0 GFLOPS (128 runs) +2048 x 2048: Q4_0 3735.7 GFLOPS (128 runs) | Q4_1 3205.3 GFLOPS (128 runs) +2048 x 2048: Q5_0 3584.5 GFLOPS (128 runs) | Q5_1 3621.7 GFLOPS (128 runs) | Q8_0 3622.3 GFLOPS (128 runs) +2048 x 2048: F16 3763.6 GFLOPS (128 runs) | F32 4153.3 GFLOPS (128 runs) +4096 x 4096: Q4_0 3891.1 GFLOPS ( 29 runs) | Q4_1 3554.0 GFLOPS ( 26 runs) +4096 x 4096: Q5_0 3753.1 GFLOPS ( 28 runs) | Q5_1 3750.1 GFLOPS ( 28 runs) | Q8_0 3768.5 GFLOPS ( 28 runs) +4096 x 4096: F16 3864.2 GFLOPS ( 29 runs) | F32 3970.5 GFLOPS ( 29 runs) + + +make -j && ./scripts/bench-all.sh 1 1 0 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M2 ULTRA | METAL | tiny | 1 | 0 | 12.32 | 1.35 | 0.49 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 11.65 | 1.30 | 0.51 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 12.08 | 1.30 | 0.51 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | base | 1 | 0 | 17.58 | 1.90 | 0.76 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 18.89 | 1.86 | 0.79 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 20.69 | 1.88 | 0.79 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | small | 1 | 0 | 49.32 | 3.85 | 1.71 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 54.91 | 3.81 | 1.82 | 0.06 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 54.92 | 3.81 | 1.79 | 0.06 | 22c96b4 | +| M2 ULTRA | METAL | medium | 1 | 0 | 134.34 | 8.04 | 3.82 | 0.13 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 151.68 | 7.59 | 4.07 | 0.14 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 151.58 | 7.67 | 4.07 | 0.14 | 22c96b4 | +| M2 ULTRA | METAL | medium-dis | 1 | 0 | 120.82 | 1.07 | 0.41 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | large-v2 | 1 | 0 | 235.63 | 12.27 | 5.85 | 0.22 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 273.38 | 11.17 | 6.40 | 0.26 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 272.44 | 11.32 | 6.29 | 0.26 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 212.51 | 1.20 | 0.47 | 0.02 | 22c96b4 | + + +make -j && ./scripts/bench-all.sh 1 1 1 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M2 ULTRA | METAL | tiny | 1 | 1 | 9.07 | 1.33 | 0.45 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 9.74 | 1.33 | 0.47 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 8.93 | 1.31 | 0.46 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | base | 1 | 1 | 15.75 | 1.87 | 0.71 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 17.04 | 1.83 | 0.74 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 17.17 | 1.83 | 0.74 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | small | 1 | 1 | 42.33 | 3.64 | 1.60 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 47.61 | 3.63 | 1.70 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 47.70 | 3.66 | 1.68 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | medium | 1 | 1 | 114.42 | 7.53 | 3.55 | 0.11 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 132.63 | 7.02 | 3.77 | 0.13 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 132.28 | 7.10 | 3.76 | 0.13 | 22c96b4 | +| M2 ULTRA | METAL | medium-dis | 1 | 1 | 102.34 | 1.01 | 0.42 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | large-v2 | 1 | 1 | 203.01 | 11.03 | 5.45 | 0.20 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 240.05 | 10.18 | 5.98 | 0.23 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 239.22 | 10.23 | 5.87 | 0.23 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 181.14 | 1.14 | 0.48 | 0.02 | 22c96b4 | + + + +## Ryzen 9 5950X + RTX 2060 + +make -j && ./scripts/bench-all.sh 8 0 0 + +Running memcpy benchmark + +memcpy: 12.36 GB/s (heat-up) +memcpy: 12.33 GB/s ( 1 thread) +memcpy: 12.38 GB/s ( 1 thread) +memcpy: 14.48 GB/s ( 2 thread) +memcpy: 15.00 GB/s ( 3 thread) +memcpy: 14.77 GB/s ( 4 thread) +memcpy: 14.60 GB/s ( 5 thread) +memcpy: 14.57 GB/s ( 6 thread) +memcpy: 14.34 GB/s ( 7 thread) +memcpy: 14.40 GB/s ( 8 thread) +sum: -5119998076.000000 + +Running ggml_mul_mat benchmark with 8 threads + + 64 x 64: Q4_0 3.1 GFLOPS (128 runs) | Q4_1 3.1 GFLOPS (128 runs) + 64 x 64: Q5_0 3.0 GFLOPS (128 runs) | Q5_1 2.9 GFLOPS (128 runs) | Q8_0 3.1 GFLOPS (128 runs) + 64 x 64: F16 3.0 GFLOPS (128 runs) | F32 3.0 GFLOPS (128 runs) + 128 x 128: Q4_0 21.1 GFLOPS (128 runs) | Q4_1 20.3 GFLOPS (128 runs) + 128 x 128: Q5_0 20.6 GFLOPS (128 runs) | Q5_1 20.4 GFLOPS (128 runs) | Q8_0 22.1 GFLOPS (128 runs) + 128 x 128: F16 21.7 GFLOPS (128 runs) | F32 21.7 GFLOPS (128 runs) + 256 x 256: Q4_0 105.7 GFLOPS (128 runs) | Q4_1 94.4 GFLOPS (128 runs) + 256 x 256: Q5_0 94.8 GFLOPS (128 runs) | Q5_1 87.5 GFLOPS (128 runs) | Q8_0 107.2 GFLOPS (128 runs) + 256 x 256: F16 95.1 GFLOPS (128 runs) | F32 94.3 GFLOPS (128 runs) + 512 x 512: Q4_0 214.7 GFLOPS (128 runs) | Q4_1 189.8 GFLOPS (128 runs) + 512 x 512: Q5_0 187.7 GFLOPS (128 runs) | Q5_1 176.2 GFLOPS (128 runs) | Q8_0 252.2 GFLOPS (128 runs) + 512 x 512: F16 220.8 GFLOPS (128 runs) | F32 218.3 GFLOPS (128 runs) +1024 x 1024: Q4_0 333.7 GFLOPS (128 runs) | Q4_1 305.8 GFLOPS (128 runs) +1024 x 1024: Q5_0 283.2 GFLOPS (128 runs) | Q5_1 268.2 GFLOPS (125 runs) | Q8_0 394.1 GFLOPS (128 runs) +1024 x 1024: F16 355.0 GFLOPS (128 runs) | F32 313.0 GFLOPS (128 runs) +2048 x 2048: Q4_0 395.0 GFLOPS ( 23 runs) | Q4_1 380.6 GFLOPS ( 23 runs) +2048 x 2048: Q5_0 336.6 GFLOPS ( 20 runs) | Q5_1 318.4 GFLOPS ( 19 runs) | Q8_0 482.6 GFLOPS ( 29 runs) +2048 x 2048: F16 424.5 GFLOPS ( 25 runs) | F32 337.7 GFLOPS ( 20 runs) +4096 x 4096: Q4_0 412.8 GFLOPS ( 4 runs) | Q4_1 405.1 GFLOPS ( 3 runs) +4096 x 4096: Q5_0 346.0 GFLOPS ( 3 runs) | Q5_1 334.6 GFLOPS ( 3 runs) | Q8_0 502.6 GFLOPS ( 4 runs) +4096 x 4096: F16 412.5 GFLOPS ( 4 runs) | F32 274.0 GFLOPS ( 3 runs) + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| Ryzen 9 5950X | AVX2 | tiny | 8 | 0 | 195.29 | 1.57 | 0.51 | 0.26 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | tiny-q5_0 | 8 | 0 | 213.33 | 1.10 | 0.50 | 0.30 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | tiny-q5_1 | 8 | 0 | 219.38 | 1.18 | 0.53 | 0.32 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | base | 8 | 0 | 424.85 | 3.71 | 1.03 | 0.46 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | base-q5_0 | 8 | 0 | 473.61 | 1.81 | 0.82 | 0.52 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | base-q5_1 | 8 | 0 | 484.14 | 1.92 | 0.85 | 0.56 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | small | 8 | 0 | 1458.32 | 12.66 | 3.09 | 1.26 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | small-q5_0 | 8 | 0 | 1673.22 | 6.42 | 2.18 | 1.45 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | small-q5_1 | 8 | 0 | 1724.78 | 6.72 | 2.32 | 1.52 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium | 8 | 0 | 4333.87 | 36.80 | 8.56 | 3.37 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium-q5_0 | 8 | 0 | 5194.09 | 19.21 | 5.71 | 3.97 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium-q5_1 | 8 | 0 | 5450.39 | 20.01 | 5.99 | 4.17 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium-dis | 8 | 0 | 3995.19 | 5.08 | 1.21 | 0.55 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2 | 8 | 0 | 8056.16 | 69.74 | 16.11 | 6.13 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2-q5_0 | 8 | 0 | 9799.58 | 35.16 | 10.49 | 7.28 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2-q5_1 | 8 | 0 | ms | 36.74 | 11.02 | 7.65 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2-dis | 8 | 0 | 7490.03 | 7.40 | 1.70 | 0.72 | 22c96b4 | + + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| RTX 2060 | AVX2 CUDA | tiny | 8 | 0 | 12.54 | 0.93 | 0.29 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 0 | 12.73 | 0.98 | 0.24 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 0 | 12.72 | 0.99 | 0.24 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base | 8 | 0 | 24.14 | 1.28 | 0.41 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 0 | 24.58 | 1.38 | 0.35 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 0 | 24.58 | 1.37 | 0.35 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small | 8 | 0 | 74.70 | 2.91 | 0.84 | 0.07 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 0 | 76.12 | 2.84 | 0.77 | 0.08 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 0 | 76.14 | 2.84 | 0.76 | 0.08 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium | 8 | 0 | 200.69 | 6.46 | 1.83 | 0.17 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 0 | 204.80 | 5.90 | 1.65 | 0.19 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 0 | 205.61 | 5.85 | 1.61 | 0.19 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-dis | 8 | 0 | 186.17 | 0.86 | 0.24 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2 | 8 | 0 | 347.22 | 10.36 | 2.82 | 0.29 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 0 | 357.06 | 8.81 | 2.58 | 0.34 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 0 | 356.97 | 8.62 | 2.49 | 0.33 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 0 | 318.05 | 1.03 | 0.34 | 0.04 | 22c96b4 | + + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| RTX 2060 | AVX2 CUDA | tiny | 8 | 1 | 7.21 | 0.76 | 0.29 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 1 | 7.42 | 0.82 | 0.18 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 1 | 7.38 | 0.82 | 0.18 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base | 8 | 1 | 13.49 | 1.04 | 0.36 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 1 | 13.94 | 1.13 | 0.26 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 1 | 13.94 | 1.14 | 0.26 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small | 8 | 1 | 42.81 | 2.33 | 0.69 | 0.05 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 1 | 44.43 | 2.25 | 0.59 | 0.06 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 1 | 44.11 | 2.24 | 0.58 | 0.06 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium | 8 | 1 | 115.47 | 5.17 | 1.45 | 0.11 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 1 | 120.37 | 4.63 | 1.25 | 0.13 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 1 | 120.28 | 4.55 | 1.21 | 0.13 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-dis | 8 | 1 | 101.69 | 0.75 | 0.20 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2 | 8 | 1 | 205.67 | 8.49 | 2.19 | 0.18 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 1 | 214.07 | 6.88 | 1.94 | 0.22 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 1 | 213.98 | 6.70 | 1.86 | 0.22 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 1 | 176.71 | 0.91 | 0.31 | 0.03 | 22c96b4 | + + + + +# V100 + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| V100 | AVX2 CUDA | tiny | 1 | 0 | 6.21 | 1.11 | 0.30 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | tiny-q5_1 | 1 | 0 | 5.97 | 1.10 | 0.26 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | base | 1 | 0 | 10.95 | 1.47 | 0.42 | 0.03 | 22c96b4 | +| V100 | AVX2 CUDA | base-q5_1 | 1 | 0 | 11.13 | 1.53 | 0.36 | 0.03 | 22c96b4 | +| V100 | AVX2 CUDA | small | 1 | 0 | 31.57 | 2.96 | 0.84 | 0.05 | 22c96b4 | +| V100 | AVX2 CUDA | small-q5_1 | 1 | 0 | 32.19 | 3.14 | 0.75 | 0.05 | 22c96b4 | +| V100 | AVX2 CUDA | medium | 1 | 0 | 85.88 | 6.49 | 1.80 | 0.10 | 22c96b4 | +| V100 | AVX2 CUDA | medium-q5_0 | 1 | 0 | 87.53 | 5.82 | 1.37 | 0.10 | 22c96b4 | +| V100 | AVX2 CUDA | large-v2 | 1 | 0 | 142.23 | 8.92 | 2.62 | 0.15 | 22c96b4 | + + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| V100 | AVX2 CUDA | tiny | 1 | 1 | 3.96 | 0.82 | 0.24 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | tiny-q5_1 | 1 | 1 | 4.05 | 0.85 | 0.18 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | base | 1 | 1 | 7.21 | 1.16 | 0.36 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | base-q5_1 | 1 | 1 | 7.39 | 1.21 | 0.26 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | small | 1 | 1 | 19.81 | 2.41 | 0.71 | 0.04 | 22c96b4 | +| V100 | AVX2 CUDA | small-q5_1 | 1 | 1 | 20.50 | 2.31 | 0.51 | 0.04 | 22c96b4 | +| V100 | AVX2 CUDA | medium | 1 | 1 | 56.02 | 4.89 | 1.44 | 0.07 | 22c96b4 | +| V100 | AVX2 CUDA | medium-q5_0 | 1 | 1 | 57.85 | 4.73 | 1.09 | 0.08 | 22c96b4 | +| V100 | AVX2 CUDA | large-v2 | 1 | 1 | 92.73 | 7.18 | 2.14 | 0.10 | 22c96b4 | + diff --git a/scripts/bench-all.sh b/scripts/bench-all.sh index 6939dafaca0..8a857c67b6c 100755 --- a/scripts/bench-all.sh +++ b/scripts/bench-all.sh @@ -2,7 +2,7 @@ # Helper script to run the bench tool on all models and print the results in share-able format -printf "Usage: ./bench.sh [n_threads] [encoder-only]\n" +printf "Usage: ./bench.sh [n_threads] [encoder-only] [flash-attn]\n" if [ -z "$1" ]; then n_threads=4 @@ -11,12 +11,19 @@ else fi encoder_only=0 -if [ -z "$2" ]; then +if [ -z "$2" ] || [ "$2" -eq 0 ]; then encoder_only=0 else encoder_only=$2 fi +fattn="" +if [ -z "$3" ] || [ "$3" -eq 0 ]; then + fattn="" +else + fattn="-fa" +fi + models=( \ "tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \ "base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \ @@ -44,13 +51,19 @@ if [ "$encoder_only" -eq 0 ]; then printf "\n" fi -printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit" -printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" +if [ "$fattn" == "-fa" ]; then + fattn_i=1 +else + fattn_i=0 +fi + +printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "FA" "Enc." "Dec." "Bch5" "PP" "Commit" +printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" for model in "${models[@]}"; do # actual run # store stderr output in a variable in order to parse it later - output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1) + output=$(./bench -m ./models/ggml-$model.bin -t $n_threads $fattn 2>&1) ret=$? # parse the output: @@ -95,6 +108,6 @@ for model in "${models[@]}"; do commit=$(git rev-parse --short HEAD) if [ $ret -eq 0 ]; then - printf "| | | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit" + printf "| | | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$fattn_i" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit" fi done diff --git a/whisper.cpp b/whisper.cpp index ff4223daf42..84aec8238cd 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -809,14 +809,15 @@ struct whisper_state { // shared between all decoders whisper_kv_cache kv_cross; + // padded buffer for flash-attention + whisper_kv_cache kv_pad; + whisper_mel mel; whisper_batch batch; whisper_decoder decoders[WHISPER_MAX_DECODERS]; - ggml_backend_t backend = nullptr; - // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -902,14 +903,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) { } static bool kv_cache_init( - const struct whisper_hparams & hparams, struct whisper_kv_cache & cache, ggml_backend_t backend, ggml_type wtype, + int64_t n_text_state, + int64_t n_text_layer, int n_ctx) { - const int64_t n_text_state = hparams.n_text_state; - const int64_t n_text_layer = hparams.n_text_layer; - const int64_t n_mem = n_text_layer*n_ctx; const int64_t n_elements = n_text_state*n_mem; @@ -941,6 +940,8 @@ static bool kv_cache_init( return false; } + ggml_backend_buffer_clear(cache.buffer, 0); + return true; } @@ -1068,6 +1069,26 @@ static void whisper_kv_cache_seq_cp( } } +static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) { + if (!wctx.params.flash_attn) { + return 1u; + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend)) { + return 32u; + } +#endif + +#ifdef GGML_USE_CUDA + if (ggml_backend_is_cuda(wctx.backend)) { + return 256u; + } +#endif + + return 1u; +} + // [EXPERIMENTAL] Token-level timestamps with DTW static bool aheads_masks_init( const whisper_context_params & cparams, @@ -1872,6 +1893,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder( const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; + const int n_state_head = n_state/n_head; + + auto & kv_pad = wstate.kv_pad; + + WHISPER_ASSERT(!!kv_pad.ctx); + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + struct ggml_init_params params = { /*.mem_size =*/ wstate.alloc_encode.meta.size(), /*.mem_buffer =*/ wstate.alloc_encode.meta.data(), @@ -1884,7 +1913,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); - const float KQscale = 1.0f/sqrtf(float(n_state)/n_head); + const float KQscale = 1.0f/sqrtf(float(n_state_head)); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1934,14 +1963,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder( Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b); - //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25)); + //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25)); // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); - //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25)); + //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.attn_v_w, @@ -1955,38 +1984,61 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_permute(ctx0, ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)), 0, 2, 1, 3); - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + if (wctx.params.flash_attn) { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0))); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_pad.k, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.k)*n_state, + ggml_element_size(kv_pad.k)*n_state_head, + 0); - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head) - ); + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_pad.v, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.v)*n_state, + ggml_element_size(kv_pad.v)*n_state_head, + 0); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); + } else { + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } } // projection @@ -2085,6 +2137,10 @@ static struct ggml_cgraph * whisper_build_graph_cross( const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; + const int n_state_head = n_state/n_head; + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + struct ggml_init_params params = { /*.mem_size =*/ wstate.alloc_cross.meta.size(), /*.mem_buffer =*/ wstate.alloc_cross.meta.data(), @@ -2097,18 +2153,18 @@ static struct ggml_cgraph * whisper_build_graph_cross( struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); - const float Kscale = pow(float(n_state) / n_head, -0.25); + const float Kscale = pow(float(n_state_head), -0.25); for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; - struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, cur); Kcross = ggml_scale(ctx0, Kcross, Kscale); - struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, cur); @@ -2116,15 +2172,25 @@ static struct ggml_cgraph * whisper_build_graph_cross( Vcross, layer.cross_attn_v_b); - Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + struct ggml_tensor * k; + struct ggml_tensor * v; - struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, - n_state*n_ctx, - (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad)); - struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, - ( n_ctx)*ggml_element_size(wstate.kv_cross.v), - (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad)); + } else { + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + + v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); @@ -2195,7 +2261,7 @@ static bool whisper_encode_internal( } if (!whisper_encode_external(wstate)) { - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } else { @@ -2218,7 +2284,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -2234,7 +2300,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -2263,11 +2329,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder( const int n_head = hparams.n_text_head; const int n_layer = hparams.n_text_layer; + const int n_state_head = n_state/n_head; + const int n_tokens = batch.n_tokens; const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; + const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256); + + const int32_t n_kv = worst_case ? n_ctx : kv_self.n; + const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); @@ -2289,12 +2359,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_set_name(position, "position"); ggml_set_input(position); - const float KQscale = pow(float(n_state)/n_head, -0.25); + const float KQscale = pow(float(n_state_head), -0.25); - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1); ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); + struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2350,12 +2422,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder( Vcur, layer.attn_v_b); - Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); + struct ggml_tensor * k; + struct ggml_tensor * v; + + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state, + (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head)); + } else { + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); @@ -2365,35 +2450,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder( struct ggml_tensor * Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), 0, 2, 1, 3); struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_state/n_head, n_kv, n_head, + n_state_head, n_kv, n_head, ggml_element_size(kv_self.k)*n_state, - ggml_element_size(kv_self.k)*n_state/n_head, + ggml_element_size(kv_self.k)*n_state_head, ggml_element_size(kv_self.k)*n_state*n_ctx*il); - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + if (wctx.params.flash_attn) { + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_state_head, n_kv, n_head, + ggml_element_size(kv_self.v)*n_state, + ggml_element_size(kv_self.v)*n_state_head, + ggml_element_size(kv_self.v)*n_state*n_ctx*il); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_state/n_head, n_head, - n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, - n_ctx*ggml_element_size(kv_self.v)*n_state*il); + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_state_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_state_head, + n_ctx*ggml_element_size(kv_self.v)*n_state*il); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } } // projection @@ -2432,80 +2530,77 @@ static struct ggml_cgraph * whisper_build_graph_decoder( Qcur, layer.cross_attn_q_b); - Qcur = ggml_scale(ctx0, Qcur, KQscale); - - // Kcross is already scaled - struct ggml_tensor * Kcross = - ggml_view_3d(ctx0, wstate.kv_cross.k, - n_state/n_head, n_audio_ctx, n_head, - ggml_element_size(wstate.kv_cross.k)*n_state, - ggml_element_size(wstate.kv_cross.k)*n_state/n_head, - ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); - - //struct ggml_tensor * Vcross = - // ggml_reshape_3d(ctx0, - // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state), - // n_state/n_head, n_head, n_audio_ctx); - - //struct ggml_tensor * V_trans = - // ggml_cpy(ctx0, - // ggml_permute(ctx0, Vcross, 1, 2, 0, 3), - // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head)); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, wstate.kv_cross.v, - n_audio_ctx, n_state/n_head, n_head, - n_audio_ctx*ggml_element_size(wstate.kv_cross.v), - n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, - n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); - - // ------ - struct ggml_tensor * Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), 0, 2, 1, 3); - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - // ); + if (wctx.params.flash_attn) { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il); - // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.v)*n_state, + ggml_element_size(wstate.kv_cross.v)*n_state_head, + ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); + cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f); - // [EXPERIMENTAL] Token-level timestamps with DTW - if (wctx.params.dtw_token_timestamps) { - if (wstate.aheads_masks.m[il] != nullptr) { - struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); - aheads_KQs = ggml_transpose(ctx0, aheads_KQs); - aheads_KQs = ggml_cont(ctx0, aheads_KQs); - aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); - aheads_KQs = ggml_transpose(ctx0, aheads_KQs); - aheads_KQs = ggml_cont(ctx0, aheads_KQs); - aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); - if (aheads_cross_QKs == NULL) { - aheads_cross_QKs = aheads_KQs; - } else { - aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); + + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_audio_ctx, n_state_head, n_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v), + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); + + // ------ + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (wctx.params.dtw_token_timestamps) { + if (wstate.aheads_masks.m[il] != nullptr) { + struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); + if (aheads_cross_QKs == NULL) { + aheads_cross_QKs = aheads_KQs; + } else { + aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs); + } } } - } - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - // cur = KQV_merged.contiguous().view(n_state, n_tokens) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } } // projection @@ -2638,7 +2733,9 @@ static bool whisper_decode_internal( return false; } - kv_self.n = whisper_kv_cache_cell_max(kv_self); + const uint32_t pad = whisper_kv_cache_get_padding(wctx); + kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad))); + //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); } @@ -2672,9 +2769,10 @@ static bool whisper_decode_internal( struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask"); auto & kv_self = wstate.kv_self; - const int32_t n_kv = kv_self.n; - wstate.inp_mask.resize(n_kv*n_tokens); + const int32_t n_kv = kv_self.n; + + wstate.inp_mask.resize(ggml_nelements(KQ_mask)); float * data = wstate.inp_mask.data(); memset(data, 0, ggml_nbytes(KQ_mask)); @@ -2690,6 +2788,12 @@ static bool whisper_decode_internal( } } } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } } ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float)); @@ -2697,7 +2801,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -3144,18 +3248,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_state * state = new whisper_state; - state->backend = whisper_backend_init(ctx->params); - if (!state->backend) { - WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); - whisper_free_state(state); - return nullptr; - } - // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // in theory, there can be a case where this is not enough, but in practice it should always be enough const int factor = 3; - if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) { + if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3166,7 +3266,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); } - if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3177,6 +3280,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); } + if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype, + ctx->model.hparams.n_audio_state, + 1, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v); + WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6); + } + // [EXPERIMENTAL] Token-level timestamps with DTW if (ctx->params.dtw_token_timestamps) { if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) { @@ -3347,6 +3464,7 @@ int whisper_ctx_init_openvino_encoder( struct whisper_context_params whisper_context_default_params() { struct whisper_context_params result = { /*.use_gpu =*/ true, + /*.flash_attn =*/ false, /*.gpu_device =*/ 0, /*.dtw_token_timestamps =*/ false, @@ -3445,6 +3563,16 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) { ggml_time_init(); + if (params.flash_attn && params.dtw_token_timestamps) { + WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__); + params.dtw_token_timestamps = false; + } + + WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn); + WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps); + whisper_context * ctx = new whisper_context; ctx->params = params; @@ -3533,6 +3661,7 @@ void whisper_free_state(struct whisper_state * state) { if (state) { kv_cache_free(state->kv_self); kv_cache_free(state->kv_cross); + kv_cache_free(state->kv_pad); #ifdef WHISPER_USE_COREML if (state->ctx_coreml != nullptr) { @@ -3555,8 +3684,6 @@ void whisper_free_state(struct whisper_state * state) { ggml_gallocr_free(state->alloc_cross.alloc); ggml_gallocr_free(state->alloc_decode.alloc); - ggml_backend_free(state->backend); - // [EXPERIMENTAL] Token-level timestamps with DTW aheads_masks_free(state->aheads_masks); diff --git a/whisper.h b/whisper.h index 6a875d3bbb9..9c7c58d874b 100644 --- a/whisper.h +++ b/whisper.h @@ -113,6 +113,7 @@ extern "C" { struct whisper_context_params { bool use_gpu; + bool flash_attn; int gpu_device; // CUDA device // [EXPERIMENTAL] Token-level timestamps with DTW