CRD716 committed on
Commit
1251039
·
unverified ·
1 Parent(s): a7b3aa5

whisper : add detect-language mode (#853)

Browse files

* add detect-language flag

* renaming and help

* no idea why that last one didn't commit

* run language detection if dl is set

* help message fix

* various fixes

* fix quitting

* fix language being english on print

Files changed (3) hide show
  1. examples/main/main.cpp +7 -0
  2. whisper.cpp +5 -1
  3. whisper.h +1 -0
examples/main/main.cpp CHANGED
@@ -66,6 +66,7 @@ struct whisper_params {
66
 
67
  bool speed_up = false;
68
  bool translate = false;
 
69
  bool diarize = false;
70
  bool split_on_word = false;
71
  bool no_fallback = false;
@@ -141,6 +142,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
141
  else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
142
  else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
143
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
 
144
  else if ( arg == "--prompt") { params.prompt = argv[++i]; }
145
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
146
  else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
@@ -191,6 +193,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
191
  fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
192
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
193
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
 
194
  fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
195
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
196
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
@@ -739,6 +742,9 @@ int main(int argc, char ** argv) {
739
  fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
740
  }
741
  }
 
 
 
742
  fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
743
  __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
744
  params.n_threads, params.n_processors,
@@ -761,6 +767,7 @@ int main(int argc, char ** argv) {
761
  wparams.print_special = params.print_special;
762
  wparams.translate = params.translate;
763
  wparams.language = params.language.c_str();
 
764
  wparams.n_threads = params.n_threads;
765
  wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
766
  wparams.offset_ms = params.offset_t_ms;
 
66
 
67
  bool speed_up = false;
68
  bool translate = false;
69
+ bool detect_language= false;
70
  bool diarize = false;
71
  bool split_on_word = false;
72
  bool no_fallback = false;
 
142
  else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
143
  else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
144
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
145
+ else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
146
  else if ( arg == "--prompt") { params.prompt = argv[++i]; }
147
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
148
  else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
 
193
  fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
194
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
195
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
196
+ fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
197
  fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
198
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
199
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
 
742
  fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
743
  }
744
  }
745
+ if (params.detect_language) {
746
+ params.language = "auto";
747
+ }
748
  fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
749
  __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
750
  params.n_threads, params.n_processors,
 
767
  wparams.print_special = params.print_special;
768
  wparams.translate = params.translate;
769
  wparams.language = params.language.c_str();
770
+ wparams.detect_language = params.detect_language;
771
  wparams.n_threads = params.n_threads;
772
  wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
773
  wparams.offset_ms = params.offset_t_ms;
whisper.cpp CHANGED
@@ -3312,6 +3312,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3312
  /*.prompt_n_tokens =*/ 0,
3313
 
3314
  /*.language =*/ "en",
 
3315
 
3316
  /*.suppress_blank =*/ true,
3317
  /*.suppress_non_speech_tokens =*/ false,
@@ -3898,7 +3899,7 @@ int whisper_full_with_state(
3898
  }
3899
 
3900
  // auto-detect language if not specified
3901
- if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
3902
  std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
3903
 
3904
  const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
@@ -3910,6 +3911,9 @@ int whisper_full_with_state(
3910
  params.language = whisper_lang_str(lang_id);
3911
 
3912
  fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
 
 
 
3913
  }
3914
 
3915
  if (params.token_timestamps) {
 
3312
  /*.prompt_n_tokens =*/ 0,
3313
 
3314
  /*.language =*/ "en",
3315
+ /*.detect_language =*/ false,
3316
 
3317
  /*.suppress_blank =*/ true,
3318
  /*.suppress_non_speech_tokens =*/ false,
 
3899
  }
3900
 
3901
  // auto-detect language if not specified
3902
+ if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
3903
  std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
3904
 
3905
  const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
 
3911
  params.language = whisper_lang_str(lang_id);
3912
 
3913
  fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
3914
+ if (params.detect_language) {
3915
+ return 0;
3916
+ }
3917
  }
3918
 
3919
  if (params.token_timestamps) {
whisper.h CHANGED
@@ -365,6 +365,7 @@ extern "C" {
365
 
366
  // for auto-detection, set to nullptr, "" or "auto"
367
  const char * language;
 
368
 
369
  // common decoding parameters:
370
  bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
 
365
 
366
  // for auto-detection, set to nullptr, "" or "auto"
367
  const char * language;
368
+ bool detect_language;
369
 
370
  // common decoding parameters:
371
  bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89