KitaitiMakoto commited on
Commit
4bf69ed
·
unverified ·
1 Parent(s): 7feeb43

ruby : Add low-level methods to transcribe (#2585)

Browse files

* Add tests for Whisper::Context#full

* Add Whisper::Context#full

* Add tests for Whisper::Error

* Add document of Whisper::Context#full [skip ci]

* Add additional signature for Whisper::Context#full

* Add description to Whisper::Context#full

* Add test for Whisper::Context#full_parallel

* Add Whisper::Context#full_parallel

* Hide Whisper's instance methods from Ruby code

* Add class to test MemoryView

* Build test class before running test

* Add test for MemoryView

* Make Whisper::Context#full and #full_parallel accept MemoryView

* Use Ruby 3.1 on CI

* Add comment on samples data type

* Update README

* Update README

* Remove unused code

.github/workflows/bindings-ruby.yml CHANGED
@@ -50,6 +50,6 @@ jobs:
50
  steps:
51
  - uses: ruby/setup-ruby@v1
52
  with:
53
- ruby-version: '3.0'
54
  - uses: actions/checkout@v4
55
  - run: rake test
 
50
  steps:
51
  - uses: ruby/setup-ruby@v1
52
  with:
53
+ ruby-version: '3.1'
54
  - uses: actions/checkout@v4
55
  - run: rake test
bindings/ruby/README.md CHANGED
@@ -160,6 +160,24 @@ Whisper.log_set ->(level, buffer, user_data) {
160
  Whisper::Context.new(MODEL)
161
  ```
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  License
164
  -------
165
 
 
160
  Whisper::Context.new(MODEL)
161
  ```
162
 
163
+ You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility.
164
+
165
+ ```ruby
166
+ require "whisper"
167
+ require "wavefile"
168
+
169
+ reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
170
+ samples = reader.enum_for(:each_buffer).map(&:samples).flatten
171
+
172
+ whisper = Whisper::Context.new("path/to/model.bin")
173
+ whisper.full(Whisper::Params.new, samples)
174
+ whisper.each_segment do |segment|
175
+ puts segment.text
176
+ end
177
+ ```
178
+
179
+ The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
180
+
181
  License
182
  -------
183
 
bindings/ruby/Rakefile CHANGED
@@ -68,3 +68,13 @@ file TEST_MODEL do
68
  sh "./models/download-ggml-model.sh base.en"
69
  end
70
  end
 
 
 
 
 
 
 
 
 
 
 
68
  sh "./models/download-ggml-model.sh base.en"
69
  end
70
  end
71
+
72
+ TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
73
+ file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
74
+ Dir.chdir "tests/jfk_reader" do
75
+ ruby "extconf.rb"
76
+ sh "make"
77
+ end
78
+ end
79
+ CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
80
+ task test: TEST_MEMORY_VIEW
bindings/ruby/ext/ruby_whisper.cpp CHANGED
@@ -1,4 +1,5 @@
1
  #include <ruby.h>
 
2
  #include "ruby_whisper.h"
3
  #define DR_WAV_IMPLEMENTATION
4
  #include "dr_wav.h"
@@ -35,11 +36,15 @@ extern "C" {
35
  VALUE mWhisper;
36
  VALUE cContext;
37
  VALUE cParams;
 
38
 
39
  static ID id_to_s;
40
  static ID id_call;
41
  static ID id___method__;
42
  static ID id_to_enum;
 
 
 
43
 
44
  static bool is_log_callback_finalized = false;
45
 
@@ -100,13 +105,13 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
100
  * log_set ->(level, buffer, user_data) { ... }, user_data -> nil
101
  */
102
  static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
103
- VALUE old_callback = rb_iv_get(self, "@log_callback");
104
  if (!NIL_P(old_callback)) {
105
  rb_undefine_finalizer(old_callback);
106
  }
107
 
108
- rb_iv_set(self, "@log_callback", log_callback);
109
- rb_iv_set(self, "@user_data", user_data);
110
 
111
  VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
112
  rb_define_finalizer(log_callback, finalize_log_callback);
@@ -115,8 +120,8 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
115
  if (is_log_callback_finalized) {
116
  return;
117
  }
118
- VALUE log_callback = rb_iv_get(mWhisper, "@log_callback");
119
- VALUE udata = rb_iv_get(mWhisper, "@user_data");
120
  rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
121
  }, nullptr);
122
 
@@ -544,6 +549,168 @@ VALUE ruby_whisper_model_type(VALUE self) {
544
  return rb_str_new2(whisper_model_type_readable(rw->context));
545
  }
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  /*
548
  * Number of segments.
549
  *
@@ -1518,15 +1685,59 @@ static VALUE ruby_whisper_c_model_type(VALUE self) {
1518
  return rb_str_new2(whisper_model_type_readable(rw->context));
1519
  }
1520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1521
  void Init_whisper() {
1522
  id_to_s = rb_intern("to_s");
1523
  id_call = rb_intern("call");
1524
  id___method__ = rb_intern("__method__");
1525
  id_to_enum = rb_intern("to_enum");
 
 
 
1526
 
1527
  mWhisper = rb_define_module("Whisper");
1528
  cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
1529
  cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
 
1530
 
1531
  rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
1532
  rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
@@ -1564,6 +1775,8 @@ void Init_whisper() {
1564
  rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
1565
  rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
1566
  rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
 
 
1567
 
1568
  rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
1569
 
@@ -1623,6 +1836,9 @@ void Init_whisper() {
1623
  rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
1624
  rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
1625
 
 
 
 
1626
  // High leve
1627
  cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
1628
 
 
1
  #include <ruby.h>
2
+ #include <ruby/memory_view.h>
3
  #include "ruby_whisper.h"
4
  #define DR_WAV_IMPLEMENTATION
5
  #include "dr_wav.h"
 
36
  VALUE mWhisper;
37
  VALUE cContext;
38
  VALUE cParams;
39
+ VALUE eError;
40
 
41
  static ID id_to_s;
42
  static ID id_call;
43
  static ID id___method__;
44
  static ID id_to_enum;
45
+ static ID id_length;
46
+ static ID id_next;
47
+ static ID id_new;
48
 
49
  static bool is_log_callback_finalized = false;
50
 
 
105
  * log_set ->(level, buffer, user_data) { ... }, user_data -> nil
106
  */
107
  static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
108
+ VALUE old_callback = rb_iv_get(self, "log_callback");
109
  if (!NIL_P(old_callback)) {
110
  rb_undefine_finalizer(old_callback);
111
  }
112
 
113
+ rb_iv_set(self, "log_callback", log_callback);
114
+ rb_iv_set(self, "user_data", user_data);
115
 
116
  VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
117
  rb_define_finalizer(log_callback, finalize_log_callback);
 
120
  if (is_log_callback_finalized) {
121
  return;
122
  }
123
+ VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
124
+ VALUE udata = rb_iv_get(mWhisper, "user_data");
125
  rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
126
  }, nullptr);
127
 
 
549
  return rb_str_new2(whisper_model_type_readable(rw->context));
550
  }
551
 
552
+ /*
553
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
554
+ * Not thread safe for same context
555
+ * Uses the specified decoding strategy to obtain the text.
556
+ *
557
+ * call-seq:
558
+ * full(params, samples, n_samples) -> nil
559
+ * full(params, samples) -> nil
560
+ *
561
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
562
+ */
563
+ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
564
+ if (argc < 2 || argc > 3) {
565
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
566
+ }
567
+
568
+ ruby_whisper *rw;
569
+ ruby_whisper_params *rwp;
570
+ Data_Get_Struct(self, ruby_whisper, rw);
571
+ VALUE params = argv[0];
572
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
573
+ VALUE samples = argv[1];
574
+ int n_samples;
575
+ rb_memory_view_t view;
576
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
577
+ if (argc == 3) {
578
+ n_samples = NUM2INT(argv[2]);
579
+ if (TYPE(samples) == T_ARRAY) {
580
+ if (RARRAY_LEN(samples) < n_samples) {
581
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
582
+ }
583
+ }
584
+ // Should check when samples.respond_to?(:length)?
585
+ } else {
586
+ if (TYPE(samples) == T_ARRAY) {
587
+ n_samples = RARRAY_LEN(samples);
588
+ } else if (memory_view_available_p) {
589
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
590
+ view.obj = Qnil;
591
+ rb_raise(rb_eArgError, "unable to get a memory view");
592
+ }
593
+ n_samples = view.byte_size / view.item_size;
594
+ } else if (rb_respond_to(samples, id_length)) {
595
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
596
+ } else {
597
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
598
+ }
599
+ }
600
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
601
+ if (memory_view_available_p) {
602
+ c_samples = (float *)view.data;
603
+ } else {
604
+ if (TYPE(samples) == T_ARRAY) {
605
+ for (int i = 0; i < n_samples; i++) {
606
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
607
+ }
608
+ } else {
609
+ // TODO: use rb_block_call
610
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
611
+ for (int i = 0; i < n_samples; i++) {
612
+ // TODO: check if iter is exhausted and raise ArgumentError appropriately
613
+ VALUE sample = rb_funcall(iter, id_next, 0);
614
+ c_samples[i] = RFLOAT_VALUE(sample);
615
+ }
616
+ }
617
+ }
618
+ const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
619
+ if (0 == result) {
620
+ return Qnil;
621
+ } else {
622
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
623
+ }
624
+ }
625
+
626
+ /*
627
+ * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
628
+ * Result is stored in the default state of the context
629
+ * Not thread safe if executed in parallel on the same context.
630
+ * It seems this approach can offer some speedup in some cases.
631
+ * However, the transcription accuracy can be worse at the beginning and end of each chunk.
632
+ *
633
+ * call-seq:
634
+ * full_parallel(params, samples) -> nil
635
+ * full_parallel(params, samples, n_samples) -> nil
636
+ * full_parallel(params, samples, n_samples, n_processors) -> nil
637
+ * full_parallel(params, samples, nil, n_processors) -> nil
638
+ */
639
+ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
640
+ if (argc < 2 || argc > 4) {
641
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
642
+ }
643
+
644
+ ruby_whisper *rw;
645
+ ruby_whisper_params *rwp;
646
+ Data_Get_Struct(self, ruby_whisper, rw);
647
+ VALUE params = argv[0];
648
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
649
+ VALUE samples = argv[1];
650
+ int n_samples;
651
+ int n_processors;
652
+ rb_memory_view_t view;
653
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
654
+ switch (argc) {
655
+ case 2:
656
+ n_processors = 1;
657
+ break;
658
+ case 3:
659
+ n_processors = 1;
660
+ break;
661
+ case 4:
662
+ n_processors = NUM2INT(argv[3]);
663
+ break;
664
+ }
665
+ if (argc >= 3 && !NIL_P(argv[2])) {
666
+ n_samples = NUM2INT(argv[2]);
667
+ if (TYPE(samples) == T_ARRAY) {
668
+ if (RARRAY_LEN(samples) < n_samples) {
669
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
670
+ }
671
+ }
672
+ // Should check when samples.respond_to?(:length)?
673
+ } else if (memory_view_available_p) {
674
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
675
+ view.obj = Qnil;
676
+ rb_raise(rb_eArgError, "unable to get a memory view");
677
+ }
678
+ n_samples = view.byte_size / view.item_size;
679
+ } else {
680
+ if (TYPE(samples) == T_ARRAY) {
681
+ n_samples = RARRAY_LEN(samples);
682
+ } else if (rb_respond_to(samples, id_length)) {
683
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
684
+ } else {
685
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
686
+ }
687
+ }
688
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
689
+ if (memory_view_available_p) {
690
+ c_samples = (float *)view.data;
691
+ } else {
692
+ if (TYPE(samples) == T_ARRAY) {
693
+ for (int i = 0; i < n_samples; i++) {
694
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
695
+ }
696
+ } else {
697
+ // FIXME: use rb_block_call
698
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
699
+ for (int i = 0; i < n_samples; i++) {
700
+ // TODO: check if iter is exhausted and raise ArgumentError
701
+ VALUE sample = rb_funcall(iter, id_next, 0);
702
+ c_samples[i] = RFLOAT_VALUE(sample);
703
+ }
704
+ }
705
+ }
706
+ const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
707
+ if (0 == result) {
708
+ return Qnil;
709
+ } else {
710
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
711
+ }
712
+ }
713
+
714
  /*
715
  * Number of segments.
716
  *
 
1685
  return rb_str_new2(whisper_model_type_readable(rw->context));
1686
  }
1687
 
1688
+ static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
1689
+ const int c_code = NUM2INT(code);
1690
+ char *raw_message;
1691
+ switch (c_code) {
1692
+ case -2:
1693
+ raw_message = "failed to compute log mel spectrogram";
1694
+ break;
1695
+ case -3:
1696
+ raw_message = "failed to auto-detect language";
1697
+ break;
1698
+ case -4:
1699
+ raw_message = "too many decoders requested";
1700
+ break;
1701
+ case -5:
1702
+ raw_message = "audio_ctx is larger than the maximum allowed";
1703
+ break;
1704
+ case -6:
1705
+ raw_message = "failed to encode";
1706
+ break;
1707
+ case -7:
1708
+ raw_message = "whisper_kv_cache_init() failed for self-attention cache";
1709
+ break;
1710
+ case -8:
1711
+ raw_message = "failed to decode";
1712
+ break;
1713
+ case -9:
1714
+ raw_message = "failed to decode";
1715
+ break;
1716
+ default:
1717
+ raw_message = "unknown error";
1718
+ break;
1719
+ }
1720
+ const VALUE message = rb_str_new2(raw_message);
1721
+ rb_call_super(1, &message);
1722
+ rb_iv_set(self, "@code", code);
1723
+
1724
+ return self;
1725
+ }
1726
+
1727
+
1728
  void Init_whisper() {
1729
  id_to_s = rb_intern("to_s");
1730
  id_call = rb_intern("call");
1731
  id___method__ = rb_intern("__method__");
1732
  id_to_enum = rb_intern("to_enum");
1733
+ id_length = rb_intern("length");
1734
+ id_next = rb_intern("next");
1735
+ id_new = rb_intern("new");
1736
 
1737
  mWhisper = rb_define_module("Whisper");
1738
  cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
1739
  cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
1740
+ eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
1741
 
1742
  rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
1743
  rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
 
1775
  rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
1776
  rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
1777
  rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
1778
+ rb_define_method(cContext, "full", ruby_whisper_full, -1);
1779
+ rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
1780
 
1781
  rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
1782
 
 
1836
  rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
1837
  rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
1838
 
1839
+ rb_define_attr(eError, "code", true, false);
1840
+ rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
1841
+
1842
  // High leve
1843
  cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
1844
 
bindings/ruby/tests/helper.rb CHANGED
@@ -1,5 +1,6 @@
1
  require "test/unit"
2
  require "whisper"
 
3
 
4
  class TestBase < Test::Unit::TestCase
5
  MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin")
 
1
  require "test/unit"
2
  require "whisper"
3
+ require_relative "jfk_reader/jfk_reader"
4
 
5
  class TestBase < Test::Unit::TestCase
6
  MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin")
bindings/ruby/tests/jfk_reader/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Makefile
2
+ jfk_reader.o
3
+ jfk_reader.so
4
+ jfk_reader.bundle
5
+ jfk_reader.dll
bindings/ruby/tests/jfk_reader/extconf.rb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ require "mkmf"
2
+
3
+ create_makefile("jfk_reader")
bindings/ruby/tests/jfk_reader/jfk_reader.c ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ruby.h>
2
+ #include <ruby/memory_view.h>
3
+ #include <ruby/encoding.h>
4
+
5
+ static VALUE
6
+ jfk_reader_initialize(VALUE self, VALUE audio_path)
7
+ {
8
+ rb_iv_set(self, "audio_path", audio_path);
9
+ return Qnil;
10
+ }
11
+
12
+ static bool
13
+ jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
14
+ {
15
+ VALUE audio_path = rb_iv_get(obj, "audio_path");
16
+ const char *audio_path_str = StringValueCStr(audio_path);
17
+ const int n_samples = 176000;
18
+ float *data = (float *)malloc(n_samples * sizeof(float));
19
+ short *samples = (short *)malloc(n_samples * sizeof(short));
20
+ FILE *file = fopen(audio_path_str, "rb");
21
+
22
+ fseek(file, 78, SEEK_SET);
23
+ fread(samples, sizeof(short), n_samples, file);
24
+ fclose(file);
25
+ for (int i = 0; i < n_samples; i++) {
26
+ data[i] = samples[i]/32768.0;
27
+ }
28
+
29
+ view->obj = obj;
30
+ view->data = (void *)data;
31
+ view->byte_size = sizeof(float) * n_samples;
32
+ view->readonly = true;
33
+ view->format = "f";
34
+ view->item_size = sizeof(float);
35
+ view->item_desc.components = NULL;
36
+ view->item_desc.length = 0;
37
+ view->ndim = 1;
38
+ view->shape = NULL;
39
+ view->sub_offsets = NULL;
40
+ view->private_data = NULL;
41
+
42
+ return true;
43
+ }
44
+
45
+ static bool
46
+ jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
47
+ {
48
+ return true;
49
+ }
50
+
51
+ static bool
52
+ jfk_reader_memory_view_available_p(const VALUE obj)
53
+ {
54
+ return true;
55
+ }
56
+
57
+ static const rb_memory_view_entry_t jfk_reader_view_entry = {
58
+ jfk_reader_get_memory_view,
59
+ jfk_reader_release_memory_view,
60
+ jfk_reader_memory_view_available_p
61
+ };
62
+
63
+ static VALUE
64
+ read_jfk(int argc, VALUE *argv, VALUE obj)
65
+ {
66
+ const char *audio_path_str = StringValueCStr(argv[0]);
67
+ const int n_samples = 176000;
68
+
69
+ short samples[n_samples];
70
+ FILE *file = fopen(audio_path_str, "rb");
71
+
72
+ fseek(file, 78, SEEK_SET);
73
+ fread(samples, sizeof(short), n_samples, file);
74
+ fclose(file);
75
+
76
+ VALUE rb_samples = rb_ary_new2(n_samples);
77
+ for (int i = 0; i < n_samples; i++) {
78
+ rb_ary_push(rb_samples, INT2FIX(samples[i]));
79
+ }
80
+
81
+ VALUE rb_data = rb_ary_new2(n_samples);
82
+ for (int i = 0; i < n_samples; i++) {
83
+ rb_ary_push(rb_data, DBL2NUM(samples[i]/32768.0));
84
+ }
85
+
86
+ float data[n_samples];
87
+ for (int i = 0; i < n_samples; i++) {
88
+ data[i] = samples[i]/32768.0;
89
+ }
90
+ void *c_data = (void *)data;
91
+ VALUE rb_void = rb_enc_str_new((const char *)c_data, sizeof(data), rb_ascii8bit_encoding());
92
+
93
+ VALUE rb_result = rb_ary_new3(3, rb_samples, rb_data, rb_void);
94
+ return rb_result;
95
+ }
96
+
97
+ void Init_jfk_reader(void)
98
+ {
99
+ VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
100
+ rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
101
+ rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
102
+
103
+
104
+ rb_define_global_function("read_jfk", read_jfk, -1);
105
+
106
+
107
+
108
+ }
bindings/ruby/tests/test_error.rb ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ require_relative "helper"
2
+
3
+ class TestError < TestBase
4
+ def test_error
5
+ error = Whisper::Error.new(-2)
6
+ assert_equal "failed to compute log mel spectrogram", error.message
7
+ assert_equal -2, error.code
8
+ end
9
+
10
+ def test_unknown_error
11
+ error = Whisper::Error.new(-20)
12
+ assert_equal "unknown error", error.message
13
+ end
14
+
15
+ def test_non_int_code
16
+ assert_raise TypeError do
17
+ error = Whisper::Error.new("non int")
18
+ end
19
+ end
20
+ end
bindings/ruby/tests/test_whisper.rb CHANGED
@@ -1,5 +1,6 @@
1
  require_relative "helper"
2
  require "stringio"
 
3
 
4
  # Exists to detect memory-related bug
5
  Whisper.log_set ->(level, buffer, user_data) {}, nil
@@ -124,4 +125,102 @@ class TestWhisper < TestBase
124
  ensure
125
  $stderr = stderr
126
  end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  end
 
1
  require_relative "helper"
2
  require "stringio"
3
+ require "etc"
4
 
5
  # Exists to detect memory-related bug
6
  Whisper.log_set ->(level, buffer, user_data) {}, nil
 
125
  ensure
126
  $stderr = stderr
127
  end
128
+
129
+ sub_test_case "full" do
130
+ def setup
131
+ super
132
+ @whisper = Whisper::Context.new(MODEL)
133
+ @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
134
+ end
135
+
136
+ def test_full
137
+ @whisper.full(@params, @samples, @samples.length)
138
+
139
+ assert_equal 1, @whisper.full_n_segments
140
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
141
+ end
142
+
143
+ def test_full_without_length
144
+ @whisper.full(@params, @samples)
145
+
146
+ assert_equal 1, @whisper.full_n_segments
147
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
148
+ end
149
+
150
+ def test_full_enumerator
151
+ samples = @samples.each
152
+ @whisper.full(@params, samples, @samples.length)
153
+
154
+ assert_equal 1, @whisper.full_n_segments
155
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
156
+ end
157
+
158
+ def test_full_enumerator_without_length
159
+ samples = @samples.each
160
+ assert_raise ArgumentError do
161
+ @whisper.full(@params, samples)
162
+ end
163
+ end
164
+
165
+ def test_full_enumerator_with_too_large_length
166
+ samples = @samples.each.take(10).to_enum
167
+ assert_raise StopIteration do
168
+ @whisper.full(@params, samples, 11)
169
+ end
170
+ end
171
+
172
+ def test_full_with_memory_view
173
+ samples = JFKReader.new(AUDIO)
174
+ @whisper.full(@params, samples)
175
+
176
+ assert_equal 1, @whisper.full_n_segments
177
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
178
+ end
179
+
180
+ def test_full_parallel
181
+ @whisper.full_parallel(@params, @samples, @samples.length, Etc.nprocessors)
182
+
183
+ assert_equal Etc.nprocessors, @whisper.full_n_segments
184
+ text = @whisper.each_segment.collect(&:text).join
185
+ assert_match /ask what you can do/i, text
186
+ assert_match /for your country/i, text
187
+ end
188
+
189
+ def test_full_parallel_with_memory_view
190
+ samples = JFKReader.new(AUDIO)
191
+ @whisper.full_parallel(@params, samples, nil, Etc.nprocessors)
192
+
193
+ assert_equal Etc.nprocessors, @whisper.full_n_segments
194
+ text = @whisper.each_segment.collect(&:text).join
195
+ assert_match /ask what you can do/i, text
196
+ assert_match /for your country/i, text
197
+ end
198
+
199
+ def test_full_parallel_without_length_and_n_processors
200
+ @whisper.full_parallel(@params, @samples)
201
+
202
+ assert_equal 1, @whisper.full_n_segments
203
+ text = @whisper.each_segment.collect(&:text).join
204
+ assert_match /ask what you can do/i, text
205
+ assert_match /for your country/i, text
206
+ end
207
+
208
+ def test_full_parallel_without_length
209
+ @whisper.full_parallel(@params, @samples, nil, Etc.nprocessors)
210
+
211
+ assert_equal Etc.nprocessors, @whisper.full_n_segments
212
+ text = @whisper.each_segment.collect(&:text).join
213
+ assert_match /ask what you can do/i, text
214
+ assert_match /for your country/i, text
215
+ end
216
+
217
+ def test_full_parallel_without_n_processors
218
+ @whisper.full_parallel(@params, @samples, @samples.length)
219
+
220
+ assert_equal 1, @whisper.full_n_segments
221
+ text = @whisper.each_segment.collect(&:text).join
222
+ assert_match /ask what you can do/i, text
223
+ assert_match /for your country/i, text
224
+ end
225
+ end
226
  end