Spaces:
Sleeping
Sleeping
whisper : language auto-detect (#59)
Browse files- examples/main/main.cpp +2 -2
- whisper.cpp +120 -5
- whisper.h +23 -0
examples/main/main.cpp
CHANGED
|
@@ -154,7 +154,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|
| 154 |
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
| 155 |
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
| 156 |
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
| 157 |
-
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n",
|
| 158 |
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
| 159 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 160 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
|
@@ -453,7 +453,7 @@ int main(int argc, char ** argv) {
|
|
| 453 |
return 2;
|
| 454 |
}
|
| 455 |
|
| 456 |
-
if (whisper_lang_id(params.language.c_str()) == -1) {
|
| 457 |
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
| 458 |
whisper_print_usage(argc, argv, params);
|
| 459 |
exit(0);
|
|
|
|
| 154 |
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
| 155 |
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
| 156 |
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
| 157 |
+
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
| 158 |
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
| 159 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 160 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
|
|
|
| 453 |
return 2;
|
| 454 |
}
|
| 455 |
|
| 456 |
+
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
| 457 |
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
| 458 |
whisper_print_usage(argc, argv, params);
|
| 459 |
exit(0);
|
whisper.cpp
CHANGED
|
@@ -1105,7 +1105,7 @@ static bool whisper_encode(
|
|
| 1105 |
|
| 1106 |
struct ggml_init_params params;
|
| 1107 |
params.mem_size = wctx.buf_compute.size();
|
| 1108 |
-
params.mem_buffer = wctx.buf_compute.data();
|
| 1109 |
|
| 1110 |
struct ggml_context * ctx0 = ggml_init(params);
|
| 1111 |
|
|
@@ -2372,8 +2372,23 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
| 2372 |
return res.size();
|
| 2373 |
}
|
| 2374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2375 |
int whisper_lang_id(const char * lang) {
|
| 2376 |
if (!g_lang.count(lang)) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2377 |
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
| 2378 |
return -1;
|
| 2379 |
}
|
|
@@ -2381,6 +2396,86 @@ int whisper_lang_id(const char * lang) {
|
|
| 2381 |
return g_lang.at(lang).first;
|
| 2382 |
}
|
| 2383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2384 |
int whisper_n_len(struct whisper_context * ctx) {
|
| 2385 |
return ctx->mel.n_len;
|
| 2386 |
}
|
|
@@ -2429,6 +2524,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
|
|
| 2429 |
return ctx->vocab.token_beg;
|
| 2430 |
}
|
| 2431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2432 |
whisper_token whisper_token_translate(void) {
|
| 2433 |
return whisper_vocab::token_translate;
|
| 2434 |
}
|
|
@@ -2661,10 +2760,25 @@ int whisper_full(
|
|
| 2661 |
} else {
|
| 2662 |
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
|
| 2663 |
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
| 2664 |
-
return -
|
| 2665 |
}
|
| 2666 |
}
|
| 2667 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2668 |
if (params.token_timestamps) {
|
| 2669 |
ctx->t_beg = 0;
|
| 2670 |
ctx->t_last = 0;
|
|
@@ -2703,7 +2817,8 @@ int whisper_full(
|
|
| 2703 |
// these tokens determine the task that will be performed
|
| 2704 |
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
| 2705 |
if (whisper_is_multilingual(ctx)) {
|
| 2706 |
-
|
|
|
|
| 2707 |
if (params.translate) {
|
| 2708 |
prompt_init.push_back(whisper_token_translate());
|
| 2709 |
} else {
|
|
@@ -2752,7 +2867,7 @@ int whisper_full(
|
|
| 2752 |
// encode audio features starting at offset seek
|
| 2753 |
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
| 2754 |
fprintf(stderr, "%s: failed to encode\n", __func__);
|
| 2755 |
-
return
|
| 2756 |
}
|
| 2757 |
|
| 2758 |
int n_past = 0;
|
|
@@ -2790,7 +2905,7 @@ int whisper_full(
|
|
| 2790 |
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
| 2791 |
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
| 2792 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 2793 |
-
return
|
| 2794 |
}
|
| 2795 |
|
| 2796 |
n_past += prompt.size();
|
|
|
|
| 1105 |
|
| 1106 |
struct ggml_init_params params;
|
| 1107 |
params.mem_size = wctx.buf_compute.size();
|
| 1108 |
+
params.mem_buffer = wctx.buf_compute.data();
|
| 1109 |
|
| 1110 |
struct ggml_context * ctx0 = ggml_init(params);
|
| 1111 |
|
|
|
|
| 2372 |
return res.size();
|
| 2373 |
}
|
| 2374 |
|
| 2375 |
+
int whisper_lang_max_id() {
|
| 2376 |
+
auto max_id = 0;
|
| 2377 |
+
for (const auto & kv : g_lang) {
|
| 2378 |
+
max_id = std::max(max_id, kv.second.first);
|
| 2379 |
+
}
|
| 2380 |
+
|
| 2381 |
+
return max_id;
|
| 2382 |
+
}
|
| 2383 |
+
|
| 2384 |
int whisper_lang_id(const char * lang) {
|
| 2385 |
if (!g_lang.count(lang)) {
|
| 2386 |
+
for (const auto & kv : g_lang) {
|
| 2387 |
+
if (kv.second.second == lang) {
|
| 2388 |
+
return kv.second.first;
|
| 2389 |
+
}
|
| 2390 |
+
}
|
| 2391 |
+
|
| 2392 |
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
| 2393 |
return -1;
|
| 2394 |
}
|
|
|
|
| 2396 |
return g_lang.at(lang).first;
|
| 2397 |
}
|
| 2398 |
|
| 2399 |
+
const char * whisper_lang_str(int id) {
|
| 2400 |
+
for (const auto & kv : g_lang) {
|
| 2401 |
+
if (kv.second.first == id) {
|
| 2402 |
+
return kv.first.c_str();
|
| 2403 |
+
}
|
| 2404 |
+
}
|
| 2405 |
+
|
| 2406 |
+
fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
|
| 2407 |
+
return NULL;
|
| 2408 |
+
}
|
| 2409 |
+
|
| 2410 |
+
int whisper_lang_auto_detect(
|
| 2411 |
+
struct whisper_context * ctx,
|
| 2412 |
+
int offset_ms,
|
| 2413 |
+
int n_threads,
|
| 2414 |
+
float * lang_probs) {
|
| 2415 |
+
const int seek = offset_ms/10;
|
| 2416 |
+
|
| 2417 |
+
if (seek < 0) {
|
| 2418 |
+
fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
|
| 2419 |
+
return -1;
|
| 2420 |
+
}
|
| 2421 |
+
|
| 2422 |
+
if (seek >= ctx->mel.n_len) {
|
| 2423 |
+
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
|
| 2424 |
+
return -2;
|
| 2425 |
+
}
|
| 2426 |
+
|
| 2427 |
+
// run the encoder
|
| 2428 |
+
if (whisper_encode(ctx, seek, n_threads) != 0) {
|
| 2429 |
+
fprintf(stderr, "%s: failed to encode\n", __func__);
|
| 2430 |
+
return -6;
|
| 2431 |
+
}
|
| 2432 |
+
|
| 2433 |
+
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
| 2434 |
+
|
| 2435 |
+
if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
| 2436 |
+
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 2437 |
+
return -7;
|
| 2438 |
+
}
|
| 2439 |
+
|
| 2440 |
+
std::vector<std::pair<float, int>> probs_id;
|
| 2441 |
+
for (const auto kv : g_lang) {
|
| 2442 |
+
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
| 2443 |
+
probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
|
| 2444 |
+
}
|
| 2445 |
+
|
| 2446 |
+
// sort descending
|
| 2447 |
+
{
|
| 2448 |
+
using pair_type = decltype(probs_id)::value_type;
|
| 2449 |
+
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
| 2450 |
+
return a.first > b.first;
|
| 2451 |
+
});
|
| 2452 |
+
}
|
| 2453 |
+
|
| 2454 |
+
// softmax
|
| 2455 |
+
{
|
| 2456 |
+
float sum = 0;
|
| 2457 |
+
for (const auto & kv : probs_id) {
|
| 2458 |
+
sum += exp(kv.first);
|
| 2459 |
+
}
|
| 2460 |
+
|
| 2461 |
+
for (auto & kv : probs_id) {
|
| 2462 |
+
kv.first = exp(kv.first) / sum;
|
| 2463 |
+
}
|
| 2464 |
+
}
|
| 2465 |
+
|
| 2466 |
+
{
|
| 2467 |
+
for (int i = 0; i < probs_id.size(); i++) {
|
| 2468 |
+
if (lang_probs) {
|
| 2469 |
+
lang_probs[probs_id[i].second] = probs_id[i].first;
|
| 2470 |
+
}
|
| 2471 |
+
|
| 2472 |
+
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
|
| 2473 |
+
}
|
| 2474 |
+
}
|
| 2475 |
+
|
| 2476 |
+
return probs_id[0].second;
|
| 2477 |
+
}
|
| 2478 |
+
|
| 2479 |
int whisper_n_len(struct whisper_context * ctx) {
|
| 2480 |
return ctx->mel.n_len;
|
| 2481 |
}
|
|
|
|
| 2524 |
return ctx->vocab.token_beg;
|
| 2525 |
}
|
| 2526 |
|
| 2527 |
+
// Returns the token for the given language id, computed as the
// start-of-transcript token plus 1 plus lang_id.
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
    const auto sot = whisper_token_sot(ctx);

    return sot + 1 + lang_id;
}
|
| 2530 |
+
|
| 2531 |
whisper_token whisper_token_translate(void) {
|
| 2532 |
return whisper_vocab::token_translate;
|
| 2533 |
}
|
|
|
|
| 2760 |
} else {
|
| 2761 |
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
|
| 2762 |
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
| 2763 |
+
return -2;
|
| 2764 |
}
|
| 2765 |
}
|
| 2766 |
|
| 2767 |
+
// auto-detect language if not specified
|
| 2768 |
+
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
| 2769 |
+
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
| 2770 |
+
|
| 2771 |
+
const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
|
| 2772 |
+
if (lang_id < 0) {
|
| 2773 |
+
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
|
| 2774 |
+
return -3;
|
| 2775 |
+
}
|
| 2776 |
+
|
| 2777 |
+
params.language = whisper_lang_str(lang_id);
|
| 2778 |
+
|
| 2779 |
+
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
| 2780 |
+
}
|
| 2781 |
+
|
| 2782 |
if (params.token_timestamps) {
|
| 2783 |
ctx->t_beg = 0;
|
| 2784 |
ctx->t_last = 0;
|
|
|
|
| 2817 |
// these tokens determine the task that will be performed
|
| 2818 |
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
| 2819 |
if (whisper_is_multilingual(ctx)) {
|
| 2820 |
+
const int lang_id = whisper_lang_id(params.language);
|
| 2821 |
+
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
|
| 2822 |
if (params.translate) {
|
| 2823 |
prompt_init.push_back(whisper_token_translate());
|
| 2824 |
} else {
|
|
|
|
| 2867 |
// encode audio features starting at offset seek
|
| 2868 |
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
| 2869 |
fprintf(stderr, "%s: failed to encode\n", __func__);
|
| 2870 |
+
return -4;
|
| 2871 |
}
|
| 2872 |
|
| 2873 |
int n_past = 0;
|
|
|
|
| 2905 |
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
| 2906 |
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
| 2907 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 2908 |
+
return -5;
|
| 2909 |
}
|
| 2910 |
|
| 2911 |
n_past += prompt.size();
|
whisper.h
CHANGED
|
@@ -150,9 +150,30 @@ extern "C" {
|
|
| 150 |
whisper_token * tokens,
|
| 151 |
int n_max_tokens);
|
| 152 |
|
|
|
|
|
|
|
|
|
|
| 153 |
// Return the id of the specified language, returns -1 if not found
|
|
|
|
|
|
|
|
|
|
| 154 |
WHISPER_API int whisper_lang_id(const char * lang);
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
| 157 |
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
| 158 |
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
|
@@ -171,6 +192,7 @@ extern "C" {
|
|
| 171 |
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
|
| 172 |
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
|
| 173 |
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
|
|
|
|
| 174 |
|
| 175 |
// Task tokens
|
| 176 |
WHISPER_API whisper_token whisper_token_translate (void);
|
|
@@ -236,6 +258,7 @@ extern "C" {
|
|
| 236 |
const whisper_token * prompt_tokens;
|
| 237 |
int prompt_n_tokens;
|
| 238 |
|
|
|
|
| 239 |
const char * language;
|
| 240 |
|
| 241 |
struct {
|
|
|
|
| 150 |
whisper_token * tokens,
|
| 151 |
int n_max_tokens);
|
| 152 |
|
| 153 |
+
// Largest language id (i.e. number of available languages - 1)
|
| 154 |
+
WHISPER_API int whisper_lang_max_id();
|
| 155 |
+
|
| 156 |
// Return the id of the specified language, returns -1 if not found
|
| 157 |
+
// Examples:
|
| 158 |
+
// "de" -> 2
|
| 159 |
+
// "german" -> 2
|
| 160 |
WHISPER_API int whisper_lang_id(const char * lang);
|
| 161 |
|
| 162 |
+
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
|
| 163 |
+
WHISPER_API const char * whisper_lang_str(int id);
|
| 164 |
+
|
| 165 |
+
// Use mel data at offset_ms to try and auto-detect the spoken language
|
| 166 |
+
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
|
| 167 |
+
// Returns the top language id or negative on failure
|
| 168 |
+
// If not null, fills the lang_probs array with the probabilities of all languages
|
| 169 |
+
// The array must be whisper_lang_max_id() + 1 in size
|
| 170 |
+
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
| 171 |
+
WHISPER_API int whisper_lang_auto_detect(
|
| 172 |
+
struct whisper_context * ctx,
|
| 173 |
+
int offset_ms,
|
| 174 |
+
int n_threads,
|
| 175 |
+
float * lang_probs);
|
| 176 |
+
|
| 177 |
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
| 178 |
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
| 179 |
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
|
|
|
| 192 |
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
|
| 193 |
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
|
| 194 |
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
|
| 195 |
+
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
|
| 196 |
|
| 197 |
// Task tokens
|
| 198 |
WHISPER_API whisper_token whisper_token_translate (void);
|
|
|
|
| 258 |
const whisper_token * prompt_tokens;
|
| 259 |
int prompt_n_tokens;
|
| 260 |
|
| 261 |
+
// for auto-detection, set to nullptr, "" or "auto"
|
| 262 |
const char * language;
|
| 263 |
|
| 264 |
struct {
|