ggerganov committed on
Commit
15949a9
·
unverified ·
1 Parent(s): 749004e

whisper : improve handling of prompts (#1981)

Browse files

* whisper : improve handling of prompts

* whisper : add whisper_token_count helper

Files changed (3) hide show
  1. examples/main/main.cpp +1 -1
  2. whisper.cpp +11 -2
  3. whisper.h +7 -1
examples/main/main.cpp CHANGED
@@ -207,7 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
207
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
208
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
209
  fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
210
- fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
211
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
212
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
213
  fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
 
207
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
208
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
209
  fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
210
+ fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
211
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
212
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
213
  fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
whisper.cpp CHANGED
@@ -3721,7 +3721,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
3721
 
3722
  if (n_max_tokens < (int) res.size()) {
3723
  WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3724
- return -1;
3725
  }
3726
 
3727
  for (int i = 0; i < (int) res.size(); i++) {
@@ -3731,6 +3731,10 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
3731
  return res.size();
3732
  }
3733
 
 
 
 
 
3734
  int whisper_lang_max_id() {
3735
  auto max_id = 0;
3736
  for (const auto & kv : g_lang) {
@@ -5313,7 +5317,12 @@ int whisper_full_with_state(
5313
  // initial prompt
5314
  if (!params.prompt_tokens && params.initial_prompt) {
5315
  prompt_tokens.resize(1024);
5316
- prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
 
 
 
 
 
5317
  params.prompt_tokens = prompt_tokens.data();
5318
  params.prompt_n_tokens = prompt_tokens.size();
5319
  }
 
3721
 
3722
  if (n_max_tokens < (int) res.size()) {
3723
  WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3724
+ return -(int) res.size();
3725
  }
3726
 
3727
  for (int i = 0; i < (int) res.size(); i++) {
 
3731
  return res.size();
3732
  }
3733
 
3734
+ int whisper_token_count(struct whisper_context * ctx, const char * text) {
3735
+ return -whisper_tokenize(ctx, text, NULL, 0);
3736
+ }
3737
+
3738
  int whisper_lang_max_id() {
3739
  auto max_id = 0;
3740
  for (const auto & kv : g_lang) {
 
5317
  // initial prompt
5318
  if (!params.prompt_tokens && params.initial_prompt) {
5319
  prompt_tokens.resize(1024);
5320
+ int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
5321
+ if (n_needed < 0) {
5322
+ prompt_tokens.resize(-n_needed);
5323
+ n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
5324
+ }
5325
+ prompt_tokens.resize(n_needed);
5326
  params.prompt_tokens = prompt_tokens.data();
5327
  params.prompt_n_tokens = prompt_tokens.size();
5328
  }
whisper.h CHANGED
@@ -337,7 +337,7 @@ extern "C" {
337
  // Convert the provided text into tokens.
338
  // The tokens pointer must be large enough to hold the resulting tokens.
339
  // Returns the number of tokens on success, no more than n_max_tokens
340
- // Returns -1 on failure
341
  // TODO: not sure if correct
342
  WHISPER_API int whisper_tokenize(
343
  struct whisper_context * ctx,
@@ -345,6 +345,10 @@ extern "C" {
345
  whisper_token * tokens,
346
  int n_max_tokens);
347
 
 
 
 
 
348
  // Largest language id (i.e. number of available languages - 1)
349
  WHISPER_API int whisper_lang_max_id();
350
 
@@ -503,6 +507,8 @@ extern "C" {
503
 
504
  // tokens to provide to the whisper decoder as initial prompt
505
  // these are prepended to any existing text context from a previous call
 
 
506
  const char * initial_prompt;
507
  const whisper_token * prompt_tokens;
508
  int prompt_n_tokens;
 
337
  // Convert the provided text into tokens.
338
  // The tokens pointer must be large enough to hold the resulting tokens.
339
  // Returns the number of tokens on success, no more than n_max_tokens
340
+ // Returns a negative number on failure - the number of tokens that would have been returned
341
  // TODO: not sure if correct
342
  WHISPER_API int whisper_tokenize(
343
  struct whisper_context * ctx,
 
345
  whisper_token * tokens,
346
  int n_max_tokens);
347
 
348
+ // Return the number of tokens in the provided text
349
+ // Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
350
+ int whisper_token_count(struct whisper_context * ctx, const char * text);
351
+
352
  // Largest language id (i.e. number of available languages - 1)
353
  WHISPER_API int whisper_lang_max_id();
354
 
 
507
 
508
  // tokens to provide to the whisper decoder as initial prompt
509
  // these are prepended to any existing text context from a previous call
510
+ // use whisper_tokenize() to convert text to tokens
511
+ // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
512
  const char * initial_prompt;
513
  const whisper_token * prompt_tokens;
514
  int prompt_n_tokens;