whisper : rename suppress_non_speech_tokens to suppress_nst (#2653)
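This is a pure rename of the public flag; decoding behaviour is unchanged, so downstream callers only need to switch to the new field name. A minimal C++ sketch of an updated caller, mirroring the examples/lsp change below (the whisper_context and audio buffer are assumed to come from the usual model/audio setup):

#include "whisper.h"

// ctx, samples and n_samples are assumed to be set up elsewhere.
static int transcribe_speech_only(struct whisper_context * ctx, const float * samples, int n_samples) {
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // previously: wparams.suppress_non_speech_tokens = true;
    wparams.suppress_nst = true; // suppress non-speech tokens during decoding

    // run the transformer and a single decoding pass
    return whisper_full(ctx, wparams, samples, n_samples);
}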
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java
CHANGED
@@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
     }
 
     /** Flag to suppress non-speech tokens. */
-    public CBool suppress_non_speech_tokens;
+    public CBool suppress_nst;
 
     /** Flag to suppress non-speech tokens. */
     public void suppressNonSpeechTokens(boolean enable) {
-        suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
+        suppress_nst = enable ? CBool.TRUE : CBool.FALSE;
     }
 
     /** Initial decoding temperature. */
@@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
         "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
         "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
         "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
-        "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
+        "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
         "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
         "new_segment_callback", "new_segment_callback_user_data",
         "progress_callback", "progress_callback_user_data",
bindings/ruby/ext/ruby_whisper.cpp
CHANGED
@@ -979,19 +979,19 @@ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
 }
 /*
  * call-seq:
- *   suppress_non_speech_tokens = force_suppress -> force_suppress
+ *   suppress_nst = force_suppress -> force_suppress
  */
-static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
+static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
+  BOOL_PARAMS_SETTER(self, suppress_nst, value)
 }
 /*
  * If true, suppresses non-speech-tokens.
  *
  * call-seq:
- *   suppress_non_speech_tokens -> bool
+ *   suppress_nst -> bool
  */
-static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
-  BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
+static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
+  BOOL_PARAMS_GETTER(self, suppress_nst)
 }
 /*
  * If true, enables token-level timestamps.
@@ -1832,8 +1832,8 @@ void Init_whisper() {
   rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
   rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
   rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
-  rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
-  rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
+  rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
+  rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
   rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
   rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
   rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
bindings/ruby/tests/test_params.rb
CHANGED
@@ -89,11 +89,11 @@ class TestParams < TestBase
     assert !@params.suppress_blank
   end
 
-  def test_suppress_non_speech_tokens
-    @params.suppress_non_speech_tokens = true
-    assert @params.suppress_non_speech_tokens
-    @params.suppress_non_speech_tokens = false
-    assert !@params.suppress_non_speech_tokens
+  def test_suppress_nst
+    @params.suppress_nst = true
+    assert @params.suppress_nst
+    @params.suppress_nst = false
+    assert !@params.suppress_nst
   end
 
   def test_token_timestamps
examples/lsp/lsp.cpp
CHANGED
@@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &audio
     wparams.n_threads = params.n_threads;
 
     wparams.audio_ctx = params.audio_ctx;
-    wparams.suppress_non_speech_tokens = true;
+    wparams.suppress_nst = true;
     // run the transformer and a single decoding pass
     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
         fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
@@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audio
     wparams.prompt_tokens = cs.prompt_tokens.data();
     wparams.prompt_n_tokens = cs.prompt_tokens.size();
     // TODO: properly expose as option
-    wparams.suppress_non_speech_tokens = true;
+    wparams.suppress_nst = true;
 
     // run the transformer and a single decoding pass
     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
examples/server/server.cpp
CHANGED
@@ -76,7 +76,7 @@ struct whisper_params {
     bool no_timestamps = false;
     bool use_gpu = true;
     bool flash_attn = false;
-    bool suppress_non_speech_tokens = false;
+    bool suppress_nst = false;
 
     std::string language = "en";
     std::string prompt = "";
@@ -136,7 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params
     fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
     fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
     fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
-    fprintf(stderr, " -sns, --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false");
+    fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
     fprintf(stderr, "\n");
 }
 
@@ -181,7 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams)
     else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
     else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
     else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
-    else if (arg == "-sns" || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; }
+    else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
     // server params
     else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
     else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@@ -477,7 +477,11 @@ void get_req_parameters(const Request & req, whisper_params & params)
     }
     if (req.has_file("suppress_non_speech"))
     {
-        params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
+        params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
+    }
+    if (req.has_file("suppress_nst"))
+    {
+        params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
     }
 }
 
@@ -793,7 +797,7 @@ int main(int argc, char ** argv) {
     wparams.no_timestamps = params.no_timestamps;
     wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
 
-    wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens;
+    wparams.suppress_nst = params.suppress_nst;
 
     whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
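After this change, get_req_parameters() reads the flag from either the legacy "suppress_non_speech" form field or the new "suppress_nst" one. A hedged client sketch using cpp-httplib (the same library the server is built on); the host, port, file path and the "file" field name are assumptions based on the default server setup, not part of this diff:

#include "httplib.h"

#include <fstream>
#include <iostream>
#include <sstream>

int main() {
    // read the audio to upload (placeholder path)
    std::ifstream wav("jfk.wav", std::ios::binary);
    std::stringstream buf;
    buf << wav.rdbuf();

    httplib::Client cli("127.0.0.1", 8080);

    // either "suppress_nst" or the legacy "suppress_non_speech" field is accepted
    httplib::MultipartFormDataItems items = {
        { "file",         buf.str(), "jfk.wav", "audio/wav" },
        { "suppress_nst", "true",    "",        "" },
    };

    auto res = cli.Post("/inference", items);
    if (res) {
        std::cout << res->body << std::endl;
    }
    return 0;
}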
include/whisper.h
CHANGED
@@ -522,8 +522,8 @@ extern "C" {
         bool detect_language;
 
         // common decoding parameters:
-        bool suppress_blank;             // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
-        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+        bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+        bool suppress_nst;   // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
 
         float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
         float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
src/whisper.cpp
CHANGED
@@ -4676,7 +4676,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
         /*.detect_language =*/ false,
 
         /*.suppress_blank =*/ true,
-        /*.suppress_non_speech_tokens =*/ false,
+        /*.suppress_nst =*/ false,
 
         /*.temperature =*/ 0.0f,
         /*.max_initial_ts =*/ 1.0f,
@@ -4960,7 +4960,7 @@ static void whisper_process_logits(
 
         // suppress non-speech tokens
        // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
-        if (params.suppress_non_speech_tokens) {
+        if (params.suppress_nst) {
            for (const std::string & token : non_speech_tokens) {
                const std::string suppress_tokens[] = {token, " " + token};
                for (const std::string & suppress_token : suppress_tokens) {