Skip to content

Commit

Permalink
whisper : add context param to disable gpu (ggerganov#1293)
Browse files Browse the repository at this point in the history
* whisper : check state->ctx_metal not null

* whisper : add whisper_context_params { use_gpu }

* whisper : new API with params & deprecate old API

* examples : use no-gpu param && whisper_init_from_file_with_params

* whisper.objc : enable metal & disable on simulator

* whisper.swiftui, metal : enable metal & support load default.metallib

* whisper.android : use new API

* bindings : use new API

* addon.node : fix build & test

* bindings : updata java binding

* bindings : add missing whisper_context_default_params_by_ref WHISPER_API for java

* metal : use SWIFTPM_MODULE_BUNDLE for GGML_SWIFT and reuse library load

* metal : move bundle var into block

* metal : use SWIFT_PACKAGE instead of GGML_SWIFT

* style : minor updates

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
jhen0409 and ggerganov authored Nov 6, 2023
1 parent 77f8354 commit 3c2aba1
Show file tree
Hide file tree
Showing 29 changed files with 421 additions and 170 deletions.
2 changes: 1 addition & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ var (
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
return (*Context)(ctx)
} else {
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.sun.jna.ptr.PointerByReference;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;

import java.util.List;

Expand All @@ -23,8 +24,9 @@ public class WhisperContext extends Structure {
public PointerByReference vocab;
public PointerByReference state;

/** populated by whisper_init_from_file() */
/** populated by whisper_init_from_file_with_params() */
String path_model;
WhisperContextParams params;

// public static class ByReference extends WhisperContext implements Structure.ByReference {
// }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;

Expand All @@ -15,8 +16,9 @@
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private Pointer ctx = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
private Pointer paramsPointer = null;
private Pointer greedyParamsPointer = null;
private Pointer beamParamsPointer = null;

public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
Expand All @@ -31,6 +33,18 @@ public File modelDir() {
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
*/
public void initContext(String modelPath) throws FileNotFoundException {
initContextImpl(modelPath, getContextDefaultParams());
}

/**
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
* @param params - params to use when initialising the context
*/
public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
initContextImpl(modelPath, params);
}

private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
Expand All @@ -43,13 +57,26 @@ public void initContext(String modelPath) throws FileNotFoundException {
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}

ctx = lib.whisper_init_from_file(modelPath);
ctx = lib.whisper_init_from_file_with_params(modelPath, params);

if (ctx == null) {
throw new FileNotFoundException(modelPath);
}
}

/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
public WhisperContextParams getContextDefaultParams() {
paramsPointer = lib.whisper_context_default_params_by_ref();
WhisperContextParams params = new WhisperContextParams(paramsPointer);
params.read();
return params;
}

/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must call either:
Expand All @@ -63,15 +90,15 @@ public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy)

// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (greedyParamsPointer == null) {
greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyPointer;
pointer = greedyParamsPointer;
} else {
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (beamParamsPointer == null) {
beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamPointer;
pointer = beamParamsPointer;
}

WhisperFullParams params = new WhisperFullParams(pointer);
Expand All @@ -93,13 +120,17 @@ private void freeContext() {
}

private void freeParams() {
if (greedyPointer != null) {
Native.free(Pointer.nativeValue(greedyPointer));
greedyPointer = null;
if (paramsPointer != null) {
Native.free(Pointer.nativeValue(paramsPointer));
paramsPointer = null;
}
if (greedyParamsPointer != null) {
Native.free(Pointer.nativeValue(greedyParamsPointer));
greedyParamsPointer = null;
}
if (beamPointer != null) {
Native.free(Pointer.nativeValue(beamPointer));
beamPointer = null;
if (beamParamsPointer != null) {
Native.free(Pointer.nativeValue(beamParamsPointer));
beamParamsPointer = null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;

public interface WhisperCppJnaLibrary extends Library {
Expand All @@ -13,12 +14,31 @@ public interface WhisperCppJnaLibrary extends Library {
String whisper_print_system_info();

/**
* Allocate (almost) all memory needed for the model by loading from a file.
* DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file(String path_model);

/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
Pointer whisper_context_default_params_by_ref();

void whisper_free_context_params(Pointer params);

/**
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @param params Pointer to whisper_context_params
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);

/**
* Allocate (almost) all memory needed for the model by loading from a buffer.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package io.github.ggerganov.whispercpp.params;

import com.sun.jna.*;

import java.util.Arrays;
import java.util.List;

/**
* Parameters for the whisper_init_from_file_with_params() function.
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
* whisper_context_default_params()
*/
public class WhisperContextParams extends Structure {

public WhisperContextParams(Pointer p) {
super(p);
}

/** Use GPU for inference Number (default = true) */
public CBool use_gpu;

/** Use GPU for inference Number (default = true) */
public void useGpu(boolean enable) {
use_gpu = enable ? CBool.TRUE : CBool.FALSE;
}

@Override
protected List<String> getFieldOrder() {
return Arrays.asList("use_gpu");
}
}
2 changes: 1 addition & 1 deletion bindings/javascript/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init_from_file(path_model.c_str());
g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_context != nullptr) {
return true;
} else {
Expand Down
2 changes: 1 addition & 1 deletion bindings/ruby/ext/ruby_whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
if (rw->context == nullptr) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
Expand Down
1 change: 1 addition & 0 deletions examples/addon.node/__test__/whisper.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const whisperParamsMock = {
language: "en",
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
use_gpu: true,
};

describe("Run whisper.node", () => {
Expand Down
7 changes: 6 additions & 1 deletion examples/addon.node/addon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct whisper_params {
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;

std::string language = "en";
std::string prompt;
Expand Down Expand Up @@ -153,7 +154,9 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {

// whisper init

struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
Expand Down Expand Up @@ -315,10 +318,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
std::string language = whisper_params.Get("language").As<Napi::String>();
std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();

params.language = language;
params.model = model;
params.fname_inp.emplace_back(input);
params.use_gpu = use_gpu;

Napi::Function callback = info[1].As<Napi::Function>();
Worker* worker = new Worker(callback, params);
Expand Down
1 change: 1 addition & 0 deletions examples/addon.node/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const whisperParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: "../../samples/jfk.wav",
use_gpu: true,
};

const arguments = process.argv.slice(2);
Expand Down
2 changes: 1 addition & 1 deletion examples/bench.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) {
g_worker.join();
Expand Down
15 changes: 11 additions & 4 deletions examples/bench/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ struct whisper_params {
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat

std::string model = "models/ggml-base.en.bin";

bool use_gpu = true;
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
Expand All @@ -23,9 +25,10 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
Expand All @@ -45,6 +48,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " %-7s 0 - whisper\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
Expand All @@ -54,7 +58,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
int whisper_bench_full(const whisper_params & params) {
// whisper init

struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

{
fprintf(stderr, "\n");
Expand Down
2 changes: 1 addition & 1 deletion examples/command.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
Expand Down
8 changes: 7 additions & 1 deletion examples/command/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct whisper_params {
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;

std::string language = "en";
std::string model = "models/ggml-base.en.bin";
Expand Down Expand Up @@ -68,6 +69,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
Expand Down Expand Up @@ -101,6 +103,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
Expand Down Expand Up @@ -610,7 +613,10 @@ int main(int argc, char ** argv) {

// whisper init

struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

// print some info about the processing
{
Expand Down
Loading

0 comments on commit 3c2aba1

Please sign in to comment.