diff --git a/bindings/ruby/.gitignore b/bindings/ruby/.gitignore index e04a90a9c69..6e3b3be0e24 100644 --- a/bindings/ruby/.gitignore +++ b/bindings/ruby/.gitignore @@ -1,3 +1,5 @@ LICENSE pkg/ -lib/whisper.* +lib/whisper.so +lib/whisper.bundle +lib/whisper.dll diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 05e19eb6e09..e7065bf9d70 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -22,7 +22,7 @@ Usage ```ruby require "whisper" -whisper = Whisper::Context.new("path/to/model.bin") +whisper = Whisper::Context.new(Whisper::Model["base"]) params = Whisper::Params.new params.language = "en" @@ -41,21 +41,60 @@ end ### Preparing model ### -Use script to download model file(s): +Some models are prepared up-front: -```bash -git clone https://github.com/ggerganov/whisper.cpp.git -cd whisper.cpp -sh ./models/download-ggml-model.sh base.en +```ruby +base_en = Whisper::Model["base.en"] +whisper = Whisper::Context.new(base_en) +``` + +At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`: + +```ruby +Whisper::Model["base"].clear_cache ``` -There are some types of models. See [models][] page for details. +You can see the list of prepared model names by `Whisper::Model.preconverted_model_names`: + +```ruby +puts Whisper::Model.preconverted_model_names +# tiny +# tiny.en +# tiny-q5_1 +# tiny.en-q5_1 +# tiny-q8_0 +# base +# base.en +# base-q5_1 +# base.en-q5_1 +# base-q8_0 +# : +# : +``` + +You can also use local model files you prepared: + +```ruby +whisper = Whisper::Context.new("path/to/your/model.bin") +``` + +Or, you can download model files: + +```ruby +model_uri = Whisper::Model::URI.new("http://example.net/uri/of/your/model.bin") +whisper = Whisper::Context.new(model_uri) +``` + +See [models][] page for details. ### Preparing audio file ### Currently, whisper.cpp accepts only 16-bit WAV files. -### API ### +API +--- + +### Segments ### Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`: @@ -107,10 +146,12 @@ whisper.transcribe("path/to/audio.wav", params) ``` +### Models ### + You can see model information: ```ruby -whisper = Whisper::Context.new("path/to/model.bin") +whisper = Whisper::Context.new(Whisper::Model["base"]) model = whisper.model model.n_vocab # => 51864 @@ -128,6 +169,8 @@ model.type # => "base" ``` +### Logging ### + You can set log callback: ```ruby @@ -160,6 +203,8 @@ Whisper.log_set ->(level, buffer, user_data) { Whisper::Context.new(MODEL) ``` +### Low-level API to transcribe ### + You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. ```ruby @@ -169,7 +214,7 @@ require "wavefile" reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000)) samples = reader.enum_for(:each_buffer).map(&:samples).flatten -whisper = Whisper::Context.new("path/to/model.bin") +whisper = Whisper::Context.new(Whisper::Model["base"]) whisper.full(Whisper::Params.new, samples) whisper.each_segment do |segment| puts segment.text diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 5f6303ba055..f640dce94f2 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -18,19 +18,9 @@ EXTSOURCES.each do |src| end CLEAN.include SOURCES -CLEAN.include FileList[ - "ext/*.o", - "ext/*.metal", - "ext/whisper.{so,bundle,dll}", - "ext/depend" - ] +CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"] -task build: FileList[ - "ext/Makefile", - "ext/ruby_whisper.h", - "ext/ruby_whisper.cpp", - "whispercpp.gemspec", - ] +task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"] directory "pkg" CLOBBER.include "pkg" diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore index 3804ab7e3e4..e96a8584c94 100644 --- a/bindings/ruby/ext/.gitignore +++ b/bindings/ruby/ext/.gitignore @@ -2,7 +2,6 @@ Makefile whisper.so whisper.bundle whisper.dll -depend scripts/get-flags.mk *.o *.c diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 6d76a7cd9ac..59388ffe0bc 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -1,7 +1,7 @@ require 'mkmf' # need to use c++ compiler flags -$CXXFLAGS << ' -std=c++11' +$CXXFLAGS << ' -std=c++17' $LDFLAGS << ' -lstdc++' @@ -35,10 +35,10 @@ $GGML_METAL_EMBED_LIBRARY = true end -$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iinclude -Isrc -Iexamples' +$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples' $MK_CFLAGS = '-std=c11 -fPIC' -$MK_CXXFLAGS = '-std=c++11 -fPIC' -$MK_NVCCFLAGS = '-std=c++11' +$MK_CXXFLAGS = '-std=c++17 -fPIC' +$MK_NVCCFLAGS = '-std=c++17' $MK_LDFLAGS = '' $OBJ_GGML = [] diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index bb6bae8a859..83fc53fc058 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -45,6 +45,7 @@ static ID id_to_enum; static ID id_length; static ID id_next; static ID id_new; +static ID id_to_path; static bool is_log_callback_finalized = false; @@ -194,7 +195,9 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { /* * call-seq: + * new(Whisper::Model["base.en"]) -> Whisper::Context * new("path/to/model.bin") -> Whisper::Context + * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context */ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; @@ -204,6 +207,9 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { rb_scan_args(argc, argv, "01", &whisper_model_file_path); Data_Get_Struct(self, ruby_whisper, rw); + if (rb_respond_to(whisper_model_file_path, id_to_path)) { + whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0); + } if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } @@ -1733,6 +1739,7 @@ void Init_whisper() { id_length = rb_intern("length"); id_next = rb_intern("next"); id_new = rb_intern("new"); + id_to_path = rb_intern("to_path"); mWhisper = rb_define_module("Whisper"); cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); diff --git a/bindings/ruby/lib/whisper.rb b/bindings/ruby/lib/whisper.rb new file mode 100644 index 00000000000..4c8e01e2eb9 --- /dev/null +++ b/bindings/ruby/lib/whisper.rb @@ -0,0 +1,2 @@ +require "whisper.so" +require "whisper/model" diff --git a/bindings/ruby/lib/whisper/model.rb b/bindings/ruby/lib/whisper/model.rb new file mode 100644 index 00000000000..be67dff368b --- /dev/null +++ b/bindings/ruby/lib/whisper/model.rb @@ -0,0 +1,159 @@ +require "whisper.so" +require "uri" +require "net/http" +require "pathname" +require "io/console/size" + +class Whisper::Model + class URI + def initialize(uri) + @uri = URI(uri) + end + + def to_path + cache + cache_path.to_path + end + + def clear_cache + path = cache_path + path.delete if path.exist? + end + + private + + def cache_path + base_cache_dir/@uri.host/@uri.path[1..] + end + + def base_cache_dir + base = case RUBY_PLATFORM + when /mswin|mingw/ + ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local" + when /darwin/ + Pathname(Dir.home)/"Library/Caches" + else + ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache" + end + base/"whisper.cpp" + end + + def cache + path = cache_path + headers = {} + headers["if-modified-since"] = path.mtime.httpdate if path.exist? + request @uri, headers + path + end + + def request(uri, headers) + Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http| + request = Net::HTTP::Get.new(uri, headers) + http.request request do |response| + case response + when Net::HTTPNotModified + # noop + when Net::HTTPOK + download response + when Net::HTTPRedirection + request URI(response["location"]) + else + raise response + end + end + end + end + + def download(response) + path = cache_path + path.dirname.mkpath unless path.dirname.exist? + downloading_path = Pathname("#{path}.downloading") + size = response.content_length + downloading_path.open "wb" do |file| + downloaded = 0 + response.read_body do |chunk| + file << chunk + downloaded += chunk.bytesize + show_progress downloaded, size + end + end + downloading_path.rename path + end + + def show_progress(current, size) + return unless size + + unless @prev + @prev = Time.now + $stderr.puts "Downloading #{@uri}" + end + + now = Time.now + return if now - @prev < 1 && current < size + + progress_width = 20 + progress = current.to_f / size + arrow_length = progress * progress_width + arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length) + line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})" + padding = ' ' * ($stderr.winsize[1] - line.size) + $stderr.print "\r#{line}#{padding}" + $stderr.puts if current >= size + @prev = now + end + + def format_bytesize(bytesize) + return "0.0 B" if bytesize.zero? + + units = %w[B KiB MiB GiB TiB] + exp = (Math.log(bytesize) / Math.log(1024)).to_i + format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp]) + end + end + + @names = {} + %w[ + tiny + tiny.en + tiny-q5_1 + tiny.en-q5_1 + tiny-q8_0 + base + base.en + base-q5_1 + base.en-q5_1 + base-q8_0 + small + small.en + small.en-tdrz + small-q5_1 + small.en-q5_1 + small-q8_0 + medium + medium.en + medium-q5_0 + medium.en-q5_0 + medium-q8_0 + large-v1 + large-v2 + large-v2-q5_0 + large-v2-8_0 + large-v3 + large-v3-q5_0 + large-v3-turbo + large-v3-turbo-q5_0 + large-v3-turbo-q8_0 + ].each do |name| + @names[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin") + end + + class << self + def [](name) + @names[name] + end + + def preconverted_model_names + @names.keys + end + end +end diff --git a/bindings/ruby/tests/jfk_reader/jfk_reader.c b/bindings/ruby/tests/jfk_reader/jfk_reader.c index a0688374d06..6657176e767 100644 --- a/bindings/ruby/tests/jfk_reader/jfk_reader.c +++ b/bindings/ruby/tests/jfk_reader/jfk_reader.c @@ -60,49 +60,9 @@ static const rb_memory_view_entry_t jfk_reader_view_entry = { jfk_reader_memory_view_available_p }; -static VALUE -read_jfk(int argc, VALUE *argv, VALUE obj) -{ - const char *audio_path_str = StringValueCStr(argv[0]); - const int n_samples = 176000; - - short samples[n_samples]; - FILE *file = fopen(audio_path_str, "rb"); - - fseek(file, 78, SEEK_SET); - fread(samples, sizeof(short), n_samples, file); - fclose(file); - - VALUE rb_samples = rb_ary_new2(n_samples); - for (int i = 0; i < n_samples; i++) { - rb_ary_push(rb_samples, INT2FIX(samples[i])); - } - - VALUE rb_data = rb_ary_new2(n_samples); - for (int i = 0; i < n_samples; i++) { - rb_ary_push(rb_data, DBL2NUM(samples[i]/32768.0)); - } - - float data[n_samples]; - for (int i = 0; i < n_samples; i++) { - data[i] = samples[i]/32768.0; - } - void *c_data = (void *)data; - VALUE rb_void = rb_enc_str_new((const char *)c_data, sizeof(data), rb_ascii8bit_encoding()); - - VALUE rb_result = rb_ary_new3(3, rb_samples, rb_data, rb_void); - return rb_result; -} - void Init_jfk_reader(void) { VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject); rb_memory_view_register(cJFKReader, &jfk_reader_view_entry); rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1); - - - rb_define_global_function("read_jfk", read_jfk, -1); - - - } diff --git a/bindings/ruby/tests/test_model.rb b/bindings/ruby/tests/test_model.rb index 2310522a644..598dbde9f13 100644 --- a/bindings/ruby/tests/test_model.rb +++ b/bindings/ruby/tests/test_model.rb @@ -1,4 +1,5 @@ require_relative "helper" +require "pathname" class TestModel < TestBase def test_model @@ -41,4 +42,23 @@ def test_gc assert_equal 1, model.ftype assert_equal "base", model.type end + + def test_pathname + path = Pathname(MODEL) + whisper = Whisper::Context.new(path) + model = whisper.model + + assert_equal 51864, model.n_vocab + assert_equal 1500, model.n_audio_ctx + assert_equal 512, model.n_audio_state + assert_equal 8, model.n_audio_head + assert_equal 6, model.n_audio_layer + assert_equal 448, model.n_text_ctx + assert_equal 512, model.n_text_state + assert_equal 8, model.n_text_head + assert_equal 6, model.n_text_layer + assert_equal 80, model.n_mels + assert_equal 1, model.ftype + assert_equal "base", model.type + end end