KitaitiMakoto committed on
Commit
7cb9a0e
·
unverified ·
1 Parent(s): 9cbd99a

ruby : Make context accept initial parameters, API to retrieve a segment and more (#2749)

Browse files

* Fix type signature for Whisper.log_set

* Use cache file for model when offline

* Extract ruby_whisper_transcribe() into a file

* Extract Whisper::Error

* Use FileList for ext/*.{c,cpp,h}

* Extract Whisper::Segment

* Extract Whisper::Model

* Extract Whisper::Params

* Extract Whisper::Context

* Extract log_callback function

* Write base code in C rather than C++

* Use chdir instead of Dir.chdir in Rakefile

* Define alloc func for Whisper::Model

* Define Whisper::Params' callback and user data reader

* Add test for Whisper::Params.new with keyword arguments

* Make Whisper::Params.new accept keyword arguments

* Update type signatures

* Update README

* Update CLEAN targets

* Fix document comment for Whisper::Params#new_segment_callback=

* Use macro to define params

* Fix dependency of build task

* Set Whisper.finalize_log_callback visibility to private

* Make Whisper::Context#full and full_parallel return self

* Add test for Whisper::Context#full_get_segment

* Add Whisper::Context#full_get_segment

* Update signatures

* Update README

* Fix signature

* Replace #initialize with .new in signature file [skip ci]

* Fix potential overflow

bindings/ruby/README.md CHANGED
@@ -24,14 +24,15 @@ require "whisper"
24
 
25
  whisper = Whisper::Context.new("base")
26
 
27
- params = Whisper::Params.new
28
- params.language = "en"
29
- params.offset = 10_000
30
- params.duration = 60_000
31
- params.max_text_tokens = 300
32
- params.translate = true
33
- params.print_timestamps = false
34
- params.initial_prompt = "Initial prompt here."
 
35
 
36
  whisper.transcribe("path/to/audio.wav", params) do |whole_text|
37
  puts whole_text
@@ -113,18 +114,18 @@ def format_time(time_ms)
113
  "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
114
  end
115
 
116
- whisper.transcribe("path/to/audio.wav", params)
117
-
118
- whisper.each_segment.with_index do |segment, index|
119
- line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
120
- nth: index + 1,
121
- st: format_time(segment.start_time),
122
- ed: format_time(segment.end_time),
123
- text: segment.text
124
- }
125
- line << " (speaker turned)" if segment.speaker_next_turn?
126
- puts line
127
- end
128
 
129
  ```
130
 
@@ -215,10 +216,11 @@ reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :
215
  samples = reader.enum_for(:each_buffer).map(&:samples).flatten
216
 
217
  whisper = Whisper::Context.new("base")
218
- whisper.full(Whisper::Params.new, samples)
219
- whisper.each_segment do |segment|
220
- puts segment.text
221
- end
 
222
  ```
223
 
224
  The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
 
24
 
25
  whisper = Whisper::Context.new("base")
26
 
27
+ params = Whisper::Params.new(
28
+ language: "en",
29
+ offset: 10_000,
30
+ duration: 60_000,
31
+ max_text_tokens: 300,
32
+ translate: true,
33
+ print_timestamps: false,
34
+ initial_prompt: "Initial prompt here."
35
+ )
36
 
37
  whisper.transcribe("path/to/audio.wav", params) do |whole_text|
38
  puts whole_text
 
114
  "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
115
  end
116
 
117
+ whisper
118
+ .transcribe("path/to/audio.wav", params)
119
+ .each_segment.with_index do |segment, index|
120
+ line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
121
+ nth: index + 1,
122
+ st: format_time(segment.start_time),
123
+ ed: format_time(segment.end_time),
124
+ text: segment.text
125
+ }
126
+ line << " (speaker turned)" if segment.speaker_next_turn?
127
+ puts line
128
+ end
129
 
130
  ```
131
 
 
216
  samples = reader.enum_for(:each_buffer).map(&:samples).flatten
217
 
218
  whisper = Whisper::Context.new("base")
219
+ whisper
220
+ .full(Whisper::Params.new, samples)
221
+ .each_segment do |segment|
222
+ puts segment.text
223
+ end
224
  ```
225
 
226
  The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
bindings/ruby/Rakefile CHANGED
@@ -18,9 +18,11 @@ EXTSOURCES.each do |src|
18
  end
19
 
20
  CLEAN.include SOURCES
21
- CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"]
22
 
23
- task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"]
 
 
24
 
25
  directory "pkg"
26
  CLOBBER.include "pkg"
@@ -29,14 +31,14 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
29
  SO_FILE = File.join("ext", LIB_NAME)
30
  LIB_FILE = File.join("lib", LIB_NAME)
31
 
32
- file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t|
33
- Dir.chdir "ext" do
34
  ruby "extconf.rb"
35
  end
36
  end
37
 
38
  file SO_FILE => "ext/Makefile" do |t|
39
- Dir.chdir "ext" do
40
  sh "make"
41
  end
42
  end
@@ -54,7 +56,7 @@ end
54
 
55
  TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
56
  file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
57
- Dir.chdir "tests/jfk_reader" do
58
  ruby "extconf.rb"
59
  sh "make"
60
  end
 
18
  end
19
 
20
  CLEAN.include SOURCES
21
+ CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]
22
 
23
+ SRC = FileList["ext/*.{c,cpp,h}"]
24
+
25
+ task build: SOURCES
26
 
27
  directory "pkg"
28
  CLOBBER.include "pkg"
 
31
  SO_FILE = File.join("ext", LIB_NAME)
32
  LIB_FILE = File.join("lib", LIB_NAME)
33
 
34
+ file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
35
+ chdir "ext" do
36
  ruby "extconf.rb"
37
  end
38
  end
39
 
40
  file SO_FILE => "ext/Makefile" do |t|
41
+ chdir "ext" do
42
  sh "make"
43
  end
44
  end
 
56
 
57
  TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
58
  file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
59
+ chdir "tests/jfk_reader" do
60
  ruby "extconf.rb"
61
  sh "make"
62
  end
bindings/ruby/ext/.gitignore CHANGED
@@ -4,10 +4,8 @@ whisper.bundle
4
  whisper.dll
5
  scripts/get-flags.mk
6
  *.o
7
- *.c
8
- *.cpp
9
- *.h
10
- *.m
11
- *.metal
12
- !ruby_whisper.cpp
13
- !ruby_whisper.h
 
4
  whisper.dll
5
  scripts/get-flags.mk
6
  *.o
7
+ /*/**/*.c
8
+ /*/**/*.cpp
9
+ /*/**/*.h
10
+ /*/**/*.m
11
+ /*/**/*.metal
 
 
bindings/ruby/ext/extconf.rb CHANGED
@@ -174,7 +174,14 @@ $OBJ_WHISPER <<
174
  'src/whisper.o'
175
 
176
  $objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
177
- $objs << "ruby_whisper.o"
 
 
 
 
 
 
 
178
 
179
  $CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
180
  $CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
 
174
  'src/whisper.o'
175
 
176
  $objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
177
+ $objs <<
178
+ "ruby_whisper.o" <<
179
+ "ruby_whisper_context.o" <<
180
+ "ruby_whisper_transcribe.o" <<
181
+ "ruby_whisper_params.o" <<
182
+ "ruby_whisper_error.o" <<
183
+ "ruby_whisper_segment.o" <<
184
+ "ruby_whisper_model.o"
185
 
186
  $CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
187
  $CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
bindings/ruby/ext/ruby_whisper.c ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+ #include <ruby/memory_view.h>
3
+ #include "ruby_whisper.h"
4
+
5
+ VALUE mWhisper;
6
+ VALUE cContext;
7
+ VALUE cParams;
8
+ VALUE eError;
9
+
10
+ VALUE cSegment;
11
+ VALUE cModel;
12
+
13
+ ID id_to_s;
14
+ ID id_call;
15
+ ID id___method__;
16
+ ID id_to_enum;
17
+ ID id_length;
18
+ ID id_next;
19
+ ID id_new;
20
+ ID id_to_path;
21
+ ID id_URI;
22
+ ID id_pre_converted_models;
23
+
24
+ static bool is_log_callback_finalized = false;
25
+
26
+ // High level API
27
+ extern VALUE ruby_whisper_segment_allocate(VALUE klass);
28
+
29
+ extern void init_ruby_whisper_context(VALUE *mWhisper);
30
+ extern void init_ruby_whisper_params(VALUE *mWhisper);
31
+ extern void init_ruby_whisper_error(VALUE *mWhisper);
32
+ extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
33
+ extern void init_ruby_whisper_model(VALUE *mWhisper);
34
+ extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
35
+
36
+ /*
37
+ * call-seq:
38
+ * lang_max_id -> Integer
39
+ */
40
+ static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
41
+ return INT2NUM(whisper_lang_max_id());
42
+ }
43
+
44
+ /*
45
+ * call-seq:
46
+ * lang_id(lang_name) -> Integer
47
+ */
48
+ static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
49
+ const char * lang_str = StringValueCStr(lang);
50
+ const int id = whisper_lang_id(lang_str);
51
+ if (-1 == id) {
52
+ rb_raise(rb_eArgError, "language not found: %s", lang_str);
53
+ }
54
+ return INT2NUM(id);
55
+ }
56
+
57
+ /*
58
+ * call-seq:
59
+ * lang_str(lang_id) -> String
60
+ */
61
+ static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
62
+ const int lang_id = NUM2INT(id);
63
+ const char * str = whisper_lang_str(lang_id);
64
+ if (NULL == str) {
65
+ rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
66
+ }
67
+ return rb_str_new2(str);
68
+ }
69
+
70
+ /*
71
+ * call-seq:
72
+ * lang_str_full(lang_id) -> String
73
+ */
74
+ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
75
+ const int lang_id = NUM2INT(id);
76
+ const char * str_full = whisper_lang_str_full(lang_id);
77
+ if (NULL == str_full) {
78
+ rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
79
+ }
80
+ return rb_str_new2(str_full);
81
+ }
82
+
83
+ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
84
+ is_log_callback_finalized = true;
85
+ return Qnil;
86
+ }
87
+
88
+ static void
89
+ ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
90
+ if (is_log_callback_finalized) {
91
+ return;
92
+ }
93
+ VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
94
+ VALUE udata = rb_iv_get(mWhisper, "user_data");
95
+ rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
96
+ }
97
+
98
+ /*
99
+ * call-seq:
100
+ * log_set ->(level, buffer, user_data) { ... }, user_data -> nil
101
+ */
102
+ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
103
+ VALUE old_callback = rb_iv_get(self, "log_callback");
104
+ if (!NIL_P(old_callback)) {
105
+ rb_undefine_finalizer(old_callback);
106
+ }
107
+
108
+ rb_iv_set(self, "log_callback", log_callback);
109
+ rb_iv_set(self, "user_data", user_data);
110
+
111
+ VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
112
+ rb_define_finalizer(log_callback, finalize_log_callback);
113
+
114
+ whisper_log_set(ruby_whisper_log_callback, NULL);
115
+
116
+ return Qnil;
117
+ }
118
+
119
+ static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
120
+ rb_gc_mark(rwm->context);
121
+ }
122
+
123
+ static VALUE ruby_whisper_model_allocate(VALUE klass) {
124
+ ruby_whisper_model *rwm;
125
+ rwm = ALLOC(ruby_whisper_model);
126
+ return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
127
+ }
128
+
129
+ void Init_whisper() {
130
+ id_to_s = rb_intern("to_s");
131
+ id_call = rb_intern("call");
132
+ id___method__ = rb_intern("__method__");
133
+ id_to_enum = rb_intern("to_enum");
134
+ id_length = rb_intern("length");
135
+ id_next = rb_intern("next");
136
+ id_new = rb_intern("new");
137
+ id_to_path = rb_intern("to_path");
138
+ id_URI = rb_intern("URI");
139
+ id_pre_converted_models = rb_intern("pre_converted_models");
140
+
141
+ mWhisper = rb_define_module("Whisper");
142
+
143
+ rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
144
+ rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
145
+ rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
146
+ rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
147
+ rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
148
+ rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
149
+
150
+ rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
151
+ rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
152
+ rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
153
+ rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
154
+ rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
155
+ rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
156
+
157
+ init_ruby_whisper_context(&mWhisper);
158
+ init_ruby_whisper_params(&mWhisper);
159
+ init_ruby_whisper_error(&mWhisper);
160
+ init_ruby_whisper_segment(&mWhisper, &cContext);
161
+ init_ruby_whisper_model(&mWhisper);
162
+
163
+ rb_require("whisper/model/uri");
164
+ }
bindings/ruby/ext/ruby_whisper.cpp DELETED
@@ -1,1962 +0,0 @@
1
- #include <ruby.h>
2
- #include <ruby/memory_view.h>
3
- #include "ruby_whisper.h"
4
- #define DR_WAV_IMPLEMENTATION
5
- #include "dr_wav.h"
6
- #include <cmath>
7
- #include <fstream>
8
- #include <cstdio>
9
- #include <string>
10
- #include <thread>
11
- #include <vector>
12
-
13
- #ifdef __cplusplus
14
- extern "C" {
15
- #endif
16
-
17
- #define BOOL_PARAMS_SETTER(self, prop, value) \
18
- ruby_whisper_params *rwp; \
19
- Data_Get_Struct(self, ruby_whisper_params, rwp); \
20
- if (value == Qfalse || value == Qnil) { \
21
- rwp->params.prop = false; \
22
- } else { \
23
- rwp->params.prop = true; \
24
- } \
25
- return value; \
26
-
27
- #define BOOL_PARAMS_GETTER(self, prop) \
28
- ruby_whisper_params *rwp; \
29
- Data_Get_Struct(self, ruby_whisper_params, rwp); \
30
- if (rwp->params.prop) { \
31
- return Qtrue; \
32
- } else { \
33
- return Qfalse; \
34
- }
35
-
36
- VALUE mWhisper;
37
- VALUE cContext;
38
- VALUE cParams;
39
- VALUE eError;
40
-
41
- VALUE cSegment;
42
- VALUE cModel;
43
-
44
- static ID id_to_s;
45
- static ID id_call;
46
- static ID id___method__;
47
- static ID id_to_enum;
48
- static ID id_length;
49
- static ID id_next;
50
- static ID id_new;
51
- static ID id_to_path;
52
- static ID id_URI;
53
- static ID id_pre_converted_models;
54
-
55
- static bool is_log_callback_finalized = false;
56
-
57
- // High level API
58
- static VALUE rb_whisper_segment_initialize(VALUE context, int index);
59
-
60
- /*
61
- * call-seq:
62
- * lang_max_id -> Integer
63
- */
64
- static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
65
- return INT2NUM(whisper_lang_max_id());
66
- }
67
-
68
- /*
69
- * call-seq:
70
- * lang_id(lang_name) -> Integer
71
- */
72
- static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
73
- const char * lang_str = StringValueCStr(lang);
74
- const int id = whisper_lang_id(lang_str);
75
- if (-1 == id) {
76
- rb_raise(rb_eArgError, "language not found: %s", lang_str);
77
- }
78
- return INT2NUM(id);
79
- }
80
-
81
- /*
82
- * call-seq:
83
- * lang_str(lang_id) -> String
84
- */
85
- static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
86
- const int lang_id = NUM2INT(id);
87
- const char * str = whisper_lang_str(lang_id);
88
- if (nullptr == str) {
89
- rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
90
- }
91
- return rb_str_new2(str);
92
- }
93
-
94
- /*
95
- * call-seq:
96
- * lang_str(lang_id) -> String
97
- */
98
- static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
99
- const int lang_id = NUM2INT(id);
100
- const char * str_full = whisper_lang_str_full(lang_id);
101
- if (nullptr == str_full) {
102
- rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
103
- }
104
- return rb_str_new2(str_full);
105
- }
106
-
107
- static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
108
- is_log_callback_finalized = true;
109
- return Qnil;
110
- }
111
-
112
- /*
113
- * call-seq:
114
- * log_set ->(level, buffer, user_data) { ... }, user_data -> nil
115
- */
116
- static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
117
- VALUE old_callback = rb_iv_get(self, "log_callback");
118
- if (!NIL_P(old_callback)) {
119
- rb_undefine_finalizer(old_callback);
120
- }
121
-
122
- rb_iv_set(self, "log_callback", log_callback);
123
- rb_iv_set(self, "user_data", user_data);
124
-
125
- VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
126
- rb_define_finalizer(log_callback, finalize_log_callback);
127
-
128
- whisper_log_set([](ggml_log_level level, const char * buffer, void * user_data) {
129
- if (is_log_callback_finalized) {
130
- return;
131
- }
132
- VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
133
- VALUE udata = rb_iv_get(mWhisper, "user_data");
134
- rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
135
- }, nullptr);
136
-
137
- return Qnil;
138
- }
139
-
140
- static void ruby_whisper_free(ruby_whisper *rw) {
141
- if (rw->context) {
142
- whisper_free(rw->context);
143
- rw->context = NULL;
144
- }
145
- }
146
-
147
- static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
148
- }
149
-
150
- void rb_whisper_mark(ruby_whisper *rw) {
151
- // call rb_gc_mark on any ruby references in rw
152
- }
153
-
154
- void rb_whisper_free(ruby_whisper *rw) {
155
- ruby_whisper_free(rw);
156
- free(rw);
157
- }
158
-
159
- void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) {
160
- rb_gc_mark(rwc->user_data);
161
- rb_gc_mark(rwc->callback);
162
- rb_gc_mark(rwc->callbacks);
163
- }
164
-
165
- void rb_whisper_params_mark(ruby_whisper_params *rwp) {
166
- rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
167
- rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
168
- rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
169
- }
170
-
171
- void rb_whisper_params_free(ruby_whisper_params *rwp) {
172
- // How to free user_data and callback only when not referred to by others?
173
- ruby_whisper_params_free(rwp);
174
- free(rwp);
175
- }
176
-
177
- static VALUE ruby_whisper_allocate(VALUE klass) {
178
- ruby_whisper *rw;
179
- rw = ALLOC(ruby_whisper);
180
- rw->context = NULL;
181
- return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
182
- }
183
-
184
- static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() {
185
- ruby_whisper_callback_container *container;
186
- container = ALLOC(ruby_whisper_callback_container);
187
- container->context = nullptr;
188
- container->user_data = Qnil;
189
- container->callback = Qnil;
190
- container->callbacks = rb_ary_new();
191
- return container;
192
- }
193
-
194
- static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
195
- const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
196
-
197
- // Currently, doesn't support state because
198
- // those require to resolve GC-related problems.
199
- if (!NIL_P(container->callback)) {
200
- rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
201
- }
202
- const long callbacks_len = RARRAY_LEN(container->callbacks);
203
- if (0 == callbacks_len) {
204
- return;
205
- }
206
- const int n_segments = whisper_full_n_segments_from_state(state);
207
- for (int i = n_new; i > 0; i--) {
208
- int i_segment = n_segments - i;
209
- VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
210
- for (int j = 0; j < callbacks_len; j++) {
211
- VALUE cb = rb_ary_entry(container->callbacks, j);
212
- rb_funcall(cb, id_call, 1, segment);
213
- }
214
- }
215
- }
216
-
217
- static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
218
- const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
219
- const VALUE progress = INT2NUM(progress_cur);
220
- // Currently, doesn't support state because
221
- // those require to resolve GC-related problems.
222
- if (!NIL_P(container->callback)) {
223
- rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
224
- }
225
- const long callbacks_len = RARRAY_LEN(container->callbacks);
226
- if (0 == callbacks_len) {
227
- return;
228
- }
229
- for (int j = 0; j < callbacks_len; j++) {
230
- VALUE cb = rb_ary_entry(container->callbacks, j);
231
- rb_funcall(cb, id_call, 1, progress);
232
- }
233
- }
234
-
235
- static bool abort_callback(void * user_data) {
236
- const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
237
- if (!NIL_P(container->callback)) {
238
- VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
239
- if (!NIL_P(result) && Qfalse != result) {
240
- return true;
241
- }
242
- }
243
- const long callbacks_len = RARRAY_LEN(container->callbacks);
244
- if (0 == callbacks_len) {
245
- return false;
246
- }
247
- for (int j = 0; j < callbacks_len; j++) {
248
- VALUE cb = rb_ary_entry(container->callbacks, j);
249
- VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
250
- if (!NIL_P(result) && Qfalse != result) {
251
- return true;
252
- }
253
- }
254
- return false;
255
- }
256
-
257
- static VALUE ruby_whisper_params_allocate(VALUE klass) {
258
- ruby_whisper_params *rwp;
259
- rwp = ALLOC(ruby_whisper_params);
260
- rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
261
- rwp->diarize = false;
262
- rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
263
- rwp->progress_callback_container = rb_whisper_callback_container_allocate();
264
- rwp->abort_callback_container = rb_whisper_callback_container_allocate();
265
- return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
266
- }
267
-
268
- /*
269
- * call-seq:
270
- * new("base.en") -> Whisper::Context
271
- * new("path/to/model.bin") -> Whisper::Context
272
- * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
273
- */
274
- static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
275
- ruby_whisper *rw;
276
- VALUE whisper_model_file_path;
277
-
278
- // TODO: we can support init from buffer here too maybe another ruby object to expose
279
- rb_scan_args(argc, argv, "01", &whisper_model_file_path);
280
- Data_Get_Struct(self, ruby_whisper, rw);
281
-
282
- VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
283
- VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
284
- if (!NIL_P(pre_converted_model)) {
285
- whisper_model_file_path = pre_converted_model;
286
- }
287
- if (TYPE(whisper_model_file_path) == T_STRING) {
288
- const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
289
- if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
290
- VALUE uri_class = rb_const_get(cModel, id_URI);
291
- whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
292
- }
293
- }
294
- if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
295
- VALUE uri_class = rb_const_get(cModel, id_URI);
296
- whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
297
- }
298
- if (rb_respond_to(whisper_model_file_path, id_to_path)) {
299
- whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
300
- }
301
- if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
302
- rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
303
- }
304
- rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
305
- if (rw->context == nullptr) {
306
- rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
307
- }
308
- return self;
309
- }
310
-
311
- static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
312
- if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
313
- rwp->new_segment_callback_container->context = self;
314
- rwp->params.new_segment_callback = new_segment_callback;
315
- rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
316
- }
317
-
318
- if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
319
- rwp->progress_callback_container->context = self;
320
- rwp->params.progress_callback = progress_callback;
321
- rwp->params.progress_callback_user_data = rwp->progress_callback_container;
322
- }
323
-
324
- if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
325
- rwp->abort_callback_container->context = self;
326
- rwp->params.abort_callback = abort_callback;
327
- rwp->params.abort_callback_user_data = rwp->abort_callback_container;
328
- }
329
- }
330
-
331
- /*
332
- * transcribe a single file
333
- * can emit to a block results
334
- *
335
- * params = Whisper::Params.new
336
- * params.duration = 60_000
337
- * whisper.transcribe "path/to/audio.wav", params do |text|
338
- * puts text
339
- * end
340
- *
341
- * call-seq:
342
- * transcribe(path_to_audio, params) {|text| ...}
343
- **/
344
- static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
345
- ruby_whisper *rw;
346
- ruby_whisper_params *rwp;
347
- VALUE wave_file_path, blk, params;
348
-
349
- rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
350
- Data_Get_Struct(self, ruby_whisper, rw);
351
- Data_Get_Struct(params, ruby_whisper_params, rwp);
352
-
353
- if (!rb_respond_to(wave_file_path, id_to_s)) {
354
- rb_raise(rb_eRuntimeError, "Expected file path to wave file");
355
- }
356
-
357
- std::string fname_inp = StringValueCStr(wave_file_path);
358
-
359
- std::vector<float> pcmf32; // mono-channel F32 PCM
360
- std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
361
-
362
- // WAV input - this is directly from main.cpp example
363
- {
364
- drwav wav;
365
- std::vector<uint8_t> wav_data; // used for pipe input from stdin
366
-
367
- if (fname_inp == "-") {
368
- {
369
- uint8_t buf[1024];
370
- while (true) {
371
- const size_t n = fread(buf, 1, sizeof(buf), stdin);
372
- if (n == 0) {
373
- break;
374
- }
375
- wav_data.insert(wav_data.end(), buf, buf + n);
376
- }
377
- }
378
-
379
- if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
380
- fprintf(stderr, "error: failed to open WAV file from stdin\n");
381
- return self;
382
- }
383
-
384
- fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
385
- } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
386
- fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
387
- return self;
388
- }
389
-
390
- if (wav.channels != 1 && wav.channels != 2) {
391
- fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
392
- return self;
393
- }
394
-
395
- if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
396
- fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
397
- return self;
398
- }
399
-
400
- if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
401
- fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
402
- return self;
403
- }
404
-
405
- if (wav.bitsPerSample != 16) {
406
- fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
407
- return self;
408
- }
409
-
410
- const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
411
-
412
- std::vector<int16_t> pcm16;
413
- pcm16.resize(n*wav.channels);
414
- drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
415
- drwav_uninit(&wav);
416
-
417
- // convert to mono, float
418
- pcmf32.resize(n);
419
- if (wav.channels == 1) {
420
- for (uint64_t i = 0; i < n; i++) {
421
- pcmf32[i] = float(pcm16[i])/32768.0f;
422
- }
423
- } else {
424
- for (uint64_t i = 0; i < n; i++) {
425
- pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
426
- }
427
- }
428
-
429
- if (rwp->diarize) {
430
- // convert to stereo, float
431
- pcmf32s.resize(2);
432
-
433
- pcmf32s[0].resize(n);
434
- pcmf32s[1].resize(n);
435
- for (uint64_t i = 0; i < n; i++) {
436
- pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
437
- pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
438
- }
439
- }
440
- }
441
- {
442
- static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
443
-
444
- rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
445
- bool is_aborted = *(bool*)user_data;
446
- return !is_aborted;
447
- };
448
- rwp->params.encoder_begin_callback_user_data = &is_aborted;
449
- }
450
-
451
- register_callbacks(rwp, &self);
452
-
453
- if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
454
- fprintf(stderr, "failed to process audio\n");
455
- return self;
456
- }
457
- const int n_segments = whisper_full_n_segments(rw->context);
458
- VALUE output = rb_str_new2("");
459
- for (int i = 0; i < n_segments; ++i) {
460
- const char * text = whisper_full_get_segment_text(rw->context, i);
461
- output = rb_str_concat(output, rb_str_new2(text));
462
- }
463
- VALUE idCall = id_call;
464
- if (blk != Qnil) {
465
- rb_funcall(blk, idCall, 1, output);
466
- }
467
- return self;
468
- }
469
-
470
- /*
471
- * call-seq:
472
- * model_n_vocab -> Integer
473
- */
474
- VALUE ruby_whisper_model_n_vocab(VALUE self) {
475
- ruby_whisper *rw;
476
- Data_Get_Struct(self, ruby_whisper, rw);
477
- return INT2NUM(whisper_model_n_vocab(rw->context));
478
- }
479
-
480
- /*
481
- * call-seq:
482
- * model_n_audio_ctx -> Integer
483
- */
484
- VALUE ruby_whisper_model_n_audio_ctx(VALUE self) {
485
- ruby_whisper *rw;
486
- Data_Get_Struct(self, ruby_whisper, rw);
487
- return INT2NUM(whisper_model_n_audio_ctx(rw->context));
488
- }
489
-
490
- /*
491
- * call-seq:
492
- * model_n_audio_state -> Integer
493
- */
494
- VALUE ruby_whisper_model_n_audio_state(VALUE self) {
495
- ruby_whisper *rw;
496
- Data_Get_Struct(self, ruby_whisper, rw);
497
- return INT2NUM(whisper_model_n_audio_state(rw->context));
498
- }
499
-
500
- /*
501
- * call-seq:
502
- * model_n_audio_head -> Integer
503
- */
504
- VALUE ruby_whisper_model_n_audio_head(VALUE self) {
505
- ruby_whisper *rw;
506
- Data_Get_Struct(self, ruby_whisper, rw);
507
- return INT2NUM(whisper_model_n_audio_head(rw->context));
508
- }
509
-
510
- /*
511
- * call-seq:
512
- * model_n_audio_layer -> Integer
513
- */
514
- VALUE ruby_whisper_model_n_audio_layer(VALUE self) {
515
- ruby_whisper *rw;
516
- Data_Get_Struct(self, ruby_whisper, rw);
517
- return INT2NUM(whisper_model_n_audio_layer(rw->context));
518
- }
519
-
520
- /*
521
- * call-seq:
522
- * model_n_text_ctx -> Integer
523
- */
524
- VALUE ruby_whisper_model_n_text_ctx(VALUE self) {
525
- ruby_whisper *rw;
526
- Data_Get_Struct(self, ruby_whisper, rw);
527
- return INT2NUM(whisper_model_n_text_ctx(rw->context));
528
- }
529
-
530
- /*
531
- * call-seq:
532
- * model_n_text_state -> Integer
533
- */
534
- VALUE ruby_whisper_model_n_text_state(VALUE self) {
535
- ruby_whisper *rw;
536
- Data_Get_Struct(self, ruby_whisper, rw);
537
- return INT2NUM(whisper_model_n_text_state(rw->context));
538
- }
539
-
540
- /*
541
- * call-seq:
542
- * model_n_text_head -> Integer
543
- */
544
- VALUE ruby_whisper_model_n_text_head(VALUE self) {
545
- ruby_whisper *rw;
546
- Data_Get_Struct(self, ruby_whisper, rw);
547
- return INT2NUM(whisper_model_n_text_head(rw->context));
548
- }
549
-
550
- /*
551
- * call-seq:
552
- * model_n_text_layer -> Integer
553
- */
554
- VALUE ruby_whisper_model_n_text_layer(VALUE self) {
555
- ruby_whisper *rw;
556
- Data_Get_Struct(self, ruby_whisper, rw);
557
- return INT2NUM(whisper_model_n_text_layer(rw->context));
558
- }
559
-
560
- /*
561
- * call-seq:
562
- * model_n_mels -> Integer
563
- */
564
- VALUE ruby_whisper_model_n_mels(VALUE self) {
565
- ruby_whisper *rw;
566
- Data_Get_Struct(self, ruby_whisper, rw);
567
- return INT2NUM(whisper_model_n_mels(rw->context));
568
- }
569
-
570
- /*
571
- * call-seq:
572
- * model_ftype -> Integer
573
- */
574
- VALUE ruby_whisper_model_ftype(VALUE self) {
575
- ruby_whisper *rw;
576
- Data_Get_Struct(self, ruby_whisper, rw);
577
- return INT2NUM(whisper_model_ftype(rw->context));
578
- }
579
-
580
- /*
581
- * call-seq:
582
- * model_type -> String
583
- */
584
- VALUE ruby_whisper_model_type(VALUE self) {
585
- ruby_whisper *rw;
586
- Data_Get_Struct(self, ruby_whisper, rw);
587
- return rb_str_new2(whisper_model_type_readable(rw->context));
588
- }
589
-
590
- /*
591
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
592
- * Not thread safe for same context
593
- * Uses the specified decoding strategy to obtain the text.
594
- *
595
- * call-seq:
596
- * full(params, samples, n_samples) -> nil
597
- * full(params, samples) -> nil
598
- *
599
- * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
600
- */
601
- VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
602
- if (argc < 2 || argc > 3) {
603
- rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
604
- }
605
-
606
- ruby_whisper *rw;
607
- ruby_whisper_params *rwp;
608
- Data_Get_Struct(self, ruby_whisper, rw);
609
- VALUE params = argv[0];
610
- Data_Get_Struct(params, ruby_whisper_params, rwp);
611
- VALUE samples = argv[1];
612
- int n_samples;
613
- rb_memory_view_t view;
614
- const bool memory_view_available_p = rb_memory_view_available_p(samples);
615
- if (argc == 3) {
616
- n_samples = NUM2INT(argv[2]);
617
- if (TYPE(samples) == T_ARRAY) {
618
- if (RARRAY_LEN(samples) < n_samples) {
619
- rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
620
- }
621
- }
622
- // Should check when samples.respond_to?(:length)?
623
- } else {
624
- if (TYPE(samples) == T_ARRAY) {
625
- n_samples = RARRAY_LEN(samples);
626
- } else if (memory_view_available_p) {
627
- if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
628
- view.obj = Qnil;
629
- rb_raise(rb_eArgError, "unable to get a memory view");
630
- }
631
- n_samples = view.byte_size / view.item_size;
632
- } else if (rb_respond_to(samples, id_length)) {
633
- n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
634
- } else {
635
- rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
636
- }
637
- }
638
- float * c_samples = (float *)malloc(n_samples * sizeof(float));
639
- if (memory_view_available_p) {
640
- c_samples = (float *)view.data;
641
- } else {
642
- if (TYPE(samples) == T_ARRAY) {
643
- for (int i = 0; i < n_samples; i++) {
644
- c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
645
- }
646
- } else {
647
- // TODO: use rb_block_call
648
- VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
649
- for (int i = 0; i < n_samples; i++) {
650
- // TODO: check if iter is exhausted and raise ArgumentError appropriately
651
- VALUE sample = rb_funcall(iter, id_next, 0);
652
- c_samples[i] = RFLOAT_VALUE(sample);
653
- }
654
- }
655
- }
656
- register_callbacks(rwp, &self);
657
- const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
658
- if (0 == result) {
659
- return Qnil;
660
- } else {
661
- rb_exc_raise(rb_funcall(eError, id_new, 1, result));
662
- }
663
- }
664
-
665
- /*
666
- * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
667
- * Result is stored in the default state of the context
668
- * Not thread safe if executed in parallel on the same context.
669
- * It seems this approach can offer some speedup in some cases.
670
- * However, the transcription accuracy can be worse at the beginning and end of each chunk.
671
- *
672
- * call-seq:
673
- * full_parallel(params, samples) -> nil
674
- * full_parallel(params, samples, n_samples) -> nil
675
- * full_parallel(params, samples, n_samples, n_processors) -> nil
676
- * full_parallel(params, samples, nil, n_processors) -> nil
677
- */
678
- static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
679
- if (argc < 2 || argc > 4) {
680
- rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
681
- }
682
-
683
- ruby_whisper *rw;
684
- ruby_whisper_params *rwp;
685
- Data_Get_Struct(self, ruby_whisper, rw);
686
- VALUE params = argv[0];
687
- Data_Get_Struct(params, ruby_whisper_params, rwp);
688
- VALUE samples = argv[1];
689
- int n_samples;
690
- int n_processors;
691
- rb_memory_view_t view;
692
- const bool memory_view_available_p = rb_memory_view_available_p(samples);
693
- switch (argc) {
694
- case 2:
695
- n_processors = 1;
696
- break;
697
- case 3:
698
- n_processors = 1;
699
- break;
700
- case 4:
701
- n_processors = NUM2INT(argv[3]);
702
- break;
703
- }
704
- if (argc >= 3 && !NIL_P(argv[2])) {
705
- n_samples = NUM2INT(argv[2]);
706
- if (TYPE(samples) == T_ARRAY) {
707
- if (RARRAY_LEN(samples) < n_samples) {
708
- rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
709
- }
710
- }
711
- // Should check when samples.respond_to?(:length)?
712
- } else if (memory_view_available_p) {
713
- if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
714
- view.obj = Qnil;
715
- rb_raise(rb_eArgError, "unable to get a memory view");
716
- }
717
- n_samples = view.byte_size / view.item_size;
718
- } else {
719
- if (TYPE(samples) == T_ARRAY) {
720
- n_samples = RARRAY_LEN(samples);
721
- } else if (rb_respond_to(samples, id_length)) {
722
- n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
723
- } else {
724
- rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
725
- }
726
- }
727
- float * c_samples = (float *)malloc(n_samples * sizeof(float));
728
- if (memory_view_available_p) {
729
- c_samples = (float *)view.data;
730
- } else {
731
- if (TYPE(samples) == T_ARRAY) {
732
- for (int i = 0; i < n_samples; i++) {
733
- c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
734
- }
735
- } else {
736
- // FIXME: use rb_block_call
737
- VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
738
- for (int i = 0; i < n_samples; i++) {
739
- // TODO: check if iter is exhausted and raise ArgumentError
740
- VALUE sample = rb_funcall(iter, id_next, 0);
741
- c_samples[i] = RFLOAT_VALUE(sample);
742
- }
743
- }
744
- }
745
- register_callbacks(rwp, &self);
746
- const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
747
- if (0 == result) {
748
- return Qnil;
749
- } else {
750
- rb_exc_raise(rb_funcall(eError, id_new, 1, result));
751
- }
752
- }
753
-
754
- /*
755
- * Number of segments.
756
- *
757
- * call-seq:
758
- * full_n_segments -> Integer
759
- */
760
- static VALUE ruby_whisper_full_n_segments(VALUE self) {
761
- ruby_whisper *rw;
762
- Data_Get_Struct(self, ruby_whisper, rw);
763
- return INT2NUM(whisper_full_n_segments(rw->context));
764
- }
765
-
766
- /*
767
- * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
768
- *
769
- * call-seq:
770
- * full_lang_id -> Integer
771
- */
772
- static VALUE ruby_whisper_full_lang_id(VALUE self) {
773
- ruby_whisper *rw;
774
- Data_Get_Struct(self, ruby_whisper, rw);
775
- return INT2NUM(whisper_full_lang_id(rw->context));
776
- }
777
-
778
- static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) {
779
- const int c_i_segment = NUM2INT(i_segment);
780
- if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
781
- rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
782
- }
783
- return c_i_segment;
784
- }
785
-
786
- /*
787
- * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
788
- *
789
- * full_get_segment_t0(3) # => 1668 (16680 ms)
790
- *
791
- * call-seq:
792
- * full_get_segment_t0(segment_index) -> Integer
793
- */
794
- static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) {
795
- ruby_whisper *rw;
796
- Data_Get_Struct(self, ruby_whisper, rw);
797
- const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
798
- const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
799
- return INT2NUM(t0);
800
- }
801
-
802
- /*
803
- * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
804
- *
805
- * full_get_segment_t1(3) # => 1668 (16680 ms)
806
- *
807
- * call-seq:
808
- * full_get_segment_t1(segment_index) -> Integer
809
- */
810
- static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) {
811
- ruby_whisper *rw;
812
- Data_Get_Struct(self, ruby_whisper, rw);
813
- const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
814
- const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
815
- return INT2NUM(t1);
816
- }
817
-
818
- /*
819
- * Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
820
- *
821
- * full_get_segment_speacker_turn_next(3) # => true
822
- *
823
- * call-seq:
824
- * full_get_segment_speacker_turn_next(segment_index) -> bool
825
- */
826
- static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) {
827
- ruby_whisper *rw;
828
- Data_Get_Struct(self, ruby_whisper, rw);
829
- const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
830
- const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
831
- return speaker_turn_next ? Qtrue : Qfalse;
832
- }
833
-
834
- /*
835
- * Text of a segment indexed by +segment_index+.
836
- *
837
- * full_get_segment_text(3) # => "ask not what your country can do for you, ..."
838
- *
839
- * call-seq:
840
- * full_get_segment_text(segment_index) -> String
841
- */
842
- static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
843
- ruby_whisper *rw;
844
- Data_Get_Struct(self, ruby_whisper, rw);
845
- const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
846
- const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
847
- return rb_str_new2(text);
848
- }
849
-
850
- /*
851
- * call-seq:
852
- * full_get_segment_no_speech_prob(segment_index) -> Float
853
- */
854
- static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
855
- ruby_whisper *rw;
856
- Data_Get_Struct(self, ruby_whisper, rw);
857
- const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
858
- const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
859
- return DBL2NUM(no_speech_prob);
860
- }
861
-
862
- /*
863
- * params.language = "auto" | "en", etc...
864
- *
865
- * call-seq:
866
- * language = lang_name -> lang_name
867
- */
868
- static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
869
- ruby_whisper_params *rwp;
870
- Data_Get_Struct(self, ruby_whisper_params, rwp);
871
- if (value == Qfalse || value == Qnil) {
872
- rwp->params.language = "auto";
873
- } else {
874
- rwp->params.language = StringValueCStr(value);
875
- }
876
- return value;
877
- }
878
- /*
879
- * call-seq:
880
- * language -> String
881
- */
882
- static VALUE ruby_whisper_params_get_language(VALUE self) {
883
- ruby_whisper_params *rwp;
884
- Data_Get_Struct(self, ruby_whisper_params, rwp);
885
- if (rwp->params.language) {
886
- return rb_str_new2(rwp->params.language);
887
- } else {
888
- return rb_str_new2("auto");
889
- }
890
- }
891
- /*
892
- * call-seq:
893
- * translate = do_translate -> do_translate
894
- */
895
- static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
896
- BOOL_PARAMS_SETTER(self, translate, value)
897
- }
898
- /*
899
- * call-seq:
900
- * translate -> bool
901
- */
902
- static VALUE ruby_whisper_params_get_translate(VALUE self) {
903
- BOOL_PARAMS_GETTER(self, translate)
904
- }
905
- /*
906
- * call-seq:
907
- * no_context = dont_use_context -> dont_use_context
908
- */
909
- static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
910
- BOOL_PARAMS_SETTER(self, no_context, value)
911
- }
912
- /*
913
- * If true, does not use past transcription (if any) as initial prompt for the decoder.
914
- *
915
- * call-seq:
916
- * no_context -> bool
917
- */
918
- static VALUE ruby_whisper_params_get_no_context(VALUE self) {
919
- BOOL_PARAMS_GETTER(self, no_context)
920
- }
921
- /*
922
- * call-seq:
923
- * single_segment = force_single -> force_single
924
- */
925
- static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
926
- BOOL_PARAMS_SETTER(self, single_segment, value)
927
- }
928
- /*
929
- * If true, forces single segment output (useful for streaming).
930
- *
931
- * call-seq:
932
- * single_segment -> bool
933
- */
934
- static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
935
- BOOL_PARAMS_GETTER(self, single_segment)
936
- }
937
- /*
938
- * call-seq:
939
- * print_special = force_print -> force_print
940
- */
941
- static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
942
- BOOL_PARAMS_SETTER(self, print_special, value)
943
- }
944
- /*
945
- * If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
946
- *
947
- * call-seq:
948
- * print_special -> bool
949
- */
950
- static VALUE ruby_whisper_params_get_print_special(VALUE self) {
951
- BOOL_PARAMS_GETTER(self, print_special)
952
- }
953
- /*
954
- * call-seq:
955
- * print_progress = force_print -> force_print
956
- */
957
- static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
958
- BOOL_PARAMS_SETTER(self, print_progress, value)
959
- }
960
- /*
961
- * If true, prints progress information.
962
- *
963
- * call-seq:
964
- * print_progress -> bool
965
- */
966
- static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
967
- BOOL_PARAMS_GETTER(self, print_progress)
968
- }
969
- /*
970
- * call-seq:
971
- * print_realtime = force_print -> force_print
972
- */
973
- static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
974
- BOOL_PARAMS_SETTER(self, print_realtime, value)
975
- }
976
- /*
977
- * If true, prints results from within whisper.cpp. (avoid it, use callback instead)
978
- * call-seq:
979
- * print_realtime -> bool
980
- */
981
- static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
982
- BOOL_PARAMS_GETTER(self, print_realtime)
983
- }
984
- /*
985
- * call-seq:
986
- * print_timestamps = force_print -> force_print
987
- */
988
- static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
989
- BOOL_PARAMS_SETTER(self, print_timestamps, value)
990
- }
991
- /*
992
- * If true, prints timestamps for each text segment when printing realtime.
993
- *
994
- * call-seq:
995
- * print_timestamps -> bool
996
- */
997
- static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
998
- BOOL_PARAMS_GETTER(self, print_timestamps)
999
- }
1000
- /*
1001
- * call-seq:
1002
- * suppress_blank = force_suppress -> force_suppress
1003
- */
1004
- static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
1005
- BOOL_PARAMS_SETTER(self, suppress_blank, value)
1006
- }
1007
- /*
1008
- * If true, suppresses blank outputs.
1009
- *
1010
- * call-seq:
1011
- * suppress_blank -> bool
1012
- */
1013
- static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
1014
- BOOL_PARAMS_GETTER(self, suppress_blank)
1015
- }
1016
- /*
1017
- * call-seq:
1018
- * suppress_nst = force_suppress -> force_suppress
1019
- */
1020
- static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
1021
- BOOL_PARAMS_SETTER(self, suppress_nst, value)
1022
- }
1023
- /*
1024
- * If true, suppresses non-speech-tokens.
1025
- *
1026
- * call-seq:
1027
- * suppress_nst -> bool
1028
- */
1029
- static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
1030
- BOOL_PARAMS_GETTER(self, suppress_nst)
1031
- }
1032
- /*
1033
- * If true, enables token-level timestamps.
1034
- *
1035
- * call-seq:
1036
- * token_timestamps -> bool
1037
- */
1038
- static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
1039
- BOOL_PARAMS_GETTER(self, token_timestamps)
1040
- }
1041
- /*
1042
- * call-seq:
1043
- * token_timestamps = force_timestamps -> force_timestamps
1044
- */
1045
- static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
1046
- BOOL_PARAMS_SETTER(self, token_timestamps, value)
1047
- }
1048
- /*
1049
- * If true, split on word rather than on token (when used with max_len).
1050
- *
1051
- * call-seq:
1052
- * translate -> bool
1053
- */
1054
- static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
1055
- BOOL_PARAMS_GETTER(self, split_on_word)
1056
- }
1057
- /*
1058
- * call-seq:
1059
- * split_on_word = force_split -> force_split
1060
- */
1061
- static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
1062
- BOOL_PARAMS_SETTER(self, split_on_word, value)
1063
- }
1064
- /*
1065
- * Tokens to provide to the whisper decoder as initial prompt
1066
- * these are prepended to any existing text context from a previous call
1067
- * use whisper_tokenize() to convert text to tokens.
1068
- * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
1069
- *
1070
- * call-seq:
1071
- * initial_prompt -> String
1072
- */
1073
- static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) {
1074
- ruby_whisper_params *rwp;
1075
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1076
- return rwp->params.initial_prompt == nullptr ? Qnil : rb_str_new2(rwp->params.initial_prompt);
1077
- }
1078
- /*
1079
- * call-seq:
1080
- * initial_prompt = prompt -> prompt
1081
- */
1082
- static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) {
1083
- ruby_whisper_params *rwp;
1084
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1085
- rwp->params.initial_prompt = StringValueCStr(value);
1086
- return value;
1087
- }
1088
- /*
1089
- * If true, enables diarization.
1090
- *
1091
- * call-seq:
1092
- * diarize -> bool
1093
- */
1094
- static VALUE ruby_whisper_params_get_diarize(VALUE self) {
1095
- ruby_whisper_params *rwp;
1096
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1097
- if (rwp->diarize) {
1098
- return Qtrue;
1099
- } else {
1100
- return Qfalse;
1101
- }
1102
- }
1103
- /*
1104
- * call-seq:
1105
- * diarize = force_diarize -> force_diarize
1106
- */
1107
- static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
1108
- ruby_whisper_params *rwp;
1109
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1110
- if (value == Qfalse || value == Qnil) {
1111
- rwp->diarize = false;
1112
- } else {
1113
- rwp->diarize = true;
1114
- } \
1115
- return value;
1116
- }
1117
-
1118
- /*
1119
- * Start offset in ms.
1120
- *
1121
- * call-seq:
1122
- * offset -> Integer
1123
- */
1124
- static VALUE ruby_whisper_params_get_offset(VALUE self) {
1125
- ruby_whisper_params *rwp;
1126
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1127
- return INT2NUM(rwp->params.offset_ms);
1128
- }
1129
- /*
1130
- * call-seq:
1131
- * offset = offset_ms -> offset_ms
1132
- */
1133
- static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
1134
- ruby_whisper_params *rwp;
1135
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1136
- rwp->params.offset_ms = NUM2INT(value);
1137
- return value;
1138
- }
1139
- /*
1140
- * Audio duration to process in ms.
1141
- *
1142
- * call-seq:
1143
- * duration -> Integer
1144
- */
1145
- static VALUE ruby_whisper_params_get_duration(VALUE self) {
1146
- ruby_whisper_params *rwp;
1147
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1148
- return INT2NUM(rwp->params.duration_ms);
1149
- }
1150
- /*
1151
- * call-seq:
1152
- * duration = duration_ms -> duration_ms
1153
- */
1154
- static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
1155
- ruby_whisper_params *rwp;
1156
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1157
- rwp->params.duration_ms = NUM2INT(value);
1158
- return value;
1159
- }
1160
-
1161
- /*
1162
- * Max tokens to use from past text as prompt for the decoder.
1163
- *
1164
- * call-seq:
1165
- * max_text_tokens -> Integer
1166
- */
1167
- static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
1168
- ruby_whisper_params *rwp;
1169
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1170
- return INT2NUM(rwp->params.n_max_text_ctx);
1171
- }
1172
- /*
1173
- * call-seq:
1174
- * max_text_tokens = n_tokens -> n_tokens
1175
- */
1176
- static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
1177
- ruby_whisper_params *rwp;
1178
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1179
- rwp->params.n_max_text_ctx = NUM2INT(value);
1180
- return value;
1181
- }
1182
- /*
1183
- * call-seq:
1184
- * temperature -> Float
1185
- */
1186
- static VALUE ruby_whisper_params_get_temperature(VALUE self) {
1187
- ruby_whisper_params *rwp;
1188
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1189
- return DBL2NUM(rwp->params.temperature);
1190
- }
1191
- /*
1192
- * call-seq:
1193
- * temperature = temp -> temp
1194
- */
1195
- static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) {
1196
- ruby_whisper_params *rwp;
1197
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1198
- rwp->params.temperature = RFLOAT_VALUE(value);
1199
- return value;
1200
- }
1201
- /*
1202
- * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
1203
- *
1204
- * call-seq:
1205
- * max_initial_ts -> Flaot
1206
- */
1207
- static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) {
1208
- ruby_whisper_params *rwp;
1209
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1210
- return DBL2NUM(rwp->params.max_initial_ts);
1211
- }
1212
- /*
1213
- * call-seq:
1214
- * max_initial_ts = timestamp -> timestamp
1215
- */
1216
- static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) {
1217
- ruby_whisper_params *rwp;
1218
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1219
- rwp->params.max_initial_ts = RFLOAT_VALUE(value);
1220
- return value;
1221
- }
1222
- /*
1223
- * call-seq:
1224
- * length_penalty -> Float
1225
- */
1226
- static VALUE ruby_whisper_params_get_length_penalty(VALUE self) {
1227
- ruby_whisper_params *rwp;
1228
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1229
- return DBL2NUM(rwp->params.length_penalty);
1230
- }
1231
- /*
1232
- * call-seq:
1233
- * length_penalty = penalty -> penalty
1234
- */
1235
- static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) {
1236
- ruby_whisper_params *rwp;
1237
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1238
- rwp->params.length_penalty = RFLOAT_VALUE(value);
1239
- return value;
1240
- }
1241
- /*
1242
- * call-seq:
1243
- * temperature_inc -> Float
1244
- */
1245
- static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) {
1246
- ruby_whisper_params *rwp;
1247
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1248
- return DBL2NUM(rwp->params.temperature_inc);
1249
- }
1250
- /*
1251
- * call-seq:
1252
- * temperature_inc = inc -> inc
1253
- */
1254
- static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) {
1255
- ruby_whisper_params *rwp;
1256
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1257
- rwp->params.temperature_inc = RFLOAT_VALUE(value);
1258
- return value;
1259
- }
1260
- /*
1261
- * Similar to OpenAI's "compression_ratio_threshold"
1262
- *
1263
- * call-seq:
1264
- * entropy_thold -> Float
1265
- */
1266
- static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) {
1267
- ruby_whisper_params *rwp;
1268
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1269
- return DBL2NUM(rwp->params.entropy_thold);
1270
- }
1271
- /*
1272
- * call-seq:
1273
- * entropy_thold = threshold -> threshold
1274
- */
1275
- static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) {
1276
- ruby_whisper_params *rwp;
1277
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1278
- rwp->params.entropy_thold = RFLOAT_VALUE(value);
1279
- return value;
1280
- }
1281
- /*
1282
- * call-seq:
1283
- * logprob_thold -> Float
1284
- */
1285
- static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) {
1286
- ruby_whisper_params *rwp;
1287
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1288
- return DBL2NUM(rwp->params.logprob_thold);
1289
- }
1290
- /*
1291
- * call-seq:
1292
- * logprob_thold = threshold -> threshold
1293
- */
1294
- static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
1295
- ruby_whisper_params *rwp;
1296
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1297
- rwp->params.logprob_thold = RFLOAT_VALUE(value);
1298
- return value;
1299
- }
1300
- /*
1301
- * call-seq:
1302
- * no_speech_thold -> Float
1303
- */
1304
- static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) {
1305
- ruby_whisper_params *rwp;
1306
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1307
- return DBL2NUM(rwp->params.no_speech_thold);
1308
- }
1309
- /*
1310
- * call-seq:
1311
- * no_speech_thold = threshold -> threshold
1312
- */
1313
- static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) {
1314
- ruby_whisper_params *rwp;
1315
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1316
- rwp->params.no_speech_thold = RFLOAT_VALUE(value);
1317
- return value;
1318
- }
1319
- /*
1320
- * Sets new segment callback, called for every newly generated text segment.
1321
- *
1322
- * params.new_segment_callback = ->(context, _, n_new, user_data) {
1323
- * # ...
1324
- * }
1325
- *
1326
- * call-seq:
1327
- * new_segment_callback = callback -> callback
1328
- */
1329
- static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) {
1330
- ruby_whisper_params *rwp;
1331
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1332
- rwp->new_segment_callback_container->callback = value;
1333
- return value;
1334
- }
1335
- /*
1336
- * Sets user data passed to the last argument of new segment callback.
1337
- *
1338
- * call-seq:
1339
- * new_segment_callback_user_data = user_data -> use_data
1340
- */
1341
- static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) {
1342
- ruby_whisper_params *rwp;
1343
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1344
- rwp->new_segment_callback_container->user_data = value;
1345
- return value;
1346
- }
1347
- /*
1348
- * Sets progress callback, called on each progress update.
1349
- *
1350
- * params.new_segment_callback = ->(context, _, n_new, user_data) {
1351
- * # ...
1352
- * }
1353
- *
1354
- * call-seq:
1355
- * progress_callback = callback -> callback
1356
- */
1357
- static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) {
1358
- ruby_whisper_params *rwp;
1359
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1360
- rwp->progress_callback_container->callback = value;
1361
- return value;
1362
- }
1363
- /*
1364
- * Sets user data passed to the last argument of progress callback.
1365
- *
1366
- * call-seq:
1367
- * progress_callback_user_data = user_data -> use_data
1368
- */
1369
- static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) {
1370
- ruby_whisper_params *rwp;
1371
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1372
- rwp->progress_callback_container->user_data = value;
1373
- return value;
1374
- }
1375
- /*
1376
- * Sets abort callback, called to check if the process should be aborted.
1377
- *
1378
- * params.abort_callback = ->(user_data) {
1379
- * # ...
1380
- * }
1381
- *
1382
- * call-seq:
1383
- * abort_callback = callback -> callback
1384
- */
1385
- static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) {
1386
- ruby_whisper_params *rwp;
1387
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1388
- rwp->abort_callback_container->callback = value;
1389
- return value;
1390
- }
1391
- /*
1392
- * Sets user data passed to the last argument of abort callback.
1393
- *
1394
- * call-seq:
1395
- * abort_callback_user_data = user_data -> use_data
1396
- */
1397
- static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) {
1398
- ruby_whisper_params *rwp;
1399
- Data_Get_Struct(self, ruby_whisper_params, rwp);
1400
- rwp->abort_callback_container->user_data = value;
1401
- return value;
1402
- }
1403
-
1404
- // High level API
1405
-
1406
- typedef struct {
1407
- VALUE context;
1408
- int index;
1409
- } ruby_whisper_segment;
1410
-
1411
- typedef struct {
1412
- VALUE context;
1413
- } ruby_whisper_model;
1414
-
1415
- static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
1416
- rb_gc_mark(rws->context);
1417
- }
1418
-
1419
- static VALUE ruby_whisper_segment_allocate(VALUE klass) {
1420
- ruby_whisper_segment *rws;
1421
- rws = ALLOC(ruby_whisper_segment);
1422
- return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
1423
- }
1424
-
1425
- static VALUE rb_whisper_segment_initialize(VALUE context, int index) {
1426
- ruby_whisper_segment *rws;
1427
- const VALUE segment = ruby_whisper_segment_allocate(cSegment);
1428
- Data_Get_Struct(segment, ruby_whisper_segment, rws);
1429
- rws->context = context;
1430
- rws->index = index;
1431
- return segment;
1432
- };
1433
-
1434
- /*
1435
- * Yields each Whisper::Segment:
1436
- *
1437
- * whisper.transcribe("path/to/audio.wav", params)
1438
- * whisper.each_segment do |segment|
1439
- * puts segment.text
1440
- * end
1441
- *
1442
- * Returns an Enumerator if no block given:
1443
- *
1444
- * whisper.transcribe("path/to/audio.wav", params)
1445
- * enum = whisper.each_segment
1446
- * enum.to_a # => [#<Whisper::Segment>, ...]
1447
- *
1448
- * call-seq:
1449
- * each_segment {|segment| ... }
1450
- * each_segment -> Enumerator
1451
- */
1452
- static VALUE ruby_whisper_each_segment(VALUE self) {
1453
- if (!rb_block_given_p()) {
1454
- const VALUE method_name = rb_funcall(self, id___method__, 0);
1455
- return rb_funcall(self, id_to_enum, 1, method_name);
1456
- }
1457
-
1458
- ruby_whisper *rw;
1459
- Data_Get_Struct(self, ruby_whisper, rw);
1460
-
1461
- const int n_segments = whisper_full_n_segments(rw->context);
1462
- for (int i = 0; i < n_segments; ++i) {
1463
- rb_yield(rb_whisper_segment_initialize(self, i));
1464
- }
1465
-
1466
- return self;
1467
- }
1468
-
1469
- /*
1470
- * Hook called on new segment. Yields each Whisper::Segment.
1471
- *
1472
- * whisper.on_new_segment do |segment|
1473
- * # ...
1474
- * end
1475
- *
1476
- * call-seq:
1477
- * on_new_segment {|segment| ... }
1478
- */
1479
- static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
1480
- ruby_whisper_params *rws;
1481
- Data_Get_Struct(self, ruby_whisper_params, rws);
1482
- const VALUE blk = rb_block_proc();
1483
- rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
1484
- return Qnil;
1485
- }
1486
-
1487
- /*
1488
- * Hook called on progress update. Yields each progress Integer between 0 and 100.
1489
- *
1490
- * whisper.on_progress do |progress|
1491
- * # ...
1492
- * end
1493
- *
1494
- * call-seq:
1495
- * on_progress {|progress| ... }
1496
- */
1497
- static VALUE ruby_whisper_params_on_progress(VALUE self) {
1498
- ruby_whisper_params *rws;
1499
- Data_Get_Struct(self, ruby_whisper_params, rws);
1500
- const VALUE blk = rb_block_proc();
1501
- rb_ary_push(rws->progress_callback_container->callbacks, blk);
1502
- return Qnil;
1503
- }
1504
-
1505
- /*
1506
- * Call block to determine whether abort or not. Return +true+ when you want to abort.
1507
- *
1508
- * params.abort_on do
1509
- * if some_condition
1510
- * true # abort
1511
- * else
1512
- * false # continue
1513
- * end
1514
- * end
1515
- *
1516
- * call-seq:
1517
- * abort_on { ... }
1518
- */
1519
- static VALUE ruby_whisper_params_abort_on(VALUE self) {
1520
- ruby_whisper_params *rws;
1521
- Data_Get_Struct(self, ruby_whisper_params, rws);
1522
- const VALUE blk = rb_block_proc();
1523
- rb_ary_push(rws->abort_callback_container->callbacks, blk);
1524
- return Qnil;
1525
- }
1526
-
1527
- /*
1528
- * Start time in milliseconds.
1529
- *
1530
- * call-seq:
1531
- * start_time -> Integer
1532
- */
1533
- static VALUE ruby_whisper_segment_get_start_time(VALUE self) {
1534
- ruby_whisper_segment *rws;
1535
- Data_Get_Struct(self, ruby_whisper_segment, rws);
1536
- ruby_whisper *rw;
1537
- Data_Get_Struct(rws->context, ruby_whisper, rw);
1538
- const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
1539
- // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
1540
- return INT2NUM(t0 * 10);
1541
- }
1542
-
1543
- /*
1544
- * End time in milliseconds.
1545
- *
1546
- * call-seq:
1547
- * end_time -> Integer
1548
- */
1549
- static VALUE ruby_whisper_segment_get_end_time(VALUE self) {
1550
- ruby_whisper_segment *rws;
1551
- Data_Get_Struct(self, ruby_whisper_segment, rws);
1552
- ruby_whisper *rw;
1553
- Data_Get_Struct(rws->context, ruby_whisper, rw);
1554
- const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
1555
- // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
1556
- return INT2NUM(t1 * 10);
1557
- }
1558
-
1559
- /*
1560
- * Whether the next segment is predicted as a speaker turn.
1561
- *
1562
- * call-seq:
1563
- * speaker_turn_next? -> bool
1564
- */
1565
- static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) {
1566
- ruby_whisper_segment *rws;
1567
- Data_Get_Struct(self, ruby_whisper_segment, rws);
1568
- ruby_whisper *rw;
1569
- Data_Get_Struct(rws->context, ruby_whisper, rw);
1570
- return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
1571
- }
1572
-
1573
- /*
1574
- * call-seq:
1575
- * text -> String
1576
- */
1577
- static VALUE ruby_whisper_segment_get_text(VALUE self) {
1578
- ruby_whisper_segment *rws;
1579
- Data_Get_Struct(self, ruby_whisper_segment, rws);
1580
- ruby_whisper *rw;
1581
- Data_Get_Struct(rws->context, ruby_whisper, rw);
1582
- const char * text = whisper_full_get_segment_text(rw->context, rws->index);
1583
- return rb_str_new2(text);
1584
- }
1585
-
1586
- /*
1587
- * call-seq:
1588
- * no_speech_prob -> Float
1589
- */
1590
- static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
1591
- ruby_whisper_segment *rws;
1592
- Data_Get_Struct(self, ruby_whisper_segment, rws);
1593
- ruby_whisper *rw;
1594
- Data_Get_Struct(rws->context, ruby_whisper, rw);
1595
- return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
1596
- }
1597
-
1598
- static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
1599
- rb_gc_mark(rwm->context);
1600
- }
1601
-
1602
- static VALUE ruby_whisper_model_allocate(VALUE klass) {
1603
- ruby_whisper_model *rwm;
1604
- rwm = ALLOC(ruby_whisper_model);
1605
- return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
1606
- }
1607
-
1608
- static VALUE rb_whisper_model_initialize(VALUE context) {
1609
- ruby_whisper_model *rwm;
1610
- const VALUE model = ruby_whisper_model_allocate(cModel);
1611
- Data_Get_Struct(model, ruby_whisper_model, rwm);
1612
- rwm->context = context;
1613
- return model;
1614
- };
1615
-
1616
- /*
1617
- * call-seq:
1618
- * model -> Whisper::Model
1619
- */
1620
- static VALUE ruby_whisper_get_model(VALUE self) {
1621
- return rb_whisper_model_initialize(self);
1622
- }
1623
-
1624
- /*
1625
- * call-seq:
1626
- * n_vocab -> Integer
1627
- */
1628
- static VALUE ruby_whisper_c_model_n_vocab(VALUE self) {
1629
- ruby_whisper_model *rwm;
1630
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1631
- ruby_whisper *rw;
1632
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1633
- return INT2NUM(whisper_model_n_vocab(rw->context));
1634
- }
1635
-
1636
- /*
1637
- * call-seq:
1638
- * n_audio_ctx -> Integer
1639
- */
1640
- static VALUE ruby_whisper_c_model_n_audio_ctx(VALUE self) {
1641
- ruby_whisper_model *rwm;
1642
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1643
- ruby_whisper *rw;
1644
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1645
- return INT2NUM(whisper_model_n_audio_ctx(rw->context));
1646
- }
1647
-
1648
- /*
1649
- * call-seq:
1650
- * n_audio_state -> Integer
1651
- */
1652
- static VALUE ruby_whisper_c_model_n_audio_state(VALUE self) {
1653
- ruby_whisper_model *rwm;
1654
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1655
- ruby_whisper *rw;
1656
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1657
- return INT2NUM(whisper_model_n_audio_state(rw->context));
1658
- }
1659
-
1660
- /*
1661
- * call-seq:
1662
- * n_audio_head -> Integer
1663
- */
1664
- static VALUE ruby_whisper_c_model_n_audio_head(VALUE self) {
1665
- ruby_whisper_model *rwm;
1666
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1667
- ruby_whisper *rw;
1668
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1669
- return INT2NUM(whisper_model_n_audio_head(rw->context));
1670
- }
1671
-
1672
- /*
1673
- * call-seq:
1674
- * n_audio_layer -> Integer
1675
- */
1676
- static VALUE ruby_whisper_c_model_n_audio_layer(VALUE self) {
1677
- ruby_whisper_model *rwm;
1678
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1679
- ruby_whisper *rw;
1680
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1681
- return INT2NUM(whisper_model_n_audio_layer(rw->context));
1682
- }
1683
-
1684
- /*
1685
- * call-seq:
1686
- * n_text_ctx -> Integer
1687
- */
1688
- static VALUE ruby_whisper_c_model_n_text_ctx(VALUE self) {
1689
- ruby_whisper_model *rwm;
1690
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1691
- ruby_whisper *rw;
1692
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1693
- return INT2NUM(whisper_model_n_text_ctx(rw->context));
1694
- }
1695
-
1696
- /*
1697
- * call-seq:
1698
- * n_text_state -> Integer
1699
- */
1700
- static VALUE ruby_whisper_c_model_n_text_state(VALUE self) {
1701
- ruby_whisper_model *rwm;
1702
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1703
- ruby_whisper *rw;
1704
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1705
- return INT2NUM(whisper_model_n_text_state(rw->context));
1706
- }
1707
-
1708
- /*
1709
- * call-seq:
1710
- * n_text_head -> Integer
1711
- */
1712
- static VALUE ruby_whisper_c_model_n_text_head(VALUE self) {
1713
- ruby_whisper_model *rwm;
1714
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1715
- ruby_whisper *rw;
1716
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1717
- return INT2NUM(whisper_model_n_text_head(rw->context));
1718
- }
1719
-
1720
- /*
1721
- * call-seq:
1722
- * n_text_layer -> Integer
1723
- */
1724
- static VALUE ruby_whisper_c_model_n_text_layer(VALUE self) {
1725
- ruby_whisper_model *rwm;
1726
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1727
- ruby_whisper *rw;
1728
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1729
- return INT2NUM(whisper_model_n_text_layer(rw->context));
1730
- }
1731
-
1732
- /*
1733
- * call-seq:
1734
- * n_mels -> Integer
1735
- */
1736
- static VALUE ruby_whisper_c_model_n_mels(VALUE self) {
1737
- ruby_whisper_model *rwm;
1738
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1739
- ruby_whisper *rw;
1740
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1741
- return INT2NUM(whisper_model_n_mels(rw->context));
1742
- }
1743
-
1744
- /*
1745
- * call-seq:
1746
- * ftype -> Integer
1747
- */
1748
- static VALUE ruby_whisper_c_model_ftype(VALUE self) {
1749
- ruby_whisper_model *rwm;
1750
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1751
- ruby_whisper *rw;
1752
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1753
- return INT2NUM(whisper_model_ftype(rw->context));
1754
- }
1755
-
1756
- /*
1757
- * call-seq:
1758
- * type -> String
1759
- */
1760
- static VALUE ruby_whisper_c_model_type(VALUE self) {
1761
- ruby_whisper_model *rwm;
1762
- Data_Get_Struct(self, ruby_whisper_model, rwm);
1763
- ruby_whisper *rw;
1764
- Data_Get_Struct(rwm->context, ruby_whisper, rw);
1765
- return rb_str_new2(whisper_model_type_readable(rw->context));
1766
- }
1767
-
1768
- static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
1769
- const int c_code = NUM2INT(code);
1770
- const char *raw_message;
1771
- switch (c_code) {
1772
- case -2:
1773
- raw_message = "failed to compute log mel spectrogram";
1774
- break;
1775
- case -3:
1776
- raw_message = "failed to auto-detect language";
1777
- break;
1778
- case -4:
1779
- raw_message = "too many decoders requested";
1780
- break;
1781
- case -5:
1782
- raw_message = "audio_ctx is larger than the maximum allowed";
1783
- break;
1784
- case -6:
1785
- raw_message = "failed to encode";
1786
- break;
1787
- case -7:
1788
- raw_message = "whisper_kv_cache_init() failed for self-attention cache";
1789
- break;
1790
- case -8:
1791
- raw_message = "failed to decode";
1792
- break;
1793
- case -9:
1794
- raw_message = "failed to decode";
1795
- break;
1796
- default:
1797
- raw_message = "unknown error";
1798
- break;
1799
- }
1800
- const VALUE message = rb_str_new2(raw_message);
1801
- rb_call_super(1, &message);
1802
- rb_iv_set(self, "@code", code);
1803
-
1804
- return self;
1805
- }
1806
-
1807
-
1808
- void Init_whisper() {
1809
- id_to_s = rb_intern("to_s");
1810
- id_call = rb_intern("call");
1811
- id___method__ = rb_intern("__method__");
1812
- id_to_enum = rb_intern("to_enum");
1813
- id_length = rb_intern("length");
1814
- id_next = rb_intern("next");
1815
- id_new = rb_intern("new");
1816
- id_to_path = rb_intern("to_path");
1817
- id_URI = rb_intern("URI");
1818
- id_pre_converted_models = rb_intern("pre_converted_models");
1819
-
1820
- mWhisper = rb_define_module("Whisper");
1821
- cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
1822
- cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
1823
- eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
1824
-
1825
- rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
1826
- rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
1827
- rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
1828
- rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
1829
- rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
1830
- rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
1831
-
1832
- rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
1833
- rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
1834
- rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
1835
- rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
1836
- rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
1837
- rb_define_singleton_method(mWhisper, "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
1838
-
1839
- rb_define_alloc_func(cContext, ruby_whisper_allocate);
1840
- rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
1841
-
1842
- rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
1843
- rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
1844
- rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
1845
- rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
1846
- rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
1847
- rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
1848
- rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
1849
- rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
1850
- rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
1851
- rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
1852
- rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
1853
- rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
1854
- rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
1855
- rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
1856
- rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
1857
- rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
1858
- rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
1859
- rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
1860
- rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
1861
- rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
1862
- rb_define_method(cContext, "full", ruby_whisper_full, -1);
1863
- rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
1864
-
1865
- rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
1866
-
1867
- rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
1868
- rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
1869
- rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
1870
- rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
1871
- rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
1872
- rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
1873
- rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
1874
- rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
1875
- rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
1876
- rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
1877
- rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
1878
- rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
1879
- rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
1880
- rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
1881
- rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
1882
- rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
1883
- rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
1884
- rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
1885
- rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
1886
- rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
1887
- rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
1888
- rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
1889
- rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
1890
- rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
1891
- rb_define_method(cParams, "initial_prompt", ruby_whisper_params_get_initial_prompt, 0);
1892
- rb_define_method(cParams, "initial_prompt=", ruby_whisper_params_set_initial_prompt, 1);
1893
- rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
1894
- rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
1895
-
1896
- rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
1897
- rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
1898
- rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
1899
- rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
1900
-
1901
- rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
1902
- rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
1903
- rb_define_method(cParams, "temperature", ruby_whisper_params_get_temperature, 0);
1904
- rb_define_method(cParams, "temperature=", ruby_whisper_params_set_temperature, 1);
1905
- rb_define_method(cParams, "max_initial_ts", ruby_whisper_params_get_max_initial_ts, 0);
1906
- rb_define_method(cParams, "max_initial_ts=", ruby_whisper_params_set_max_initial_ts, 1);
1907
- rb_define_method(cParams, "length_penalty", ruby_whisper_params_get_length_penalty, 0);
1908
- rb_define_method(cParams, "length_penalty=", ruby_whisper_params_set_length_penalty, 1);
1909
- rb_define_method(cParams, "temperature_inc", ruby_whisper_params_get_temperature_inc, 0);
1910
- rb_define_method(cParams, "temperature_inc=", ruby_whisper_params_set_temperature_inc, 1);
1911
- rb_define_method(cParams, "entropy_thold", ruby_whisper_params_get_entropy_thold, 0);
1912
- rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
1913
- rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
1914
- rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
1915
- rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0);
1916
- rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1);
1917
-
1918
- rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
1919
- rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
1920
- rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1);
1921
- rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1);
1922
- rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
1923
- rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
1924
-
1925
- rb_define_attr(eError, "code", true, false);
1926
- rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
1927
-
1928
- // High leve
1929
- cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
1930
-
1931
- rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
1932
- rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
1933
- rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
1934
- rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
1935
- rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
1936
- rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
1937
- rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
1938
- rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
1939
- rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
1940
- rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
1941
-
1942
- cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
1943
- rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
1944
- rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
1945
- rb_define_method(cModel, "n_vocab", ruby_whisper_c_model_n_vocab, 0);
1946
- rb_define_method(cModel, "n_audio_ctx", ruby_whisper_c_model_n_audio_ctx, 0);
1947
- rb_define_method(cModel, "n_audio_state", ruby_whisper_c_model_n_audio_state, 0);
1948
- rb_define_method(cModel, "n_audio_head", ruby_whisper_c_model_n_audio_head, 0);
1949
- rb_define_method(cModel, "n_audio_layer", ruby_whisper_c_model_n_audio_layer, 0);
1950
- rb_define_method(cModel, "n_text_ctx", ruby_whisper_c_model_n_text_ctx, 0);
1951
- rb_define_method(cModel, "n_text_state", ruby_whisper_c_model_n_text_state, 0);
1952
- rb_define_method(cModel, "n_text_head", ruby_whisper_c_model_n_text_head, 0);
1953
- rb_define_method(cModel, "n_text_layer", ruby_whisper_c_model_n_text_layer, 0);
1954
- rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0);
1955
- rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0);
1956
- rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0);
1957
-
1958
- rb_require("whisper/model/uri");
1959
- }
1960
- #ifdef __cplusplus
1961
- }
1962
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bindings/ruby/ext/ruby_whisper.h CHANGED
@@ -22,4 +22,13 @@ typedef struct {
22
  ruby_whisper_callback_container *abort_callback_container;
23
  } ruby_whisper_params;
24
 
 
 
 
 
 
 
 
 
 
25
  #endif
 
22
  ruby_whisper_callback_container *abort_callback_container;
23
  } ruby_whisper_params;
24
 
25
+ typedef struct {
26
+ VALUE context;
27
+ int index;
28
+ } ruby_whisper_segment;
29
+
30
+ typedef struct {
31
+ VALUE context;
32
+ } ruby_whisper_model;
33
+
34
  #endif
bindings/ruby/ext/ruby_whisper_context.c ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+ #include <ruby/memory_view.h>
3
+ #include "ruby_whisper.h"
4
+
5
+ extern ID id_to_s;
6
+ extern ID id___method__;
7
+ extern ID id_to_enum;
8
+ extern ID id_length;
9
+ extern ID id_next;
10
+ extern ID id_new;
11
+ extern ID id_to_path;
12
+ extern ID id_URI;
13
+ extern ID id_pre_converted_models;
14
+
15
+ extern VALUE cContext;
16
+ extern VALUE eError;
17
+ extern VALUE cModel;
18
+
19
+ extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
20
+ extern VALUE rb_whisper_model_initialize(VALUE context);
21
+ extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
22
+ extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
23
+
24
+ static void
25
+ ruby_whisper_free(ruby_whisper *rw)
26
+ {
27
+ if (rw->context) {
28
+ whisper_free(rw->context);
29
+ rw->context = NULL;
30
+ }
31
+ }
32
+
33
+ void
34
+ rb_whisper_mark(ruby_whisper *rw)
35
+ {
36
+ // call rb_gc_mark on any ruby references in rw
37
+ }
38
+
39
+ void
40
+ rb_whisper_free(ruby_whisper *rw)
41
+ {
42
+ ruby_whisper_free(rw);
43
+ free(rw);
44
+ }
45
+
46
+ static VALUE
47
+ ruby_whisper_allocate(VALUE klass)
48
+ {
49
+ ruby_whisper *rw;
50
+ rw = ALLOC(ruby_whisper);
51
+ rw->context = NULL;
52
+ return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
53
+ }
54
+
55
+ /*
56
+ * call-seq:
57
+ * new("base.en") -> Whisper::Context
58
+ * new("path/to/model.bin") -> Whisper::Context
59
+ * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
60
+ */
61
+ static VALUE
62
+ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
63
+ {
64
+ ruby_whisper *rw;
65
+ VALUE whisper_model_file_path;
66
+
67
+ // TODO: we can support init from buffer here too maybe another ruby object to expose
68
+ rb_scan_args(argc, argv, "01", &whisper_model_file_path);
69
+ Data_Get_Struct(self, ruby_whisper, rw);
70
+
71
+ VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
72
+ VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
73
+ if (!NIL_P(pre_converted_model)) {
74
+ whisper_model_file_path = pre_converted_model;
75
+ }
76
+ if (TYPE(whisper_model_file_path) == T_STRING) {
77
+ const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
78
+ if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
79
+ VALUE uri_class = rb_const_get(cModel, id_URI);
80
+ whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
81
+ }
82
+ }
83
+ if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
84
+ VALUE uri_class = rb_const_get(cModel, id_URI);
85
+ whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
86
+ }
87
+ if (rb_respond_to(whisper_model_file_path, id_to_path)) {
88
+ whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
89
+ }
90
+ if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
91
+ rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
92
+ }
93
+ rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
94
+ if (rw->context == NULL) {
95
+ rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
96
+ }
97
+ return self;
98
+ }
99
+
100
+ /*
101
+ * call-seq:
102
+ * model_n_vocab -> Integer
103
+ */
104
+ VALUE ruby_whisper_model_n_vocab(VALUE self)
105
+ {
106
+ ruby_whisper *rw;
107
+ Data_Get_Struct(self, ruby_whisper, rw);
108
+ return INT2NUM(whisper_model_n_vocab(rw->context));
109
+ }
110
+
111
+ /*
112
+ * call-seq:
113
+ * model_n_audio_ctx -> Integer
114
+ */
115
+ VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
116
+ {
117
+ ruby_whisper *rw;
118
+ Data_Get_Struct(self, ruby_whisper, rw);
119
+ return INT2NUM(whisper_model_n_audio_ctx(rw->context));
120
+ }
121
+
122
+ /*
123
+ * call-seq:
124
+ * model_n_audio_state -> Integer
125
+ */
126
+ VALUE ruby_whisper_model_n_audio_state(VALUE self)
127
+ {
128
+ ruby_whisper *rw;
129
+ Data_Get_Struct(self, ruby_whisper, rw);
130
+ return INT2NUM(whisper_model_n_audio_state(rw->context));
131
+ }
132
+
133
+ /*
134
+ * call-seq:
135
+ * model_n_audio_head -> Integer
136
+ */
137
+ VALUE ruby_whisper_model_n_audio_head(VALUE self)
138
+ {
139
+ ruby_whisper *rw;
140
+ Data_Get_Struct(self, ruby_whisper, rw);
141
+ return INT2NUM(whisper_model_n_audio_head(rw->context));
142
+ }
143
+
144
+ /*
145
+ * call-seq:
146
+ * model_n_audio_layer -> Integer
147
+ */
148
+ VALUE ruby_whisper_model_n_audio_layer(VALUE self)
149
+ {
150
+ ruby_whisper *rw;
151
+ Data_Get_Struct(self, ruby_whisper, rw);
152
+ return INT2NUM(whisper_model_n_audio_layer(rw->context));
153
+ }
154
+
155
+ /*
156
+ * call-seq:
157
+ * model_n_text_ctx -> Integer
158
+ */
159
+ VALUE ruby_whisper_model_n_text_ctx(VALUE self)
160
+ {
161
+ ruby_whisper *rw;
162
+ Data_Get_Struct(self, ruby_whisper, rw);
163
+ return INT2NUM(whisper_model_n_text_ctx(rw->context));
164
+ }
165
+
166
+ /*
167
+ * call-seq:
168
+ * model_n_text_state -> Integer
169
+ */
170
+ VALUE ruby_whisper_model_n_text_state(VALUE self)
171
+ {
172
+ ruby_whisper *rw;
173
+ Data_Get_Struct(self, ruby_whisper, rw);
174
+ return INT2NUM(whisper_model_n_text_state(rw->context));
175
+ }
176
+
177
+ /*
178
+ * call-seq:
179
+ * model_n_text_head -> Integer
180
+ */
181
+ VALUE ruby_whisper_model_n_text_head(VALUE self)
182
+ {
183
+ ruby_whisper *rw;
184
+ Data_Get_Struct(self, ruby_whisper, rw);
185
+ return INT2NUM(whisper_model_n_text_head(rw->context));
186
+ }
187
+
188
+ /*
189
+ * call-seq:
190
+ * model_n_text_layer -> Integer
191
+ */
192
+ VALUE ruby_whisper_model_n_text_layer(VALUE self)
193
+ {
194
+ ruby_whisper *rw;
195
+ Data_Get_Struct(self, ruby_whisper, rw);
196
+ return INT2NUM(whisper_model_n_text_layer(rw->context));
197
+ }
198
+
199
+ /*
200
+ * call-seq:
201
+ * model_n_mels -> Integer
202
+ */
203
+ VALUE ruby_whisper_model_n_mels(VALUE self)
204
+ {
205
+ ruby_whisper *rw;
206
+ Data_Get_Struct(self, ruby_whisper, rw);
207
+ return INT2NUM(whisper_model_n_mels(rw->context));
208
+ }
209
+
210
+ /*
211
+ * call-seq:
212
+ * model_ftype -> Integer
213
+ */
214
+ VALUE ruby_whisper_model_ftype(VALUE self)
215
+ {
216
+ ruby_whisper *rw;
217
+ Data_Get_Struct(self, ruby_whisper, rw);
218
+ return INT2NUM(whisper_model_ftype(rw->context));
219
+ }
220
+
221
+ /*
222
+ * call-seq:
223
+ * model_type -> String
224
+ */
225
+ VALUE ruby_whisper_model_type(VALUE self)
226
+ {
227
+ ruby_whisper *rw;
228
+ Data_Get_Struct(self, ruby_whisper, rw);
229
+ return rb_str_new2(whisper_model_type_readable(rw->context));
230
+ }
231
+
232
+ /*
233
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
234
+ * Not thread safe for same context
235
+ * Uses the specified decoding strategy to obtain the text.
236
+ *
237
+ * call-seq:
238
+ * full(params, samples, n_samples) -> nil
239
+ * full(params, samples) -> nil
240
+ *
241
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
242
+ */
243
+ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
244
+ {
245
+ if (argc < 2 || argc > 3) {
246
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
247
+ }
248
+
249
+ ruby_whisper *rw;
250
+ ruby_whisper_params *rwp;
251
+ Data_Get_Struct(self, ruby_whisper, rw);
252
+ VALUE params = argv[0];
253
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
254
+ VALUE samples = argv[1];
255
+ int n_samples;
256
+ rb_memory_view_t view;
257
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
258
+ if (argc == 3) {
259
+ n_samples = NUM2INT(argv[2]);
260
+ if (TYPE(samples) == T_ARRAY) {
261
+ if (RARRAY_LEN(samples) < n_samples) {
262
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
263
+ }
264
+ }
265
+ // Should check when samples.respond_to?(:length)?
266
+ } else {
267
+ if (TYPE(samples) == T_ARRAY) {
268
+ n_samples = RARRAY_LEN(samples);
269
+ } else if (memory_view_available_p) {
270
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
271
+ view.obj = Qnil;
272
+ rb_raise(rb_eArgError, "unable to get a memory view");
273
+ }
274
+ n_samples = view.byte_size / view.item_size;
275
+ } else if (rb_respond_to(samples, id_length)) {
276
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
277
+ } else {
278
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
279
+ }
280
+ }
281
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
282
+ if (memory_view_available_p) {
283
+ c_samples = (float *)view.data;
284
+ } else {
285
+ if (TYPE(samples) == T_ARRAY) {
286
+ for (int i = 0; i < n_samples; i++) {
287
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
288
+ }
289
+ } else {
290
+ // TODO: use rb_block_call
291
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
292
+ for (int i = 0; i < n_samples; i++) {
293
+ // TODO: check if iter is exhausted and raise ArgumentError appropriately
294
+ VALUE sample = rb_funcall(iter, id_next, 0);
295
+ c_samples[i] = RFLOAT_VALUE(sample);
296
+ }
297
+ }
298
+ }
299
+ register_callbacks(rwp, &self);
300
+ const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
301
+ if (0 == result) {
302
+ return self;
303
+ } else {
304
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
305
+ }
306
+ }
307
+
308
+ /*
309
+ * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
310
+ * Result is stored in the default state of the context
311
+ * Not thread safe if executed in parallel on the same context.
312
+ * It seems this approach can offer some speedup in some cases.
313
+ * However, the transcription accuracy can be worse at the beginning and end of each chunk.
314
+ *
315
+ * call-seq:
316
+ * full_parallel(params, samples) -> nil
317
+ * full_parallel(params, samples, n_samples) -> nil
318
+ * full_parallel(params, samples, n_samples, n_processors) -> nil
319
+ * full_parallel(params, samples, nil, n_processors) -> nil
320
+ */
321
+ static VALUE
322
+ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
323
+ {
324
+ if (argc < 2 || argc > 4) {
325
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
326
+ }
327
+
328
+ ruby_whisper *rw;
329
+ ruby_whisper_params *rwp;
330
+ Data_Get_Struct(self, ruby_whisper, rw);
331
+ VALUE params = argv[0];
332
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
333
+ VALUE samples = argv[1];
334
+ int n_samples;
335
+ int n_processors;
336
+ rb_memory_view_t view;
337
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
338
+ switch (argc) {
339
+ case 2:
340
+ n_processors = 1;
341
+ break;
342
+ case 3:
343
+ n_processors = 1;
344
+ break;
345
+ case 4:
346
+ n_processors = NUM2INT(argv[3]);
347
+ break;
348
+ }
349
+ if (argc >= 3 && !NIL_P(argv[2])) {
350
+ n_samples = NUM2INT(argv[2]);
351
+ if (TYPE(samples) == T_ARRAY) {
352
+ if (RARRAY_LEN(samples) < n_samples) {
353
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
354
+ }
355
+ }
356
+ // Should check when samples.respond_to?(:length)?
357
+ } else if (memory_view_available_p) {
358
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
359
+ view.obj = Qnil;
360
+ rb_raise(rb_eArgError, "unable to get a memory view");
361
+ }
362
+ n_samples = view.byte_size / view.item_size;
363
+ } else {
364
+ if (TYPE(samples) == T_ARRAY) {
365
+ n_samples = RARRAY_LEN(samples);
366
+ } else if (rb_respond_to(samples, id_length)) {
367
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
368
+ } else {
369
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
370
+ }
371
+ }
372
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
373
+ if (memory_view_available_p) {
374
+ c_samples = (float *)view.data;
375
+ } else {
376
+ if (TYPE(samples) == T_ARRAY) {
377
+ for (int i = 0; i < n_samples; i++) {
378
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
379
+ }
380
+ } else {
381
+ // FIXME: use rb_block_call
382
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
383
+ for (int i = 0; i < n_samples; i++) {
384
+ // TODO: check if iter is exhausted and raise ArgumentError
385
+ VALUE sample = rb_funcall(iter, id_next, 0);
386
+ c_samples[i] = RFLOAT_VALUE(sample);
387
+ }
388
+ }
389
+ }
390
+ register_callbacks(rwp, &self);
391
+ const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
392
+ if (0 == result) {
393
+ return self;
394
+ } else {
395
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
396
+ }
397
+ }
398
+
399
+ /*
400
+ * Number of segments.
401
+ *
402
+ * call-seq:
403
+ * full_n_segments -> Integer
404
+ */
405
+ static VALUE
406
+ ruby_whisper_full_n_segments(VALUE self)
407
+ {
408
+ ruby_whisper *rw;
409
+ Data_Get_Struct(self, ruby_whisper, rw);
410
+ return INT2NUM(whisper_full_n_segments(rw->context));
411
+ }
412
+
413
+ /*
414
+ * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
415
+ *
416
+ * call-seq:
417
+ * full_lang_id -> Integer
418
+ */
419
+ static VALUE
420
+ ruby_whisper_full_lang_id(VALUE self)
421
+ {
422
+ ruby_whisper *rw;
423
+ Data_Get_Struct(self, ruby_whisper, rw);
424
+ return INT2NUM(whisper_full_lang_id(rw->context));
425
+ }
426
+
427
+ static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment)
428
+ {
429
+ const int c_i_segment = NUM2INT(i_segment);
430
+ if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
431
+ rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
432
+ }
433
+ return c_i_segment;
434
+ }
435
+
436
+ /*
437
+ * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
438
+ *
439
+ * full_get_segment_t0(3) # => 1668 (16680 ms)
440
+ *
441
+ * call-seq:
442
+ * full_get_segment_t0(segment_index) -> Integer
443
+ */
444
+ static VALUE
445
+ ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
446
+ {
447
+ ruby_whisper *rw;
448
+ Data_Get_Struct(self, ruby_whisper, rw);
449
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
450
+ const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
451
+ return INT2NUM(t0);
452
+ }
453
+
454
+ /*
455
+ * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
456
+ *
457
+ * full_get_segment_t1(3) # => 1668 (16680 ms)
458
+ *
459
+ * call-seq:
460
+ * full_get_segment_t1(segment_index) -> Integer
461
+ */
462
+ static VALUE
463
+ ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
464
+ {
465
+ ruby_whisper *rw;
466
+ Data_Get_Struct(self, ruby_whisper, rw);
467
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
468
+ const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
469
+ return INT2NUM(t1);
470
+ }
471
+
472
+ /*
473
+ * Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
474
+ *
475
+ * full_get_segment_speacker_turn_next(3) # => true
476
+ *
477
+ * call-seq:
478
+ * full_get_segment_speacker_turn_next(segment_index) -> bool
479
+ */
480
+ static VALUE
481
+ ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
482
+ {
483
+ ruby_whisper *rw;
484
+ Data_Get_Struct(self, ruby_whisper, rw);
485
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
486
+ const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
487
+ return speaker_turn_next ? Qtrue : Qfalse;
488
+ }
489
+
490
+ /*
491
+ * Text of a segment indexed by +segment_index+.
492
+ *
493
+ * full_get_segment_text(3) # => "ask not what your country can do for you, ..."
494
+ *
495
+ * call-seq:
496
+ * full_get_segment_text(segment_index) -> String
497
+ */
498
+ static VALUE
499
+ ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
500
+ {
501
+ ruby_whisper *rw;
502
+ Data_Get_Struct(self, ruby_whisper, rw);
503
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
504
+ const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
505
+ return rb_str_new2(text);
506
+ }
507
+
508
+ /*
509
+ * call-seq:
510
+ * full_get_segment_no_speech_prob(segment_index) -> Float
511
+ */
512
+ static VALUE
513
+ ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
514
+ {
515
+ ruby_whisper *rw;
516
+ Data_Get_Struct(self, ruby_whisper, rw);
517
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
518
+ const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
519
+ return DBL2NUM(no_speech_prob);
520
+ }
521
+
522
+ // High level API
523
+
524
+ static VALUE
525
+ ruby_whisper_full_get_segment(VALUE self, VALUE i_segment)
526
+ {
527
+ return rb_whisper_segment_initialize(self, NUM2INT(i_segment));
528
+ }
529
+
530
+ /*
531
+ * Yields each Whisper::Segment:
532
+ *
533
+ * whisper.transcribe("path/to/audio.wav", params)
534
+ * whisper.each_segment do |segment|
535
+ * puts segment.text
536
+ * end
537
+ *
538
+ * Returns an Enumerator if no block given:
539
+ *
540
+ * whisper.transcribe("path/to/audio.wav", params)
541
+ * enum = whisper.each_segment
542
+ * enum.to_a # => [#<Whisper::Segment>, ...]
543
+ *
544
+ * call-seq:
545
+ * each_segment {|segment| ... }
546
+ * each_segment -> Enumerator
547
+ */
548
+ static VALUE
549
+ ruby_whisper_each_segment(VALUE self)
550
+ {
551
+ if (!rb_block_given_p()) {
552
+ const VALUE method_name = rb_funcall(self, id___method__, 0);
553
+ return rb_funcall(self, id_to_enum, 1, method_name);
554
+ }
555
+
556
+ ruby_whisper *rw;
557
+ Data_Get_Struct(self, ruby_whisper, rw);
558
+
559
+ const int n_segments = whisper_full_n_segments(rw->context);
560
+ for (int i = 0; i < n_segments; ++i) {
561
+ rb_yield(rb_whisper_segment_initialize(self, i));
562
+ }
563
+
564
+ return self;
565
+ }
566
+
567
+ /*
568
+ * call-seq:
569
+ * model -> Whisper::Model
570
+ */
571
+ static VALUE
572
+ ruby_whisper_get_model(VALUE self)
573
+ {
574
+ return rb_whisper_model_initialize(self);
575
+ }
576
+
577
+ void
578
+ init_ruby_whisper_context(VALUE *mWhisper)
579
+ {
580
+ cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
581
+
582
+ rb_define_alloc_func(cContext, ruby_whisper_allocate);
583
+ rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
584
+
585
+ rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
586
+ rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
587
+ rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
588
+ rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
589
+ rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
590
+ rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
591
+ rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
592
+ rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
593
+ rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
594
+ rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
595
+ rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
596
+ rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
597
+ rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
598
+ rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
599
+ rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
600
+ rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
601
+ rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
602
+ rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
603
+ rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
604
+ rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
605
+ rb_define_method(cContext, "full", ruby_whisper_full, -1);
606
+ rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
607
+
608
+ // High leve
609
+ rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1);
610
+ rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
611
+
612
+ rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
613
+ }
bindings/ruby/ext/ruby_whisper_error.c ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+
3
+ extern VALUE eError;
4
+
5
+ VALUE ruby_whisper_error_initialize(VALUE self, VALUE code)
6
+ {
7
+ const int c_code = NUM2INT(code);
8
+ const char *raw_message;
9
+ switch (c_code) {
10
+ case -2:
11
+ raw_message = "failed to compute log mel spectrogram";
12
+ break;
13
+ case -3:
14
+ raw_message = "failed to auto-detect language";
15
+ break;
16
+ case -4:
17
+ raw_message = "too many decoders requested";
18
+ break;
19
+ case -5:
20
+ raw_message = "audio_ctx is larger than the maximum allowed";
21
+ break;
22
+ case -6:
23
+ raw_message = "failed to encode";
24
+ break;
25
+ case -7:
26
+ raw_message = "whisper_kv_cache_init() failed for self-attention cache";
27
+ break;
28
+ case -8:
29
+ raw_message = "failed to decode";
30
+ break;
31
+ case -9:
32
+ raw_message = "failed to decode";
33
+ break;
34
+ default:
35
+ raw_message = "unknown error";
36
+ break;
37
+ }
38
+ const VALUE message = rb_str_new2(raw_message);
39
+ rb_call_super(1, &message);
40
+ rb_iv_set(self, "@code", code);
41
+
42
+ return self;
43
+ }
44
+
45
+ void
46
+ init_ruby_whisper_error(VALUE *mWhisper)
47
+ {
48
+ eError = rb_define_class_under(*mWhisper, "Error", rb_eStandardError);
49
+
50
+ rb_define_attr(eError, "code", true, false);
51
+ rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
52
+ }
bindings/ruby/ext/ruby_whisper_model.c ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+ #include "ruby_whisper.h"
3
+
4
+ extern VALUE cModel;
5
+
6
+ static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
7
+ rb_gc_mark(rwm->context);
8
+ }
9
+
10
+ static VALUE ruby_whisper_model_allocate(VALUE klass) {
11
+ ruby_whisper_model *rwm;
12
+ rwm = ALLOC(ruby_whisper_model);
13
+ return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
14
+ }
15
+
16
+ VALUE rb_whisper_model_initialize(VALUE context) {
17
+ ruby_whisper_model *rwm;
18
+ const VALUE model = ruby_whisper_model_allocate(cModel);
19
+ Data_Get_Struct(model, ruby_whisper_model, rwm);
20
+ rwm->context = context;
21
+ return model;
22
+ };
23
+
24
+ /*
25
+ * call-seq:
26
+ * n_vocab -> Integer
27
+ */
28
+ static VALUE
29
+ ruby_whisper_model_n_vocab(VALUE self)
30
+ {
31
+ ruby_whisper_model *rwm;
32
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
33
+ ruby_whisper *rw;
34
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
35
+ return INT2NUM(whisper_model_n_vocab(rw->context));
36
+ }
37
+
38
+ /*
39
+ * call-seq:
40
+ * n_audio_ctx -> Integer
41
+ */
42
+ static VALUE
43
+ ruby_whisper_model_n_audio_ctx(VALUE self)
44
+ {
45
+ ruby_whisper_model *rwm;
46
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
47
+ ruby_whisper *rw;
48
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
49
+ return INT2NUM(whisper_model_n_audio_ctx(rw->context));
50
+ }
51
+
52
+ /*
53
+ * call-seq:
54
+ * n_audio_state -> Integer
55
+ */
56
+ static VALUE
57
+ ruby_whisper_model_n_audio_state(VALUE self)
58
+ {
59
+ ruby_whisper_model *rwm;
60
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
61
+ ruby_whisper *rw;
62
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
63
+ return INT2NUM(whisper_model_n_audio_state(rw->context));
64
+ }
65
+
66
+ /*
67
+ * call-seq:
68
+ * n_audio_head -> Integer
69
+ */
70
+ static VALUE
71
+ ruby_whisper_model_n_audio_head(VALUE self)
72
+ {
73
+ ruby_whisper_model *rwm;
74
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
75
+ ruby_whisper *rw;
76
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
77
+ return INT2NUM(whisper_model_n_audio_head(rw->context));
78
+ }
79
+
80
+ /*
81
+ * call-seq:
82
+ * n_audio_layer -> Integer
83
+ */
84
+ static VALUE
85
+ ruby_whisper_model_n_audio_layer(VALUE self)
86
+ {
87
+ ruby_whisper_model *rwm;
88
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
89
+ ruby_whisper *rw;
90
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
91
+ return INT2NUM(whisper_model_n_audio_layer(rw->context));
92
+ }
93
+
94
+ /*
95
+ * call-seq:
96
+ * n_text_ctx -> Integer
97
+ */
98
+ static VALUE
99
+ ruby_whisper_model_n_text_ctx(VALUE self)
100
+ {
101
+ ruby_whisper_model *rwm;
102
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
103
+ ruby_whisper *rw;
104
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
105
+ return INT2NUM(whisper_model_n_text_ctx(rw->context));
106
+ }
107
+
108
+ /*
109
+ * call-seq:
110
+ * n_text_state -> Integer
111
+ */
112
+ static VALUE
113
+ ruby_whisper_model_n_text_state(VALUE self)
114
+ {
115
+ ruby_whisper_model *rwm;
116
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
117
+ ruby_whisper *rw;
118
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
119
+ return INT2NUM(whisper_model_n_text_state(rw->context));
120
+ }
121
+
122
+ /*
123
+ * call-seq:
124
+ * n_text_head -> Integer
125
+ */
126
+ static VALUE
127
+ ruby_whisper_model_n_text_head(VALUE self)
128
+ {
129
+ ruby_whisper_model *rwm;
130
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
131
+ ruby_whisper *rw;
132
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
133
+ return INT2NUM(whisper_model_n_text_head(rw->context));
134
+ }
135
+
136
+ /*
137
+ * call-seq:
138
+ * n_text_layer -> Integer
139
+ */
140
+ static VALUE
141
+ ruby_whisper_model_n_text_layer(VALUE self)
142
+ {
143
+ ruby_whisper_model *rwm;
144
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
145
+ ruby_whisper *rw;
146
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
147
+ return INT2NUM(whisper_model_n_text_layer(rw->context));
148
+ }
149
+
150
+ /*
151
+ * call-seq:
152
+ * n_mels -> Integer
153
+ */
154
+ static VALUE
155
+ ruby_whisper_model_n_mels(VALUE self)
156
+ {
157
+ ruby_whisper_model *rwm;
158
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
159
+ ruby_whisper *rw;
160
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
161
+ return INT2NUM(whisper_model_n_mels(rw->context));
162
+ }
163
+
164
+ /*
165
+ * call-seq:
166
+ * ftype -> Integer
167
+ */
168
+ static VALUE
169
+ ruby_whisper_model_ftype(VALUE self)
170
+ {
171
+ ruby_whisper_model *rwm;
172
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
173
+ ruby_whisper *rw;
174
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
175
+ return INT2NUM(whisper_model_ftype(rw->context));
176
+ }
177
+
178
+ /*
179
+ * call-seq:
180
+ * type -> String
181
+ */
182
+ static VALUE
183
+ ruby_whisper_model_type(VALUE self)
184
+ {
185
+ ruby_whisper_model *rwm;
186
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
187
+ ruby_whisper *rw;
188
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
189
+ return rb_str_new2(whisper_model_type_readable(rw->context));
190
+ }
191
+
192
+ void
193
+ init_ruby_whisper_model(VALUE *mWhisper)
194
+ {
195
+ cModel = rb_define_class_under(*mWhisper, "Model", rb_cObject);
196
+
197
+ rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
198
+ rb_define_method(cModel, "n_vocab", ruby_whisper_model_n_vocab, 0);
199
+ rb_define_method(cModel, "n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
200
+ rb_define_method(cModel, "n_audio_state", ruby_whisper_model_n_audio_state, 0);
201
+ rb_define_method(cModel, "n_audio_head", ruby_whisper_model_n_audio_head, 0);
202
+ rb_define_method(cModel, "n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
203
+ rb_define_method(cModel, "n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
204
+ rb_define_method(cModel, "n_text_state", ruby_whisper_model_n_text_state, 0);
205
+ rb_define_method(cModel, "n_text_head", ruby_whisper_model_n_text_head, 0);
206
+ rb_define_method(cModel, "n_text_layer", ruby_whisper_model_n_text_layer, 0);
207
+ rb_define_method(cModel, "n_mels", ruby_whisper_model_n_mels, 0);
208
+ rb_define_method(cModel, "ftype", ruby_whisper_model_ftype, 0);
209
+ rb_define_method(cModel, "type", ruby_whisper_model_type, 0);
210
+ }
bindings/ruby/ext/ruby_whisper_params.c ADDED
@@ -0,0 +1,1077 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+ #include "ruby_whisper.h"
3
+
4
/*
 * Body of a boolean setter for Whisper::Params.
 * Expands to a complete function body: unwraps +self+, stores false in
 * params.prop when +value+ is false or nil and true otherwise, then returns
 * +value+. Must be the only content of the setter function because it
 * declares a local and returns.
 */
#define BOOL_PARAMS_SETTER(self, prop, value) \
  ruby_whisper_params *rwp; \
  Data_Get_Struct(self, ruby_whisper_params, rwp); \
  if (value == Qfalse || value == Qnil) { \
    rwp->params.prop = false; \
  } else { \
    rwp->params.prop = true; \
  } \
  return value;

/*
 * Body of a boolean getter for Whisper::Params.
 * Expands to a complete function body: unwraps +self+ and returns Qtrue or
 * Qfalse according to params.prop.
 */
#define BOOL_PARAMS_GETTER(self, prop) \
  ruby_whisper_params *rwp; \
  Data_Get_Struct(self, ruby_whisper_params, rwp); \
  if (rwp->params.prop) { \
    return Qtrue; \
  } else { \
    return Qfalse; \
  }

/*
 * Registers one parameter: interns its name into id_<param_name>, records
 * that ID at index +nth+ of param_names, and defines the reader and the
 * "<param_name>=" writer on cParams using the
 * ruby_whisper_params_get_<param_name> / ruby_whisper_params_set_<param_name>
 * functions.
 */
#define DEFINE_PARAM(param_name, nth) \
  id_ ## param_name = rb_intern(#param_name); \
  param_names[nth] = id_ ## param_name; \
  rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
  rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
28
+
29
+ #define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
30
+
31
+ extern VALUE cParams;
32
+
33
+ extern ID id_call;
34
+
35
+ extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
36
+
37
+ static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
38
+ static ID id_language;
39
+ static ID id_translate;
40
+ static ID id_no_context;
41
+ static ID id_single_segment;
42
+ static ID id_print_special;
43
+ static ID id_print_progress;
44
+ static ID id_print_realtime;
45
+ static ID id_print_timestamps;
46
+ static ID id_suppress_blank;
47
+ static ID id_suppress_nst;
48
+ static ID id_token_timestamps;
49
+ static ID id_split_on_word;
50
+ static ID id_initial_prompt;
51
+ static ID id_diarize;
52
+ static ID id_offset;
53
+ static ID id_duration;
54
+ static ID id_max_text_tokens;
55
+ static ID id_temperature;
56
+ static ID id_max_initial_ts;
57
+ static ID id_length_penalty;
58
+ static ID id_temperature_inc;
59
+ static ID id_entropy_thold;
60
+ static ID id_logprob_thold;
61
+ static ID id_no_speech_thold;
62
+ static ID id_new_segment_callback;
63
+ static ID id_new_segment_callback_user_data;
64
+ static ID id_progress_callback;
65
+ static ID id_progress_callback_user_data;
66
+ static ID id_abort_callback;
67
+ static ID id_abort_callback_user_data;
68
+
69
+ static void
70
+ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
71
+ {
72
+ rb_gc_mark(rwc->user_data);
73
+ rb_gc_mark(rwc->callback);
74
+ rb_gc_mark(rwc->callbacks);
75
+ }
76
+
77
+ static ruby_whisper_callback_container*
78
+ rb_whisper_callback_container_allocate() {
79
+ ruby_whisper_callback_container *container;
80
+ container = ALLOC(ruby_whisper_callback_container);
81
+ container->context = NULL;
82
+ container->user_data = Qnil;
83
+ container->callback = Qnil;
84
+ container->callbacks = rb_ary_new();
85
+ return container;
86
+ }
87
+
88
+ static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
89
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
90
+
91
+ // Currently, doesn't support state because
92
+ // those require to resolve GC-related problems.
93
+ if (!NIL_P(container->callback)) {
94
+ rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
95
+ }
96
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
97
+ if (0 == callbacks_len) {
98
+ return;
99
+ }
100
+ const int n_segments = whisper_full_n_segments_from_state(state);
101
+ for (int i = n_new; i > 0; i--) {
102
+ int i_segment = n_segments - i;
103
+ VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
104
+ for (int j = 0; j < callbacks_len; j++) {
105
+ VALUE cb = rb_ary_entry(container->callbacks, j);
106
+ rb_funcall(cb, id_call, 1, segment);
107
+ }
108
+ }
109
+ }
110
+
111
+ static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { // C trampoline installed by register_callbacks(); forwards progress updates to Ruby
112
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; // user_data is the container set in register_callbacks()
113
+ const VALUE progress = INT2NUM(progress_cur); // converted once and shared by all receivers
114
+ // The whisper_state is not exposed to Ruby yet (Qnil is passed in its place)
115
+ // because wrapping it requires resolving GC-related problems.
116
+ if (!NIL_P(container->callback)) {
117
+ rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); // callback.call(context, nil, progress, user_data)
118
+ }
119
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
120
+ if (0 == callbacks_len) {
121
+ return;
122
+ }
123
+ for (int j = 0; j < callbacks_len; j++) { // registered blocks receive only the progress value
124
+ VALUE cb = rb_ary_entry(container->callbacks, j);
125
+ rb_funcall(cb, id_call, 1, progress);
126
+ }
127
+ }
128
+
129
+ static bool abort_callback(void * user_data) { // C trampoline installed by register_callbacks(); returns true as soon as any Ruby callback answers truthy
130
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; // user_data is the container set in register_callbacks()
131
+ if (!NIL_P(container->callback)) {
132
+ VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); // callback.call(user_data)
133
+ if (!NIL_P(result) && Qfalse != result) { // Ruby truthiness: anything except nil/false requests an abort
134
+ return true;
135
+ }
136
+ }
137
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
138
+ if (0 == callbacks_len) {
139
+ return false;
140
+ }
141
+ for (int j = 0; j < callbacks_len; j++) { // registered blocks are consulted in registration order; the first truthy answer wins
142
+ VALUE cb = rb_ary_entry(container->callbacks, j);
143
+ VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
144
+ if (!NIL_P(result) && Qfalse != result) {
145
+ return true;
146
+ }
147
+ }
148
+ return false; // nobody asked to abort
149
+ }
150
+
151
+ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { // Installs the C trampolines into rwp->params for each callback kind that has a callback object or at least one registered block
152
+ if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
153
+ rwp->new_segment_callback_container->context = context; // the pointer is stored, not the VALUE; NOTE(review): presumably it must stay valid while the callbacks can fire — confirm at the call site
154
+ rwp->params.new_segment_callback = new_segment_callback;
155
+ rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
156
+ }
157
+
158
+ if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
159
+ rwp->progress_callback_container->context = context;
160
+ rwp->params.progress_callback = progress_callback;
161
+ rwp->params.progress_callback_user_data = rwp->progress_callback_container;
162
+ }
163
+
164
+ if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
165
+ rwp->abort_callback_container->context = context; // stored for symmetry; abort_callback() itself does not read it
166
+ rwp->params.abort_callback = abort_callback;
167
+ rwp->params.abort_callback_user_data = rwp->abort_callback_container;
168
+ }
169
+ }
170
+
171
+ void
172
+ rb_whisper_params_mark(ruby_whisper_params *rwp)
173
+ {
174
+ rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
175
+ rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
176
+ rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
177
+ }
178
+
179
+ void
180
+ ruby_whisper_params_free(ruby_whisper_params *rwp)
181
+ {
182
+ }
183
+
184
+ void
185
+ rb_whisper_params_free(ruby_whisper_params *rwp) // free function passed to Data_Wrap_Struct in ruby_whisper_params_allocate
186
+ {
187
+ // How to free user_data and callback only when not referred to by others?
188
+ ruby_whisper_params_free(rwp); // currently an empty hook
189
+ free(rwp); // NOTE(review): the three callback containers ALLOC'ed in ruby_whisper_params_allocate are not released here or in the hook above — looks like a leak; confirm
190
+ }
191
+
192
+ static VALUE
193
+ ruby_whisper_params_allocate(VALUE klass) // alloc func for Whisper::Params: builds the struct with whisper's greedy-sampling defaults
194
+ {
195
+ ruby_whisper_params *rwp;
196
+ rwp = ALLOC(ruby_whisper_params);
197
+ rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
198
+ rwp->diarize = false;
199
+ rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); // each container creates a Ruby Array via rb_ary_new()
200
+ rwp->progress_callback_container = rb_whisper_callback_container_allocate(); // NOTE(review): the earlier Arrays are referenced only from malloc'ed memory until the struct is wrapped below,
201
+ rwp->abort_callback_container = rb_whisper_callback_container_allocate(); // so a GC triggered by these allocations may not see them — confirm whether a GC guard is needed
202
+ return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); // from here on rb_whisper_params_mark keeps callbacks and user data alive
203
+ }
204
+
205
+ /*
206
+ * params.language = "auto" | "en", etc...
207
+ *
208
+ * call-seq:
209
+ * language = lang_name -> lang_name
210
+ */
211
+ static VALUE
212
+ ruby_whisper_params_set_language(VALUE self, VALUE value)
213
+ {
214
+ ruby_whisper_params *rwp;
215
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
216
+ if (value == Qfalse || value == Qnil) { // false/nil fall back to automatic language detection
217
+ rwp->params.language = "auto";
218
+ } else {
219
+ rwp->params.language = StringValueCStr(value); // borrows the Ruby String's buffer; NOTE(review): nothing visible here keeps `value` alive or unmodified afterwards — confirm lifetime
220
+ }
221
+ return value;
222
+ }
223
+ /*
224
+ * call-seq:
225
+ * language -> String
226
+ */
227
+ static VALUE
228
+ ruby_whisper_params_get_language(VALUE self)
229
+ {
230
+ ruby_whisper_params *rwp;
231
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
232
+ if (rwp->params.language) {
233
+ return rb_str_new2(rwp->params.language);
234
+ } else {
235
+ return rb_str_new2("auto");
236
+ }
237
+ }
238
+ /*
239
+ * call-seq:
240
+ * translate = do_translate -> do_translate
241
+ */
242
+ static VALUE
243
+ ruby_whisper_params_set_translate(VALUE self, VALUE value)
244
+ {
245
+ BOOL_PARAMS_SETTER(self, translate, value)
246
+ }
247
+ /*
248
+ * call-seq:
249
+ * translate -> bool
250
+ */
251
+ static VALUE
252
+ ruby_whisper_params_get_translate(VALUE self)
253
+ {
254
+ BOOL_PARAMS_GETTER(self, translate)
255
+ }
256
+ /*
257
+ * call-seq:
258
+ * no_context = dont_use_context -> dont_use_context
259
+ */
260
+ static VALUE
261
+ ruby_whisper_params_set_no_context(VALUE self, VALUE value)
262
+ {
263
+ BOOL_PARAMS_SETTER(self, no_context, value)
264
+ }
265
+ /*
266
+ * If true, does not use past transcription (if any) as initial prompt for the decoder.
267
+ *
268
+ * call-seq:
269
+ * no_context -> bool
270
+ */
271
+ static VALUE
272
+ ruby_whisper_params_get_no_context(VALUE self)
273
+ {
274
+ BOOL_PARAMS_GETTER(self, no_context)
275
+ }
276
+ /*
277
+ * call-seq:
278
+ * single_segment = force_single -> force_single
279
+ */
280
+ static VALUE
281
+ ruby_whisper_params_set_single_segment(VALUE self, VALUE value)
282
+ {
283
+ BOOL_PARAMS_SETTER(self, single_segment, value)
284
+ }
285
+ /*
286
+ * If true, forces single segment output (useful for streaming).
287
+ *
288
+ * call-seq:
289
+ * single_segment -> bool
290
+ */
291
+ static VALUE
292
+ ruby_whisper_params_get_single_segment(VALUE self)
293
+ {
294
+ BOOL_PARAMS_GETTER(self, single_segment)
295
+ }
296
+ /*
297
+ * call-seq:
298
+ * print_special = force_print -> force_print
299
+ */
300
+ static VALUE
301
+ ruby_whisper_params_set_print_special(VALUE self, VALUE value)
302
+ {
303
+ BOOL_PARAMS_SETTER(self, print_special, value)
304
+ }
305
+ /*
306
+ * If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
307
+ *
308
+ * call-seq:
309
+ * print_special -> bool
310
+ */
311
+ static VALUE
312
+ ruby_whisper_params_get_print_special(VALUE self)
313
+ {
314
+ BOOL_PARAMS_GETTER(self, print_special)
315
+ }
316
+ /*
317
+ * call-seq:
318
+ * print_progress = force_print -> force_print
319
+ */
320
+ static VALUE
321
+ ruby_whisper_params_set_print_progress(VALUE self, VALUE value)
322
+ {
323
+ BOOL_PARAMS_SETTER(self, print_progress, value)
324
+ }
325
+ /*
326
+ * If true, prints progress information.
327
+ *
328
+ * call-seq:
329
+ * print_progress -> bool
330
+ */
331
+ static VALUE
332
+ ruby_whisper_params_get_print_progress(VALUE self)
333
+ {
334
+ BOOL_PARAMS_GETTER(self, print_progress)
335
+ }
336
+ /*
337
+ * call-seq:
338
+ * print_realtime = force_print -> force_print
339
+ */
340
+ static VALUE
341
+ ruby_whisper_params_set_print_realtime(VALUE self, VALUE value)
342
+ {
343
+ BOOL_PARAMS_SETTER(self, print_realtime, value)
344
+ }
345
+ /*
346
+ * If true, prints results from within whisper.cpp. (avoid it, use callback instead)
347
+ * call-seq:
348
+ * print_realtime -> bool
349
+ */
350
+ static VALUE
351
+ ruby_whisper_params_get_print_realtime(VALUE self)
352
+ {
353
+ BOOL_PARAMS_GETTER(self, print_realtime)
354
+ }
355
+ /*
356
+ * call-seq:
357
+ * print_timestamps = force_print -> force_print
358
+ */
359
+ static VALUE
360
+ ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value)
361
+ {
362
+ BOOL_PARAMS_SETTER(self, print_timestamps, value)
363
+ }
364
+ /*
365
+ * If true, prints timestamps for each text segment when printing realtime.
366
+ *
367
+ * call-seq:
368
+ * print_timestamps -> bool
369
+ */
370
+ static VALUE
371
+ ruby_whisper_params_get_print_timestamps(VALUE self)
372
+ {
373
+ BOOL_PARAMS_GETTER(self, print_timestamps)
374
+ }
375
+ /*
376
+ * call-seq:
377
+ * suppress_blank = force_suppress -> force_suppress
378
+ */
379
+ static VALUE
380
+ ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value)
381
+ {
382
+ BOOL_PARAMS_SETTER(self, suppress_blank, value)
383
+ }
384
+ /*
385
+ * If true, suppresses blank outputs.
386
+ *
387
+ * call-seq:
388
+ * suppress_blank -> bool
389
+ */
390
+ static VALUE
391
+ ruby_whisper_params_get_suppress_blank(VALUE self)
392
+ {
393
+ BOOL_PARAMS_GETTER(self, suppress_blank)
394
+ }
395
+ /*
396
+ * call-seq:
397
+ * suppress_nst = force_suppress -> force_suppress
398
+ */
399
+ static VALUE
400
+ ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value)
401
+ {
402
+ BOOL_PARAMS_SETTER(self, suppress_nst, value)
403
+ }
404
+ /*
405
+ * If true, suppresses non-speech-tokens.
406
+ *
407
+ * call-seq:
408
+ * suppress_nst -> bool
409
+ */
410
+ static VALUE
411
+ ruby_whisper_params_get_suppress_nst(VALUE self)
412
+ {
413
+ BOOL_PARAMS_GETTER(self, suppress_nst)
414
+ }
415
+ /*
416
+ * If true, enables token-level timestamps.
417
+ *
418
+ * call-seq:
419
+ * token_timestamps -> bool
420
+ */
421
+ static VALUE
422
+ ruby_whisper_params_get_token_timestamps(VALUE self)
423
+ {
424
+ BOOL_PARAMS_GETTER(self, token_timestamps)
425
+ }
426
+ /*
427
+ * call-seq:
428
+ * token_timestamps = force_timestamps -> force_timestamps
429
+ */
430
+ static VALUE
431
+ ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value)
432
+ {
433
+ BOOL_PARAMS_SETTER(self, token_timestamps, value)
434
+ }
435
+ /*
436
+ * If true, split on word rather than on token (when used with max_len).
437
+ *
438
+ * call-seq:
439
+ * split_on_word -> bool
440
+ */
441
+ static VALUE
442
+ ruby_whisper_params_get_split_on_word(VALUE self)
443
+ {
444
+ BOOL_PARAMS_GETTER(self, split_on_word)
445
+ }
446
+ /*
447
+ * call-seq:
448
+ * split_on_word = force_split -> force_split
449
+ */
450
+ static VALUE
451
+ ruby_whisper_params_set_split_on_word(VALUE self, VALUE value)
452
+ {
453
+ BOOL_PARAMS_SETTER(self, split_on_word, value)
454
+ }
455
+ /*
456
+ * Tokens to provide to the whisper decoder as initial prompt
457
+ * these are prepended to any existing text context from a previous call
458
+ * use whisper_tokenize() to convert text to tokens.
459
+ * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
460
+ *
461
+ * call-seq:
462
+ * initial_prompt -> String
463
+ */
464
+ static VALUE
465
+ ruby_whisper_params_get_initial_prompt(VALUE self)
466
+ {
467
+ ruby_whisper_params *rwp;
468
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
469
+ return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt);
470
+ }
471
+ /*
472
+ * call-seq:
473
+ * initial_prompt = prompt -> prompt
474
+ */
475
+ static VALUE
476
+ ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) // Sets the initial prompt text; nil clears it
477
+ {
478
+ ruby_whisper_params *rwp;
479
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
480
+ rwp->params.initial_prompt = NIL_P(value) ? NULL : StringValueCStr(value); // nil maps to NULL, mirroring the getter which maps NULL back to nil; other non-String values still raise TypeError
481
+ return value;
482
+ }
483
+ /*
484
+ * If true, enables diarization.
485
+ *
486
+ * call-seq:
487
+ * diarize -> bool
488
+ */
489
+ static VALUE
490
+ ruby_whisper_params_get_diarize(VALUE self)
491
+ {
492
+ ruby_whisper_params *rwp;
493
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
494
+ if (rwp->diarize) {
495
+ return Qtrue;
496
+ } else {
497
+ return Qfalse;
498
+ }
499
+ }
500
+ /*
501
+ * call-seq:
502
+ * diarize = force_diarize -> force_diarize
503
+ */
504
+ static VALUE
505
+ ruby_whisper_params_set_diarize(VALUE self, VALUE value) // Enables diarization for any value other than false/nil
506
+ {
507
+ ruby_whisper_params *rwp;
508
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
509
+ if (value == Qfalse || value == Qnil) {
510
+ rwp->diarize = false;
511
+ } else {
512
+ rwp->diarize = true;
513
+ }
514
+ return value; // the assigned object is returned unchanged, as Ruby setters conventionally do
515
+ }
516
+
517
+ /*
518
+ * Start offset in ms.
519
+ *
520
+ * call-seq:
521
+ * offset -> Integer
522
+ */
523
+ static VALUE
524
+ ruby_whisper_params_get_offset(VALUE self)
525
+ {
526
+ ruby_whisper_params *rwp;
527
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
528
+ return INT2NUM(rwp->params.offset_ms);
529
+ }
530
+ /*
531
+ * call-seq:
532
+ * offset = offset_ms -> offset_ms
533
+ */
534
+ static VALUE
535
+ ruby_whisper_params_set_offset(VALUE self, VALUE value)
536
+ {
537
+ ruby_whisper_params *rwp;
538
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
539
+ rwp->params.offset_ms = NUM2INT(value);
540
+ return value;
541
+ }
542
+ /*
543
+ * Audio duration to process in ms.
544
+ *
545
+ * call-seq:
546
+ * duration -> Integer
547
+ */
548
+ static VALUE
549
+ ruby_whisper_params_get_duration(VALUE self)
550
+ {
551
+ ruby_whisper_params *rwp;
552
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
553
+ return INT2NUM(rwp->params.duration_ms);
554
+ }
555
+ /*
556
+ * call-seq:
557
+ * duration = duration_ms -> duration_ms
558
+ */
559
+ static VALUE
560
+ ruby_whisper_params_set_duration(VALUE self, VALUE value)
561
+ {
562
+ ruby_whisper_params *rwp;
563
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
564
+ rwp->params.duration_ms = NUM2INT(value);
565
+ return value;
566
+ }
567
+
568
+ /*
569
+ * Max tokens to use from past text as prompt for the decoder.
570
+ *
571
+ * call-seq:
572
+ * max_text_tokens -> Integer
573
+ */
574
+ static VALUE
575
+ ruby_whisper_params_get_max_text_tokens(VALUE self)
576
+ {
577
+ ruby_whisper_params *rwp;
578
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
579
+ return INT2NUM(rwp->params.n_max_text_ctx);
580
+ }
581
+ /*
582
+ * call-seq:
583
+ * max_text_tokens = n_tokens -> n_tokens
584
+ */
585
+ static VALUE
586
+ ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value)
587
+ {
588
+ ruby_whisper_params *rwp;
589
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
590
+ rwp->params.n_max_text_ctx = NUM2INT(value);
591
+ return value;
592
+ }
593
+ /*
594
+ * call-seq:
595
+ * temperature -> Float
596
+ */
597
+ static VALUE
598
+ ruby_whisper_params_get_temperature(VALUE self)
599
+ {
600
+ ruby_whisper_params *rwp;
601
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
602
+ return DBL2NUM(rwp->params.temperature);
603
+ }
604
+ /*
605
+ * call-seq:
606
+ * temperature = temp -> temp
607
+ */
608
+ static VALUE
609
+ ruby_whisper_params_set_temperature(VALUE self, VALUE value) // Sets the sampling temperature from any Ruby Numeric
610
+ {
611
+ ruby_whisper_params *rwp;
612
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
613
+ rwp->params.temperature = NUM2DBL(value); // NUM2DBL type-checks and converts Integer etc.; RFLOAT_VALUE blindly assumed a Float object
614
+ return value;
615
+ }
616
+ /*
617
+ * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
618
+ *
619
+ * call-seq:
620
+ * max_initial_ts -> Float
621
+ */
622
+ static VALUE
623
+ ruby_whisper_params_get_max_initial_ts(VALUE self)
624
+ {
625
+ ruby_whisper_params *rwp;
626
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
627
+ return DBL2NUM(rwp->params.max_initial_ts);
628
+ }
629
+ /*
630
+ * call-seq:
631
+ * max_initial_ts = timestamp -> timestamp
632
+ */
633
+ static VALUE
634
+ ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) // Sets max_initial_ts from any Ruby Numeric
635
+ {
636
+ ruby_whisper_params *rwp;
637
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
638
+ rwp->params.max_initial_ts = NUM2DBL(value); // NUM2DBL type-checks and converts Integer etc.; RFLOAT_VALUE blindly assumed a Float object
639
+ return value;
640
+ }
641
+ /*
642
+ * call-seq:
643
+ * length_penalty -> Float
644
+ */
645
+ static VALUE
646
+ ruby_whisper_params_get_length_penalty(VALUE self)
647
+ {
648
+ ruby_whisper_params *rwp;
649
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
650
+ return DBL2NUM(rwp->params.length_penalty);
651
+ }
652
+ /*
653
+ * call-seq:
654
+ * length_penalty = penalty -> penalty
655
+ */
656
+ static VALUE
657
+ ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) // Sets length_penalty from any Ruby Numeric
658
+ {
659
+ ruby_whisper_params *rwp;
660
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
661
+ rwp->params.length_penalty = NUM2DBL(value); // NUM2DBL type-checks and converts Integer etc.; RFLOAT_VALUE blindly assumed a Float object
662
+ return value;
663
+ }
664
+ /*
665
+ * call-seq:
666
+ * temperature_inc -> Float
667
+ */
668
+ static VALUE
669
+ ruby_whisper_params_get_temperature_inc(VALUE self)
670
+ {
671
+ ruby_whisper_params *rwp;
672
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
673
+ return DBL2NUM(rwp->params.temperature_inc);
674
+ }
675
+ /*
676
+ * call-seq:
677
+ * temperature_inc = inc -> inc
678
+ */
679
+ static VALUE
680
+ ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) // Sets temperature_inc from any Ruby Numeric
681
+ {
682
+ ruby_whisper_params *rwp;
683
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
684
+ rwp->params.temperature_inc = NUM2DBL(value); // NUM2DBL type-checks and converts Integer etc.; RFLOAT_VALUE blindly assumed a Float object
685
+ return value;
686
+ }
687
+ /*
688
+ * Similar to OpenAI's "compression_ratio_threshold"
689
+ *
690
+ * call-seq:
691
+ * entropy_thold -> Float
692
+ */
693
+ static VALUE
694
+ ruby_whisper_params_get_entropy_thold(VALUE self)
695
+ {
696
+ ruby_whisper_params *rwp;
697
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
698
+ return DBL2NUM(rwp->params.entropy_thold);
699
+ }
700
+ /*
701
+ * call-seq:
702
+ * entropy_thold = threshold -> threshold
703
+ */
704
+ static VALUE
705
+ ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) // Sets entropy_thold from any Ruby Numeric
706
+ {
707
+ ruby_whisper_params *rwp;
708
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
709
+ rwp->params.entropy_thold = NUM2DBL(value); // NUM2DBL type-checks and converts Integer etc.; RFLOAT_VALUE blindly assumed a Float object
710
+ return value;
711
+ }
712
+ /*
713
+ * call-seq:
714
+ * logprob_thold -> Float
715
+ */
716
+ static VALUE
717
+ ruby_whisper_params_get_logprob_thold(VALUE self)
718
+ {
719
+ ruby_whisper_params *rwp;
720
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
721
+ return DBL2NUM(rwp->params.logprob_thold);
722
+ }
723
+ /*
724
+ * call-seq:
725
+ * logprob_thold = threshold -> threshold
726
+ */
727
+ static VALUE
728
+ ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) // Sets logprob_thold from any Ruby Numeric
729
+ {
730
+ ruby_whisper_params *rwp;
731
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
732
+ rwp->params.logprob_thold = NUM2DBL(value); // NUM2DBL type-checks and converts Integer etc.; RFLOAT_VALUE blindly assumed a Float object
733
+ return value;
734
+ }
735
+ /*
736
+ * call-seq:
737
+ * no_speech_thold -> Float
738
+ */
739
+ static VALUE
740
+ ruby_whisper_params_get_no_speech_thold(VALUE self)
741
+ {
742
+ ruby_whisper_params *rwp;
743
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
744
+ return DBL2NUM(rwp->params.no_speech_thold);
745
+ }
746
+ /*
747
+ * call-seq:
748
+ * no_speech_thold = threshold -> threshold
749
+ */
750
+ static VALUE
751
+ ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) // Sets no_speech_thold from any Ruby Numeric
752
+ {
753
+ ruby_whisper_params *rwp;
754
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
755
+ rwp->params.no_speech_thold = NUM2DBL(value); // NUM2DBL type-checks and converts Integer etc.; RFLOAT_VALUE blindly assumed a Float object
756
+ return value;
757
+ }
758
+ static VALUE
759
+ ruby_whisper_params_get_new_segment_callback(VALUE self)
760
+ {
761
+ ruby_whisper_params *rwp;
762
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
763
+ return rwp->new_segment_callback_container->callback;
764
+ }
765
+ /*
766
+ * Sets new segment callback, called for every newly generated text segment.
767
+ *
768
+ * params.new_segment_callback = ->(context, _, n_new, user_data) {
769
+ * # ...
770
+ * }
771
+ *
772
+ * call-seq:
773
+ * new_segment_callback = callback -> callback
774
+ */
775
+ static VALUE
776
+ ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value)
777
+ {
778
+ ruby_whisper_params *rwp;
779
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
780
+ rwp->new_segment_callback_container->callback = value;
781
+ return value;
782
+ }
783
+ static VALUE
784
+ ruby_whisper_params_get_new_segment_callback_user_data(VALUE self)
785
+ {
786
+ ruby_whisper_params *rwp;
787
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
788
+ return rwp->new_segment_callback_container->user_data;
789
+ }
790
+ /*
791
+ * Sets user data passed to the last argument of new segment callback.
792
+ *
793
+ * call-seq:
794
+ * new_segment_callback_user_data = user_data -> user_data
795
+ */
796
+ static VALUE
797
+ ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value)
798
+ {
799
+ ruby_whisper_params *rwp;
800
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
801
+ rwp->new_segment_callback_container->user_data = value;
802
+ return value;
803
+ }
804
+ static VALUE
805
+ ruby_whisper_params_get_progress_callback(VALUE self)
806
+ {
807
+ ruby_whisper_params *rwp;
808
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
809
+ return rwp->progress_callback_container->callback;
810
+ }
811
+ /*
812
+ * Sets progress callback, called on each progress update.
813
+ *
814
+ * params.progress_callback = ->(context, _, progress, user_data) {
815
+ * # ...
816
+ * }
817
+ *
818
+ * +progress+ is an Integer between 0 and 100.
819
+ *
820
+ * call-seq:
821
+ * progress_callback = callback -> callback
822
+ */
823
+ static VALUE
824
+ ruby_whisper_params_set_progress_callback(VALUE self, VALUE value)
825
+ {
826
+ ruby_whisper_params *rwp;
827
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
828
+ rwp->progress_callback_container->callback = value;
829
+ return value;
830
+ }
831
+ static VALUE
832
+ ruby_whisper_params_get_progress_callback_user_data(VALUE self)
833
+ {
834
+ ruby_whisper_params *rwp;
835
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
836
+ return rwp->progress_callback_container->user_data;
837
+ }
838
+ /*
839
+ * Sets user data passed to the last argument of progress callback.
840
+ *
841
+ * call-seq:
842
+ * progress_callback_user_data = user_data -> user_data
843
+ */
844
+ static VALUE
845
+ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
846
+ {
847
+ ruby_whisper_params *rwp;
848
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
849
+ rwp->progress_callback_container->user_data = value;
850
+ return value;
851
+ }
852
+ static VALUE
853
+ ruby_whisper_params_get_abort_callback(VALUE self)
854
+ {
855
+ ruby_whisper_params *rwp;
856
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
857
+ return rwp->abort_callback_container->callback;
858
+ }
859
+ /*
860
+ * Sets abort callback, called to check if the process should be aborted.
861
+ *
862
+ * params.abort_callback = ->(user_data) {
863
+ * # ...
864
+ * }
865
+ *
866
+ * call-seq:
867
+ * abort_callback = callback -> callback
868
+ */
869
+ static VALUE
870
+ ruby_whisper_params_set_abort_callback(VALUE self, VALUE value)
871
+ {
872
+ ruby_whisper_params *rwp;
873
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
874
+ rwp->abort_callback_container->callback = value;
875
+ return value;
876
+ }
877
+ static VALUE
878
+ ruby_whisper_params_get_abort_callback_user_data(VALUE self)
879
+ {
880
+ ruby_whisper_params *rwp;
881
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
882
+ return rwp->abort_callback_container->user_data;
883
+ }
884
+ /*
885
+ * Sets user data passed to the last argument of abort callback.
886
+ *
887
+ * call-seq:
888
+ * abort_callback_user_data = user_data -> user_data
889
+ */
890
+ static VALUE
891
+ ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
892
+ {
893
+ ruby_whisper_params *rwp;
894
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
895
+ rwp->abort_callback_container->user_data = value;
896
+ return value;
897
+ }
898
+
899
+ #define SET_PARAM_IF_SAME(param_name) \
900
+ if (id == id_ ## param_name) { \
901
+ ruby_whisper_params_set_ ## param_name(self, value); \
902
+ continue; \
903
+ }
904
+
905
+ static VALUE
906
+ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
907
+ {
908
+ // Whisper::Params#initialize: applies each given keyword argument through the matching setter; omitted keywords keep their defaults.
909
+ VALUE kw_hash;
910
+ VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
911
+ VALUE value;
912
+ ruby_whisper_params *rwp;
913
+ ID id;
914
+ int i;
915
+
916
+ rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
917
+ if (NIL_P(kw_hash)) { // no keywords given: nothing to override
918
+ return self;
919
+ }
920
+
921
+ rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values); // pass the arrays themselves (decaying to element pointers), not pointers-to-array
922
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
923
+
924
+ for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
925
+ id = param_names[i];
926
+ value = values[i];
927
+ if (value == Qundef) { // keyword was not supplied
928
+ continue;
929
+ }
930
+ if (id == id_diarize) {
931
+ rwp->diarize = RTEST(value); // convert Ruby truthiness explicitly: a raw VALUE for nil is non-zero and would read as true
932
+ continue;
933
+ } else {
934
+ SET_PARAM_IF_SAME(language)
935
+ SET_PARAM_IF_SAME(translate)
936
+ SET_PARAM_IF_SAME(no_context)
937
+ SET_PARAM_IF_SAME(single_segment)
938
+ SET_PARAM_IF_SAME(print_special)
939
+ SET_PARAM_IF_SAME(print_progress)
940
+ SET_PARAM_IF_SAME(print_realtime)
941
+ SET_PARAM_IF_SAME(print_timestamps)
942
+ SET_PARAM_IF_SAME(suppress_blank)
943
+ SET_PARAM_IF_SAME(suppress_nst)
944
+ SET_PARAM_IF_SAME(token_timestamps)
945
+ SET_PARAM_IF_SAME(split_on_word)
946
+ SET_PARAM_IF_SAME(initial_prompt)
947
+ SET_PARAM_IF_SAME(offset)
948
+ SET_PARAM_IF_SAME(duration)
949
+ SET_PARAM_IF_SAME(max_text_tokens)
950
+ SET_PARAM_IF_SAME(temperature)
951
+ SET_PARAM_IF_SAME(max_initial_ts)
952
+ SET_PARAM_IF_SAME(length_penalty)
953
+ SET_PARAM_IF_SAME(temperature_inc)
954
+ SET_PARAM_IF_SAME(entropy_thold)
955
+ SET_PARAM_IF_SAME(logprob_thold)
956
+ SET_PARAM_IF_SAME(no_speech_thold)
957
+ SET_PARAM_IF_SAME(new_segment_callback)
958
+ SET_PARAM_IF_SAME(new_segment_callback_user_data)
959
+ SET_PARAM_IF_SAME(progress_callback)
960
+ SET_PARAM_IF_SAME(progress_callback_user_data)
961
+ SET_PARAM_IF_SAME(abort_callback)
962
+ SET_PARAM_IF_SAME(abort_callback_user_data)
963
+ }
964
+ }
965
+
966
+ return self;
967
+ }
968
+
969
+ #undef SET_PARAM_IF_SAME
970
+
971
+ /*
972
+ * Hook called on new segment. Yields each Whisper::Segment.
973
+ *
974
+ * params.on_new_segment do |segment|
975
+ * # ...
976
+ * end
977
+ *
978
+ * call-seq:
979
+ * on_new_segment {|segment| ... }
980
+ */
981
+ static VALUE
982
+ ruby_whisper_params_on_new_segment(VALUE self)
983
+ {
984
+ ruby_whisper_params *rws;
985
+ Data_Get_Struct(self, ruby_whisper_params, rws);
986
+ const VALUE blk = rb_block_proc();
987
+ rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
988
+ return Qnil;
989
+ }
990
+
991
+ /*
992
+ * Hook called on progress update. Yields each progress Integer between 0 and 100.
993
+ *
994
+ * params.on_progress do |progress|
995
+ * # ...
996
+ * end
997
+ *
998
+ * call-seq:
999
+ * on_progress {|progress| ... }
1000
+ */
1001
+ static VALUE
1002
+ ruby_whisper_params_on_progress(VALUE self)
1003
+ {
1004
+ ruby_whisper_params *rws;
1005
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1006
+ const VALUE blk = rb_block_proc();
1007
+ rb_ary_push(rws->progress_callback_container->callbacks, blk);
1008
+ return Qnil;
1009
+ }
1010
+
1011
+ /*
1012
+ * Call block to determine whether abort or not. Return +true+ when you want to abort.
1013
+ *
1014
+ * params.abort_on do
1015
+ * if some_condition
1016
+ * true # abort
1017
+ * else
1018
+ * false # continue
1019
+ * end
1020
+ * end
1021
+ *
1022
+ * call-seq:
1023
+ * abort_on { ... }
1024
+ */
1025
+ static VALUE
1026
+ ruby_whisper_params_abort_on(VALUE self)
1027
+ {
1028
+ ruby_whisper_params *rws;
1029
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1030
+ const VALUE blk = rb_block_proc();
1031
+ rb_ary_push(rws->abort_callback_container->callbacks, blk);
1032
+ return Qnil;
1033
+ }
1034
+
1035
+ void
1036
+ init_ruby_whisper_params(VALUE *mWhisper)
1037
+ {
1038
+ cParams = rb_define_class_under(*mWhisper, "Params", rb_cObject);
1039
+
1040
+ rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
1041
+ rb_define_method(cParams, "initialize", ruby_whisper_params_initialize, -1);
1042
+
1043
+ DEFINE_PARAM(language, 0)
1044
+ DEFINE_PARAM(translate, 1)
1045
+ DEFINE_PARAM(no_context, 2)
1046
+ DEFINE_PARAM(single_segment, 3)
1047
+ DEFINE_PARAM(print_special, 4)
1048
+ DEFINE_PARAM(print_progress, 5)
1049
+ DEFINE_PARAM(print_realtime, 6)
1050
+ DEFINE_PARAM(print_timestamps, 7)
1051
+ DEFINE_PARAM(suppress_blank, 8)
1052
+ DEFINE_PARAM(suppress_nst, 9)
1053
+ DEFINE_PARAM(token_timestamps, 10)
1054
+ DEFINE_PARAM(split_on_word, 11)
1055
+ DEFINE_PARAM(initial_prompt, 12)
1056
+ DEFINE_PARAM(diarize, 13)
1057
+ DEFINE_PARAM(offset, 14)
1058
+ DEFINE_PARAM(duration, 15)
1059
+ DEFINE_PARAM(max_text_tokens, 16)
1060
+ DEFINE_PARAM(temperature, 17)
1061
+ DEFINE_PARAM(max_initial_ts, 18)
1062
+ DEFINE_PARAM(length_penalty, 19)
1063
+ DEFINE_PARAM(temperature_inc, 20)
1064
+ DEFINE_PARAM(entropy_thold, 21)
1065
+ DEFINE_PARAM(logprob_thold, 22)
1066
+ DEFINE_PARAM(no_speech_thold, 23)
1067
+ DEFINE_PARAM(new_segment_callback, 24)
1068
+ DEFINE_PARAM(new_segment_callback_user_data, 25)
1069
+ DEFINE_PARAM(progress_callback, 26)
1070
+ DEFINE_PARAM(progress_callback_user_data, 27)
1071
+ DEFINE_PARAM(abort_callback, 28)
1072
+ DEFINE_PARAM(abort_callback_user_data, 29)
1073
+
1074
+ rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
1075
+ rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
1076
+ rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
1077
+ }
bindings/ruby/ext/ruby_whisper_segment.c ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+ #include "ruby_whisper.h"
3
+
4
+ extern VALUE cSegment;
5
+
6
+ static void
7
+ rb_whisper_segment_mark(ruby_whisper_segment *rws)
8
+ {
9
+ rb_gc_mark(rws->context);
10
+ }
11
+
12
+ VALUE
13
+ ruby_whisper_segment_allocate(VALUE klass) // alloc func for Whisper::Segment; also used by rb_whisper_segment_initialize
14
+ {
15
+ ruby_whisper_segment *rws;
16
+ rws = ZALLOC(ruby_whisper_segment); // zero-fill so rb_whisper_segment_mark never reads an uninitialized VALUE (a zero VALUE is Qfalse) when a bare Segment is allocated
17
+ return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
18
+ }
19
+
20
+ VALUE
21
+ rb_whisper_segment_initialize(VALUE context, int index) // Builds a Whisper::Segment that refers to segment `index` of `context`
22
+ {
23
+ ruby_whisper_segment *rws;
24
+ const VALUE segment = ruby_whisper_segment_allocate(cSegment);
25
+ Data_Get_Struct(segment, ruby_whisper_segment, rws);
26
+ rws->context = context; // marked by rb_whisper_segment_mark, so the context outlives the segment
27
+ rws->index = index;
28
+ return segment;
29
+ }
30
+
31
+ /*
32
+ * Start time in milliseconds.
33
+ *
34
+ * call-seq:
35
+ * start_time -> Integer
36
+ */
37
+ static VALUE
38
+ ruby_whisper_segment_get_start_time(VALUE self)
39
+ {
40
+ ruby_whisper_segment *rws;
41
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
42
+ ruby_whisper *rw;
43
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
44
+ const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
45
+ // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
46
+ return LL2NUM(t0 * 10); // LL2NUM keeps the full 64-bit product; INT2NUM would first narrow it to int
47
+ }
48
+
49
+ /*
50
+ * End time in milliseconds.
51
+ *
52
+ * call-seq:
53
+ * end_time -> Integer
54
+ */
55
+ static VALUE
56
+ ruby_whisper_segment_get_end_time(VALUE self)
57
+ {
58
+ ruby_whisper_segment *rws;
59
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
60
+ ruby_whisper *rw;
61
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
62
+ const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
63
+ // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
64
+ return LL2NUM(t1 * 10); // LL2NUM keeps the full 64-bit product; INT2NUM would first narrow it to int
65
+ }
66
+
67
+ /*
68
+ * Whether the next segment is predicted as a speaker turn.
69
+ *
70
+ * call-seq:
71
+ * speaker_next_turn? -> bool
72
+ */
73
+ static VALUE
74
+ ruby_whisper_segment_get_speaker_turn_next(VALUE self)
75
+ {
76
+ ruby_whisper_segment *rws;
77
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
78
+ ruby_whisper *rw;
79
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
80
+ return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
81
+ }
82
+
83
+ /*
84
+ * call-seq:
85
+ * text -> String
86
+ */
87
+ static VALUE
88
+ ruby_whisper_segment_get_text(VALUE self)
89
+ {
90
+ ruby_whisper_segment *rws;
91
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
92
+ ruby_whisper *rw;
93
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
94
+ const char * text = whisper_full_get_segment_text(rw->context, rws->index);
95
+ return rb_str_new2(text);
96
+ }
97
+
98
+ /*
99
+ * call-seq:
100
+ * no_speech_prob -> Float
101
+ */
102
+ static VALUE
103
+ ruby_whisper_segment_get_no_speech_prob(VALUE self)
104
+ {
105
+ ruby_whisper_segment *rws;
106
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
107
+ ruby_whisper *rw;
108
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
109
+ return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
110
+ }
111
+
112
+ void
113
+ init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext)
114
+ {
115
+ cSegment = rb_define_class_under(*mWhisper, "Segment", rb_cObject);
116
+
117
+ rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
118
+ rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
119
+ rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
120
+ rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
121
+ rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
122
+ rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
123
+ }
bindings/ruby/ext/ruby_whisper_transcribe.cpp ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+ #include "ruby_whisper.h"
3
+ #define DR_WAV_IMPLEMENTATION
4
+ #include "dr_wav.h"
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ #ifdef __cplusplus
9
+ extern "C" {
10
+ #endif
11
+
12
+ extern ID id_to_s;
13
+ extern ID id_call;
14
+
15
+ extern void
16
+ register_callbacks(ruby_whisper_params * rwp, VALUE * self);
17
+
18
+ /*
19
+ * transcribe a single file
20
+ * can emit to a block results
21
+ *
22
+ * params = Whisper::Params.new
23
+ * params.duration = 60_000
24
+ * whisper.transcribe "path/to/audio.wav", params do |text|
25
+ * puts text
26
+ * end
27
+ *
28
+ * call-seq:
29
+ * transcribe(path_to_audio, params) {|text| ...}
30
+ **/
31
+ VALUE
32
+ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
33
+ ruby_whisper *rw;
34
+ ruby_whisper_params *rwp;
35
+ VALUE wave_file_path, blk, params;
36
+
37
+ rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
38
+ Data_Get_Struct(self, ruby_whisper, rw);
39
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
40
+
41
+ if (!rb_respond_to(wave_file_path, id_to_s)) {
42
+ rb_raise(rb_eRuntimeError, "Expected file path to wave file");
43
+ }
44
+
45
+ std::string fname_inp = StringValueCStr(wave_file_path);
46
+
47
+ std::vector<float> pcmf32; // mono-channel F32 PCM
48
+ std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
49
+
50
+ // WAV input - this is directly from main.cpp example
51
+ {
52
+ drwav wav;
53
+ std::vector<uint8_t> wav_data; // used for pipe input from stdin
54
+
55
+ if (fname_inp == "-") {
56
+ {
57
+ uint8_t buf[1024];
58
+ while (true) {
59
+ const size_t n = fread(buf, 1, sizeof(buf), stdin);
60
+ if (n == 0) {
61
+ break;
62
+ }
63
+ wav_data.insert(wav_data.end(), buf, buf + n);
64
+ }
65
+ }
66
+
67
+ if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
68
+ fprintf(stderr, "error: failed to open WAV file from stdin\n");
69
+ return self;
70
+ }
71
+
72
+ fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
73
+ } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
74
+ fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
75
+ return self;
76
+ }
77
+
78
+ if (wav.channels != 1 && wav.channels != 2) {
79
+ fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
80
+ return self;
81
+ }
82
+
83
+ if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
84
+ fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
85
+ return self;
86
+ }
87
+
88
+ if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
89
+ fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
90
+ return self;
91
+ }
92
+
93
+ if (wav.bitsPerSample != 16) {
94
+ fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
95
+ return self;
96
+ }
97
+
98
+ const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
99
+
100
+ std::vector<int16_t> pcm16;
101
+ pcm16.resize(n*wav.channels);
102
+ drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
103
+ drwav_uninit(&wav);
104
+
105
+ // convert to mono, float
106
+ pcmf32.resize(n);
107
+ if (wav.channels == 1) {
108
+ for (uint64_t i = 0; i < n; i++) {
109
+ pcmf32[i] = float(pcm16[i])/32768.0f;
110
+ }
111
+ } else {
112
+ for (uint64_t i = 0; i < n; i++) {
113
+ pcmf32[i] = float((int32_t)pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
114
+ }
115
+ }
116
+
117
+ if (rwp->diarize) {
118
+ // convert to stereo, float
119
+ pcmf32s.resize(2);
120
+
121
+ pcmf32s[0].resize(n);
122
+ pcmf32s[1].resize(n);
123
+ for (uint64_t i = 0; i < n; i++) {
124
+ pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
125
+ pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
126
+ }
127
+ }
128
+ }
129
+ {
130
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
131
+
132
+ rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
133
+ bool is_aborted = *(bool*)user_data;
134
+ return !is_aborted;
135
+ };
136
+ rwp->params.encoder_begin_callback_user_data = &is_aborted;
137
+ }
138
+
139
+ register_callbacks(rwp, &self);
140
+
141
+ if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
142
+ fprintf(stderr, "failed to process audio\n");
143
+ return self;
144
+ }
145
+ const int n_segments = whisper_full_n_segments(rw->context);
146
+ VALUE output = rb_str_new2("");
147
+ for (int i = 0; i < n_segments; ++i) {
148
+ const char * text = whisper_full_get_segment_text(rw->context, i);
149
+ output = rb_str_concat(output, rb_str_new2(text));
150
+ }
151
+ VALUE idCall = id_call;
152
+ if (blk != Qnil) {
153
+ rb_funcall(blk, idCall, 1, output);
154
+ }
155
+ return self;
156
+ }
157
+ #ifdef __cplusplus
158
+ }
159
+ #endif
bindings/ruby/lib/whisper/model/uri.rb CHANGED
@@ -65,6 +65,13 @@ module Whisper
65
  end
66
  end
67
  end
 
 
 
 
 
 
 
68
  end
69
 
70
  def download(response)
 
65
  end
66
  end
67
  end
68
+ rescue => err
69
+ if cache_path.exist?
70
+ warn err
71
+ # Use cache file
72
+ else
73
+ raise
74
+ end
75
  end
76
 
77
  def download(response)
bindings/ruby/sig/whisper.rbs CHANGED
@@ -20,13 +20,12 @@ module Whisper
20
  def self.lang_id: (string name) -> Integer
21
  def self.lang_str: (Integer id) -> String
22
  def self.lang_str_full: (Integer id) -> String
23
- def self.log_set=: (log_callback) -> log_callback
24
- def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer
25
 
26
  class Context
27
- def initialize: (string | _ToPath | ::URI::HTTP ) -> void
28
- def transcribe: (string, Params) -> void
29
- | (string, Params) { (String) -> void } -> void
30
  def model_n_vocab: () -> Integer
31
  def model_n_audio_ctx: () -> Integer
32
  def model_n_audio_state: () -> Integer
@@ -35,6 +34,10 @@ module Whisper
35
  def model_n_mels: () -> Integer
36
  def model_ftype: () -> Integer
37
  def model_type: () -> String
 
 
 
 
38
  def full_n_segments: () -> Integer
39
  def full_lang_id: () -> Integer
40
  def full_get_segment_t0: (Integer) -> Integer
@@ -42,18 +45,46 @@ module Whisper
42
  def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
43
  def full_get_segment_text: (Integer) -> String
44
  def full_get_segment_no_speech_prob: (Integer) -> Float
45
- def full: (Params, Array[Float], ?Integer) -> void
46
- | (Params, _Samples, ?Integer) -> void
47
- def full_parallel: (Params, Array[Float], ?Integer) -> void
48
- | (Params, _Samples, ?Integer) -> void
49
- | (Params, _Samples, ?Integer?, Integer) -> void
50
- def each_segment: { (Segment) -> void } -> void
51
- | () -> Enumerator[Segment]
52
- def model: () -> Model
53
  end
54
 
55
  class Params
56
- def initialize: () -> void
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def language=: (String) -> String # TODO: Enumerate lang names
58
  def language: () -> String
59
  def translate=: (boolish) -> boolish
@@ -79,7 +110,7 @@ module Whisper
79
  def split_on_word=: (boolish) -> boolish
80
  def split_on_word: () -> (true | false)
81
  def initial_prompt=: (_ToS) -> _ToS
82
- def initial_prompt: () -> String
83
  def diarize=: (boolish) -> boolish
84
  def diarize: () -> (true | false)
85
  def offset=: (Integer) -> Integer
@@ -103,19 +134,25 @@ module Whisper
103
  def no_speech_thold=: (Float) -> Float
104
  def no_speech_thold: () -> Float
105
  def new_segment_callback=: (new_segment_callback) -> new_segment_callback
 
106
  def new_segment_callback_user_data=: (Object) -> Object
 
107
  def progress_callback=: (progress_callback) -> progress_callback
 
108
  def progress_callback_user_data=: (Object) -> Object
 
109
  def abort_callback=: (abort_callback) -> abort_callback
 
110
  def abort_callback_user_data=: (Object) -> Object
 
111
  def on_new_segment: { (Segment) -> void } -> void
112
- def on_progress: { (Integer) -> void } -> void
113
- def abort_on: { (Object) -> boolish } -> void
114
  end
115
 
116
  class Model
117
  def self.pre_converted_models: () -> Hash[String, Model::URI]
118
- def initialize: () -> void
119
  def n_vocab: () -> Integer
120
  def n_audio_ctx: () -> Integer
121
  def n_audio_state: () -> Integer
@@ -130,14 +167,13 @@ module Whisper
130
  def type: () -> String
131
 
132
  class URI
133
- def initialize: (string | ::URI::HTTP) -> void
134
  def to_path: -> String
135
  def clear_cache: -> void
136
  end
137
  end
138
 
139
  class Segment
140
- def initialize: () -> void
141
  def start_time: () -> Integer
142
  def end_time: () -> Integer
143
  def speaker_next_turn?: () -> (true | false)
@@ -148,6 +184,6 @@ module Whisper
148
  class Error < StandardError
149
  attr_reader code: Integer
150
 
151
- def initialize: (Integer) -> void
152
  end
153
  end
 
20
  def self.lang_id: (string name) -> Integer
21
  def self.lang_str: (Integer id) -> String
22
  def self.lang_str_full: (Integer id) -> String
23
+ def self.log_set: (log_callback, Object? user_data) -> log_callback
 
24
 
25
  class Context
26
+ def self.new: (string | _ToPath | ::URI::HTTP) -> instance
27
+ def transcribe: (string, Params) -> self
28
+ | (string, Params) { (String) -> void } -> self
29
  def model_n_vocab: () -> Integer
30
  def model_n_audio_ctx: () -> Integer
31
  def model_n_audio_state: () -> Integer
 
34
  def model_n_mels: () -> Integer
35
  def model_ftype: () -> Integer
36
  def model_type: () -> String
37
+ def each_segment: { (Segment) -> void } -> void
38
+ | () -> Enumerator[Segment]
39
+ def model: () -> Model
40
+ def full_get_segment: (Integer nth) -> Segment
41
  def full_n_segments: () -> Integer
42
  def full_lang_id: () -> Integer
43
  def full_get_segment_t0: (Integer) -> Integer
 
45
  def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
46
  def full_get_segment_text: (Integer) -> String
47
  def full_get_segment_no_speech_prob: (Integer) -> Float
48
+ def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
49
+ | (Params, _Samples, ?Integer n_samples) -> self
50
+ def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
51
+ | (Params, _Samples, ?Integer n_samples) -> self
52
+ | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
 
 
 
53
  end
54
 
55
  class Params
56
+ def self.new: (
57
+ ?language: string,
58
+ ?translate: boolish,
59
+ ?no_context: boolish,
60
+ ?single_segment: boolish,
61
+ ?print_special: boolish,
62
+ ?print_progress: boolish,
63
+ ?print_realtime: boolish,
64
+ ?print_timestamps: boolish,
65
+ ?suppress_blank: boolish,
66
+ ?suppress_nst: boolish,
67
+ ?token_timestamps: boolish,
68
+ ?split_on_word: boolish,
69
+ ?initial_prompt: string | nil,
70
+ ?diarize: boolish,
71
+ ?offset: Integer,
72
+ ?duration: Integer,
73
+ ?max_text_tokens: Integer,
74
+ ?temperature: Float,
75
+ ?max_initial_ts: Float,
76
+ ?length_penalty: Float,
77
+ ?temperature_inc: Float,
78
+ ?entropy_thold: Float,
79
+ ?logprob_thold: Float,
80
+ ?no_speech_thold: Float,
81
+ ?new_segment_callback: new_segment_callback,
82
+ ?new_segment_callback_user_data: Object,
83
+ ?progress_callback: progress_callback,
84
+ ?progress_callback_user_data: Object,
85
+ ?abort_callback: abort_callback,
86
+ ?abort_callback_user_data: Object
87
+ ) -> instance
88
  def language=: (String) -> String # TODO: Enumerate lang names
89
  def language: () -> String
90
  def translate=: (boolish) -> boolish
 
110
  def split_on_word=: (boolish) -> boolish
111
  def split_on_word: () -> (true | false)
112
  def initial_prompt=: (_ToS) -> _ToS
113
+ def initial_prompt: () -> (String | nil)
114
  def diarize=: (boolish) -> boolish
115
  def diarize: () -> (true | false)
116
  def offset=: (Integer) -> Integer
 
134
  def no_speech_thold=: (Float) -> Float
135
  def no_speech_thold: () -> Float
136
  def new_segment_callback=: (new_segment_callback) -> new_segment_callback
137
+ def new_segment_callback: () -> (new_segment_callback | nil)
138
  def new_segment_callback_user_data=: (Object) -> Object
139
+ def new_segment_callback_user_data: () -> Object
140
  def progress_callback=: (progress_callback) -> progress_callback
141
+ def progress_callback: () -> (progress_callback | nil)
142
  def progress_callback_user_data=: (Object) -> Object
143
+ def progress_callback_user_data: () -> Object
144
  def abort_callback=: (abort_callback) -> abort_callback
145
+ def abort_callback: () -> (abort_callback | nil)
146
  def abort_callback_user_data=: (Object) -> Object
147
+ def abort_callback_user_data: () -> Object
148
  def on_new_segment: { (Segment) -> void } -> void
149
+ def on_progress: { (Integer progress) -> void } -> void
150
+ def abort_on: { (Object user_data) -> boolish } -> void
151
  end
152
 
153
  class Model
154
  def self.pre_converted_models: () -> Hash[String, Model::URI]
155
+ def self.new: () -> instance
156
  def n_vocab: () -> Integer
157
  def n_audio_ctx: () -> Integer
158
  def n_audio_state: () -> Integer
 
167
  def type: () -> String
168
 
169
  class URI
170
+ def self.new: (string | ::URI::HTTP) -> self
171
  def to_path: -> String
172
  def clear_cache: -> void
173
  end
174
  end
175
 
176
  class Segment
 
177
  def start_time: () -> Integer
178
  def end_time: () -> Integer
179
  def speaker_next_turn?: () -> (true | false)
 
184
  class Error < StandardError
185
  attr_reader code: Integer
186
 
187
+ def self.new: (Integer code) -> instance
188
  end
189
  end
bindings/ruby/tests/test_params.rb CHANGED
@@ -1,6 +1,39 @@
1
  require_relative "helper"
2
 
3
  class TestParams < TestBase
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def setup
5
  @params = Whisper::Params.new
6
  end
@@ -157,4 +190,57 @@ class TestParams < TestBase
157
  @params.no_speech_thold = 0.2
158
  assert_in_delta 0.2, @params.no_speech_thold
159
  end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  end
 
1
  require_relative "helper"
2
 
3
  class TestParams < TestBase
4
+ PARAM_NAMES = [
5
+ :language,
6
+ :translate,
7
+ :no_context,
8
+ :single_segment,
9
+ :print_special,
10
+ :print_progress,
11
+ :print_realtime,
12
+ :print_timestamps,
13
+ :suppress_blank,
14
+ :suppress_nst,
15
+ :token_timestamps,
16
+ :split_on_word,
17
+ :initial_prompt,
18
+ :diarize,
19
+ :offset,
20
+ :duration,
21
+ :max_text_tokens,
22
+ :temperature,
23
+ :max_initial_ts,
24
+ :length_penalty,
25
+ :temperature_inc,
26
+ :entropy_thold,
27
+ :logprob_thold,
28
+ :no_speech_thold,
29
+ :new_segment_callback,
30
+ :new_segment_callback_user_data,
31
+ :progress_callback,
32
+ :progress_callback_user_data,
33
+ :abort_callback,
34
+ :abort_callback_user_data,
35
+ ]
36
+
37
  def setup
38
  @params = Whisper::Params.new
39
  end
 
190
  @params.no_speech_thold = 0.2
191
  assert_in_delta 0.2, @params.no_speech_thold
192
  end
193
+
194
+ def test_new_with_kw_args
195
+ params = Whisper::Params.new(language: "es")
196
+ assert_equal "es", params.language
197
+ assert_equal 1.0, params.max_initial_ts
198
+ end
199
+
200
+ def test_new_with_kw_args_non_existent
201
+ assert_raise ArgumentError do
202
+ Whisper::Params.new(non_existent: "value")
203
+ end
204
+ end
205
+
206
+ def test_new_with_kw_args_wrong_type
207
+ assert_raise TypeError do
208
+ Whisper::Params.new(language: 3)
209
+ end
210
+ end
211
+
212
+ data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
213
+ def test_new_with_kw_args_default_values(param)
214
+ default_value = @params.send(param)
215
+ value = case [param, default_value]
216
+ in [*, true | false]
217
+ !default_value
218
+ in [*, Integer | Float]
219
+ default_value + 1
220
+ in [:language, *]
221
+ "es"
222
+ in [:initial_prompt, *]
223
+ "Initial prompt"
224
+ in [/_callback\Z/, *]
225
+ proc {}
226
+ in [/_user_data\Z/, *]
227
+ Object.new
228
+ end
229
+ params = Whisper::Params.new(param => value)
230
+ if Float === value
231
+ assert_in_delta value, params.send(param)
232
+ else
233
+ assert_equal value, params.send(param)
234
+ end
235
+
236
+ PARAM_NAMES.reject {|name| name == param}.each do |name|
237
+ expected = @params.send(name)
238
+ actual = params.send(name)
239
+ if Float === expected
240
+ assert_in_delta expected, actual
241
+ else
242
+ assert_equal expected, actual
243
+ end
244
+ end
245
+ end
246
  end
bindings/ruby/tests/test_whisper.rb CHANGED
@@ -29,6 +29,12 @@ class TestWhisper < TestBase
29
  assert_equal 0, whisper.full_lang_id
30
  end
31
 
 
 
 
 
 
 
32
  def test_full_get_segment_t0
33
  assert_equal 0, whisper.full_get_segment_t0(0)
34
  assert_raise IndexError do
 
29
  assert_equal 0, whisper.full_lang_id
30
  end
31
 
32
+ def test_full_get_segment
33
+ segment = whisper.full_get_segment(0)
34
+ assert_equal 0, segment.start_time
35
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
36
+ end
37
+
38
  def test_full_get_segment_t0
39
  assert_equal 0, whisper.full_get_segment_t0(0)
40
  assert_raise IndexError do