Spaces:
Running
Running
whisper : improve handling of prompts (#1981)
Browse files* whisper : improve handling of prompts
* whisper : add whisper_token_count helper
- examples/main/main.cpp +1 -1
- whisper.cpp +11 -2
- whisper.h +7 -1
examples/main/main.cpp
CHANGED
|
@@ -207,7 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 207 |
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
|
| 208 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
| 209 |
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
|
| 210 |
-
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n",
|
| 211 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 212 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
| 213 |
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
|
|
|
|
| 207 |
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
|
| 208 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
| 209 |
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
|
| 210 |
+
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
|
| 211 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 212 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
| 213 |
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
|
whisper.cpp
CHANGED
|
@@ -3721,7 +3721,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
| 3721 |
|
| 3722 |
if (n_max_tokens < (int) res.size()) {
|
| 3723 |
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
| 3724 |
-
return -
|
| 3725 |
}
|
| 3726 |
|
| 3727 |
for (int i = 0; i < (int) res.size(); i++) {
|
|
@@ -3731,6 +3731,10 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
| 3731 |
return res.size();
|
| 3732 |
}
|
| 3733 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3734 |
int whisper_lang_max_id() {
|
| 3735 |
auto max_id = 0;
|
| 3736 |
for (const auto & kv : g_lang) {
|
|
@@ -5313,7 +5317,12 @@ int whisper_full_with_state(
|
|
| 5313 |
// initial prompt
|
| 5314 |
if (!params.prompt_tokens && params.initial_prompt) {
|
| 5315 |
prompt_tokens.resize(1024);
|
| 5316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5317 |
params.prompt_tokens = prompt_tokens.data();
|
| 5318 |
params.prompt_n_tokens = prompt_tokens.size();
|
| 5319 |
}
|
|
|
|
| 3721 |
|
| 3722 |
if (n_max_tokens < (int) res.size()) {
|
| 3723 |
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
| 3724 |
+
return -(int) res.size();
|
| 3725 |
}
|
| 3726 |
|
| 3727 |
for (int i = 0; i < (int) res.size(); i++) {
|
|
|
|
| 3731 |
return res.size();
|
| 3732 |
}
|
| 3733 |
|
| 3734 |
+
int whisper_token_count(struct whisper_context * ctx, const char * text) {
|
| 3735 |
+
return -whisper_tokenize(ctx, text, NULL, 0);
|
| 3736 |
+
}
|
| 3737 |
+
|
| 3738 |
int whisper_lang_max_id() {
|
| 3739 |
auto max_id = 0;
|
| 3740 |
for (const auto & kv : g_lang) {
|
|
|
|
| 5317 |
// initial prompt
|
| 5318 |
if (!params.prompt_tokens && params.initial_prompt) {
|
| 5319 |
prompt_tokens.resize(1024);
|
| 5320 |
+
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
|
| 5321 |
+
if (n_needed < 0) {
|
| 5322 |
+
prompt_tokens.resize(-n_needed);
|
| 5323 |
+
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
|
| 5324 |
+
}
|
| 5325 |
+
prompt_tokens.resize(n_needed);
|
| 5326 |
params.prompt_tokens = prompt_tokens.data();
|
| 5327 |
params.prompt_n_tokens = prompt_tokens.size();
|
| 5328 |
}
|
whisper.h
CHANGED
|
@@ -337,7 +337,7 @@ extern "C" {
|
|
| 337 |
// Convert the provided text into tokens.
|
| 338 |
// The tokens pointer must be large enough to hold the resulting tokens.
|
| 339 |
// Returns the number of tokens on success, no more than n_max_tokens
|
| 340 |
-
// Returns
|
| 341 |
// TODO: not sure if correct
|
| 342 |
WHISPER_API int whisper_tokenize(
|
| 343 |
struct whisper_context * ctx,
|
|
@@ -345,6 +345,10 @@ extern "C" {
|
|
| 345 |
whisper_token * tokens,
|
| 346 |
int n_max_tokens);
|
| 347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
// Largest language id (i.e. number of available languages - 1)
|
| 349 |
WHISPER_API int whisper_lang_max_id();
|
| 350 |
|
|
@@ -503,6 +507,8 @@ extern "C" {
|
|
| 503 |
|
| 504 |
// tokens to provide to the whisper decoder as initial prompt
|
| 505 |
// these are prepended to any existing text context from a previous call
|
|
|
|
|
|
|
| 506 |
const char * initial_prompt;
|
| 507 |
const whisper_token * prompt_tokens;
|
| 508 |
int prompt_n_tokens;
|
|
|
|
| 337 |
// Convert the provided text into tokens.
|
| 338 |
// The tokens pointer must be large enough to hold the resulting tokens.
|
| 339 |
// Returns the number of tokens on success, no more than n_max_tokens
|
| 340 |
+
// Returns a negative number on failure - the number of tokens that would have been returned
|
| 341 |
// TODO: not sure if correct
|
| 342 |
WHISPER_API int whisper_tokenize(
|
| 343 |
struct whisper_context * ctx,
|
|
|
|
| 345 |
whisper_token * tokens,
|
| 346 |
int n_max_tokens);
|
| 347 |
|
| 348 |
+
// Return the number of tokens in the provided text
|
| 349 |
+
// Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
|
| 350 |
+
int whisper_token_count(struct whisper_context * ctx, const char * text);
|
| 351 |
+
|
| 352 |
// Largest language id (i.e. number of available languages - 1)
|
| 353 |
WHISPER_API int whisper_lang_max_id();
|
| 354 |
|
|
|
|
| 507 |
|
| 508 |
// tokens to provide to the whisper decoder as initial prompt
|
| 509 |
// these are prepended to any existing text context from a previous call
|
| 510 |
+
// use whisper_tokenize() to convert text to tokens
|
| 511 |
+
// maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
|
| 512 |
const char * initial_prompt;
|
| 513 |
const whisper_token * prompt_tokens;
|
| 514 |
int prompt_n_tokens;
|