denersc, ggerganov committed
Commit ce7ca09 · unverified · 1 parent: 8e9c985

whisper : token-level timestamps with DTW (#1485)


* whisper.cpp: impl dtw algo

* WIP: producing and placing DTW timestamps on tokens

* Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false.

* Fix mistake causing incorrect alignment of dtw timestamps

* implement N_TOP_MOST and CUSTOM alignment heads setting

* whisper: fix typo on alignment heads enum

* Fix issues related to changes in whisper.cpp

* Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function

* decoder: save cross QKs only if requested

* Calling median filter with ggml_map_custom1

* Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads

* Copying cross QKs from decoder backend correctly

* dtw: cleanup

* Fix incorrect n_frames passed to dtw when near end of audio

* Fix aheads_masks_init for backend != CPU

* whisper : minor style

* main : add dtw (wip)

* whisper: fix invalid memory access in aheads_masks_init

* main : add dtw (cont)

* whisper : minor

---------

Co-authored-by: Georgi Gerganov <[email protected]>
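
For orientation, here is a minimal sketch of the API surface this commit adds: enable dtw_token_timestamps plus a matching alignment-heads preset in the context params, then read t_dtw off each token. The model path and PCM buffer are illustrative assumptions, and error handling is trimmed.

    // Sketch only: enabling DTW token-level timestamps through the new context params.
    // Assumes `pcmf32` holds 16 kHz mono float samples; the model path is illustrative.
    #include "whisper.h"
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> pcmf32; // load 16 kHz mono audio here (not shown)

        struct whisper_context_params cparams = whisper_context_default_params();
        cparams.dtw_token_timestamps = true;
        cparams.dtw_aheads_preset    = WHISPER_AHEADS_TINY; // must match the loaded model

        struct whisper_context * ctx = whisper_init_from_file_with_params("ggml-tiny.bin", cparams);
        if (ctx == nullptr) return 1;

        struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) return 1;

        for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
                const whisper_token_data td = whisper_full_get_token_data(ctx, i, j);
                // t_dtw uses the same 10 ms units as t0/t1; it stays -1 when not computed
                printf("%s -> t_dtw = %lld\n", whisper_full_get_token_text(ctx, i, j), (long long) td.t_dtw);
            }
        }

        whisper_free(ctx);
        return 0;
    }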

Files changed (3):
  1. examples/main/main.cpp  +39 -12
  2. whisper.cpp             +572 -9
  3. whisper.h               +41 -0
examples/main/main.cpp CHANGED
@@ -26,17 +26,17 @@ void replace_all(std::string & s, const std::string & search, const std::string
 
 // command-line parameters
 struct whisper_params {
-    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_processors  = 1;
-    int32_t offset_t_ms   = 0;
-    int32_t offset_n      = 0;
-    int32_t duration_ms   = 0;
-    int32_t progress_step = 5;
-    int32_t max_context   = -1;
-    int32_t max_len       = 0;
-    int32_t best_of       = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
-    int32_t beam_size     = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
-    int32_t audio_ctx     = 0;
+    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_processors  = 1;
+    int32_t offset_t_ms   = 0;
+    int32_t offset_n      = 0;
+    int32_t duration_ms   = 0;
+    int32_t progress_step = 5;
+    int32_t max_context   = -1;
+    int32_t max_len       = 0;
+    int32_t best_of       = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size     = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
+    int32_t audio_ctx     = 0;
 
     float word_thold    = 0.01f;
     float entropy_thold = 2.40f;
@@ -76,6 +76,8 @@ struct whisper_params {
 
     std::string openvino_encode_device = "CPU";
 
+    std::string dtw = "";
+
     std::vector<std::string> fname_inp = {};
     std::vector<std::string> fname_out = {};
 };
@@ -149,6 +151,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-m"    || arg == "--model")       { params.model                  = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")        { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
+        else if (arg == "-dtw"  || arg == "--dtw")         { params.dtw                    = argv[++i]; }
        else if (arg == "-ls"   || arg == "--log-score")   { params.log_score              = true; }
        else if (arg == "-ng"   || arg == "--no-gpu")      { params.use_gpu                = false; }
        else {
@@ -208,6 +211,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n", params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n", "");
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
+    fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n", params.dtw.c_str());
    fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n", params.log_score ? "true" : "false");
    fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
    fprintf(stderr, "\n");
@@ -649,7 +653,8 @@ bool output_json(
                     times_o(token.t0, token.t1, false);
                 }
                 value_i("id", token.id, false);
-                value_f("p", token.p, true);
+                value_f("p", token.p, false);
+                value_f("t_dtw", token.t_dtw, true);
                 end_obj(j == (n - 1));
             }
             end_arr(!params.diarize && !params.tinydiarize);
@@ -889,6 +894,28 @@ int main(int argc, char ** argv) {
     struct whisper_context_params cparams = whisper_context_default_params();
     cparams.use_gpu = params.use_gpu;
 
+    if (!params.dtw.empty()) {
+        cparams.dtw_token_timestamps = true;
+        cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
+
+        if (params.dtw == "tiny")      cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
+        if (params.dtw == "tiny.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
+        if (params.dtw == "base")      cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
+        if (params.dtw == "base.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
+        if (params.dtw == "small")     cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
+        if (params.dtw == "small.en")  cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
+        if (params.dtw == "medium")    cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
+        if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
+        if (params.dtw == "large.v1")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
+        if (params.dtw == "large.v2")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
+        if (params.dtw == "large.v3")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
+
+        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+            fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
+            return 3;
+        }
+    }
+
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     if (ctx == nullptr) {
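
With the changes above, token-level DTW timestamps can be requested from the CLI, for example (paths illustrative):

    ./main -m models/ggml-tiny.bin -f samples/jfk.wav -dtw tiny -oj

The name passed to -dtw must match the loaded model; combined with JSON output, each token then carries a t_dtw field next to its probability p.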
whisper.cpp CHANGED
@@ -351,6 +351,35 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "yue", { 99, "cantonese", } },
 };
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+static const whisper_ahead g_aheads_tiny_en[]   = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
+static const whisper_ahead g_aheads_tiny[]      = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
+static const whisper_ahead g_aheads_base_en[]   = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
+static const whisper_ahead g_aheads_base[]      = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
+static const whisper_ahead g_aheads_small_en[]  = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
+static const whisper_ahead g_aheads_small[]     = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
+static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
+static const whisper_ahead g_aheads_medium[]    = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
+static const whisper_ahead g_aheads_large_v1[]  = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
+static const whisper_ahead g_aheads_large_v2[]  = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
+static const whisper_ahead g_aheads_large_v3[]  = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+
+static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
+    { WHISPER_AHEADS_TINY_EN,   {  8, g_aheads_tiny_en   } },
+    { WHISPER_AHEADS_TINY,      {  6, g_aheads_tiny      } },
+    { WHISPER_AHEADS_BASE_EN,   {  5, g_aheads_base_en   } },
+    { WHISPER_AHEADS_BASE,      {  8, g_aheads_base      } },
+    { WHISPER_AHEADS_SMALL_EN,  { 19, g_aheads_small_en  } },
+    { WHISPER_AHEADS_SMALL,     { 10, g_aheads_small     } },
+    { WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
+    { WHISPER_AHEADS_MEDIUM,    {  6, g_aheads_medium    } },
+    { WHISPER_AHEADS_LARGE_V1,  {  9, g_aheads_large_v1  } },
+    { WHISPER_AHEADS_LARGE_V2,  { 23, g_aheads_large_v2  } },
+    { WHISPER_AHEADS_LARGE_V3,  { 10, g_aheads_large_v3  } },
+};
+
+static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
+
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -750,6 +779,13 @@ struct whisper_decoder {
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+struct whisper_aheads_masks {
+    std::vector<struct ggml_tensor *> m; // One mask per text layer.
+    struct ggml_context * ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+};
+
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
@@ -823,6 +859,11 @@ struct whisper_state {
 
     std::vector<float> energy; // PCM signal energy
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    whisper_aheads_masks aheads_masks;
+    ggml_tensor * aheads_cross_QKs = nullptr;
+    std::vector<float> aheads_cross_QKs_data;
+
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx = 0; // 0 - use default
 };
@@ -1027,6 +1068,132 @@ static void whisper_kv_cache_seq_cp(
     }
 }
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+static bool aheads_masks_init(
+        const whisper_context_params & cparams,
+        const whisper_hparams & hparams,
+        struct whisper_aheads_masks & aheads_masks,
+        ggml_backend_t backend) {
+
+    const int32_t n_text_layer = hparams.n_text_layer;
+    const int32_t n_head = hparams.n_text_head;
+
+    // Sanity checks
+    if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+        WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != WHISPER_AHEADS_NONE\n", __func__);
+        return false;
+    } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+        if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
+            WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model\n", __func__, 1, n_text_layer);
+            return false;
+        }
+    } else {
+        const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
+            if (aheads.n_heads == 0) {
+                WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0\n", __func__);
+                return false;
+            }
+            if (aheads.heads == NULL) {
+                WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset\n", __func__);
+                return false;
+            }
+        }
+        for (size_t i = 0; i < aheads.n_heads; ++i) {
+            if (aheads.heads[i].n_text_layer >= n_text_layer) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers\n", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
                return false;
+            }
+            if (aheads.heads[i].n_text_layer < 0) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0\n", __func__);
+                return false;
+            }
+            if (aheads.heads[i].n_head >= n_head) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads\n", __func__, aheads.heads[i].n_head + 1, n_head);
+                return false;
+            }
+            if (aheads.heads[i].n_head < 0) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0\n", __func__);
+                return false;
+            }
+        }
+    }
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ static_cast<size_t>(n_text_layer)*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+
+    aheads_masks.ctx = ggml_init(params);
+
+    if (!aheads_masks.ctx) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
+        return false;
+    }
+
+    for (int64_t il = 0; il < n_text_layer; ++il) {
+        auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+        if (!aheads.empty()) {
+            aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size()));
+        } else {
+            aheads_masks.m.push_back(nullptr);
+        }
+    }
+
+    aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
+    if (!aheads_masks.buffer) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
+        return false;
+    }
+
+    // Set data on mask tensors
+    // Since this must be backend agnostic, we get tensor data with
+    // ggml_backend_tensor_get, copy our desired values and send it back
+    // to backend with ggml_backend_tensor_set
+    std::vector<float> mask_data;
+    for (int64_t il = 0; il < n_text_layer; ++il) {
+        if (aheads_masks.m[il] != nullptr) {
+            auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+
+            size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1] * sizeof(float);
+            mask_data.resize(data_size);
+            ggml_backend_tensor_get(aheads_masks.m[il], mask_data.data(), 0, data_size);
+            memset(mask_data.data(), 0, data_size);
+
+            for (size_t ih = 0; ih < aheads.size(); ++ih) {
+                // row ih, column aheads[ih] of the n_head-wide mask
+                size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0]));
+                float v = 1.0f;
+                memcpy(mask_data.data() + pos, &v, sizeof(float));
+            }
+
+            ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size);
+        }
+    }
+
+    if (aheads_masks.m.empty()) {
+        WHISPER_LOG_ERROR("%s: no alignment heads were set\n", __func__);
+        return false;
+    }
+
+    return true;
+}
+
+static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
+    ggml_free(aheads_masks.ctx);
+    ggml_backend_buffer_free(aheads_masks.buffer);
+    aheads_masks.ctx = nullptr;
+}
+
+static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
+    size_t size = 0;
+    for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
+        if (aheads_masks.m[i] != nullptr)
+            size += ggml_nbytes(aheads_masks.m[i]);
+    }
+    return size;
+}
+
 static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
     ggml_backend_t backend_gpu = NULL;
 
@@ -2105,6 +2272,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         whisper_context & wctx,
         whisper_state & wstate,
         const whisper_batch & batch,
+        bool save_alignment_heads_QKs,
         bool worst_case) {
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;
@@ -2158,6 +2326,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     struct ggml_tensor * inpL = cur;
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    struct ggml_tensor * aheads_cross_QKs = nullptr;
+
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];
 
@@ -2337,6 +2508,24 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
                 struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
+                // [EXPERIMENTAL] Token-level timestamps with DTW
+                if (wctx.params.dtw_token_timestamps) {
+                    if (wstate.aheads_masks.m[il] != nullptr) {
+                        struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+                        if (aheads_cross_QKs == NULL) {
+                            aheads_cross_QKs = aheads_KQs;
+                        } else {
+                            aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                        }
+                    }
+                }
+
                 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -2422,6 +2611,16 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
+        aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs);
+        aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs);
+        if (save_alignment_heads_QKs) {
+            ggml_build_forward_expand(gf, aheads_cross_QKs);
+            wstate.aheads_cross_QKs = aheads_cross_QKs;
+        }
+    }
+
     ggml_build_forward_expand(gf, logits);
 
     ggml_free(ctx0);
@@ -2444,6 +2643,7 @@ static bool whisper_decode_internal(
         whisper_state & wstate,
         const whisper_batch & batch,
        const int n_threads,
+        bool save_alignment_heads_QKs,
        ggml_abort_callback abort_callback,
        void * abort_callback_data) {
    const int64_t t_start_us = ggml_time_us();
@@ -2475,7 +2675,7 @@ static bool whisper_decode_internal(
    {
        auto & alloc = wstate.alloc_decode.alloc;
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, false);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
 
        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
            // should never happen as we pre-allocate the memory
@@ -3003,6 +3203,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
        WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
    }
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    if (ctx->params.dtw_token_timestamps) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+            WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
+            whisper_free_state(state);
+            return nullptr;
+        }
+        const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
+        WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
+    }
+
 #ifdef WHISPER_USE_COREML
    const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
@@ -3095,7 +3306,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
        whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
 
-        return whisper_build_graph_decoder(*ctx, *state, state->batch, true);
+        return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
    });
 
    if (!ok) {
@@ -3161,8 +3372,17 @@ int whisper_ctx_init_openvino_encoder(
 
 struct whisper_context_params whisper_context_default_params() {
    struct whisper_context_params result = {
-        /*.use_gpu    =*/ true,
-        /*.gpu_device =*/ 0,
+        /*.use_gpu              =*/ true,
+        /*.gpu_device           =*/ 0,
+
+        /*.dtw_token_timestamps =*/ false,
+        /*.dtw_aheads_preset    =*/ WHISPER_AHEADS_NONE,
+        /*.dtw_n_top            =*/ -1,
+        /*.dtw_aheads           =*/ {
+            /*.n_heads =*/ 0,
+            /*.heads   =*/ NULL,
+        },
+        /*.dtw_mem_size         =*/ 1024*1024*128,
    };
    return result;
 }
@@ -3357,6 +3577,9 @@ void whisper_free_state(struct whisper_state * state) {
 
        ggml_backend_free(state->backend);
 
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        aheads_masks_free(state->aheads_masks);
+
        delete state;
    }
 }
@@ -3476,7 +3699,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 
    whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
 
-    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
        return 1;
    }
@@ -4411,6 +4634,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
    return txt[0] == ' ';
 }
 
+static void whisper_exp_compute_token_level_timestamps_dtw(
+        struct whisper_context * ctx,
+        struct whisper_state * state,
+        struct whisper_full_params params,
+        int i_segment,
+        size_t n_segments,
+        int seek,
+        int n_frames,
+        int medfilt_width,
+        int n_threads);
+
 // wrap the last segment to max_len characters
 // returns the number of new segments
 static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
@@ -4779,7 +5013,7 @@ static whisper_token_data whisper_sample_token(
    const whisper_decoder & decoder,
    bool best) {
    whisper_token_data result = {
-        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
    };
 
    const auto & vocab = ctx.vocab;
@@ -4897,7 +5131,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
        const auto id = dist(decoder.rng);
        //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
-        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
+        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
 
        if (result[i].id >= vocab.token_beg) {
            result[i].tid = result[i].id;
@@ -5259,7 +5493,7 @@ int whisper_full_with_state(
 
        whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
-        if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+        if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
            WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
            return -7;
        }
@@ -5559,7 +5793,7 @@ int whisper_full_with_state(
 
                assert(batch.n_tokens > 0);
 
-                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                    WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                    return -8;
                }
@@ -5682,6 +5916,9 @@ int whisper_full_with_state(
 
            const auto & tokens_cur = best_decoder.sequence.tokens;
 
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            const auto n_segments_before = state->result_all.size();
+
            //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
            // update prompt_past
@@ -5799,6 +6036,17 @@ int whisper_full_with_state(
                }
            }
 
+            // FIXME: will timestamp offsets be correct?
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            {
+                const auto n_segments = state->result_all.size() - n_segments_before;
+                if (ctx->params.dtw_token_timestamps && n_segments) {
+                    const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
+                    whisper_exp_compute_token_level_timestamps_dtw(
+                            ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                }
+            }
+
            // update audio window
            seek += seek_delta;
 
@@ -6601,6 +6849,321 @@ static void whisper_exp_compute_token_level_timestamps(
    //}
 }
 
+//
+// token level timestamps - dtw version
+//
+
+// n_text_layer -> total text layers on model
+// n_head -> total heads per text layer on model
+static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
+    std::vector<uint32_t> ret;
+    if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+        return ret;
+    } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+        if (il >= n_text_layer - cparams.dtw_n_top) {
+            for (int32_t i = 0; i < n_head; ++i) {
+                ret.push_back(i);
+            }
+        }
+    } else {
+        const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+        for (size_t i = 0; i < aheads.n_heads; ++i) {
+            if (aheads.heads[i].n_text_layer == il) {
+                ret.push_back(aheads.heads[i].n_head);
+            }
+        }
+    }
+    return ret;
+}
+
+// dtw + backtrace to return found path
+// based on
+// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
+static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
+    WHISPER_ASSERT(ggml_n_dims(x) == 2);
+
+    int64_t N = x->ne[0];
+    int64_t M = x->ne[1];
+    struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
+    struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
+
+    cost = ggml_set_f32(cost, INFINITY);
+    trace = ggml_set_f32(trace, -1);
+    ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
+
+    // dtw
+    // supposedly can be optimized by computing diagonals in parallel ?
+    // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
+    for (int64_t j = 1; j < M + 1; ++j) {
+        for (int64_t i = 1; i < N + 1; ++i) {
+            float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
+            float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0);
+            float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0);
+
+            float c;
+            int32_t t;
+            if (c0 < c1 && c0 < c2) {
+                c = c0;
+                t = 0;
+            } else if (c1 < c0 && c1 < c2) {
+                c = c1;
+                t = 1;
+            } else {
+                c = c2;
+                t = 2;
+            }
+
+            c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
+            ggml_set_f32_nd(cost, i, j, 0, 0, c);
+            ggml_set_i32_nd(trace, i, j, 0, 0, t);
+        }
+    }
+
+    // Backtrace
+    const int64_t BT_MAX_ROWS = N + M - 1;
+    struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
+    // trace[0, :] = 2;
+    for (int64_t i = 0; i < M + 1; ++i)
+        ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
+    // trace[:, 0] = 1;
+    for (int64_t i = 0; i < N + 1; ++i)
+        ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
+    int bt_row_idx = BT_MAX_ROWS - 1;
+    int64_t i = N;
+    int64_t j = M;
+    while (i > 0 || j > 0) {
+        ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
+        ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
+        --bt_row_idx;
+
+        int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0);
+        if (t == 0) {
+            --i;
+            --j;
+        } else if (t == 1) {
+            --i;
+        } else if (t == 2) {
+            --j;
+        } else {
+            WHISPER_ASSERT(0);
+        }
+    }
+
+    // FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
+    // Clip + transpose
+    // This might not be entirely necessary for our case, but leaving it for now so output matrix
+    // is identical to dtw on openAI timing.py
+    const int64_t result_n_cols = BT_MAX_ROWS - bt_row_idx - 1;
+    ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
+    for (int64_t i = 0; i < 2; ++i) {
+        for (int64_t j = 0; j < result_n_cols; ++j) {
+            int32_t v = ggml_get_i32_nd(bt, j + bt_row_idx + 1, i, 0, 0);
+            ggml_set_i32_nd(r, i, j, 0, 0, v);
+        }
+    }
+
+    return r;
+}
+
+struct median_filter_user_data {
+    int filter_width;
+};
+
+static void median_filter(struct ggml_tensor * dst, const struct ggml_tensor * a, int ith, int nth, void * userdata) {
+    int filter_width = ((median_filter_user_data *) userdata)->filter_width;
+    WHISPER_ASSERT(nth == 1);
+    WHISPER_ASSERT(ith == 0);
+    WHISPER_ASSERT(filter_width < a->ne[2]);
+    WHISPER_ASSERT(filter_width % 2);
+    WHISPER_ASSERT(ggml_n_dims(a) == 3);
+    WHISPER_ASSERT(a->type == GGML_TYPE_F32);
+
+    std::vector<float> filter;
+    filter.reserve(filter_width);
+    for (int64_t i = 0; i < a->ne[0]; ++i) {
+        for (int64_t j = 0; j < a->ne[1]; ++j) {
+            for (int64_t k = 0; k < a->ne[2]; ++k) {
+                for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
+                    // "reflect" padding
+                    int64_t idx = k + off;
+                    if (idx < 0) {
+                        idx = -idx;
+                    } else if (idx >= a->ne[2]) {
+                        idx = 2*(a->ne[2] - 1) - idx;
+                    }
+
+                    filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
+                }
+                std::sort(filter.begin(), filter.end());
+                const float v = filter[filter.size()/2];
+                ggml_set_f32_nd(dst, i, j, k, 0, v);
+                filter.clear();
+            }
+        }
+    }
+}
+
+static void whisper_exp_compute_token_level_timestamps_dtw(
+        struct whisper_context * ctx,
+        struct whisper_state * state,
+        struct whisper_full_params params,
+        int i_segment,
+        size_t n_segments,
+        int seek,
+        int n_frames,
+        int medfilt_width,
+        int n_threads)
+{
+    const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
+    WHISPER_ASSERT(medfilt_width % 2);
+    WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
+    WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
+
+    // FIXME: Allocating mem every time we call this func
+    // Our ggml buffer should be pre-allocated somewhere during init and reused
+    // when we call this function
+    struct ggml_init_params gparams = {
+        /*.mem_size   =*/ ctx->params.dtw_mem_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+    struct ggml_context * gctx = ggml_init(gparams);
+
+    // Build token sequence that will be passed to decoder
+    // sot + [lang] + text result + eot
+    std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
+    if (whisper_is_multilingual(ctx)) {
+        const int lang_id = whisper_lang_id(params.language);
+        state->lang_id = lang_id;
+        tokens.push_back(whisper_token_lang(ctx, lang_id));
+    }
+    const size_t sot_sequence_length = tokens.size();
+    tokens.push_back(whisper_token_not(ctx));
+    for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto & t : segment.tokens) {
+            // Only text tokens
+            if (t.id < whisper_token_eot(ctx)) {
+                tokens.push_back(t.id);
+            }
+        }
+    }
+    tokens.push_back(whisper_token_eot(ctx));
+
+    // Get result tokens, pass them along to decoder to get cross attention QKs
+    // used in timestamping
+    // Decoder already returns only alignment head QKs, already concatenated in
+    // one tensor.
+    whisper_kv_cache_clear(state->kv_self);
+    whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
+    whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
+        WHISPER_LOG_INFO("DECODER FAILED\n");
+        WHISPER_ASSERT(0);
+    }
+    WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
+
+    const auto n_audio_tokens = n_frames/2;
+    WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
+    const auto n_tokens = state->aheads_cross_QKs->ne[0];
+    const auto n_heads = state->aheads_cross_QKs->ne[2];
+
+    // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
+    // tokens (i.e. discarding rows at the end of tensor)
+    // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32);
+    WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs));
+    ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
+    auto & data = state->aheads_cross_QKs_data;
+    data.resize(n_tokens * n_audio_ctx * n_heads);
+    ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
+    for (int k = 0; k < n_heads; ++k) {
+        for (int j = 0; j < n_audio_tokens; ++j) {
+            memcpy(
+                (char *) w->data + j * w->nb[1] + k * w->nb[2],
+                data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
+                n_tokens * sizeof(float)
+            );
+        }
+    }
+
+    // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
+    // we already permuted N_TOKENS dimension to columns on last loop, because ggml_norm
+    // operates over columns. Afterwards, permute to a shape that facilitates mean
+    // operation (after median filter)
+    // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    w = ggml_norm(gctx, w, 1e-9);
+    w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3);
+
+    // Pass median filter - this is done over AUDIO_TOKENS dimension.
+    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Same dims
+    median_filter_user_data mf_user_data = {medfilt_width};
+    w = ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);
+
+    // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
+    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
+    w = ggml_mean(gctx, w);
+    w = ggml_scale(gctx, w, -1.0);
+    w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
+
+    // Remove SOT sequence and EOT
+    // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
+    w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);
+
+    // Compute
+    struct ggml_cgraph * gf = ggml_new_graph(gctx);
+    ggml_build_forward_expand(gf, w);
+    ggml_graph_compute_with_ctx(gctx, gf, n_threads);
+
+    ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
+
+    // Place timestamps on segments
+    int32_t last_v = 0;
+    auto seg_i = state->result_all.begin() + i_segment;
+    auto tok_i = seg_i->tokens.begin();
+    for (int i = 0; i < alignment->ne[1]; ++i) {
+        int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0);
+        if (v != last_v) {
+            int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0);
+            int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20 ms of audio
+            last_v = v;
+
+            // Skip non-text tokens
+            while (!(tok_i->id < whisper_token_eot(ctx))) {
+                ++tok_i;
+                if (tok_i == seg_i->tokens.end()) {
+                    ++seg_i;
+                    tok_i = seg_i->tokens.begin();
+                }
+            }
+
+            tok_i->t_dtw = timestamp;
+            ++tok_i;
+            if (tok_i == seg_i->tokens.end()) {
+                ++seg_i;
+                tok_i = seg_i->tokens.begin();
+            }
+        }
+    }
+
+    // Print DTW timestamps
+    /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto & t : segment.tokens) {
+            const char * tok = whisper_token_to_str(ctx, t.id);
+            fprintf(stderr, "|%s|(%.2f) ", tok, (float) t.t_dtw/100);
+        }
+        fprintf(stderr, "\n");
+    }*/
+
+    ggml_free(gctx);
+}
+
 void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
    g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
    g_state.log_callback_user_data = user_data;
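
For readers new to the algorithm, the core of dtw_and_backtrace above is the textbook DTW recurrence over a cost matrix; here is a plain-array restatement of that inner loop (a sketch for illustration, not code from the commit):

    // DTW cost/trace fill, mirroring dtw_and_backtrace but on plain arrays.
    // x is the (N tokens) x (M audio positions) distance matrix fed to DTW.
    #include <cmath>
    #include <vector>

    static void dtw_cost(const std::vector<std::vector<float>> & x,
                         std::vector<std::vector<float>> & cost,
                         std::vector<std::vector<int>> & trace) {
        const size_t N = x.size();
        const size_t M = x[0].size();
        cost.assign (N + 1, std::vector<float>(M + 1, INFINITY));
        trace.assign(N + 1, std::vector<int>  (M + 1, -1));
        cost[0][0] = 0.0f;

        for (size_t j = 1; j <= M; ++j) {
            for (size_t i = 1; i <= N; ++i) {
                const float c0 = cost[i - 1][j - 1]; // advance token and audio
                const float c1 = cost[i - 1][j];     // advance token only
                const float c2 = cost[i][j - 1];     // advance audio only

                float c; int t;
                if (c0 < c1 && c0 < c2)      { c = c0; t = 0; }
                else if (c1 < c0 && c1 < c2) { c = c1; t = 1; }
                else                         { c = c2; t = 2; }

                cost[i][j]  = x[i - 1][j - 1] + c; // accumulate local cost
                trace[i][j] = t;                   // remember which move won
            }
        }
    }

Backtracking then walks trace from (N, M) down to (0, 0); each time the token index advances, the current audio index becomes that token's timestamp, at two mel frames (20 ms) per audio position, offset by seek.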
whisper.h CHANGED
@@ -84,9 +84,45 @@ extern "C" {
     typedef int32_t whisper_token;
     typedef int32_t whisper_seq_id;
 
+    enum whisper_alignment_heads_preset {
+        WHISPER_AHEADS_NONE,
+        WHISPER_AHEADS_N_TOP_MOST,  // All heads from the N-top-most text-layers
+        WHISPER_AHEADS_CUSTOM,
+        WHISPER_AHEADS_TINY_EN,
+        WHISPER_AHEADS_TINY,
+        WHISPER_AHEADS_BASE_EN,
+        WHISPER_AHEADS_BASE,
+        WHISPER_AHEADS_SMALL_EN,
+        WHISPER_AHEADS_SMALL,
+        WHISPER_AHEADS_MEDIUM_EN,
+        WHISPER_AHEADS_MEDIUM,
+        WHISPER_AHEADS_LARGE_V1,
+        WHISPER_AHEADS_LARGE_V2,
+        WHISPER_AHEADS_LARGE_V3,
+    };
+
+    typedef struct whisper_ahead {
+        int n_text_layer;
+        int n_head;
+    } whisper_ahead;
+
+    typedef struct whisper_aheads {
+        size_t n_heads;
+        const whisper_ahead * heads;
+    } whisper_aheads;
+
     struct whisper_context_params {
         bool use_gpu;
         int  gpu_device; // CUDA device
+
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        bool dtw_token_timestamps;
+        enum whisper_alignment_heads_preset dtw_aheads_preset;
+
+        int dtw_n_top;
+        struct whisper_aheads dtw_aheads;
+
+        size_t dtw_mem_size; // TODO: remove
     };
 
     typedef struct whisper_token_data {
@@ -103,6 +139,11 @@ extern "C" {
         int64_t t0; // start time of the token
         int64_t t1; //   end time of the token
 
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        // do not use if you haven't computed token-level timestamps with dtw
+        // Roughly corresponds to the moment in audio in which the token was output
+        int64_t t_dtw;
+
         float vlen; // voice length of the token
     } whisper_token_data;
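
Besides the per-model presets, the new params accept hand-picked heads. A sketch assuming, purely for illustration, that text layer 3 head 1 and layer 4 head 2 are good alignment heads for some model:

    #include "whisper.h"

    // Custom alignment heads via WHISPER_AHEADS_CUSTOM; the {n_text_layer, n_head}
    // pairs below are illustrative, not tuned for any real model. The array is
    // referenced rather than copied, so keep it alive through context/state init.
    static const whisper_ahead k_my_aheads[] = { { 3, 1 }, { 4, 2 } };

    struct whisper_context_params dtw_custom_cparams(void) {
        struct whisper_context_params cparams = whisper_context_default_params();
        cparams.dtw_token_timestamps = true;
        cparams.dtw_aheads_preset    = WHISPER_AHEADS_CUSTOM;
        cparams.dtw_aheads.n_heads   = 2;
        cparams.dtw_aheads.heads     = k_my_aheads;
        return cparams;
    }

Alternatively, WHISPER_AHEADS_N_TOP_MOST with dtw_n_top = k selects every head of the top k text layers, as implemented by get_alignment_heads_by_layer in whisper.cpp.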