ggerganov committed on
Commit
dad1114
·
1 Parent(s): 443fc8d

whisper : language auto-detect (#59)

Browse files
Files changed (3) hide show
  1. examples/main/main.cpp +2 -2
  2. whisper.cpp +120 -5
  3. whisper.h +23 -0
examples/main/main.cpp CHANGED
@@ -154,7 +154,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
154
  fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
155
  fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
156
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
157
- fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
158
  fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
159
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
160
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
@@ -453,7 +453,7 @@ int main(int argc, char ** argv) {
453
  return 2;
454
  }
455
 
456
- if (whisper_lang_id(params.language.c_str()) == -1) {
457
  fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
458
  whisper_print_usage(argc, argv, params);
459
  exit(0);
 
154
  fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
155
  fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
156
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
157
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
158
  fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
159
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
160
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
 
453
  return 2;
454
  }
455
 
456
+ if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
457
  fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
458
  whisper_print_usage(argc, argv, params);
459
  exit(0);
whisper.cpp CHANGED
@@ -1105,7 +1105,7 @@ static bool whisper_encode(
1105
 
1106
  struct ggml_init_params params;
1107
  params.mem_size = wctx.buf_compute.size();
1108
- params.mem_buffer = wctx.buf_compute.data();
1109
 
1110
  struct ggml_context * ctx0 = ggml_init(params);
1111
 
@@ -2372,8 +2372,23 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
2372
  return res.size();
2373
  }
2374
 
 
 
 
 
 
 
 
 
 
2375
  int whisper_lang_id(const char * lang) {
2376
  if (!g_lang.count(lang)) {
 
 
 
 
 
 
2377
  fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
2378
  return -1;
2379
  }
@@ -2381,6 +2396,86 @@ int whisper_lang_id(const char * lang) {
2381
  return g_lang.at(lang).first;
2382
  }
2383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2384
  int whisper_n_len(struct whisper_context * ctx) {
2385
  return ctx->mel.n_len;
2386
  }
@@ -2429,6 +2524,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
2429
  return ctx->vocab.token_beg;
2430
  }
2431
 
 
 
 
 
2432
  whisper_token whisper_token_translate(void) {
2433
  return whisper_vocab::token_translate;
2434
  }
@@ -2661,10 +2760,25 @@ int whisper_full(
2661
  } else {
2662
  if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
2663
  fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
2664
- return -1;
2665
  }
2666
  }
2667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2668
  if (params.token_timestamps) {
2669
  ctx->t_beg = 0;
2670
  ctx->t_last = 0;
@@ -2703,7 +2817,8 @@ int whisper_full(
2703
  // these tokens determine the task that will be performed
2704
  std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
2705
  if (whisper_is_multilingual(ctx)) {
2706
- prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
 
2707
  if (params.translate) {
2708
  prompt_init.push_back(whisper_token_translate());
2709
  } else {
@@ -2752,7 +2867,7 @@ int whisper_full(
2752
  // encode audio features starting at offset seek
2753
  if (whisper_encode(ctx, seek, params.n_threads) != 0) {
2754
  fprintf(stderr, "%s: failed to encode\n", __func__);
2755
- return 7;
2756
  }
2757
 
2758
  int n_past = 0;
@@ -2790,7 +2905,7 @@ int whisper_full(
2790
  for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
2791
  if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
2792
  fprintf(stderr, "%s: failed to decode\n", __func__);
2793
- return 8;
2794
  }
2795
 
2796
  n_past += prompt.size();
 
1105
 
1106
  struct ggml_init_params params;
1107
  params.mem_size = wctx.buf_compute.size();
1108
+ params.mem_buffer = wctx.buf_compute.data();
1109
 
1110
  struct ggml_context * ctx0 = ggml_init(params);
1111
 
 
2372
  return res.size();
2373
  }
2374
 
2375
+ int whisper_lang_max_id() {
2376
+ auto max_id = 0;
2377
+ for (const auto & kv : g_lang) {
2378
+ max_id = std::max(max_id, kv.second.first);
2379
+ }
2380
+
2381
+ return max_id;
2382
+ }
2383
+
2384
  int whisper_lang_id(const char * lang) {
2385
  if (!g_lang.count(lang)) {
2386
+ for (const auto & kv : g_lang) {
2387
+ if (kv.second.second == lang) {
2388
+ return kv.second.first;
2389
+ }
2390
+ }
2391
+
2392
  fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
2393
  return -1;
2394
  }
 
2396
  return g_lang.at(lang).first;
2397
  }
2398
 
2399
+ const char * whisper_lang_str(int id) {
2400
+ for (const auto & kv : g_lang) {
2401
+ if (kv.second.first == id) {
2402
+ return kv.first.c_str();
2403
+ }
2404
+ }
2405
+
2406
+ fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
2407
+ return NULL;
2408
+ }
2409
+
2410
+ int whisper_lang_auto_detect(
2411
+ struct whisper_context * ctx,
2412
+ int offset_ms,
2413
+ int n_threads,
2414
+ float * lang_probs) {
2415
+ const int seek = offset_ms/10;
2416
+
2417
+ if (seek < 0) {
2418
+ fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
2419
+ return -1;
2420
+ }
2421
+
2422
+ if (seek >= ctx->mel.n_len) {
2423
+ fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
2424
+ return -2;
2425
+ }
2426
+
2427
+ // run the encoder
2428
+ if (whisper_encode(ctx, seek, n_threads) != 0) {
2429
+ fprintf(stderr, "%s: failed to encode\n", __func__);
2430
+ return -6;
2431
+ }
2432
+
2433
+ const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
2434
+
2435
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
2436
+ fprintf(stderr, "%s: failed to decode\n", __func__);
2437
+ return -7;
2438
+ }
2439
+
2440
+ std::vector<std::pair<float, int>> probs_id;
2441
+ for (const auto kv : g_lang) {
2442
+ const auto token_lang = whisper_token_lang(ctx, kv.second.first);
2443
+ probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
2444
+ }
2445
+
2446
+ // sort descending
2447
+ {
2448
+ using pair_type = decltype(probs_id)::value_type;
2449
+ std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
2450
+ return a.first > b.first;
2451
+ });
2452
+ }
2453
+
2454
+ // softmax
2455
+ {
2456
+ float sum = 0;
2457
+ for (const auto & kv : probs_id) {
2458
+ sum += exp(kv.first);
2459
+ }
2460
+
2461
+ for (auto & kv : probs_id) {
2462
+ kv.first = exp(kv.first) / sum;
2463
+ }
2464
+ }
2465
+
2466
+ {
2467
+ for (int i = 0; i < probs_id.size(); i++) {
2468
+ if (lang_probs) {
2469
+ lang_probs[probs_id[i].second] = probs_id[i].first;
2470
+ }
2471
+
2472
+ //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
2473
+ }
2474
+ }
2475
+
2476
+ return probs_id[0].second;
2477
+ }
2478
+
2479
  int whisper_n_len(struct whisper_context * ctx) {
2480
  return ctx->mel.n_len;
2481
  }
 
2524
  return ctx->vocab.token_beg;
2525
  }
2526
 
2527
+ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
2528
+ return whisper_token_sot(ctx) + 1 + lang_id;
2529
+ }
2530
+
2531
  whisper_token whisper_token_translate(void) {
2532
  return whisper_vocab::token_translate;
2533
  }
 
2760
  } else {
2761
  if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
2762
  fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
2763
+ return -2;
2764
  }
2765
  }
2766
 
2767
+ // auto-detect language if not specified
2768
+ if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
2769
+ std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
2770
+
2771
+ const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
2772
+ if (lang_id < 0) {
2773
+ fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
2774
+ return -3;
2775
+ }
2776
+
2777
+ params.language = whisper_lang_str(lang_id);
2778
+
2779
+ fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
2780
+ }
2781
+
2782
  if (params.token_timestamps) {
2783
  ctx->t_beg = 0;
2784
  ctx->t_last = 0;
 
2817
  // these tokens determine the task that will be performed
2818
  std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
2819
  if (whisper_is_multilingual(ctx)) {
2820
+ const int lang_id = whisper_lang_id(params.language);
2821
+ prompt_init.push_back(whisper_token_lang(ctx, lang_id));
2822
  if (params.translate) {
2823
  prompt_init.push_back(whisper_token_translate());
2824
  } else {
 
2867
  // encode audio features starting at offset seek
2868
  if (whisper_encode(ctx, seek, params.n_threads) != 0) {
2869
  fprintf(stderr, "%s: failed to encode\n", __func__);
2870
+ return -4;
2871
  }
2872
 
2873
  int n_past = 0;
 
2905
  for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
2906
  if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
2907
  fprintf(stderr, "%s: failed to decode\n", __func__);
2908
+ return -5;
2909
  }
2910
 
2911
  n_past += prompt.size();
whisper.h CHANGED
@@ -150,9 +150,30 @@ extern "C" {
150
  whisper_token * tokens,
151
  int n_max_tokens);
152
 
 
 
 
153
  // Return the id of the specified language, returns -1 if not found
 
 
 
154
  WHISPER_API int whisper_lang_id(const char * lang);
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
157
  WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
158
  WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
@@ -171,6 +192,7 @@ extern "C" {
171
  WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
172
  WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
173
  WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
 
174
 
175
  // Task tokens
176
  WHISPER_API whisper_token whisper_token_translate (void);
@@ -236,6 +258,7 @@ extern "C" {
236
  const whisper_token * prompt_tokens;
237
  int prompt_n_tokens;
238
 
 
239
  const char * language;
240
 
241
  struct {
 
150
  whisper_token * tokens,
151
  int n_max_tokens);
152
 
153
+ // Largest language id (i.e. number of available languages - 1)
154
+ WHISPER_API int whisper_lang_max_id();
155
+
156
  // Return the id of the specified language, returns -1 if not found
157
+ // Examples:
158
+ // "de" -> 2
159
+ // "german" -> 2
160
  WHISPER_API int whisper_lang_id(const char * lang);
161
 
162
+ // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
163
+ WHISPER_API const char * whisper_lang_str(int id);
164
+
165
+ // Use mel data at offset_ms to try and auto-detect the spoken language
166
+ // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
167
+ // Returns the top language id or negative on failure
168
+ // If not null, fills the lang_probs array with the probabilities of all languages
169
+ // The array must be whisper_lang_max_id() + 1 in size
170
+ // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
171
+ WHISPER_API int whisper_lang_auto_detect(
172
+ struct whisper_context * ctx,
173
+ int offset_ms,
174
+ int n_threads,
175
+ float * lang_probs);
176
+
177
  WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
178
  WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
179
  WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
 
192
  WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
193
  WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
194
  WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
195
+ WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
196
 
197
  // Task tokens
198
  WHISPER_API whisper_token whisper_token_translate (void);
 
258
  const whisper_token * prompt_tokens;
259
  int prompt_n_tokens;
260
 
261
+ // for auto-detection, set to nullptr, "" or "auto"
262
  const char * language;
263
 
264
  struct {