Spaces:
Running
Running
whisper : suppress tokens with a regex (#1997)
Browse files
* Allow a regular expression to describe tokens to suppress.
Example: --suppress-regex "[,\.]|[ ]?[0-9]+" will suppress commas, periods, and numeric tokens.
Technique inspired by https://github.com/openai/whisper/discussions/1041
Co-authored-by: Georgi Gerganov <[email protected]>
* Blind change to fix Java test.
---------
Co-authored-by: Georgi Gerganov <[email protected]>
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java
CHANGED
|
@@ -148,6 +148,9 @@ public class WhisperFullParams extends Structure {
|
|
| 148 |
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
|
| 149 |
}
|
| 150 |
|
|
|
|
|
|
|
|
|
|
| 151 |
/** Tokens to provide to the whisper decoder as an initial prompt.
|
| 152 |
* These are prepended to any existing text context from a previous call. */
|
| 153 |
public String initial_prompt;
|
|
@@ -319,7 +322,7 @@ public class WhisperFullParams extends Structure {
|
|
| 319 |
"no_context", "single_segment", "no_timestamps",
|
| 320 |
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
| 321 |
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
| 322 |
-
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
| 323 |
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
|
| 324 |
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
|
| 325 |
"new_segment_callback", "new_segment_callback_user_data",
|
|
|
|
| 148 |
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
|
| 149 |
}
|
| 150 |
|
| 151 |
+
/** Regular expression matching tokens to suppress. */
|
| 152 |
+
public String suppress_regex;
|
| 153 |
+
|
| 154 |
/** Tokens to provide to the whisper decoder as an initial prompt.
|
| 155 |
* These are prepended to any existing text context from a previous call. */
|
| 156 |
public String initial_prompt;
|
|
|
|
| 322 |
"no_context", "single_segment", "no_timestamps",
|
| 323 |
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
| 324 |
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
| 325 |
+
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
| 326 |
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
|
| 327 |
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
|
| 328 |
"new_segment_callback", "new_segment_callback_user_data",
|
examples/command/command.cpp
CHANGED
|
@@ -52,6 +52,9 @@ struct whisper_params {
|
|
| 52 |
std::string prompt;
|
| 53 |
std::string context;
|
| 54 |
std::string grammar;
|
|
|
|
|
|
|
|
|
|
| 55 |
};
|
| 56 |
|
| 57 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
@@ -85,6 +88,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 85 |
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
| 86 |
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
| 87 |
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
|
|
|
| 88 |
else {
|
| 89 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 90 |
whisper_print_usage(argc, argv, params);
|
|
@@ -122,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 122 |
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
| 123 |
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
| 124 |
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
|
|
|
| 125 |
fprintf(stderr, "\n");
|
| 126 |
}
|
| 127 |
|
|
@@ -167,6 +172,8 @@ std::string transcribe(
|
|
| 167 |
|
| 168 |
wparams.initial_prompt = params.context.data();
|
| 169 |
|
|
|
|
|
|
|
| 170 |
const auto & grammar_parsed = params.grammar_parsed;
|
| 171 |
auto grammar_rules = grammar_parsed.c_rules();
|
| 172 |
|
|
|
|
| 52 |
std::string prompt;
|
| 53 |
std::string context;
|
| 54 |
std::string grammar;
|
| 55 |
+
|
| 56 |
+
// A regular expression that matches tokens to suppress
|
| 57 |
+
std::string suppress_regex;
|
| 58 |
};
|
| 59 |
|
| 60 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
|
|
| 88 |
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
| 89 |
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
| 90 |
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
| 91 |
+
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
|
| 92 |
else {
|
| 93 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 94 |
whisper_print_usage(argc, argv, params);
|
|
|
|
| 126 |
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
| 127 |
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
| 128 |
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
| 129 |
+
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
|
| 130 |
fprintf(stderr, "\n");
|
| 131 |
}
|
| 132 |
|
|
|
|
| 172 |
|
| 173 |
wparams.initial_prompt = params.context.data();
|
| 174 |
|
| 175 |
+
wparams.suppress_regex = params.suppress_regex.c_str();
|
| 176 |
+
|
| 177 |
const auto & grammar_parsed = params.grammar_parsed;
|
| 178 |
auto grammar_rules = grammar_parsed.c_rules();
|
| 179 |
|
examples/main/main.cpp
CHANGED
|
@@ -6,6 +6,7 @@
|
|
| 6 |
#include <cmath>
|
| 7 |
#include <fstream>
|
| 8 |
#include <cstdio>
|
|
|
|
| 9 |
#include <string>
|
| 10 |
#include <thread>
|
| 11 |
#include <vector>
|
|
@@ -78,6 +79,9 @@ struct whisper_params {
|
|
| 78 |
// [TDRZ] speaker turn string
|
| 79 |
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
| 80 |
|
|
|
|
|
|
|
|
|
|
| 81 |
std::string openvino_encode_device = "CPU";
|
| 82 |
|
| 83 |
std::string dtw = "";
|
|
@@ -160,6 +164,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 160 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 161 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 162 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 163 |
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
| 164 |
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
| 165 |
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
|
@@ -223,6 +228,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 223 |
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
| 224 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 225 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 226 |
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
| 227 |
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
| 228 |
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
|
@@ -1033,6 +1039,8 @@ int main(int argc, char ** argv) {
|
|
| 1033 |
|
| 1034 |
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
| 1035 |
|
|
|
|
|
|
|
| 1036 |
wparams.initial_prompt = params.prompt.c_str();
|
| 1037 |
|
| 1038 |
wparams.greedy.best_of = params.best_of;
|
|
|
|
| 6 |
#include <cmath>
|
| 7 |
#include <fstream>
|
| 8 |
#include <cstdio>
|
| 9 |
+
#include <regex>
|
| 10 |
#include <string>
|
| 11 |
#include <thread>
|
| 12 |
#include <vector>
|
|
|
|
| 79 |
// [TDRZ] speaker turn string
|
| 80 |
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
| 81 |
|
| 82 |
+
// A regular expression that matches tokens to suppress
|
| 83 |
+
std::string suppress_regex;
|
| 84 |
+
|
| 85 |
std::string openvino_encode_device = "CPU";
|
| 86 |
|
| 87 |
std::string dtw = "";
|
|
|
|
| 164 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 165 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 166 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 167 |
+
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
|
| 168 |
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
| 169 |
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
| 170 |
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
|
|
|
| 228 |
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
| 229 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 230 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 231 |
+
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
|
| 232 |
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
| 233 |
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
| 234 |
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
|
|
|
| 1039 |
|
| 1040 |
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
| 1041 |
|
| 1042 |
+
wparams.suppress_regex = params.suppress_regex.c_str();
|
| 1043 |
+
|
| 1044 |
wparams.initial_prompt = params.prompt.c_str();
|
| 1045 |
|
| 1046 |
wparams.greedy.best_of = params.best_of;
|
whisper.cpp
CHANGED
|
@@ -4553,6 +4553,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 4553 |
|
| 4554 |
/*.tdrz_enable =*/ false,
|
| 4555 |
|
|
|
|
|
|
|
| 4556 |
/*.initial_prompt =*/ nullptr,
|
| 4557 |
/*.prompt_tokens =*/ nullptr,
|
| 4558 |
/*.prompt_n_tokens =*/ 0,
|
|
@@ -4796,6 +4798,17 @@ static void whisper_process_logits(
|
|
| 4796 |
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
| 4797 |
}
|
| 4798 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4799 |
// suppress non-speech tokens
|
| 4800 |
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
| 4801 |
if (params.suppress_non_speech_tokens) {
|
|
|
|
| 4553 |
|
| 4554 |
/*.tdrz_enable =*/ false,
|
| 4555 |
|
| 4556 |
+
/*.suppress_regex =*/ nullptr,
|
| 4557 |
+
|
| 4558 |
/*.initial_prompt =*/ nullptr,
|
| 4559 |
/*.prompt_tokens =*/ nullptr,
|
| 4560 |
/*.prompt_n_tokens =*/ 0,
|
|
|
|
| 4798 |
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
| 4799 |
}
|
| 4800 |
|
| 4801 |
+
// suppress any tokens matching a regular expression
|
| 4802 |
+
// ref: https://github.com/openai/whisper/discussions/1041
|
| 4803 |
+
if (params.suppress_regex != nullptr) {
|
| 4804 |
+
std::regex re(params.suppress_regex);
|
| 4805 |
+
for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
|
| 4806 |
+
if (std::regex_match(token_id.first, re)) {
|
| 4807 |
+
logits[token_id.second] = -INFINITY;
|
| 4808 |
+
}
|
| 4809 |
+
}
|
| 4810 |
+
}
|
| 4811 |
+
|
| 4812 |
// suppress non-speech tokens
|
| 4813 |
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
| 4814 |
if (params.suppress_non_speech_tokens) {
|
whisper.h
CHANGED
|
@@ -505,6 +505,9 @@ extern "C" {
|
|
| 505 |
// [EXPERIMENTAL] [TDRZ] tinydiarize
|
| 506 |
bool tdrz_enable; // enable tinydiarize speaker turn detection
|
| 507 |
|
|
|
|
|
|
|
|
|
|
| 508 |
// tokens to provide to the whisper decoder as initial prompt
|
| 509 |
// these are prepended to any existing text context from a previous call
|
| 510 |
// use whisper_tokenize() to convert text to tokens
|
|
|
|
| 505 |
// [EXPERIMENTAL] [TDRZ] tinydiarize
|
| 506 |
bool tdrz_enable; // enable tinydiarize speaker turn detection
|
| 507 |
|
| 508 |
+
// A regular expression that matches tokens to suppress
|
| 509 |
+
const char * suppress_regex;
|
| 510 |
+
|
| 511 |
// tokens to provide to the whisper decoder as initial prompt
|
| 512 |
// these are prepended to any existing text context from a previous call
|
| 513 |
// use whisper_tokenize() to convert text to tokens
|