ggerganov commited on
Commit
e02ade6
·
unverified ·
1 Parent(s): dedf05b

whisper : restore decoder temperature fallbacks

Browse files

I disabled this because there were many complaints about slow decoding.
The current implementation does not allow batching the decoders when
using the "best of" or "beam size" parameters, so the decoding time is
proportional to the number of decoders, which is obviously not great.

However, now there are even more complaints about wrong decodings and
repetition.

So, making a compromise by re-enabling the fallbacks, but defaulting to
just 2 "best of" / "beam size" decoders. Also, the temperature step is
increased from 0.2 to 0.4 - i.e. from maximum of 5 fallbacks to maximum
of 2.

Also, the stream example now has fallbacks enabled by default.

close #471 #477 #508 #612 #719 #731

examples/main/main.cpp CHANGED
@@ -57,7 +57,7 @@ struct whisper_params {
57
  int32_t duration_ms = 0;
58
  int32_t max_context = -1;
59
  int32_t max_len = 0;
60
- int32_t best_of = 5;
61
  int32_t beam_size = -1;
62
 
63
  float word_thold = 0.01f;
 
57
  int32_t duration_ms = 0;
58
  int32_t max_context = -1;
59
  int32_t max_len = 0;
60
+ int32_t best_of = 2;
61
  int32_t beam_size = -1;
62
 
63
  float word_thold = 0.01f;
examples/stream/stream.cpp CHANGED
@@ -43,6 +43,7 @@ struct whisper_params {
43
 
44
  bool speed_up = false;
45
  bool translate = false;
 
46
  bool print_special = false;
47
  bool no_context = true;
48
  bool no_timestamps = false;
@@ -73,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
73
  else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
74
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
75
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
 
76
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
77
  else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
78
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@@ -94,22 +96,23 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
94
  fprintf(stderr, "\n");
95
  fprintf(stderr, "options:\n");
96
  fprintf(stderr, " -h, --help [default] show this help message and exit\n");
97
- fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
98
- fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
99
- fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
100
- fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
101
- fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
102
- fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
103
- fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
104
- fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
105
- fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
106
- fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
107
- fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
108
- fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
109
- fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
110
- fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
111
- fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
112
- fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
 
113
  fprintf(stderr, "\n");
114
  }
115
 
@@ -297,7 +300,8 @@ int main(int argc, char ** argv) {
297
  wparams.speed_up = params.speed_up;
298
 
299
  // disable temperature fallback
300
- wparams.temperature_inc = -1.0f;
 
301
 
302
  wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
303
  wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
 
43
 
44
  bool speed_up = false;
45
  bool translate = false;
46
+ bool no_fallback = false;
47
  bool print_special = false;
48
  bool no_context = true;
49
  bool no_timestamps = false;
 
74
  else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
75
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
76
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
77
+ else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
78
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
79
  else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
80
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
 
96
  fprintf(stderr, "\n");
97
  fprintf(stderr, "options:\n");
98
  fprintf(stderr, " -h, --help [default] show this help message and exit\n");
99
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
100
+ fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
101
+ fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
102
+ fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
103
+ fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
104
+ fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
105
+ fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
106
+ fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
107
+ fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
108
+ fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
109
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
110
+ fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
111
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
112
+ fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
113
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
114
+ fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
115
+ fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
116
  fprintf(stderr, "\n");
117
  }
118
 
 
300
  wparams.speed_up = params.speed_up;
301
 
302
  // disable temperature fallback
303
+ //wparams.temperature_inc = -1.0f;
304
+ wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
305
 
306
  wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
307
  wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
whisper.cpp CHANGED
@@ -3220,7 +3220,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3220
  /*.max_initial_ts =*/ 1.0f,
3221
  /*.length_penalty =*/ -1.0f,
3222
 
3223
- /*.temperature_inc =*/ 0.0f, // TODO: temporary disabled until improve performance
3224
  /*.entropy_thold =*/ 2.4f,
3225
  /*.logprob_thold =*/ -1.0f,
3226
  /*.no_speech_thold =*/ 0.6f,
@@ -3252,13 +3252,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3252
  case WHISPER_SAMPLING_GREEDY:
3253
  {
3254
  result.greedy = {
3255
- /*.best_of =*/ 1,
3256
  };
3257
  } break;
3258
  case WHISPER_SAMPLING_BEAM_SEARCH:
3259
  {
3260
  result.beam_search = {
3261
- /*.beam_size =*/ 5,
3262
 
3263
  /*.patience =*/ -1.0f,
3264
  };
 
3220
  /*.max_initial_ts =*/ 1.0f,
3221
  /*.length_penalty =*/ -1.0f,
3222
 
3223
+ /*.temperature_inc =*/ 0.4f,
3224
  /*.entropy_thold =*/ 2.4f,
3225
  /*.logprob_thold =*/ -1.0f,
3226
  /*.no_speech_thold =*/ 0.6f,
 
3252
  case WHISPER_SAMPLING_GREEDY:
3253
  {
3254
  result.greedy = {
3255
+ /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
3256
  };
3257
  } break;
3258
  case WHISPER_SAMPLING_BEAM_SEARCH:
3259
  {
3260
  result.beam_search = {
3261
+ /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
3262
 
3263
  /*.patience =*/ -1.0f,
3264
  };