Spaces:
Sleeping
Sleeping
stream : add beam size parameter(#2836)
Browse files* feat: Add beam size parameter to stream.cpp for beam search configuration
* feat: Add beam size parameter to whisper full params in stream example
* fix: Remove duplicate beam search size assignment in server.cpp
examples/stream/stream.cpp
CHANGED
|
@@ -23,6 +23,7 @@ struct whisper_params {
|
|
| 23 |
int32_t capture_id = -1;
|
| 24 |
int32_t max_tokens = 32;
|
| 25 |
int32_t audio_ctx = 0;
|
|
|
|
| 26 |
|
| 27 |
float vad_thold = 0.6f;
|
| 28 |
float freq_thold = 100.0f;
|
|
@@ -59,6 +60,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
|
|
| 59 |
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
| 60 |
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
| 61 |
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
|
|
|
| 62 |
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
| 63 |
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
| 64 |
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
|
@@ -96,6 +98,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 96 |
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
| 97 |
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
| 98 |
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
|
|
|
| 99 |
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
| 100 |
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
| 101 |
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
|
@@ -298,7 +301,7 @@ int main(int argc, char ** argv) {
|
|
| 298 |
|
| 299 |
// run the inference
|
| 300 |
{
|
| 301 |
-
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 302 |
|
| 303 |
wparams.print_progress = false;
|
| 304 |
wparams.print_special = params.print_special;
|
|
@@ -309,6 +312,7 @@ int main(int argc, char ** argv) {
|
|
| 309 |
wparams.max_tokens = params.max_tokens;
|
| 310 |
wparams.language = params.language.c_str();
|
| 311 |
wparams.n_threads = params.n_threads;
|
|
|
|
| 312 |
|
| 313 |
wparams.audio_ctx = params.audio_ctx;
|
| 314 |
|
|
|
|
| 23 |
int32_t capture_id = -1;
|
| 24 |
int32_t max_tokens = 32;
|
| 25 |
int32_t audio_ctx = 0;
|
| 26 |
+
int32_t beam_size = -1;
|
| 27 |
|
| 28 |
float vad_thold = 0.6f;
|
| 29 |
float freq_thold = 100.0f;
|
|
|
|
| 60 |
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
| 61 |
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
| 62 |
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
| 63 |
+
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
| 64 |
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
| 65 |
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
| 66 |
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
|
|
|
| 98 |
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
| 99 |
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
| 100 |
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
| 101 |
+
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
| 102 |
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
| 103 |
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
| 104 |
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
|
|
|
| 301 |
|
| 302 |
// run the inference
|
| 303 |
{
|
| 304 |
+
whisper_full_params wparams = whisper_full_default_params(params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY);
|
| 305 |
|
| 306 |
wparams.print_progress = false;
|
| 307 |
wparams.print_special = params.print_special;
|
|
|
|
| 312 |
wparams.max_tokens = params.max_tokens;
|
| 313 |
wparams.language = params.language.c_str();
|
| 314 |
wparams.n_threads = params.n_threads;
|
| 315 |
+
wparams.beam_search.beam_size = params.beam_size;
|
| 316 |
|
| 317 |
wparams.audio_ctx = params.audio_ctx;
|
| 318 |
|