Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

whisper : add context param for disable gpu #1293

Merged
merged 17 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ var (
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
return (*Context)(ctx)
} else {
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.sun.jna.ptr.PointerByReference;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;

import java.util.List;

Expand All @@ -23,8 +24,9 @@ public class WhisperContext extends Structure {
public PointerByReference vocab;
public PointerByReference state;

/** populated by whisper_init_from_file() */
/** populated by whisper_init_from_file_with_params() */
String path_model;
WhisperContextParams params;

// public static class ByReference extends WhisperContext implements Structure.ByReference {
// }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;

Expand All @@ -15,8 +16,9 @@
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private Pointer ctx = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
private Pointer paramsPointer = null;
private Pointer greedyParamsPointer = null;
private Pointer beamParamsPointer = null;

public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
Expand All @@ -31,6 +33,18 @@ public File modelDir() {
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
*/
public void initContext(String modelPath) throws FileNotFoundException {
initContextImpl(modelPath, getContextDefaultParams());
}

/**
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
* @param params - params to use when initialising the context
*/
public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
initContextImpl(modelPath, params);
}

private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
Expand All @@ -43,13 +57,26 @@ public void initContext(String modelPath) throws FileNotFoundException {
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}

ctx = lib.whisper_init_from_file(modelPath);
ctx = lib.whisper_init_from_file_with_params(modelPath, params);

if (ctx == null) {
throw new FileNotFoundException(modelPath);
}
}

/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
public WhisperContextParams getContextDefaultParams() {
paramsPointer = lib.whisper_context_default_params_by_ref();
WhisperContextParams params = new WhisperContextParams(paramsPointer);
params.read();
return params;
}

/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must call either:
Expand All @@ -63,15 +90,15 @@ public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy)

// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (greedyParamsPointer == null) {
greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyPointer;
pointer = greedyParamsPointer;
} else {
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (beamParamsPointer == null) {
beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamPointer;
pointer = beamParamsPointer;
}

WhisperFullParams params = new WhisperFullParams(pointer);
Expand All @@ -93,13 +120,17 @@ private void freeContext() {
}

private void freeParams() {
if (greedyPointer != null) {
Native.free(Pointer.nativeValue(greedyPointer));
greedyPointer = null;
if (paramsPointer != null) {
Native.free(Pointer.nativeValue(paramsPointer));
paramsPointer = null;
}
if (greedyParamsPointer != null) {
Native.free(Pointer.nativeValue(greedyParamsPointer));
greedyParamsPointer = null;
}
if (beamPointer != null) {
Native.free(Pointer.nativeValue(beamPointer));
beamPointer = null;
if (beamParamsPointer != null) {
Native.free(Pointer.nativeValue(beamParamsPointer));
beamParamsPointer = null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;

public interface WhisperCppJnaLibrary extends Library {
Expand All @@ -13,12 +14,31 @@ public interface WhisperCppJnaLibrary extends Library {
String whisper_print_system_info();

/**
* Allocate (almost) all memory needed for the model by loading from a file.
* DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file(String path_model);

/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
Pointer whisper_context_default_params_by_ref();

void whisper_free_context_params(Pointer params);

/**
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @param params Pointer to whisper_context_params
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);

/**
* Allocate (almost) all memory needed for the model by loading from a buffer.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package io.github.ggerganov.whispercpp.params;

import com.sun.jna.*;

import java.util.Arrays;
import java.util.List;

/**
* Parameters for the whisper_init_from_file_with_params() function.
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
* whisper_context_default_params()
*/
public class WhisperContextParams extends Structure {

public WhisperContextParams(Pointer p) {
super(p);
}

/** Use GPU for inference Number (default = true) */
public CBool use_gpu;

/** Use GPU for inference Number (default = true) */
public void useGpu(boolean enable) {
use_gpu = enable ? CBool.TRUE : CBool.FALSE;
}

@Override
protected List<String> getFieldOrder() {
return Arrays.asList("use_gpu");
}
}
2 changes: 1 addition & 1 deletion bindings/javascript/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init_from_file(path_model.c_str());
g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_context != nullptr) {
return true;
} else {
Expand Down
2 changes: 1 addition & 1 deletion bindings/ruby/ext/ruby_whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
if (rw->context == nullptr) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
Expand Down
1 change: 1 addition & 0 deletions examples/addon.node/__test__/whisper.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const whisperParamsMock = {
language: "en",
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
use_gpu: true,
};

describe("Run whisper.node", () => {
Expand Down
7 changes: 6 additions & 1 deletion examples/addon.node/addon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct whisper_params {
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;

std::string language = "en";
std::string prompt;
Expand Down Expand Up @@ -153,7 +154,9 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {

// whisper init

struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
Expand Down Expand Up @@ -315,10 +318,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
std::string language = whisper_params.Get("language").As<Napi::String>();
std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();

params.language = language;
params.model = model;
params.fname_inp.emplace_back(input);
params.use_gpu = use_gpu;

Napi::Function callback = info[1].As<Napi::Function>();
Worker* worker = new Worker(callback, params);
Expand Down
1 change: 1 addition & 0 deletions examples/addon.node/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const whisperParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: "../../samples/jfk.wav",
use_gpu: true,
};

const arguments = process.argv.slice(2);
Expand Down
2 changes: 1 addition & 1 deletion examples/bench.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) {
g_worker.join();
Expand Down
15 changes: 11 additions & 4 deletions examples/bench/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ struct whisper_params {
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat

std::string model = "models/ggml-base.en.bin";

bool use_gpu = true;
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
Expand All @@ -23,9 +25,10 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
Expand All @@ -45,6 +48,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " %-7s 0 - whisper\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
Expand All @@ -54,7 +58,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
int whisper_bench_full(const whisper_params & params) {
// whisper init

struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

{
fprintf(stderr, "\n");
Expand Down
2 changes: 1 addition & 1 deletion examples/command.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
Expand Down
8 changes: 7 additions & 1 deletion examples/command/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct whisper_params {
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;

std::string language = "en";
std::string model = "models/ggml-base.en.bin";
Expand Down Expand Up @@ -68,6 +69,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
Expand Down Expand Up @@ -101,6 +103,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
Expand Down Expand Up @@ -610,7 +613,10 @@ int main(int argc, char ** argv) {

// whisper init

struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

// print some info about the processing
{
Expand Down
Loading
Loading