Skip to content

Commit

Permalink
whisper : fix gpu device selection (#2728)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov authored Jan 13, 2025
1 parent e940fbf commit eb68324
Showing 1 changed file with 35 additions and 13 deletions.
48 changes: 35 additions & 13 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

ggml_backend_dev_t dev = nullptr;

int cnt = 0;
if (params.use_gpu) {
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
if (!result) {
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
if (cnt == 0 || cnt == params.gpu_device) {
dev = dev_cur;
}

if (++cnt > params.gpu_device) {
break;
}
return result;
}
}
}

return nullptr;
if (dev == nullptr) {
WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
return nullptr;
}

WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
if (!result) {
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
}

return result;
}

static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
Expand Down Expand Up @@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
}

static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();

if (!params.use_gpu) {
return ggml_backend_cpu_buffer_type();
return result;
}

// if we have a GPU device - use it
int cnt = 0;
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
return ggml_backend_dev_buffer_type(dev);
if (cnt == 0 || cnt == params.gpu_device) {
result = ggml_backend_dev_buffer_type(dev);
}

if (++cnt > params.gpu_device) {
break;
}
}
}

return ggml_backend_cpu_buffer_type();
return result;
}

// load the model from a ggml file
Expand Down

0 comments on commit eb68324

Please sign in to comment.