Spaces:
Sleeping
whisper : token-level timestamps with DTW (#1485)
Browse files* whisper.cpp: impl dtw algo
* WIP: producing and placing DTW timestamps on tokens
* Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false.
* Fix mistake causing incorrect alignment of dtw timestamps
* implement N_TOP_MOST and CUSTOM alignment heads setting
* whisper: fix typo on alignment heads enum
* Fix issues related to changes in whisper.cpp
* Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function
* decoder: save cross QKs only if requested
* Calling median filter with ggml_map_custom1
* Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads
* Copying cross QKs from decoder backend correctly
* dtw: cleanup
* Fix incorrect n_frames passed to dtw when near end of audio
* Fix aheads_masks_init for backend != CPU
* whisper : minor style
* main : add dtw (wip)
* whisper: fix invalid memory access in aheads_masks_init
* main : add dtw (cont)
* whisper : minor
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- examples/main/main.cpp +39 -12
- whisper.cpp +572 -9
- whisper.h +41 -0
|
@@ -26,17 +26,17 @@ void replace_all(std::string & s, const std::string & search, const std::string
|
|
| 26 |
|
| 27 |
// command-line parameters
|
| 28 |
struct whisper_params {
|
| 29 |
-
int32_t n_threads
|
| 30 |
-
int32_t n_processors
|
| 31 |
-
int32_t offset_t_ms
|
| 32 |
-
int32_t offset_n
|
| 33 |
-
int32_t duration_ms
|
| 34 |
-
int32_t progress_step =
|
| 35 |
-
int32_t max_context
|
| 36 |
-
int32_t max_len
|
| 37 |
-
int32_t best_of
|
| 38 |
-
int32_t beam_size
|
| 39 |
-
int32_t audio_ctx
|
| 40 |
|
| 41 |
float word_thold = 0.01f;
|
| 42 |
float entropy_thold = 2.40f;
|
|
@@ -76,6 +76,8 @@ struct whisper_params {
|
|
| 76 |
|
| 77 |
std::string openvino_encode_device = "CPU";
|
| 78 |
|
|
|
|
|
|
|
| 79 |
std::vector<std::string> fname_inp = {};
|
| 80 |
std::vector<std::string> fname_out = {};
|
| 81 |
};
|
|
@@ -149,6 +151,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 149 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 150 |
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
| 151 |
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
|
|
|
| 152 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 153 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 154 |
else {
|
|
@@ -208,6 +211,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 208 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 209 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
| 210 |
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
|
|
|
|
| 211 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 212 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 213 |
fprintf(stderr, "\n");
|
|
@@ -649,7 +653,8 @@ bool output_json(
|
|
| 649 |
times_o(token.t0, token.t1, false);
|
| 650 |
}
|
| 651 |
value_i("id", token.id, false);
|
| 652 |
-
value_f("p", token.p,
|
|
|
|
| 653 |
end_obj(j == (n - 1));
|
| 654 |
}
|
| 655 |
end_arr(!params.diarize && !params.tinydiarize);
|
|
@@ -889,6 +894,28 @@ int main(int argc, char ** argv) {
|
|
| 889 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 890 |
cparams.use_gpu = params.use_gpu;
|
| 891 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 892 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 893 |
|
| 894 |
if (ctx == nullptr) {
|
|
|
|
| 26 |
|
| 27 |
// command-line parameters
|
| 28 |
struct whisper_params {
|
| 29 |
+
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
| 30 |
+
int32_t n_processors = 1;
|
| 31 |
+
int32_t offset_t_ms = 0;
|
| 32 |
+
int32_t offset_n = 0;
|
| 33 |
+
int32_t duration_ms = 0;
|
| 34 |
+
int32_t progress_step = 5;
|
| 35 |
+
int32_t max_context = -1;
|
| 36 |
+
int32_t max_len = 0;
|
| 37 |
+
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
|
| 38 |
+
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
| 39 |
+
int32_t audio_ctx = 0;
|
| 40 |
|
| 41 |
float word_thold = 0.01f;
|
| 42 |
float entropy_thold = 2.40f;
|
|
|
|
| 76 |
|
| 77 |
std::string openvino_encode_device = "CPU";
|
| 78 |
|
| 79 |
+
std::string dtw = "";
|
| 80 |
+
|
| 81 |
std::vector<std::string> fname_inp = {};
|
| 82 |
std::vector<std::string> fname_out = {};
|
| 83 |
};
|
|
|
|
| 151 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 152 |
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
| 153 |
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
| 154 |
+
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 155 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 156 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 157 |
else {
|
|
|
|
| 211 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 212 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
| 213 |
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
|
| 214 |
+
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
| 215 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 216 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 217 |
fprintf(stderr, "\n");
|
|
|
|
| 653 |
times_o(token.t0, token.t1, false);
|
| 654 |
}
|
| 655 |
value_i("id", token.id, false);
|
| 656 |
+
value_f("p", token.p, false);
|
| 657 |
+
value_f("t_dtw", token.t_dtw, true);
|
| 658 |
end_obj(j == (n - 1));
|
| 659 |
}
|
| 660 |
end_arr(!params.diarize && !params.tinydiarize);
|
|
|
|
| 894 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 895 |
cparams.use_gpu = params.use_gpu;
|
| 896 |
|
| 897 |
+
if (!params.dtw.empty()) {
|
| 898 |
+
cparams.dtw_token_timestamps = true;
|
| 899 |
+
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
|
| 900 |
+
|
| 901 |
+
if (params.dtw == "tiny") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
|
| 902 |
+
if (params.dtw == "tiny.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
|
| 903 |
+
if (params.dtw == "base") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
|
| 904 |
+
if (params.dtw == "base.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
|
| 905 |
+
if (params.dtw == "small") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
|
| 906 |
+
if (params.dtw == "small.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
|
| 907 |
+
if (params.dtw == "medium") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
|
| 908 |
+
if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
|
| 909 |
+
if (params.dtw == "large.v1") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
|
| 910 |
+
if (params.dtw == "large.v2") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
|
| 911 |
+
if (params.dtw == "large.v3") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
|
| 912 |
+
|
| 913 |
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
| 914 |
+
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
|
| 915 |
+
return 3;
|
| 916 |
+
}
|
| 917 |
+
}
|
| 918 |
+
|
| 919 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 920 |
|
| 921 |
if (ctx == nullptr) {
|
|
@@ -351,6 +351,35 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
| 351 |
{ "yue", { 99, "cantonese", } },
|
| 352 |
};
|
| 353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
struct whisper_mel {
|
| 355 |
int n_len;
|
| 356 |
int n_len_org;
|
|
@@ -750,6 +779,13 @@ struct whisper_decoder {
|
|
| 750 |
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
| 751 |
};
|
| 752 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 753 |
struct whisper_state {
|
| 754 |
int64_t t_sample_us = 0;
|
| 755 |
int64_t t_encode_us = 0;
|
|
@@ -823,6 +859,11 @@ struct whisper_state {
|
|
| 823 |
|
| 824 |
std::vector<float> energy; // PCM signal energy
|
| 825 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
// [EXPERIMENTAL] speed-up techniques
|
| 827 |
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
| 828 |
};
|
|
@@ -1027,6 +1068,132 @@ static void whisper_kv_cache_seq_cp(
|
|
| 1027 |
}
|
| 1028 |
}
|
| 1029 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1030 |
static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
|
| 1031 |
ggml_backend_t backend_gpu = NULL;
|
| 1032 |
|
|
@@ -2105,6 +2272,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2105 |
whisper_context & wctx,
|
| 2106 |
whisper_state & wstate,
|
| 2107 |
const whisper_batch & batch,
|
|
|
|
| 2108 |
bool worst_case) {
|
| 2109 |
const auto & model = wctx.model;
|
| 2110 |
const auto & hparams = model.hparams;
|
|
@@ -2158,6 +2326,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2158 |
|
| 2159 |
struct ggml_tensor * inpL = cur;
|
| 2160 |
|
|
|
|
|
|
|
|
|
|
| 2161 |
for (int il = 0; il < n_layer; ++il) {
|
| 2162 |
const auto & layer = model.layers_decoder[il];
|
| 2163 |
|
|
@@ -2337,6 +2508,24 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2337 |
|
| 2338 |
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
| 2339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2340 |
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
| 2341 |
|
| 2342 |
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
@@ -2422,6 +2611,16 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2422 |
|
| 2423 |
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
| 2424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2425 |
ggml_build_forward_expand(gf, logits);
|
| 2426 |
|
| 2427 |
ggml_free(ctx0);
|
|
@@ -2444,6 +2643,7 @@ static bool whisper_decode_internal(
|
|
| 2444 |
whisper_state & wstate,
|
| 2445 |
const whisper_batch & batch,
|
| 2446 |
const int n_threads,
|
|
|
|
| 2447 |
ggml_abort_callback abort_callback,
|
| 2448 |
void * abort_callback_data) {
|
| 2449 |
const int64_t t_start_us = ggml_time_us();
|
|
@@ -2475,7 +2675,7 @@ static bool whisper_decode_internal(
|
|
| 2475 |
{
|
| 2476 |
auto & alloc = wstate.alloc_decode.alloc;
|
| 2477 |
|
| 2478 |
-
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, false);
|
| 2479 |
|
| 2480 |
if (!ggml_gallocr_alloc_graph(alloc, gf)) {
|
| 2481 |
// should never happen as we pre-allocate the memory
|
|
@@ -3003,6 +3203,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
| 3003 |
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
| 3004 |
}
|
| 3005 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3006 |
#ifdef WHISPER_USE_COREML
|
| 3007 |
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
| 3008 |
|
|
@@ -3095,7 +3306,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
| 3095 |
|
| 3096 |
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
|
| 3097 |
|
| 3098 |
-
return whisper_build_graph_decoder(*ctx, *state, state->batch, true);
|
| 3099 |
});
|
| 3100 |
|
| 3101 |
if (!ok) {
|
|
@@ -3161,8 +3372,17 @@ int whisper_ctx_init_openvino_encoder(
|
|
| 3161 |
|
| 3162 |
struct whisper_context_params whisper_context_default_params() {
|
| 3163 |
struct whisper_context_params result = {
|
| 3164 |
-
/*.use_gpu
|
| 3165 |
-
/*.gpu_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3166 |
};
|
| 3167 |
return result;
|
| 3168 |
}
|
|
@@ -3357,6 +3577,9 @@ void whisper_free_state(struct whisper_state * state) {
|
|
| 3357 |
|
| 3358 |
ggml_backend_free(state->backend);
|
| 3359 |
|
|
|
|
|
|
|
|
|
|
| 3360 |
delete state;
|
| 3361 |
}
|
| 3362 |
}
|
|
@@ -3476,7 +3699,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
| 3476 |
|
| 3477 |
whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
|
| 3478 |
|
| 3479 |
-
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
|
| 3480 |
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
| 3481 |
return 1;
|
| 3482 |
}
|
|
@@ -4411,6 +4634,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
|
|
| 4411 |
return txt[0] == ' ';
|
| 4412 |
}
|
| 4413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4414 |
// wrap the last segment to max_len characters
|
| 4415 |
// returns the number of new segments
|
| 4416 |
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
|
|
@@ -4779,7 +5013,7 @@ static whisper_token_data whisper_sample_token(
|
|
| 4779 |
const whisper_decoder & decoder,
|
| 4780 |
bool best) {
|
| 4781 |
whisper_token_data result = {
|
| 4782 |
-
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
| 4783 |
};
|
| 4784 |
|
| 4785 |
const auto & vocab = ctx.vocab;
|
|
@@ -4897,7 +5131,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
| 4897 |
const auto id = dist(decoder.rng);
|
| 4898 |
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
| 4899 |
|
| 4900 |
-
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
| 4901 |
|
| 4902 |
if (result[i].id >= vocab.token_beg) {
|
| 4903 |
result[i].tid = result[i].id;
|
|
@@ -5259,7 +5493,7 @@ int whisper_full_with_state(
|
|
| 5259 |
|
| 5260 |
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
| 5261 |
|
| 5262 |
-
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
| 5263 |
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
| 5264 |
return -7;
|
| 5265 |
}
|
|
@@ -5559,7 +5793,7 @@ int whisper_full_with_state(
|
|
| 5559 |
|
| 5560 |
assert(batch.n_tokens > 0);
|
| 5561 |
|
| 5562 |
-
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
| 5563 |
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
| 5564 |
return -8;
|
| 5565 |
}
|
|
@@ -5682,6 +5916,9 @@ int whisper_full_with_state(
|
|
| 5682 |
|
| 5683 |
const auto & tokens_cur = best_decoder.sequence.tokens;
|
| 5684 |
|
|
|
|
|
|
|
|
|
|
| 5685 |
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
| 5686 |
|
| 5687 |
// update prompt_past
|
|
@@ -5799,6 +6036,17 @@ int whisper_full_with_state(
|
|
| 5799 |
}
|
| 5800 |
}
|
| 5801 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5802 |
// update audio window
|
| 5803 |
seek += seek_delta;
|
| 5804 |
|
|
@@ -6601,6 +6849,321 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 6601 |
//}
|
| 6602 |
}
|
| 6603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6604 |
void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
|
| 6605 |
g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
|
| 6606 |
g_state.log_callback_user_data = user_data;
|
|
|
|
| 351 |
{ "yue", { 99, "cantonese", } },
|
| 352 |
};
|
| 353 |
|
| 354 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 355 |
+
static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
|
| 356 |
+
static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
|
| 357 |
+
static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
|
| 358 |
+
static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
|
| 359 |
+
static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
|
| 360 |
+
static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
|
| 361 |
+
static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
|
| 362 |
+
static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
|
| 363 |
+
static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
|
| 364 |
+
static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
|
| 365 |
+
static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
|
| 366 |
+
|
| 367 |
+
static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
| 368 |
+
{ WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
|
| 369 |
+
{ WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } },
|
| 370 |
+
{ WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } },
|
| 371 |
+
{ WHISPER_AHEADS_BASE, { 8, g_aheads_base } },
|
| 372 |
+
{ WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } },
|
| 373 |
+
{ WHISPER_AHEADS_SMALL, { 10, g_aheads_small } },
|
| 374 |
+
{ WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
|
| 375 |
+
{ WHISPER_AHEADS_MEDIUM, { 6, g_aheads_medium } },
|
| 376 |
+
{ WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
|
| 377 |
+
{ WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
|
| 378 |
+
{ WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
|
| 379 |
+
};
|
| 380 |
+
|
| 381 |
+
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
| 382 |
+
|
| 383 |
struct whisper_mel {
|
| 384 |
int n_len;
|
| 385 |
int n_len_org;
|
|
|
|
| 779 |
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
| 780 |
};
|
| 781 |
|
| 782 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 783 |
+
struct whisper_aheads_masks {
|
| 784 |
+
std::vector<struct ggml_tensor *> m; // One mask per text layer.
|
| 785 |
+
struct ggml_context * ctx = nullptr;
|
| 786 |
+
ggml_backend_buffer_t buffer = nullptr;
|
| 787 |
+
};
|
| 788 |
+
|
| 789 |
struct whisper_state {
|
| 790 |
int64_t t_sample_us = 0;
|
| 791 |
int64_t t_encode_us = 0;
|
|
|
|
| 859 |
|
| 860 |
std::vector<float> energy; // PCM signal energy
|
| 861 |
|
| 862 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 863 |
+
whisper_aheads_masks aheads_masks;
|
| 864 |
+
ggml_tensor * aheads_cross_QKs = nullptr;
|
| 865 |
+
std::vector<float> aheads_cross_QKs_data;
|
| 866 |
+
|
| 867 |
// [EXPERIMENTAL] speed-up techniques
|
| 868 |
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
| 869 |
};
|
|
|
|
| 1068 |
}
|
| 1069 |
}
|
| 1070 |
|
| 1071 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 1072 |
+
static bool aheads_masks_init(
|
| 1073 |
+
const whisper_context_params & cparams,
|
| 1074 |
+
const whisper_hparams & hparams,
|
| 1075 |
+
struct whisper_aheads_masks & aheads_masks,
|
| 1076 |
+
ggml_backend_t backend) {
|
| 1077 |
+
|
| 1078 |
+
const int32_t n_text_layer = hparams.n_text_layer;
|
| 1079 |
+
const int32_t n_head = hparams.n_text_head;
|
| 1080 |
+
|
| 1081 |
+
// Sanity checks
|
| 1082 |
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
| 1083 |
+
WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
|
| 1084 |
+
return false;
|
| 1085 |
+
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
|
| 1086 |
+
if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
|
| 1087 |
+
WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
|
| 1088 |
+
return false;
|
| 1089 |
+
}
|
| 1090 |
+
} else {
|
| 1091 |
+
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
|
| 1092 |
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
|
| 1093 |
+
if (aheads.n_heads == 0) {
|
| 1094 |
+
WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
|
| 1095 |
+
return false;
|
| 1096 |
+
}
|
| 1097 |
+
if (aheads.heads == NULL) {
|
| 1098 |
+
WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
|
| 1099 |
+
return false;
|
| 1100 |
+
}
|
| 1101 |
+
}
|
| 1102 |
+
for (size_t i = 0; i < aheads.n_heads; ++i) {
|
| 1103 |
+
if (aheads.heads[i].n_text_layer >= n_text_layer) {
|
| 1104 |
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
|
| 1105 |
+
return false;
|
| 1106 |
+
}
|
| 1107 |
+
if (aheads.heads[i].n_text_layer < 0) {
|
| 1108 |
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__);
|
| 1109 |
+
return false;
|
| 1110 |
+
}
|
| 1111 |
+
if (aheads.heads[i].n_head >= n_head) {
|
| 1112 |
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head);
|
| 1113 |
+
return false;
|
| 1114 |
+
}
|
| 1115 |
+
if (aheads.heads[i].n_head < 0) {
|
| 1116 |
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__);
|
| 1117 |
+
return false;
|
| 1118 |
+
}
|
| 1119 |
+
}
|
| 1120 |
+
}
|
| 1121 |
+
|
| 1122 |
+
struct ggml_init_params params = {
|
| 1123 |
+
/*.mem_size =*/ (size_t) static_cast<size_t>(n_text_layer)*ggml_tensor_overhead(),
|
| 1124 |
+
/*.mem_buffer =*/ nullptr,
|
| 1125 |
+
/*.no_alloc =*/ true,
|
| 1126 |
+
};
|
| 1127 |
+
|
| 1128 |
+
aheads_masks.ctx = ggml_init(params);
|
| 1129 |
+
|
| 1130 |
+
if (!aheads_masks.ctx) {
|
| 1131 |
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
|
| 1132 |
+
return false;
|
| 1133 |
+
}
|
| 1134 |
+
|
| 1135 |
+
for (int64_t il = 0; il < n_text_layer; ++il) {
|
| 1136 |
+
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
|
| 1137 |
+
if (!aheads.empty()) {
|
| 1138 |
+
aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size()));
|
| 1139 |
+
} else {
|
| 1140 |
+
aheads_masks.m.push_back(nullptr);
|
| 1141 |
+
}
|
| 1142 |
+
}
|
| 1143 |
+
|
| 1144 |
+
aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
|
| 1145 |
+
if (!aheads_masks.buffer) {
|
| 1146 |
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
|
| 1147 |
+
return false;
|
| 1148 |
+
}
|
| 1149 |
+
|
| 1150 |
+
// Set data on mask tensors
|
| 1151 |
+
// Since this must be backend agnostic, we get tensor data with
|
| 1152 |
+
// ggml_backend_tensor_get, copy our desired values and send it back
|
| 1153 |
+
// to backend with ggml_backend_tensor_set
|
| 1154 |
+
std::vector<float> mask_data;
|
| 1155 |
+
for (int64_t il = 0; il < n_text_layer; ++il) {
|
| 1156 |
+
if (aheads_masks.m[il] != nullptr) {
|
| 1157 |
+
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
|
| 1158 |
+
|
| 1159 |
+
size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1] * sizeof(float);
|
| 1160 |
+
mask_data.resize(data_size);
|
| 1161 |
+
ggml_backend_tensor_get(aheads_masks.m[il], mask_data.data(), 0, data_size);
|
| 1162 |
+
memset(mask_data.data(), 0, data_size);
|
| 1163 |
+
|
| 1164 |
+
for (size_t ih = 0; ih < aheads.size(); ++ih) {
|
| 1165 |
+
size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0] * aheads[ih]));
|
| 1166 |
+
float v = 1.0f;
|
| 1167 |
+
memcpy(mask_data.data() + pos, &v, sizeof(float));
|
| 1168 |
+
}
|
| 1169 |
+
|
| 1170 |
+
ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size);
|
| 1171 |
+
}
|
| 1172 |
+
}
|
| 1173 |
+
|
| 1174 |
+
if (aheads_masks.m.empty()) {
|
| 1175 |
+
WHISPER_LOG_ERROR("%s: \n", __func__);
|
| 1176 |
+
return false;
|
| 1177 |
+
}
|
| 1178 |
+
|
| 1179 |
+
return true;
|
| 1180 |
+
}
|
| 1181 |
+
|
| 1182 |
+
static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
|
| 1183 |
+
ggml_free(aheads_masks.ctx);
|
| 1184 |
+
ggml_backend_buffer_free(aheads_masks.buffer);
|
| 1185 |
+
aheads_masks.ctx = nullptr;
|
| 1186 |
+
}
|
| 1187 |
+
|
| 1188 |
+
static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
| 1189 |
+
size_t size = 0;
|
| 1190 |
+
for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
|
| 1191 |
+
if (aheads_masks.m[i] != nullptr)
|
| 1192 |
+
size += ggml_nbytes(aheads_masks.m[i]);
|
| 1193 |
+
}
|
| 1194 |
+
return size;
|
| 1195 |
+
}
|
| 1196 |
+
|
| 1197 |
static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
|
| 1198 |
ggml_backend_t backend_gpu = NULL;
|
| 1199 |
|
|
|
|
| 2272 |
whisper_context & wctx,
|
| 2273 |
whisper_state & wstate,
|
| 2274 |
const whisper_batch & batch,
|
| 2275 |
+
bool save_alignment_heads_QKs,
|
| 2276 |
bool worst_case) {
|
| 2277 |
const auto & model = wctx.model;
|
| 2278 |
const auto & hparams = model.hparams;
|
|
|
|
| 2326 |
|
| 2327 |
struct ggml_tensor * inpL = cur;
|
| 2328 |
|
| 2329 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 2330 |
+
struct ggml_tensor * aheads_cross_QKs = nullptr;
|
| 2331 |
+
|
| 2332 |
for (int il = 0; il < n_layer; ++il) {
|
| 2333 |
const auto & layer = model.layers_decoder[il];
|
| 2334 |
|
|
|
|
| 2508 |
|
| 2509 |
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
| 2510 |
|
| 2511 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 2512 |
+
if (wctx.params.dtw_token_timestamps) {
|
| 2513 |
+
if (wstate.aheads_masks.m[il] != nullptr) {
|
| 2514 |
+
struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
|
| 2515 |
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
| 2516 |
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
| 2517 |
+
aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
|
| 2518 |
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
| 2519 |
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
| 2520 |
+
aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
|
| 2521 |
+
if (aheads_cross_QKs == NULL) {
|
| 2522 |
+
aheads_cross_QKs = aheads_KQs;
|
| 2523 |
+
} else {
|
| 2524 |
+
aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
|
| 2525 |
+
}
|
| 2526 |
+
}
|
| 2527 |
+
}
|
| 2528 |
+
|
| 2529 |
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
| 2530 |
|
| 2531 |
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
|
|
| 2611 |
|
| 2612 |
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
| 2613 |
|
| 2614 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 2615 |
+
if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
|
| 2616 |
+
aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs);
|
| 2617 |
+
aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs);
|
| 2618 |
+
if (save_alignment_heads_QKs) {
|
| 2619 |
+
ggml_build_forward_expand(gf, aheads_cross_QKs);
|
| 2620 |
+
wstate.aheads_cross_QKs = aheads_cross_QKs;
|
| 2621 |
+
}
|
| 2622 |
+
}
|
| 2623 |
+
|
| 2624 |
ggml_build_forward_expand(gf, logits);
|
| 2625 |
|
| 2626 |
ggml_free(ctx0);
|
|
|
|
| 2643 |
whisper_state & wstate,
|
| 2644 |
const whisper_batch & batch,
|
| 2645 |
const int n_threads,
|
| 2646 |
+
bool save_alignment_heads_QKs,
|
| 2647 |
ggml_abort_callback abort_callback,
|
| 2648 |
void * abort_callback_data) {
|
| 2649 |
const int64_t t_start_us = ggml_time_us();
|
|
|
|
| 2675 |
{
|
| 2676 |
auto & alloc = wstate.alloc_decode.alloc;
|
| 2677 |
|
| 2678 |
+
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
|
| 2679 |
|
| 2680 |
if (!ggml_gallocr_alloc_graph(alloc, gf)) {
|
| 2681 |
// should never happen as we pre-allocate the memory
|
|
|
|
| 3203 |
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
| 3204 |
}
|
| 3205 |
|
| 3206 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 3207 |
+
if (ctx->params.dtw_token_timestamps) {
|
| 3208 |
+
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
|
| 3209 |
+
WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
|
| 3210 |
+
whisper_free_state(state);
|
| 3211 |
+
return nullptr;
|
| 3212 |
+
}
|
| 3213 |
+
const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
|
| 3214 |
+
WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
|
| 3215 |
+
}
|
| 3216 |
+
|
| 3217 |
#ifdef WHISPER_USE_COREML
|
| 3218 |
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
| 3219 |
|
|
|
|
| 3306 |
|
| 3307 |
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
|
| 3308 |
|
| 3309 |
+
return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
|
| 3310 |
});
|
| 3311 |
|
| 3312 |
if (!ok) {
|
|
|
|
| 3372 |
|
| 3373 |
struct whisper_context_params whisper_context_default_params() {
|
| 3374 |
struct whisper_context_params result = {
|
| 3375 |
+
/*.use_gpu =*/ true,
|
| 3376 |
+
/*.gpu_device =*/ 0,
|
| 3377 |
+
|
| 3378 |
+
/*.dtw_token_timestamps =*/ false,
|
| 3379 |
+
/*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE,
|
| 3380 |
+
/*.dtw_n_top =*/ -1,
|
| 3381 |
+
/*.dtw_aheads =*/ {
|
| 3382 |
+
/*.n_heads =*/ 0,
|
| 3383 |
+
/*.heads =*/ NULL,
|
| 3384 |
+
},
|
| 3385 |
+
/*.dtw_mem_size =*/ 1024*1024*128,
|
| 3386 |
};
|
| 3387 |
return result;
|
| 3388 |
}
|
|
|
|
| 3577 |
|
| 3578 |
ggml_backend_free(state->backend);
|
| 3579 |
|
| 3580 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 3581 |
+
aheads_masks_free(state->aheads_masks);
|
| 3582 |
+
|
| 3583 |
delete state;
|
| 3584 |
}
|
| 3585 |
}
|
|
|
|
| 3699 |
|
| 3700 |
whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
|
| 3701 |
|
| 3702 |
+
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
|
| 3703 |
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
| 3704 |
return 1;
|
| 3705 |
}
|
|
|
|
| 4634 |
return txt[0] == ' ';
|
| 4635 |
}
|
| 4636 |
|
| 4637 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
// Forward declaration; the implementation lives further down in this file.
// Computes per-token DTW timestamps for segments [i_segment, i_segment + n_segments).
static void whisper_exp_compute_token_level_timestamps_dtw(
        struct whisper_context * ctx,
        struct whisper_state * state,
        struct whisper_full_params params,
        int i_segment,
        size_t n_segments,
        int seek,
        int n_frames,
        int medfilt_width,
        int n_threads);
|
| 4647 |
+
|
| 4648 |
// wrap the last segment to max_len characters
|
| 4649 |
// returns the number of new segments
|
| 4650 |
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
|
|
|
|
| 5013 |
const whisper_decoder & decoder,
|
| 5014 |
bool best) {
|
| 5015 |
whisper_token_data result = {
|
| 5016 |
+
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
|
| 5017 |
};
|
| 5018 |
|
| 5019 |
const auto & vocab = ctx.vocab;
|
|
|
|
| 5131 |
const auto id = dist(decoder.rng);
|
| 5132 |
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
| 5133 |
|
| 5134 |
+
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
|
| 5135 |
|
| 5136 |
if (result[i].id >= vocab.token_beg) {
|
| 5137 |
result[i].tid = result[i].id;
|
|
|
|
| 5493 |
|
| 5494 |
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
| 5495 |
|
| 5496 |
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
| 5497 |
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
| 5498 |
return -7;
|
| 5499 |
}
|
|
|
|
| 5793 |
|
| 5794 |
assert(batch.n_tokens > 0);
|
| 5795 |
|
| 5796 |
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
| 5797 |
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
| 5798 |
return -8;
|
| 5799 |
}
|
|
|
|
| 5916 |
|
| 5917 |
const auto & tokens_cur = best_decoder.sequence.tokens;
|
| 5918 |
|
| 5919 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 5920 |
+
const auto n_segments_before = state->result_all.size();
|
| 5921 |
+
|
| 5922 |
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
| 5923 |
|
| 5924 |
// update prompt_past
|
|
|
|
| 6036 |
}
|
| 6037 |
}
|
| 6038 |
|
| 6039 |
+
// FIXME: will timestamp offsets be correct?
|
| 6040 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 6041 |
+
{
|
| 6042 |
+
const auto n_segments = state->result_all.size() - n_segments_before;
|
| 6043 |
+
if (ctx->params.dtw_token_timestamps && n_segments) {
|
| 6044 |
+
const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
|
| 6045 |
+
whisper_exp_compute_token_level_timestamps_dtw(
|
| 6046 |
+
ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
|
| 6047 |
+
}
|
| 6048 |
+
}
|
| 6049 |
+
|
| 6050 |
// update audio window
|
| 6051 |
seek += seek_delta;
|
| 6052 |
|
|
|
|
| 6849 |
//}
|
| 6850 |
}
|
| 6851 |
|
| 6852 |
+
//
|
| 6853 |
+
// token level timestamps - dtw version
|
| 6854 |
+
//
|
| 6855 |
+
|
| 6856 |
+
// n_text_layer -> total text layers on model
|
| 6857 |
+
// n_head -> total heads per text layer on model
|
| 6858 |
+
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
|
| 6859 |
+
std::vector<uint32_t> ret;
|
| 6860 |
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
| 6861 |
+
return ret;
|
| 6862 |
+
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
|
| 6863 |
+
if (il >= n_text_layer - cparams.dtw_n_top) {
|
| 6864 |
+
for (int32_t i = 0; i < n_head; ++i) {
|
| 6865 |
+
ret.push_back(i);
|
| 6866 |
+
}
|
| 6867 |
+
}
|
| 6868 |
+
} else {
|
| 6869 |
+
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
|
| 6870 |
+
for (size_t i = 0; i < aheads.n_heads; ++i) {
|
| 6871 |
+
if (aheads.heads[i].n_text_layer == il) {
|
| 6872 |
+
ret.push_back(aheads.heads[i].n_head);
|
| 6873 |
+
}
|
| 6874 |
+
}
|
| 6875 |
+
}
|
| 6876 |
+
return ret;
|
| 6877 |
+
}
|
| 6878 |
+
|
| 6879 |
+
// dtw + backtrace to return found path
// based on
// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
//
// x is a 2D (N x M) cost matrix. Returns a 2 x K I32 tensor (allocated in
// `ctx`) holding the (row, column) index pairs of the minimum-cost monotonic
// path through x, in path order.
static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
    WHISPER_ASSERT(ggml_n_dims(x) == 2);

    int64_t N = x->ne[0];
    int64_t M = x->ne[1];
    // cost/trace have one extra row+column: index 0 is the DP boundary,
    // so cost[i][j] corresponds to x[i-1][j-1]
    struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
    struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);

    cost = ggml_set_f32(cost, INFINITY);
    // NOTE(review): trace is I32 yet is filled via ggml_set_f32 — assumes
    // ggml_set_f32 also handles integer tensors; confirm against ggml API
    trace = ggml_set_f32(trace, -1);
    ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);

    // dtw
    // supposedly can be optmized by computing diagonals in parallel ?
    // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
    for (int64_t j = 1; j < M + 1; ++j) {
        for (int64_t i = 1; i < N + 1; ++i) {
            // three possible predecessors: diagonal, up, left
            float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
            float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0);
            float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0);

            float c;
            int32_t t;
            if (c0 < c1 && c0 < c2) {
                c = c0;
                t = 0; // came from (i-1, j-1)
            } else if (c1 < c0 && c1 < c2) {
                c = c1;
                t = 1; // came from (i-1, j)
            } else {
                c = c2;
                t = 2; // came from (i, j-1) — also the tie-break case
            }

            c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
            ggml_set_f32_nd(cost, i, j, 0, 0, c);
            ggml_set_i32_nd(trace, i, j, 0, 0, t);
        }
    }

    // Backtrace
    // A path can have at most N + M - 1 steps (start + one move per step)
    const int64_t BT_MAX_ROWS = N + M - 1;
    struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
    // trace[0, :] = 2;
    for (int64_t i = 0; i < M + 1; ++i)
        ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
    //trace[:, 0] = 1;
    for (int64_t i = 0; i < N + 1; ++i)
        ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
    // fill bt from the last row backwards, so the path ends up in order
    int bt_row_idx = BT_MAX_ROWS - 1;
    int64_t i = N;
    int64_t j = M;
    while (i > 0 || j > 0) {
        ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
        ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
        --bt_row_idx;

        int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0);
        if (t == 0) {
            --i;
            --j;
        } else if (t == 1) {
            --i;
        } else if (t == 2) {
            --j;
        } else {
            WHISPER_ASSERT(0);
        }
    }

    // FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
    // Clip + transpose
    // This might not be entirely necessary for our case, but leaving it for now so output matrix
    // is identical to dtw on openAI timing.py
    // Only rows [bt_row_idx+1, BT_MAX_ROWS) of bt were actually written
    const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1;
    ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
    for (int64_t i = 0; i < 2; ++i) {
        for (int64_t j = 0; j < result_n_cols; ++j) {
            int32_t v = ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
            ggml_set_i32_nd(r, i, j, 0, 0, v);
        }
    }

    return r;
}
|
| 6967 |
+
|
| 6968 |
+
// Parameters for the median_filter() custom ggml op below,
// passed through ggml_map_custom1's userdata pointer.
struct median_filter_user_data {
    int filter_width; // window width; must be odd (asserted in median_filter)
};
|
| 6971 |
+
|
| 6972 |
+
static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata) {
|
| 6973 |
+
int filter_width = ((median_filter_user_data *) userdata)->filter_width;
|
| 6974 |
+
WHISPER_ASSERT(nth == 1);
|
| 6975 |
+
WHISPER_ASSERT(ith == 0);
|
| 6976 |
+
WHISPER_ASSERT(filter_width < a->ne[2]);
|
| 6977 |
+
WHISPER_ASSERT(filter_width % 2);
|
| 6978 |
+
WHISPER_ASSERT(ggml_n_dims(a) == 3);
|
| 6979 |
+
WHISPER_ASSERT(a->type == GGML_TYPE_F32);
|
| 6980 |
+
|
| 6981 |
+
std::vector<float> filter;
|
| 6982 |
+
filter.reserve(filter_width);
|
| 6983 |
+
for (int64_t i = 0; i < a->ne[0]; ++i) {
|
| 6984 |
+
for (int64_t j = 0; j < a->ne[1]; ++j) {
|
| 6985 |
+
for (int64_t k = 0; k < a->ne[2]; ++k) {
|
| 6986 |
+
for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
|
| 6987 |
+
// "reflect" padding
|
| 6988 |
+
int64_t idx = k + off;
|
| 6989 |
+
if (idx < 0) {
|
| 6990 |
+
idx = -idx;
|
| 6991 |
+
} else if (idx >= a->ne[2]) {
|
| 6992 |
+
idx = 2*(a->ne[2] - 1) - idx;
|
| 6993 |
+
}
|
| 6994 |
+
|
| 6995 |
+
filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
|
| 6996 |
+
}
|
| 6997 |
+
std::sort(filter.begin(), filter.end());
|
| 6998 |
+
const float v = filter[filter.size()/2];
|
| 6999 |
+
ggml_set_f32_nd(dst, i, j, k, 0, v);
|
| 7000 |
+
filter.clear();
|
| 7001 |
+
}
|
| 7002 |
+
}
|
| 7003 |
+
}
|
| 7004 |
+
}
|
| 7005 |
+
|
| 7006 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
//
// Re-decodes the tokens of segments [i_segment, i_segment + n_segments),
// collecting the cross-attention weights of the alignment heads, smooths and
// averages them, then runs DTW (dtw_and_backtrace) to map each text token to
// a moment in the audio. Results are written to each token's t_dtw field,
// in centiseconds (the debug print below shows t_dtw/100 as seconds).
//
// seek          - offset added to each DTW time index (*2, i.e. 20ms per index)
// n_frames      - number of valid audio frames in the current window
// medfilt_width - median-filter width; must be odd
static void whisper_exp_compute_token_level_timestamps_dtw(
        struct whisper_context * ctx,
        struct whisper_state * state,
        struct whisper_full_params params,
        int i_segment,
        size_t n_segments,
        int seek,
        int n_frames,
        int medfilt_width,
        int n_threads)
{
    const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
    WHISPER_ASSERT(medfilt_width % 2);
    WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
    WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);

    // FIXME: Allocating mem everytime we call this func
    // Our ggml buffer should be pre-allocated somewhere during init and reused
    // when we call this function
    struct ggml_init_params gparams = {
        /*.mem_size   =*/ ctx->params.dtw_mem_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * gctx = ggml_init(gparams);

    // Build token sequence that will be passed to decoder
    // sot + [lang] + text result + eot
    std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
    if (whisper_is_multilingual(ctx)) {
        const int lang_id = whisper_lang_id(params.language);
        state->lang_id = lang_id;
        tokens.push_back(whisper_token_lang(ctx, lang_id));
    }
    // length of the sot sequence prefix; used below to strip it from the
    // attention-weight matrix before DTW
    const size_t sot_sequence_length = tokens.size();
    tokens.push_back(whisper_token_not(ctx));
    for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
        auto & segment = state->result_all[i];
        for (auto &t: segment.tokens) {
            // Only text tokens
            if (t.id < whisper_token_eot(ctx)) {
                tokens.push_back(t.id);
            }
        }
    }
    tokens.push_back(whisper_token_eot(ctx));

    // Get result tokens, pass then along to decoder to get cross attention QKs
    // used in timestamping
    // Decoder already returns only alignment head QKs, already concatenated in
    // one tensor.
    whisper_kv_cache_clear(state->kv_self);
    whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
    whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
    // save_alignment_heads_QKs=true so the decoder stores aheads_cross_QKs
    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
        WHISPER_LOG_INFO("DECODER FAILED\n");
        WHISPER_ASSERT(0);
    }
    WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);

    // each audio token covers two frames
    const auto n_audio_tokens = n_frames/2;
    WHISPER_ASSERT(state->aheads_cross_QKs != NULL); // (redundant with the assert above)
    WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
    const auto n_tokens = state->aheads_cross_QKs->ne[0];
    const auto n_heads = state->aheads_cross_QKs->ne[2];

    // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
    // tokens (i.e. discarding rows at the end of tensor)
    // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
    WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32);
    WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs));
    ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
    auto & data = state->aheads_cross_QKs_data;
    data.resize(n_tokens * n_audio_ctx * n_heads);
    ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
    for (int k = 0; k < n_heads; ++k) {
        for (int j = 0; j < n_audio_tokens; ++j) {
            // copy one row of n_tokens floats; source rows are strided by the
            // full n_audio_ctx, destination only by n_audio_tokens
            memcpy(
                (char *) w->data + j * w->nb[1] + k * w->nb[2],
                data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
                n_tokens * sizeof(float)
            );
        }
    }

    // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
    // we already permuted N_TOKENS dimension to columns on last loop, becase ggml_norm
    // operates over columns. Afterwards, permute to a shape that facilitates mean
    // operation (after median filter)
    // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
    // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
    w = ggml_norm(gctx, w, 1e-9);
    w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3);

    // Pass median filter - this is done over AUDIO_TOKENS dimension.
    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
    // OUT: Same dims
    median_filter_user_data mf_user_data = {medfilt_width};
    w = ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);

    // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
    w = ggml_mean(gctx, w);
    w = ggml_scale(gctx, w, -1.0);
    w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);

    // Remove SOT sequence and EOT
    // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
    w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);

    // Compute
    struct ggml_cgraph * gf = ggml_new_graph(gctx);
    ggml_build_forward_expand(gf, w);
    ggml_graph_compute_with_ctx(gctx, gf, n_threads);

    ggml_tensor * alignment = dtw_and_backtrace(gctx, w);

    // Place timestamps on segments
    // Walk the DTW path: every time the token index advances, the paired time
    // index gives that token's timestamp.
    int32_t last_v = 0;
    auto seg_i = state->result_all.begin() + i_segment;
    auto tok_i = seg_i->tokens.begin();
    for (int i = 0; i < alignment->ne[1]; ++i) {
        int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0);
        if (v != last_v) {
            int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0);
            int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
            last_v = v;

            // Skip non-text tokens
            // NOTE(review): no end-of-range guard when advancing seg_i — assumes
            // the DTW path never maps more token advances than there are text
            // tokens in the covered segments; confirm this invariant holds
            while (!(tok_i->id < whisper_token_eot(ctx))) {
                ++tok_i;
                if (tok_i == seg_i->tokens.end()) {
                    ++seg_i;
                    tok_i = seg_i->tokens.begin();
                }
            }

            tok_i->t_dtw = timestamp;
            ++tok_i;
            if (tok_i == seg_i->tokens.end()) {
                ++seg_i;
                tok_i = seg_i->tokens.begin();
            }
        }
    }

    // Print DTW timestamps
    /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
        auto & segment = state->result_all[i];
        for (auto &t: segment.tokens) {
            const char * tok = whisper_token_to_str(ctx, t.id);
            fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100);
        }
        fprintf(stderr, "\n");
    }*/

    ggml_free(gctx);
}
|
| 7166 |
+
|
| 7167 |
void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
|
| 7168 |
g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
|
| 7169 |
g_state.log_callback_user_data = user_data;
|
|
@@ -84,9 +84,45 @@ extern "C" {
|
|
| 84 |
typedef int32_t whisper_token;
|
| 85 |
typedef int32_t whisper_seq_id;
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
struct whisper_context_params {
|
| 88 |
bool use_gpu;
|
| 89 |
int gpu_device; // CUDA device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
};
|
| 91 |
|
| 92 |
typedef struct whisper_token_data {
|
|
@@ -103,6 +139,11 @@ extern "C" {
|
|
| 103 |
int64_t t0; // start time of the token
|
| 104 |
int64_t t1; // end time of the token
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
float vlen; // voice length of the token
|
| 107 |
} whisper_token_data;
|
| 108 |
|
|
|
|
| 84 |
typedef int32_t whisper_token;
|
| 85 |
typedef int32_t whisper_seq_id;
|
| 86 |
|
| 87 |
+
    // Alignment-head selection presets for DTW token-level timestamps.
    enum whisper_alignment_heads_preset {
        WHISPER_AHEADS_NONE,
        WHISPER_AHEADS_N_TOP_MOST,  // All heads from the N-top-most text-layers
        WHISPER_AHEADS_CUSTOM,      // User-provided list of heads (see whisper_aheads)
        // Per-model presets:
        WHISPER_AHEADS_TINY_EN,
        WHISPER_AHEADS_TINY,
        WHISPER_AHEADS_BASE_EN,
        WHISPER_AHEADS_BASE,
        WHISPER_AHEADS_SMALL_EN,
        WHISPER_AHEADS_SMALL,
        WHISPER_AHEADS_MEDIUM_EN,
        WHISPER_AHEADS_MEDIUM,
        WHISPER_AHEADS_LARGE_V1,
        WHISPER_AHEADS_LARGE_V2,
        WHISPER_AHEADS_LARGE_V3,
    };
|
| 103 |
+
|
| 104 |
+
    // One cross-attention alignment head, identified by its text (decoder)
    // layer index and the head index within that layer.
    typedef struct whisper_ahead {
        int n_text_layer; // text (decoder) layer index
        int n_head;       // head index within that layer
    } whisper_ahead;

    // Custom set of alignment heads (used with WHISPER_AHEADS_CUSTOM).
    typedef struct whisper_aheads {
        size_t n_heads;
        const whisper_ahead * heads; // points to n_heads entries; presumably
                                     // caller-owned — confirm required lifetime
    } whisper_aheads;
|
| 113 |
+
|
| 114 |
    struct whisper_context_params {
        bool use_gpu;
        int  gpu_device; // CUDA device

        // [EXPERIMENTAL] Token-level timestamps with DTW
        bool dtw_token_timestamps;                        // enable DTW timestamping
        enum whisper_alignment_heads_preset dtw_aheads_preset;

        int dtw_n_top;                    // used with WHISPER_AHEADS_N_TOP_MOST
        struct whisper_aheads dtw_aheads; // used with WHISPER_AHEADS_CUSTOM

        size_t dtw_mem_size; // TODO: remove
    };
|
| 127 |
|
| 128 |
typedef struct whisper_token_data {
|
|
|
|
| 139 |
int64_t t0; // start time of the token
|
| 140 |
int64_t t1; // end time of the token
|
| 141 |
|
| 142 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 143 |
+
// do not use if you haven't computed token-level timestamps with dtw
|
| 144 |
+
// Roughly corresponds to the moment in audio in which the token was output
|
| 145 |
+
int64_t t_dtw;
|
| 146 |
+
|
| 147 |
float vlen; // voice length of the token
|
| 148 |
} whisper_token_data;
|
| 149 |
|