ulatekh ggerganov committed on
Commit
8cc6334
·
unverified ·
1 Parent(s): 2616a7c

whisper : suppress tokens with a regex (#1997)

Browse files

* Allow a regular expression to describe tokens to suppress.

Example: --suppress-regex "[,\.]|[ ]?[0-9]+" will suppress commas, periods, and numeric tokens.

Technique inspired by https://github.com/openai/whisper/discussions/1041

Co-authored-by: Georgi Gerganov <[email protected]>

* Blind change to fix Java test.

---------

Co-authored-by: Georgi Gerganov <[email protected]>

bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java CHANGED
@@ -148,6 +148,9 @@ public class WhisperFullParams extends Structure {
148
  tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
149
  }
150
 
 
 
 
151
  /** Tokens to provide to the whisper decoder as an initial prompt.
152
  * These are prepended to any existing text context from a previous call. */
153
  public String initial_prompt;
@@ -319,7 +322,7 @@ public class WhisperFullParams extends Structure {
319
  "no_context", "single_segment", "no_timestamps",
320
  "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
321
  "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
322
- "tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
323
  "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
324
  "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
325
  "new_segment_callback", "new_segment_callback_user_data",
 
148
  tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
149
  }
150
 
151
+ /** Regular expression matching tokens to suppress. */
152
+ public String suppress_regex;
153
+
154
  /** Tokens to provide to the whisper decoder as an initial prompt.
155
  * These are prepended to any existing text context from a previous call. */
156
  public String initial_prompt;
 
322
  "no_context", "single_segment", "no_timestamps",
323
  "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
324
  "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
325
+ "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
326
  "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
327
  "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
328
  "new_segment_callback", "new_segment_callback_user_data",
examples/command/command.cpp CHANGED
@@ -52,6 +52,9 @@ struct whisper_params {
52
  std::string prompt;
53
  std::string context;
54
  std::string grammar;
 
 
 
55
  };
56
 
57
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -85,6 +88,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
85
  else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
86
  else if ( arg == "--grammar") { params.grammar = argv[++i]; }
87
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
 
88
  else {
89
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
90
  whisper_print_usage(argc, argv, params);
@@ -122,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
122
  fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
123
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
124
  fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
 
125
  fprintf(stderr, "\n");
126
  }
127
 
@@ -167,6 +172,8 @@ std::string transcribe(
167
 
168
  wparams.initial_prompt = params.context.data();
169
 
 
 
170
  const auto & grammar_parsed = params.grammar_parsed;
171
  auto grammar_rules = grammar_parsed.c_rules();
172
 
 
52
  std::string prompt;
53
  std::string context;
54
  std::string grammar;
55
+
56
+ // A regular expression that matches tokens to suppress
57
+ std::string suppress_regex;
58
  };
59
 
60
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
88
  else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
89
  else if ( arg == "--grammar") { params.grammar = argv[++i]; }
90
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
91
+ else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
92
  else {
93
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
94
  whisper_print_usage(argc, argv, params);
 
126
  fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
127
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
128
  fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
129
+ fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
130
  fprintf(stderr, "\n");
131
  }
132
 
 
172
 
173
  wparams.initial_prompt = params.context.data();
174
 
175
+ wparams.suppress_regex = params.suppress_regex.c_str();
176
+
177
  const auto & grammar_parsed = params.grammar_parsed;
178
  auto grammar_rules = grammar_parsed.c_rules();
179
 
examples/main/main.cpp CHANGED
@@ -6,6 +6,7 @@
6
  #include <cmath>
7
  #include <fstream>
8
  #include <cstdio>
 
9
  #include <string>
10
  #include <thread>
11
  #include <vector>
@@ -78,6 +79,9 @@ struct whisper_params {
78
  // [TDRZ] speaker turn string
79
  std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
80
 
 
 
 
81
  std::string openvino_encode_device = "CPU";
82
 
83
  std::string dtw = "";
@@ -160,6 +164,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
160
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
161
  else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
162
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
163
  else if ( arg == "--grammar") { params.grammar = argv[++i]; }
164
  else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
165
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -223,6 +228,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
223
  fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
224
  fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
225
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
226
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
227
  fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
228
  fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
@@ -1033,6 +1039,8 @@ int main(int argc, char ** argv) {
1033
 
1034
  wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
1035
 
 
 
1036
  wparams.initial_prompt = params.prompt.c_str();
1037
 
1038
  wparams.greedy.best_of = params.best_of;
 
6
  #include <cmath>
7
  #include <fstream>
8
  #include <cstdio>
9
+ #include <regex>
10
  #include <string>
11
  #include <thread>
12
  #include <vector>
 
79
  // [TDRZ] speaker turn string
80
  std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
81
 
82
+ // A regular expression that matches tokens to suppress
83
+ std::string suppress_regex;
84
+
85
  std::string openvino_encode_device = "CPU";
86
 
87
  std::string dtw = "";
 
164
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
165
  else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
166
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
167
+ else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
168
  else if ( arg == "--grammar") { params.grammar = argv[++i]; }
169
  else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
170
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
 
228
  fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
229
  fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
230
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
231
+ fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
232
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
233
  fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
234
  fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
 
1039
 
1040
  wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
1041
 
1042
+ wparams.suppress_regex = params.suppress_regex.c_str();
1043
+
1044
  wparams.initial_prompt = params.prompt.c_str();
1045
 
1046
  wparams.greedy.best_of = params.best_of;
whisper.cpp CHANGED
@@ -4553,6 +4553,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4553
 
4554
  /*.tdrz_enable =*/ false,
4555
 
 
 
4556
  /*.initial_prompt =*/ nullptr,
4557
  /*.prompt_tokens =*/ nullptr,
4558
  /*.prompt_n_tokens =*/ 0,
@@ -4796,6 +4798,17 @@ static void whisper_process_logits(
4796
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
4797
  }
4798
 
 
 
 
 
 
 
 
 
 
 
 
4799
  // suppress non-speech tokens
4800
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4801
  if (params.suppress_non_speech_tokens) {
 
4553
 
4554
  /*.tdrz_enable =*/ false,
4555
 
4556
+ /* suppress_regex =*/ nullptr,
4557
+
4558
  /*.initial_prompt =*/ nullptr,
4559
  /*.prompt_tokens =*/ nullptr,
4560
  /*.prompt_n_tokens =*/ 0,
 
4798
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
4799
  }
4800
 
4801
+ // suppress any tokens matching a regular expression
4802
+ // ref: https://github.com/openai/whisper/discussions/1041
4803
+ if (params.suppress_regex != nullptr) {
4804
+ std::regex re(params.suppress_regex);
4805
+ for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
4806
+ if (std::regex_match(token_id.first, re)) {
4807
+ logits[token_id.second] = -INFINITY;
4808
+ }
4809
+ }
4810
+ }
4811
+
4812
  // suppress non-speech tokens
4813
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4814
  if (params.suppress_non_speech_tokens) {
whisper.h CHANGED
@@ -505,6 +505,9 @@ extern "C" {
505
  // [EXPERIMENTAL] [TDRZ] tinydiarize
506
  bool tdrz_enable; // enable tinydiarize speaker turn detection
507
 
 
 
 
508
  // tokens to provide to the whisper decoder as initial prompt
509
  // these are prepended to any existing text context from a previous call
510
  // use whisper_tokenize() to convert text to tokens
 
505
  // [EXPERIMENTAL] [TDRZ] tinydiarize
506
  bool tdrz_enable; // enable tinydiarize speaker turn detection
507
 
508
+ // A regular expression that matches tokens to suppress
509
+ const char * suppress_regex;
510
+
511
  // tokens to provide to the whisper decoder as initial prompt
512
  // these are prepended to any existing text context from a previous call
513
  // use whisper_tokenize() to convert text to tokens