From ae1bd690419032c95406940c8533a905cb1ae026 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Nov 2023 22:45:08 +0200 Subject: [PATCH] bench : add batch size 5 bench --- examples/bench/bench.cpp | 30 ++++++++++++++++++++---------- extra/bench-all.sh | 7 ++++--- whisper.cpp | 16 +++++++++++++--- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index db1c4e800cd..949e5737167 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) { } // heat encoder if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to encode: %d\n", ret); return 4; } @@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) { // prompt heat if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } // text-generation heat if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } @@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) { // actual run if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to encode: %d\n", ret); return 4; } - for (int i = 0; i < 16; i++) { - if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + // text-generation + for (int i = 0; i < 256; i++) { + if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } } - for (int i = 0; i < 256; i++) { - if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + // batched decoding + for (int i = 0; i < 64; i++) { + if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); + return 4; + } + } + + // prompt processing + for (int i = 0; i < 16; i++) { + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } } diff --git a/extra/bench-all.sh b/extra/bench-all.sh index db042673d69..af8f67599a4 100755 --- a/extra/bench-all.sh +++ b/extra/bench-all.sh @@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then printf "\n" fi -printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit" -printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" +printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit" +printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" for model in "${models[@]}"; do # actual run @@ -56,6 +56,7 @@ for model in "${models[@]}"; do # parse the output: encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}') decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}') + batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}') prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}') system_info=$(echo "$output" | grep "system_info") n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}') @@ -94,6 +95,6 @@ for model in "${models[@]}"; do commit=$(git rev-parse --short HEAD) if [ $ret -eq 0 ]; then - printf "| | | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit" + printf "| | | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit" fi done diff --git a/whisper.cpp b/whisper.cpp index 21fac2e5d21..84e23c91ffb 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -773,13 +773,15 @@ struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; + int64_t t_batchd_us = 0; int64_t t_prompt_us = 0; int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls - int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) - int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding) + int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures @@ -2616,9 +2618,12 @@ static bool whisper_decode_internal( if (batch.n_tokens == 1) { wstate.t_decode_us += ggml_time_us() - t_start_us; wstate.n_decode++; + } else if (batch.n_tokens < 16) { + wstate.t_batchd_us += ggml_time_us() - t_start_us; + wstate.n_batchd += n_tokens; } else { wstate.t_prompt_us += ggml_time_us() - t_start_us; - wstate.n_prompt++; + wstate.n_prompt += n_tokens; } return !(abort_callback && abort_callback(abort_callback_data)); @@ -3827,6 +3832,7 @@ void whisper_print_timings(struct whisper_context * ctx) { const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_encode = std::max(1, ctx->state->n_encode); const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_batchd = std::max(1, ctx->state->n_batchd); const int32_t n_prompt = std::max(1, ctx->state->n_prompt); WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); @@ -3834,6 +3840,7 @@ void whisper_print_timings(struct whisper_context * ctx) { WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd); WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); } WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); @@ -3850,6 +3857,7 @@ void whisper_reset_timings(struct whisper_context * ctx) { ctx->state->n_sample = 0; ctx->state->n_encode = 0; ctx->state->n_decode = 0; + ctx->state->n_batchd = 0; ctx->state->n_prompt = 0; } } @@ -5896,11 +5904,13 @@ int whisper_full_parallel( ctx->state->t_sample_us += states[i]->t_sample_us; ctx->state->t_encode_us += states[i]->t_encode_us; ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_batchd_us += states[i]->t_batchd_us; ctx->state->t_prompt_us += states[i]->t_prompt_us; ctx->state->n_sample += states[i]->n_sample; ctx->state->n_encode += states[i]->n_encode; ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_batchd += states[i]->n_batchd; ctx->state->n_prompt += states[i]->n_prompt; whisper_free_state(states[i]);