
whisper : Metal and ggml-alloc support #1270

Merged · 44 commits · Sep 15, 2023
Changes from 1 commit

Commits
fbc3f80
metal : init
ggerganov Sep 10, 2023
949ab63
whisper : factor out graph builds
ggerganov Sep 10, 2023
bed5ad6
whisper : allocate encoder and decoder using ggml-alloc
ggerganov Sep 10, 2023
af6f67b
whisper : ggml-alloc is now supported
ggerganov Sep 10, 2023
fa672b4
whisper : CoreML support ggml-alloc
ggerganov Sep 10, 2023
794e8fe
build : fix ggml-alloc
ggerganov Sep 10, 2023
9a78b72
ios : update submodule
ggerganov Sep 10, 2023
06d1d28
extra : update sync-ggml.sh script to also sync ggml-alloc
ggerganov Sep 10, 2023
4d9acc6
ci : see if this is causing the crash
ggerganov Sep 10, 2023
2770d46
whisper : refactor ggml-alloc init
ggerganov Sep 11, 2023
4845b9e
whisper.android : try to fix build
ggerganov Sep 11, 2023
d3b2dd4
whisper : initial Metal version
ggerganov Sep 11, 2023
de94c78
Merge branch 'master' into metal-and-alloc
ggerganov Sep 12, 2023
3b9979a
ci : try to debug vmem issue
ggerganov Sep 12, 2023
fbc9ddc
metal : decoder works on GPU!
ggerganov Sep 12, 2023
79a8805
metal : add multi-decoder support
ggerganov Sep 12, 2023
9fdd415
ggml : fix ggml_nbytes (probably temp solution)
ggerganov Sep 12, 2023
cd47637
metal : run "cross" step on the GPU
ggerganov Sep 12, 2023
ec9a7db
whisper : remove ggml_repeat in the encoder
ggerganov Sep 12, 2023
3074a7f
whisper : offload the Encoder to Metal
ggerganov Sep 12, 2023
905c944
ggml : use simpler ggml_bytes() implementation
ggerganov Sep 13, 2023
b19888c
ggml-alloc : try to make CI happy by reducing vram to 128GB
ggerganov Sep 13, 2023
254b687
whisper : add whisper_allocr to wrap ggml_allocr
ggerganov Sep 13, 2023
b6f0966
whisper : factor out alloc init in a function
ggerganov Sep 13, 2023
77f4bf4
cmake : update to support Metal build
ggerganov Sep 13, 2023
796f84c
whisper : add <functional> header
ggerganov Sep 13, 2023
181bb8c
objc : fix build (no Metal yet)
ggerganov Sep 13, 2023
257d794
ios : add Metal support
ggerganov Sep 13, 2023
16db4da
swiftui : fix build
ggerganov Sep 13, 2023
8e8daa8
metal : speed-up KQ multiplication
ggerganov Sep 13, 2023
ecb23fb
metal : sync latest llama.cpp kernels
ggerganov Sep 13, 2023
23277d2
readme : add Metal info
ggerganov Sep 13, 2023
d37f56e
ios : update submodule
ggerganov Sep 13, 2023
d863f72
coreml : add code to toggle Core ML config (CPU, ANE, GPU)
ggerganov Sep 13, 2023
f408c64
bench : fix timings by running a pre-heat
ggerganov Sep 13, 2023
e81c67a
bench : start benching the decoder
ggerganov Sep 14, 2023
af947cb
whisper : add ggml_mul_mat_pad
ggerganov Sep 14, 2023
c46167f
bench : fix uninitialized vars
ggerganov Sep 14, 2023
f365543
whisper : add comment for disabling mul-mat padding
ggerganov Sep 14, 2023
2b4160a
whisper : add description of ggml_mul_mat_pad
ggerganov Sep 14, 2023
0d5e4cd
whisper : clean-up ggml_mul_mat_pad
ggerganov Sep 14, 2023
bfcb2a2
metal : remove the "concurrent" flag
ggerganov Sep 14, 2023
a166457
bench : variable n_past
ggerganov Sep 14, 2023
3ac0558
ios : update SPM package
ggerganov Sep 15, 2023
whisper : ggml-alloc is now supported
ggerganov committed Sep 10, 2023
commit af6f67b251dc78acdcd76d5f47df40ca454ec332
109 changes: 88 additions & 21 deletions whisper.cpp
@@ -120,6 +120,21 @@ static void byteswap_tensor(ggml_tensor * tensor) {
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16

//
// ggml helpers
//

static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

    if (plan.work_size > 0) {
        buf.resize(plan.work_size);
        plan.work_data = buf.data();
    }

    ggml_graph_compute(graph, &plan);
}

// available whisper models
enum e_model {
MODEL_UNKNOWN,
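
For orientation, here is a hedged, stand-alone sketch of how this helper is meant to be used: a graph is built once, and the same std::vector is handed to every compute call so the ggml_cplan work buffer is only reallocated when it needs to grow. The example assumes the ggml C API of this period (ggml_build_forward returning a graph by value) and that ggml_graph_compute_helper above is in scope; it is not part of the diff.

// sketch: compute a + b with the reusable work-buffer pattern
#include "ggml.h"

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_f32(a, 1.0f);
    ggml_set_f32(b, 2.0f);

    struct ggml_cgraph gf = ggml_build_forward(ggml_add(ctx, a, b));

    std::vector<uint8_t> work_buffer; // reused across calls, grows to plan.work_size

    ggml_graph_compute_helper(work_buffer, &gf, /*n_threads =*/ 4);

    printf("result[0] = %f\n", ggml_get_f32_1d(gf.nodes[gf.n_nodes - 1], 0)); // 3.0

    ggml_free(ctx);
    return 0;
}
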
@@ -606,6 +621,9 @@ struct whisper_state {
    // memory buffers used by encode / decode contexts
    std::vector<uint8_t> buf_compute;

    // reusable buffer for `struct ggml_graph_plan.work_data`
    std::vector<uint8_t> work_buffer;

    // ggml-alloc
    std::vector<uint8_t> buf_encode;
    std::vector<uint8_t> buf_encode_post;
@@ -1407,6 +1425,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
    ggml_allocr * alloc = wstate.alloc_encode;

    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
    ggml_allocr_alloc(alloc, mel);

    assert(mel->type == GGML_TYPE_F32);
    if (!ggml_allocr_is_measure(alloc)) {
        float * dst = (float *) mel->data;
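
The ggml_allocr_is_measure() guard exists because each allocator is first run in a "measure" pass, where tensors get no real backing memory and only the required buffer size is recorded; the mel data is written only on real passes. A rough sketch of that two-pass pattern, assuming the ggml-alloc API of this period; build_graph and the alignment constant are illustrative, not code from this commit:

    // 1) measure pass: learn the peak amount of memory the graph needs
    static const size_t tensor_alignment = 32;                 // assumed alignment
    ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
    ggml_cgraph * gf = build_graph();                           // ggml_allocr_is_measure(alloc) == true here
    const size_t size_needed = ggml_allocr_alloc_graph(alloc, gf);
    ggml_allocr_free(alloc);

    // 2) real pass: back the allocator with an actual buffer of that size and reuse it per call
    std::vector<uint8_t> buf(size_needed);                      // e.g. wstate.buf_encode
    alloc = ggml_allocr_new(buf.data(), buf.size(), tensor_alignment);
    gf = build_graph();                                          // now inputs such as `mel` have valid ->data
    ggml_allocr_alloc_graph(alloc, gf);
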
@@ -1796,6 +1816,32 @@ static bool whisper_encode_internal(
              const int n_threads) {
    const int64_t t_start_us = ggml_time_us();

    // encoder
    {
        auto & alloc = wstate.alloc_encode;

        ggml_allocr_reset(alloc);

        ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate, mel_offset);

        ggml_allocr_alloc_graph(alloc, gf);

        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
    }

    // encoder_post
    {
        auto & alloc = wstate.alloc_encode_post;

        ggml_allocr_reset(alloc);

        ggml_cgraph * gf = whisper_build_graph_encoder_post(wctx, wstate);

        ggml_allocr_alloc_graph(alloc, gf);

        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
    }

    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

    wstate.t_encode_us += ggml_time_us() - t_start_us;
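
Both blocks repeat the same per-call cycle: reset the allocator, rebuild the graph, let ggml-alloc place the intermediate tensors inside the pre-measured buffer, then compute. A hedged sketch of how that cycle could be written once; the wrapper name and shape are hypothetical, not part of this commit:

// hypothetical convenience wrapper for the reset -> build -> alloc -> compute cycle
template <typename TBuild>
static void compute_with_alloc(ggml_allocr * alloc, std::vector<uint8_t> & work_buffer,
                               int n_threads, TBuild && build_graph) {
    ggml_allocr_reset(alloc);             // forget last call's placements, keep the buffer

    ggml_cgraph * gf = build_graph();     // e.g. whisper_build_graph_encoder(wctx, wstate, mel_offset)

    ggml_allocr_alloc_graph(alloc, gf);   // assign offsets inside the pre-measured buffer

    ggml_graph_compute_helper(work_buffer, gf, n_threads);
}

Used as, for example: compute_with_alloc(wstate.alloc_encode, wstate.work_buffer, n_threads, [&] { return whisper_build_graph_encoder(wctx, wstate, mel_offset); });
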
@@ -1841,11 +1887,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
    ggml_allocr * alloc = wstate.alloc_decode;

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    ggml_allocr_alloc(alloc, embd);

    if (!ggml_allocr_is_measure(alloc)) {
        memcpy(embd->data, tokens, N*ggml_element_size(embd));
    }

    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    ggml_allocr_alloc(alloc, position);

    if (!ggml_allocr_is_measure(alloc)) {
        for (int i = 0; i < N; ++i) {
            ((int32_t *) position->data)[i] = n_past + i;
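
The position tensor holds the absolute positions n_past … n_past + N - 1 of the new tokens, which is what lets the decoder keep appending to its KV cache across calls. For reference, such a tensor is typically consumed right after the token embedding lookup, roughly as in the paraphrased snippet below; the names model.d_te / model.d_pe follow whisper.cpp conventions, but the exact code sits outside this hunk:

    // token embeddings + learned positional embeddings for the N new tokens
    struct ggml_tensor * cur =
        ggml_add(ctx0,
                 ggml_get_rows(ctx0, model.d_te, embd),       // token embedding lookup
                 ggml_get_rows(ctx0, model.d_pe, position));  // positional embedding lookup
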
@@ -2162,33 +2212,51 @@ static bool whisper_decode_internal(
              const int n_tokens,
              const int n_past,
              const int n_threads) {
    //const int64_t t_start_us = ggml_time_us();
    const int64_t t_start_us = ggml_time_us();

    //auto & logits_out = wstate.logits;
    const auto & model = wctx.model;
    const auto & hparams = model.hparams;

    //const int n_vocab = hparams.n_vocab;
    const int n_vocab = hparams.n_vocab;

    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
    auto & logits_out = wstate.logits;

    //// extract logits for all N tokens
    ////logits_out.resize(N*n_vocab);
    ////memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
    struct ggml_tensor * logits;

    //// extract logits only for the last token
    //logits_out.resize(n_vocab);
    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
    // decoder
    {
        auto & alloc = wstate.alloc_decode;

        //if (N > 1) {
        //    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
        //    //        ggml_used_mem(ctx0)/1024.0/1024.0,
        //    //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
        //    //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
        //    //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
        //    //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
        //}
        ggml_allocr_reset(alloc);

        //wstate.t_decode_us += ggml_time_us() - t_start_us;
        //wstate.n_decode++;
        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);

        ggml_allocr_alloc_graph(alloc, gf);

        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);

        logits = gf->nodes[gf->n_nodes - 1];
    }

    // extract logits for all N tokens
    //logits_out.resize(N*n_vocab);
    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);

    // extract logits only for the last token
    logits_out.resize(n_vocab);
    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);

    if (n_tokens > 1) {
        //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
        //        ggml_used_mem(ctx0)/1024.0/1024.0,
        //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
        //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
        //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
        //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
    }

    wstate.t_decode_us += ggml_time_us() - t_start_us;
    wstate.n_decode++;

    return true;
}
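
Taking gf->nodes[gf->n_nodes - 1] relies on the logits matmul being the last node added to the graph. An alternative, shown here only as a sketch, is to name the tensor when the graph is built and look it up by name after computing:

    // in whisper_build_graph_decoder, after the final matmul:
    ggml_set_name(logits, "logits");

    // in whisper_decode_internal, after ggml_graph_compute_helper(...):
    struct ggml_tensor * logits = ggml_graph_get_tensor(gf, "logits");
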
@@ -2759,7 +2827,6 @@ int whisper_ctx_init_openvino_encoder(
}

struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {

log("%s: loading model from '%s'\n", __func__, path_model);

auto fin = std::ifstream(path_model, std::ios::binary);