ggerganov commited on
Commit
5b0631d
·
unverified ·
1 Parent(s): 647c7e7

whisper : rename suppress_non_speech_tokens to suppress_nst (#2653)

Browse files
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java CHANGED
@@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
181
  }
182
 
183
  /** Flag to suppress non-speech tokens. */
184
- public CBool suppress_non_speech_tokens;
185
 
186
  /** Flag to suppress non-speech tokens. */
187
  public void suppressNonSpeechTokens(boolean enable) {
188
- suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
189
  }
190
 
191
  /** Initial decoding temperature. */
@@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
315
  "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
316
  "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
317
  "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
318
- "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
319
  "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
320
  "new_segment_callback", "new_segment_callback_user_data",
321
  "progress_callback", "progress_callback_user_data",
 
181
  }
182
 
183
  /** Flag to suppress non-speech tokens. */
184
+ public CBool suppress_nst;
185
 
186
  /** Flag to suppress non-speech tokens. */
187
  public void suppressNonSpeechTokens(boolean enable) {
188
+ suppress_nst = enable ? CBool.TRUE : CBool.FALSE;
189
  }
190
 
191
  /** Initial decoding temperature. */
 
315
  "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
316
  "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
317
  "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
318
+ "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
319
  "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
320
  "new_segment_callback", "new_segment_callback_user_data",
321
  "progress_callback", "progress_callback_user_data",
bindings/ruby/ext/ruby_whisper.cpp CHANGED
@@ -979,19 +979,19 @@ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
979
  }
980
  /*
981
  * call-seq:
982
- * suppress_non_speech_tokens = force_suppress -> force_suppress
983
  */
984
- static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
985
- BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
986
  }
987
  /*
988
  * If true, suppresses non-speech-tokens.
989
  *
990
  * call-seq:
991
- * suppress_non_speech_tokens -> bool
992
  */
993
- static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
994
- BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
995
  }
996
  /*
997
  * If true, enables token-level timestamps.
@@ -1832,8 +1832,8 @@ void Init_whisper() {
1832
  rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
1833
  rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
1834
  rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
1835
- rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
1836
- rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
1837
  rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
1838
  rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
1839
  rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
 
979
  }
980
  /*
981
  * call-seq:
982
+ * suppress_nst = force_suppress -> force_suppress
983
  */
984
+ static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
985
+ BOOL_PARAMS_SETTER(self, suppress_nst, value)
986
  }
987
  /*
988
  * If true, suppresses non-speech-tokens.
989
  *
990
  * call-seq:
991
+ * suppress_nst -> bool
992
  */
993
+ static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
994
+ BOOL_PARAMS_GETTER(self, suppress_nst)
995
  }
996
  /*
997
  * If true, enables token-level timestamps.
 
1832
  rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
1833
  rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
1834
  rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
1835
+ rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
1836
+ rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
1837
  rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
1838
  rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
1839
  rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
bindings/ruby/tests/test_params.rb CHANGED
@@ -89,11 +89,11 @@ class TestParams < TestBase
89
  assert [email protected]_blank
90
  end
91
 
92
- def test_suppress_non_speech_tokens
93
- @params.suppress_non_speech_tokens = true
94
- assert @params.suppress_non_speech_tokens
95
- @params.suppress_non_speech_tokens = false
96
- assert !@params.suppress_non_speech_tokens
97
  end
98
 
99
  def test_token_timestamps
 
89
  assert [email protected]_blank
90
  end
91
 
92
+ def test_suppress_nst
93
+ @params.suppress_nst = true
94
+ assert @params.suppress_nst
95
+ @params.suppress_nst = false
96
+ assert !@params.suppress_nst
97
  end
98
 
99
  def test_token_timestamps
examples/lsp/lsp.cpp CHANGED
@@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au
181
  wparams.n_threads = params.n_threads;
182
 
183
  wparams.audio_ctx = params.audio_ctx;
184
- wparams.suppress_non_speech_tokens = true;
185
  // run the transformer and a single decoding pass
186
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
187
  fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
@@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi
225
  wparams.prompt_tokens = cs.prompt_tokens.data();
226
  wparams.prompt_n_tokens = cs.prompt_tokens.size();
227
  // TODO: properly expose as option
228
- wparams.suppress_non_speech_tokens = true;
229
 
230
  // run the transformer and a single decoding pass
231
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
 
181
  wparams.n_threads = params.n_threads;
182
 
183
  wparams.audio_ctx = params.audio_ctx;
184
+ wparams.suppress_nst = true;
185
  // run the transformer and a single decoding pass
186
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
187
  fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
 
225
  wparams.prompt_tokens = cs.prompt_tokens.data();
226
  wparams.prompt_n_tokens = cs.prompt_tokens.size();
227
  // TODO: properly expose as option
228
+ wparams.suppress_nst = true;
229
 
230
  // run the transformer and a single decoding pass
231
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
examples/server/server.cpp CHANGED
@@ -76,7 +76,7 @@ struct whisper_params {
76
  bool no_timestamps = false;
77
  bool use_gpu = true;
78
  bool flash_attn = false;
79
- bool suppress_non_speech_tokens = false;
80
 
81
  std::string language = "en";
82
  std::string prompt = "";
@@ -136,7 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
136
  fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
137
  fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
138
  fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
139
- fprintf(stderr, " -sns, --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false");
140
  fprintf(stderr, "\n");
141
  }
142
 
@@ -181,7 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
181
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
182
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
183
  else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
184
- else if (arg == "-sns" || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; }
185
  // server params
186
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
187
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@@ -477,7 +477,11 @@ void get_req_parameters(const Request & req, whisper_params & params)
477
  }
478
  if (req.has_file("suppress_non_speech"))
479
  {
480
- params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
 
 
 
 
481
  }
482
  }
483
 
@@ -793,7 +797,7 @@ int main(int argc, char ** argv) {
793
  wparams.no_timestamps = params.no_timestamps;
794
  wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
795
 
796
- wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens;
797
 
798
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
799
 
 
76
  bool no_timestamps = false;
77
  bool use_gpu = true;
78
  bool flash_attn = false;
79
+ bool suppress_nst = false;
80
 
81
  std::string language = "en";
82
  std::string prompt = "";
 
136
  fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
137
  fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
138
  fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
139
+ fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
140
  fprintf(stderr, "\n");
141
  }
142
 
 
181
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
182
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
183
  else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
184
+ else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
185
  // server params
186
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
187
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
 
477
  }
478
  if (req.has_file("suppress_non_speech"))
479
  {
480
+ params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
481
+ }
482
+ if (req.has_file("suppress_nst"))
483
+ {
484
+ params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
485
  }
486
  }
487
 
 
797
  wparams.no_timestamps = params.no_timestamps;
798
  wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
799
 
800
+ wparams.suppress_nst = params.suppress_nst;
801
 
802
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
803
 
include/whisper.h CHANGED
@@ -522,8 +522,8 @@ extern "C" {
522
  bool detect_language;
523
 
524
  // common decoding parameters:
525
- bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
526
- bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
527
 
528
  float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
529
  float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
 
522
  bool detect_language;
523
 
524
  // common decoding parameters:
525
+ bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
526
+ bool suppress_nst; // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
527
 
528
  float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
529
  float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
src/whisper.cpp CHANGED
@@ -4676,7 +4676,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4676
  /*.detect_language =*/ false,
4677
 
4678
  /*.suppress_blank =*/ true,
4679
- /*.suppress_non_speech_tokens =*/ false,
4680
 
4681
  /*.temperature =*/ 0.0f,
4682
  /*.max_initial_ts =*/ 1.0f,
@@ -4960,7 +4960,7 @@ static void whisper_process_logits(
4960
 
4961
  // suppress non-speech tokens
4962
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4963
- if (params.suppress_non_speech_tokens) {
4964
  for (const std::string & token : non_speech_tokens) {
4965
  const std::string suppress_tokens[] = {token, " " + token};
4966
  for (const std::string & suppress_token : suppress_tokens) {
 
4676
  /*.detect_language =*/ false,
4677
 
4678
  /*.suppress_blank =*/ true,
4679
+ /*.suppress_nst =*/ false,
4680
 
4681
  /*.temperature =*/ 0.0f,
4682
  /*.max_initial_ts =*/ 1.0f,
 
4960
 
4961
  // suppress non-speech tokens
4962
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4963
+ if (params.suppress_nst) {
4964
  for (const std::string & token : non_speech_tokens) {
4965
  const std::string suppress_tokens[] = {token, " " + token};
4966
  for (const std::string & suppress_token : suppress_tokens) {