From 239defe61dbe9dddc6304942e8a3d03d6a3c69ab Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 3 Nov 2023 22:26:27 +0200 Subject: [PATCH] sync : whisper.cpp (ARM 32-bit, abort callback, wav_writer, etc.) (#602) --- examples/common.h | 100 +++++++++++++++ examples/whisper/main.cpp | 78 +++++++++--- examples/whisper/whisper.cpp | 61 +++++---- examples/whisper/whisper.h | 10 ++ scripts/sync-whisper.sh | 17 ++- src/ggml-impl.h | 6 - src/ggml-metal.m | 8 +- src/ggml-quants.c | 241 ++++++++++++++++++++++++----------- src/ggml.c | 18 +++ 9 files changed, 419 insertions(+), 120 deletions(-) diff --git a/examples/common.h b/examples/common.h index 1d4e9c37c..9a94bab7a 100644 --- a/examples/common.h +++ b/examples/common.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #define COMMON_SAMPLE_RATE 16000 @@ -142,6 +144,104 @@ bool read_wav( std::vector> & pcmf32s, bool stereo); +// Write PCM data into WAV audio file +class wav_writer { +private: + std::ofstream file; + uint32_t dataSize = 0; + std::string wav_filename; + + bool write_header(const uint32_t sample_rate, + const uint16_t bits_per_sample, + const uint16_t channels) { + + file.write("RIFF", 4); + file.write("\0\0\0\0", 4); // Placeholder for file size + file.write("WAVE", 4); + file.write("fmt ", 4); + + const uint32_t sub_chunk_size = 16; + const uint16_t audio_format = 1; // PCM format + const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8; + const uint16_t block_align = channels * bits_per_sample / 8; + + file.write(reinterpret_cast(&sub_chunk_size), 4); + file.write(reinterpret_cast(&audio_format), 2); + file.write(reinterpret_cast(&channels), 2); + file.write(reinterpret_cast(&sample_rate), 4); + file.write(reinterpret_cast(&byte_rate), 4); + file.write(reinterpret_cast(&block_align), 2); + file.write(reinterpret_cast(&bits_per_sample), 2); + file.write("data", 4); + file.write("\0\0\0\0", 4); // Placeholder for data size + + return true; + } + + // It is assumed that PCM data is normalized to a range from -1 to 1 + bool write_audio(const float * data, size_t length) { + for (size_t i = 0; i < length; ++i) { + const auto intSample = static_cast(data[i] * 32767); + file.write(reinterpret_cast(&intSample), sizeof(int16_t)); + dataSize += sizeof(int16_t); + } + if (file.is_open()) { + file.seekp(4, std::ios::beg); + uint32_t fileSize = 36 + dataSize; + file.write(reinterpret_cast(&fileSize), 4); + file.seekp(40, std::ios::beg); + file.write(reinterpret_cast(&dataSize), 4); + file.seekp(0, std::ios::end); + } + return true; + } + + bool open_wav(const std::string & filename) { + if (filename != wav_filename) { + if (file.is_open()) { + file.close(); + } + } + if (!file.is_open()) { + file.open(filename, std::ios::binary); + wav_filename = filename; + dataSize = 0; + } + return file.is_open(); + } + +public: + bool open(const std::string & filename, + const uint32_t sample_rate, + const uint16_t bits_per_sample, + const uint16_t channels) { + + if (open_wav(filename)) { + write_header(sample_rate, bits_per_sample, channels); + } else { + return false; + } + + return true; + } + + bool close() { + file.close(); + return true; + } + + bool write(const float * data, size_t length) { + return write_audio(data, length); + } + + ~wav_writer() { + if (file.is_open()) { + file.close(); + } + } +}; + + // Apply a high-pass frequency filter to PCM audio // Suppresses frequencies below cutoff Hz void high_pass_filter( diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 60c1cca75..bed0789f9 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -83,6 +83,7 @@ struct whisper_params { bool output_wts = false; bool output_csv = false; bool output_jsn = false; + bool output_jsn_full = false; bool output_lrc = false; bool print_special = false; bool print_colors = false; @@ -151,6 +152,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; } else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } @@ -206,6 +208,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false"); + fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false"); fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); @@ -511,7 +514,12 @@ bool output_score(struct whisper_context * ctx, const char * fname, const whispe return true; } -bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { +bool output_json( + struct whisper_context * ctx, + const char * fname, + const whisper_params & params, + std::vector> pcmf32s, + bool full) { std::ofstream fout(fname); int indent = 0; @@ -528,7 +536,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper auto end_arr = [&](bool end) { indent--; doindent(); - fout << (end ? "]\n" : "},\n"); + fout << (end ? "]\n" : "],\n"); }; auto start_obj = [&](const char *name) { @@ -569,12 +577,29 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper end_value(end); }; + auto value_f = [&](const char *name, const float val, bool end) { + start_value(name); + fout << val; + end_value(end); + }; + auto value_b = [&](const char *name, const bool val, bool end) { start_value(name); fout << (val ? "true" : "false"); end_value(end); }; + auto times_o = [&](int64_t t0, int64_t t1, bool end) { + start_obj("timestamps"); + value_s("from", to_timestamp(t0, true).c_str(), false); + value_s("to", to_timestamp(t1, true).c_str(), true); + end_obj(false); + start_obj("offsets"); + value_i("from", t0 * 10, false); + value_i("to", t1 * 10, true); + end_obj(end); + }; + if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); return false; @@ -620,15 +645,26 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper const int64_t t1 = whisper_full_get_segment_t1(ctx, i); start_obj(nullptr); - start_obj("timestamps"); - value_s("from", to_timestamp(t0, true).c_str(), false); - value_s("to", to_timestamp(t1, true).c_str(), true); - end_obj(false); - start_obj("offsets"); - value_i("from", t0 * 10, false); - value_i("to", t1 * 10, true); - end_obj(false); - value_s("text", text, !params.diarize && !params.tinydiarize); + times_o(t0, t1, false); + value_s("text", text, !params.diarize && !params.tinydiarize && !full); + + if (full) { + start_arr("tokens"); + const int n = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n; ++j) { + auto token = whisper_full_get_token_data(ctx, i, j); + start_obj(nullptr); + value_s("text", whisper_token_to_str(ctx, token.id), false); + if(token.t0 > -1 && token.t1 > -1) { + // If we have per-token timestamps, write them out + times_o(token.t0, token.t1, false); + } + value_i("id", token.id, false); + value_f("p", token.p, true); + end_obj(j == (n - 1)); + } + end_arr(!params.diarize && !params.tinydiarize); + } if (params.diarize && pcmf32s.size() == 2) { value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); @@ -912,7 +948,7 @@ int main(int argc, char ** argv) { wparams.offset_ms = params.offset_t_ms; wparams.duration_ms = params.duration_ms; - wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0; wparams.thold_pt = params.word_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.split_on_word = params.split_on_word; @@ -944,8 +980,9 @@ int main(int argc, char ** argv) { wparams.progress_callback_user_data = &user_data; } - // example for abort mechanism - // in this example, we do not abort the processing, but we could if the flag is set to true + // examples for abort mechanism + // in examples below, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race @@ -957,6 +994,17 @@ int main(int argc, char ** argv) { wparams.encoder_begin_callback_user_data = &is_aborted; } + // the callback is called before every computation - if it returns true, the computation is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.abort_callback = [](void * user_data) { + bool is_aborted = *(bool*)user_data; + return is_aborted; + }; + wparams.abort_callback_user_data = &is_aborted; + } + if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 10; @@ -1000,7 +1048,7 @@ int main(int argc, char ** argv) { // output to JSON file if (params.output_jsn) { const auto fname_jsn = fname_out + ".json"; - output_json(ctx, fname_jsn.c_str(), params, pcmf32s); + output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full); } // output to LRC file diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 6d4a2c127..17ef4d9e8 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -120,14 +120,23 @@ static void byteswap_tensor(ggml_tensor * tensor) { //#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 16 +#define WHISPER_MAX_NODES 4096 // // ggml helpers // -static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { +static void ggml_graph_compute_helper( + std::vector & buf, + ggml_cgraph * graph, + int n_threads, + whisper_abort_callback abort_callback, + void * abort_callback_data) { struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + plan.abort_callback = abort_callback; + plan.abort_callback_data = abort_callback_data; + if (plan.work_size > 0) { buf.resize(plan.work_size); plan.work_data = buf.data(); @@ -655,7 +664,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::funct auto & meta = allocr.meta; auto & data = allocr.data; - meta.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead()); + meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); alloc = ggml_allocr_new_measure(tensor_alignment); @@ -1608,7 +1617,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( struct ggml_context * ctx0 = ggml_init(params); - ggml_cgraph * gf = ggml_new_graph(ctx0); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); ggml_allocr * alloc = wstate.alloc_encode.alloc; @@ -1922,7 +1931,9 @@ static bool whisper_encode_internal( whisper_context & wctx, whisper_state & wstate, const int mel_offset, - const int n_threads) { + const int n_threads, + whisper_abort_callback abort_callback, + void * abort_callback_data) { const int64_t t_start_us = ggml_time_us(); // conv @@ -1936,7 +1947,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } } @@ -1955,10 +1966,10 @@ static bool whisper_encode_internal( ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_graph_compute(wstate.ctx_metal, gf); } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } #else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif } @@ -1977,10 +1988,10 @@ static bool whisper_encode_internal( ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_graph_compute(wstate.ctx_metal, gf); } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } #else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif } @@ -2024,7 +2035,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( struct ggml_context * ctx0 = ggml_init(params); - ggml_cgraph * gf = ggml_new_graph(ctx0); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); ggml_allocr * alloc = wstate.alloc_decode.alloc; @@ -2346,7 +2357,9 @@ static bool whisper_decode_internal( const whisper_token * tokens, const int n_tokens, const int n_past, - const int n_threads) { + const int n_threads, + whisper_abort_callback abort_callback, + void * abort_callback_data) { const int64_t t_start_us = ggml_time_us(); const auto & model = wctx.model; @@ -2375,10 +2388,10 @@ static bool whisper_decode_internal( ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_graph_compute(wstate.ctx_metal, gf); } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } #else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif } @@ -3290,7 +3303,7 @@ int whisper_set_mel( } int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { - if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) { + if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return -1; } @@ -3299,7 +3312,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state } int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) { + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return -1; } @@ -3310,7 +3323,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { const int selected_decoder_id = 0; - if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return 1; } @@ -3327,7 +3340,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i return false; } - if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return 1; } @@ -4594,7 +4607,7 @@ int whisper_full_with_state( } // encode audio features starting at offset seek - if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) { + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { log("%s: failed to encode\n", __func__); return -6; } @@ -4677,7 +4690,7 @@ int whisper_full_with_state( } WHISPER_PRINT_DEBUG("\n\n"); - if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { log("%s: failed to decode\n", __func__); return -7; } @@ -4901,7 +4914,7 @@ int whisper_full_with_state( //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); - if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { log("%s: failed to decode\n", __func__); return -8; } @@ -5270,6 +5283,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) return ctx->state->result_all[i_segment].t1; } +bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].speaker_turn_next; +} + bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) { return ctx->state->result_all[i_segment].speaker_turn_next; } @@ -5471,12 +5488,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { double tsum = 0.0; // heat-up - ggml_graph_compute_helper(work, gf, n_threads); + ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); for (int i = 0; i < n_max; ++i) { const int64_t t0 = ggml_time_us(); - ggml_graph_compute_helper(work, gf, n_threads); + ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); const int64_t t1 = ggml_time_us(); diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 73ab4d799..c3118c9c9 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -334,6 +334,11 @@ extern "C" { // If it returns false, the computation is aborted typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); + // Abort callback + // If not NULL, called before ggml computation + // If it returns true, the computation is aborted + typedef bool (*whisper_abort_callback)(void * user_data); + // Logits filter callback // Can be used to modify the logits before sampling // If not NULL, called after applying temperature to logits @@ -428,6 +433,10 @@ extern "C" { whisper_encoder_begin_callback encoder_begin_callback; void * encoder_begin_callback_user_data; + // called each time before ggml computation starts + whisper_abort_callback abort_callback; + void * abort_callback_user_data; + // called by each decoder to filter obtained logits whisper_logits_filter_callback logits_filter_callback; void * logits_filter_callback_user_data; @@ -485,6 +494,7 @@ extern "C" { // Get whether the next segment is predicted as a speaker turn WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); + WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment); // Get the text of the specified segment WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); diff --git a/scripts/sync-whisper.sh b/scripts/sync-whisper.sh index c17a091da..6976c77dc 100755 --- a/scripts/sync-whisper.sh +++ b/scripts/sync-whisper.sh @@ -1,20 +1,31 @@ #!/bin/bash cp -rpv ../whisper.cpp/ggml.c src/ggml.c +cp -rpv ../whisper.cpp/ggml-impl.h src/ggml-impl.h cp -rpv ../whisper.cpp/ggml-alloc.c src/ggml-alloc.c -cp -rpv ../whisper.cpp/ggml-cuda.h src/ggml-cuda.h +cp -rpv ../whisper.cpp/ggml-backend-impl.h src/ggml-backend-impl.h +cp -rpv ../whisper.cpp/ggml-backend.c src/ggml-backend.c cp -rpv ../whisper.cpp/ggml-cuda.cu src/ggml-cuda.cu -cp -rpv ../whisper.cpp/ggml-opencl.h src/ggml-opencl.h -cp -rpv ../whisper.cpp/ggml-opencl.cpp src/ggml-opencl.cpp +cp -rpv ../whisper.cpp/ggml-cuda.h src/ggml-cuda.h cp -rpv ../whisper.cpp/ggml-metal.h src/ggml-metal.h cp -rpv ../whisper.cpp/ggml-metal.m src/ggml-metal.m cp -rpv ../whisper.cpp/ggml-metal.metal src/ggml-metal.metal +#cp -rpv ../whisper.cpp/ggml-mpi.h src/ggml-mpi.h +#cp -rpv ../whisper.cpp/ggml-mpi.m src/ggml-mpi.m +cp -rpv ../whisper.cpp/ggml-opencl.cpp src/ggml-opencl.cpp +cp -rpv ../whisper.cpp/ggml-opencl.h src/ggml-opencl.h +cp -rpv ../whisper.cpp/ggml-quants.c src/ggml-quants.c +cp -rpv ../whisper.cpp/ggml-quants.h src/ggml-quants.h + cp -rpv ../whisper.cpp/ggml.h include/ggml/ggml.h cp -rpv ../whisper.cpp/ggml-alloc.h include/ggml/ggml-alloc.h +cp -rpv ../whisper.cpp/ggml-backend.h include/ggml/ggml-backend.h + cp -rpv ../whisper.cpp/examples/common.h examples/common.h cp -rpv ../whisper.cpp/examples/common.cpp examples/common.cpp cp -rpv ../whisper.cpp/examples/common-ggml.h examples/common-ggml.h cp -rpv ../whisper.cpp/examples/common-ggml.cpp examples/common-ggml.cpp + cp -rpv ../whisper.cpp/whisper.h examples/whisper/whisper.h cp -rpv ../whisper.cpp/whisper.cpp examples/whisper/whisper.cpp cp -rpv ../whisper.cpp/examples/main/main.cpp examples/whisper/main.cpp diff --git a/src/ggml-impl.h b/src/ggml-impl.h index d88f26144..06c07339e 100644 --- a/src/ggml-impl.h +++ b/src/ggml-impl.h @@ -39,12 +39,6 @@ extern "C" { #endif #endif -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - // 16-bit float // on Arm, we use __fp16 // on x86, we use uint16_t diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 9136a7cf6..43d0dff09 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -210,7 +210,13 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); - NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + NSString * sourcePath; + NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; + if (ggmlMetalPathResources) { + sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { + sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + } if (sourcePath == nil) { GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); sourcePath = @"ggml-metal.metal"; diff --git a/src/ggml-quants.c b/src/ggml-quants.c index 740be6dc5..a48eda732 100644 --- a/src/ggml-quants.c +++ b/src/ggml-quants.c @@ -14,26 +14,6 @@ // #include -#if !defined(__aarch64__) -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} -#endif - #else #ifdef __wasm_simd128__ @@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) { #if defined(_MSC_VER) || defined(__MINGW32__) #include #else -#if !defined(__riscv) && !defined(__s390__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if !defined(__riscv) #include #endif #endif #endif #endif #endif +#endif #ifdef __riscv_v_intrinsic #include @@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) { #undef MIN #undef MAX + #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #if defined(__ARM_NEON) - #if !defined(__aarch64__) +// 64-bit compatibility + +// vaddvq_s16 +// vpaddq_s16 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + inline static int32_t vaddvq_s32(int32x4_t v) { return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); } @@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { return res; } +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 + #endif #endif @@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const int32x4_t vzero = vdupq_n_s32(0); #endif - int8x16x2_t q2bytes; + ggml_int8x16x2_t q2bytes; uint8_t aux[16]; float sum = 0; @@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri vst1q_u8(aux, scales); const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), @@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri #endif #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = vld1q_s8_x2(q8); q8 += 32;\ + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ MULTIPLY_ACCUM_WITH_SCALE((index)); @@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri for (int j = 0; j < QK_K/128; ++j) { - const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; + const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; - int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; + ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); MULTIPLY_ACCUM_WITH_SCALE(0); @@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const int32x4_t vzero = vdupq_n_s32(0); #endif - int8x16x4_t q2bytes; + ggml_int8x16x4_t q2bytes; uint32_t aux32[2]; const uint8_t * scales = (const uint8_t *)aux32; @@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t q2bits = vld1q_u8(q2); - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); @@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t m3 = vshlq_n_u8(m0, 3); const int8_t m32 = 32; - int8x16x4_t q3bytes; + ggml_int8x16x4_t q3bytes; float sum = 0; @@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * restrict qh = x[i].hmask; const int8_t * restrict q8 = y[i].qs; - uint8x16x2_t qhbits = vld1q_u8_x2(qh); + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); - uint8x16x4_t q3h; + ggml_uint8x16x4_t q3h; int32_t isum = 0; @@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri for (int j = 0; j < QK_K/128; ++j) { - const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; - const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; - const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; + const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; + const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; + const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); @@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t m3b = vdupq_n_u8(0x3); const uint8x16_t mh = vdupq_n_u8(4); - int8x16x4_t q3bytes; + ggml_int8x16x4_t q3bytes; uint16_t aux16[2]; int8_t * scales = (int8_t *)aux16; @@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { - uint8x16x4_t q3h; + ggml_uint8x16x4_t q3h; const uint8x8_t hbits = vld1_u8(x[i].hmask); const uint8x16_t q3bits = vld1q_u8(x[i].qs); - const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs); const uint16_t a = *(const uint16_t *)x[i].scales; aux16[0] = a & 0x0f0f; @@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const int32x4_t mzero = vdupq_n_s32(0); #endif - int8x16x2_t q4bytes; - int8x16x2_t q8bytes; + ggml_int8x16x2_t q4bytes; + ggml_int8x16x2_t q8bytes; float sumf = 0; @@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri for (int j = 0; j < QK_K/64; ++j) { - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; #ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x2(q8); q8 += 32; + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); sumi1 += vaddvq_s32(p1) * scales[2*j+0]; - q8bytes = vld1q_s8_x2(q8); q8 += 32; + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); @@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri sumi2 += vaddvq_s32(p2) * scales[2*j+1]; #else - q8bytes = vld1q_s8_x2(q8); q8 += 32; + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), @@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; - q8bytes = vld1q_s8_x2(q8); q8 += 32; + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), @@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri float sumf = 0; - int8x16x2_t q4bytes; - int8x16x4_t q8bytes; + ggml_int8x16x2_t q4bytes; + ggml_int8x16x4_t q8bytes; float sum_mins = 0.f; @@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const float d = y[i].d * (float)x[i].d[0]; - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); #ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x4(q8); + q8bytes = ggml_vld1q_s8_x4(q8); q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); @@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; #else - q8bytes = vld1q_s8_x4(q8); + q8bytes = ggml_vld1q_s8_x4(q8); q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), @@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const int32x4_t mzero = vdupq_n_s32(0); #endif - int8x16x4_t q5bytes; + ggml_int8x16x4_t q5bytes; float sumf = 0; @@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - uint8x16x2_t qhbits = vld1q_u8_x2(qh); + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); - uint8x16x4_t q5h; + ggml_uint8x16x4_t q5h; int32_t sumi = 0; for (int j = 0; j < QK_K/64; ++j) { - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); @@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const int32x4_t mzero = vdupq_n_s32(0); #endif - int8x16x4_t q5bytes; - uint8x16x4_t q5h; + ggml_int8x16x4_t q5bytes; + ggml_uint8x16x4_t q5h; float sumf = 0; @@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const uint8x8_t qhbits = vld1_u8(qh); - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); @@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t mone = vdupq_n_u8(3); - int8x16x4_t q6bytes; - uint8x16x4_t q6h; + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { @@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri const int8_t * restrict scale = x[i].scales; - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); const int8x16_t scales = vld1q_s8(scale); - const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), @@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri for (int j = 0; j < QK_K/128; ++j) { - uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; - uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; - int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; + ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); @@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri scale += 2; #endif - q8bytes = vld1q_s8_x4(q8); q8 += 64; + q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; shifted = vshrq_n_u8(qhbits.val[0], 4); q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); @@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t mone = vdupq_n_u8(3); - int8x16x4_t q6bytes; - uint8x16x4_t q6h; + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { @@ -7002,9 +7097,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri int32_t isum = 0; - uint8x16_t qhbits = vld1q_u8(qh); - uint8x16x2_t q6bits = vld1q_u8_x2(q6); - int8x16x4_t q8bytes = vld1q_s8_x4(q8); + uint8x16_t qhbits = vld1q_u8(qh); + ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6); + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); uint8x16_t shifted = vshrq_n_u8(qhbits, 2); diff --git a/src/ggml.c b/src/ggml.c index af9f96c33..018f0ce0b 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -143,6 +143,12 @@ void ggml_print_backtrace(void) { } #endif +#undef MIN +#undef MAX + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + /*#define GGML_PERF*/ #define GGML_DEBUG 0 #define GGML_GELU_FP16 @@ -604,6 +610,18 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { // simd mappings // +#if defined(__ARM_NEON) +#if !defined(__aarch64__) + +// 64-bit compatibility + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +#endif +#endif + // we define a common set of C macros which map to specific intrinsics based on the current architecture // we then implement the fundamental computation operations below using only these macros // adding support for new architectures requires to define the corresponding SIMD macros