Spaces:
Sleeping
Sleeping
wip : experimental color coding of tokens based on probabilities
Browse files- main.cpp +36 -9
- whisper.cpp +88 -47
- whisper.h +9 -0
main.cpp
CHANGED
|
@@ -5,12 +5,20 @@
|
|
| 5 |
#define DR_WAV_IMPLEMENTATION
|
| 6 |
#include "dr_wav.h"
|
| 7 |
|
|
|
|
| 8 |
#include <fstream>
|
| 9 |
#include <cstdio>
|
| 10 |
#include <string>
|
| 11 |
#include <thread>
|
| 12 |
#include <vector>
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
// 500 -> 00:05.000
|
| 15 |
// 6000 -> 01:00.000
|
| 16 |
std::string to_timestamp(int64_t t) {
|
|
@@ -41,6 +49,7 @@ struct whisper_params {
|
|
| 41 |
bool output_vtt = false;
|
| 42 |
bool output_srt = false;
|
| 43 |
bool print_special_tokens = false;
|
|
|
|
| 44 |
bool no_timestamps = false;
|
| 45 |
|
| 46 |
std::string language = "en";
|
|
@@ -87,6 +96,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 87 |
params.output_srt = true;
|
| 88 |
} else if (arg == "-ps" || arg == "--print_special") {
|
| 89 |
params.print_special_tokens = true;
|
|
|
|
|
|
|
| 90 |
} else if (arg == "-nt" || arg == "--no_timestamps") {
|
| 91 |
params.no_timestamps = true;
|
| 92 |
} else if (arg == "-m" || arg == "--model") {
|
|
@@ -122,6 +133,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|
| 122 |
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
|
| 123 |
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
|
| 124 |
fprintf(stderr, " -ps, --print_special print special tokens\n");
|
|
|
|
| 125 |
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
|
| 126 |
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
|
| 127 |
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
|
|
@@ -222,7 +234,7 @@ int main(int argc, char ** argv) {
|
|
| 222 |
{
|
| 223 |
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 224 |
|
| 225 |
-
wparams.print_realtime =
|
| 226 |
wparams.print_progress = false;
|
| 227 |
wparams.print_timestamps = !params.no_timestamps;
|
| 228 |
wparams.print_special_tokens = params.print_special_tokens;
|
|
@@ -242,16 +254,34 @@ int main(int argc, char ** argv) {
|
|
| 242 |
|
| 243 |
const int n_segments = whisper_full_n_segments(ctx);
|
| 244 |
for (int i = 0; i < n_segments; ++i) {
|
| 245 |
-
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 246 |
-
|
| 247 |
if (params.no_timestamps) {
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
} else {
|
| 251 |
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 252 |
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 253 |
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
}
|
| 256 |
}
|
| 257 |
}
|
|
@@ -260,7 +290,6 @@ int main(int argc, char ** argv) {
|
|
| 260 |
|
| 261 |
// output to text file
|
| 262 |
if (params.output_txt) {
|
| 263 |
-
|
| 264 |
const auto fname_txt = fname_inp + ".txt";
|
| 265 |
std::ofstream fout_txt(fname_txt);
|
| 266 |
if (!fout_txt.is_open()) {
|
|
@@ -279,7 +308,6 @@ int main(int argc, char ** argv) {
|
|
| 279 |
|
| 280 |
// output to VTT file
|
| 281 |
if (params.output_vtt) {
|
| 282 |
-
|
| 283 |
const auto fname_vtt = fname_inp + ".vtt";
|
| 284 |
std::ofstream fout_vtt(fname_vtt);
|
| 285 |
if (!fout_vtt.is_open()) {
|
|
@@ -304,7 +332,6 @@ int main(int argc, char ** argv) {
|
|
| 304 |
|
| 305 |
// output to SRT file
|
| 306 |
if (params.output_srt) {
|
| 307 |
-
|
| 308 |
const auto fname_srt = fname_inp + ".srt";
|
| 309 |
std::ofstream fout_srt(fname_srt);
|
| 310 |
if (!fout_srt.is_open()) {
|
|
|
|
| 5 |
#define DR_WAV_IMPLEMENTATION
|
| 6 |
#include "dr_wav.h"
|
| 7 |
|
| 8 |
+
#include <cmath>
|
| 9 |
#include <fstream>
|
| 10 |
#include <cstdio>
|
| 11 |
#include <string>
|
| 12 |
#include <thread>
|
| 13 |
#include <vector>
|
| 14 |
|
| 15 |
+
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
|
| 16 |
+
// Lowest is red, middle is yellow, highest is green.
|
| 17 |
+
const std::vector<std::string> k_colors = {
|
| 18 |
+
"\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
|
| 19 |
+
"\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
// 500 -> 00:05.000
|
| 23 |
// 6000 -> 01:00.000
|
| 24 |
std::string to_timestamp(int64_t t) {
|
|
|
|
| 49 |
bool output_vtt = false;
|
| 50 |
bool output_srt = false;
|
| 51 |
bool print_special_tokens = false;
|
| 52 |
+
bool print_colors = false;
|
| 53 |
bool no_timestamps = false;
|
| 54 |
|
| 55 |
std::string language = "en";
|
|
|
|
| 96 |
params.output_srt = true;
|
| 97 |
} else if (arg == "-ps" || arg == "--print_special") {
|
| 98 |
params.print_special_tokens = true;
|
| 99 |
+
} else if (arg == "-pc" || arg == "--print_colors") {
|
| 100 |
+
params.print_colors = true;
|
| 101 |
} else if (arg == "-nt" || arg == "--no_timestamps") {
|
| 102 |
params.no_timestamps = true;
|
| 103 |
} else if (arg == "-m" || arg == "--model") {
|
|
|
|
| 133 |
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
|
| 134 |
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
|
| 135 |
fprintf(stderr, " -ps, --print_special print special tokens\n");
|
| 136 |
+
fprintf(stderr, " -pc, --print_colors print colors\n");
|
| 137 |
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
|
| 138 |
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
|
| 139 |
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
|
|
|
|
| 234 |
{
|
| 235 |
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 236 |
|
| 237 |
+
wparams.print_realtime = !params.print_colors;
|
| 238 |
wparams.print_progress = false;
|
| 239 |
wparams.print_timestamps = !params.no_timestamps;
|
| 240 |
wparams.print_special_tokens = params.print_special_tokens;
|
|
|
|
| 254 |
|
| 255 |
const int n_segments = whisper_full_n_segments(ctx);
|
| 256 |
for (int i = 0; i < n_segments; ++i) {
|
|
|
|
|
|
|
| 257 |
if (params.no_timestamps) {
|
| 258 |
+
if (params.print_colors) {
|
| 259 |
+
// TODO
|
| 260 |
+
} else {
|
| 261 |
+
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 262 |
+
printf("%s", text);
|
| 263 |
+
fflush(stdout);
|
| 264 |
+
}
|
| 265 |
} else {
|
| 266 |
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 267 |
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 268 |
|
| 269 |
+
if (params.print_colors) {
|
| 270 |
+
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
| 271 |
+
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
| 272 |
+
const char * text = whisper_full_get_token_text(ctx, i, j);
|
| 273 |
+
const float p = whisper_full_get_token_p (ctx, i, j);
|
| 274 |
+
|
| 275 |
+
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
| 276 |
+
|
| 277 |
+
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
| 278 |
+
}
|
| 279 |
+
printf("\n");
|
| 280 |
+
} else {
|
| 281 |
+
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 282 |
+
|
| 283 |
+
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
| 284 |
+
}
|
| 285 |
}
|
| 286 |
}
|
| 287 |
}
|
|
|
|
| 290 |
|
| 291 |
// output to text file
|
| 292 |
if (params.output_txt) {
|
|
|
|
| 293 |
const auto fname_txt = fname_inp + ".txt";
|
| 294 |
std::ofstream fout_txt(fname_txt);
|
| 295 |
if (!fout_txt.is_open()) {
|
|
|
|
| 308 |
|
| 309 |
// output to VTT file
|
| 310 |
if (params.output_vtt) {
|
|
|
|
| 311 |
const auto fname_vtt = fname_inp + ".vtt";
|
| 312 |
std::ofstream fout_vtt(fname_vtt);
|
| 313 |
if (!fout_vtt.is_open()) {
|
|
|
|
| 332 |
|
| 333 |
// output to SRT file
|
| 334 |
if (params.output_srt) {
|
|
|
|
| 335 |
const auto fname_srt = fname_inp + ".srt";
|
| 336 |
std::ofstream fout_srt(fname_srt);
|
| 337 |
if (!fout_srt.is_open()) {
|
whisper.cpp
CHANGED
|
@@ -210,9 +210,12 @@ struct whisper_vocab {
|
|
| 210 |
}
|
| 211 |
};
|
| 212 |
|
| 213 |
-
struct
|
| 214 |
-
|
| 215 |
-
whisper_token id
|
|
|
|
|
|
|
|
|
|
| 216 |
};
|
| 217 |
|
| 218 |
struct whisper_segment {
|
|
@@ -220,6 +223,8 @@ struct whisper_segment {
|
|
| 220 |
int64_t t1;
|
| 221 |
|
| 222 |
std::string text;
|
|
|
|
|
|
|
| 223 |
};
|
| 224 |
|
| 225 |
// medium
|
|
@@ -407,7 +412,7 @@ struct whisper_context {
|
|
| 407 |
std::vector<float> probs;
|
| 408 |
std::vector<float> logits;
|
| 409 |
|
| 410 |
-
std::vector<
|
| 411 |
std::vector<whisper_segment> result_all;
|
| 412 |
|
| 413 |
std::vector<whisper_token> prompt_past;
|
|
@@ -1786,9 +1791,11 @@ bool whisper_decode(
|
|
| 1786 |
}
|
| 1787 |
|
| 1788 |
// the most basic sampling scheme - select the top token
|
| 1789 |
-
|
| 1790 |
const whisper_vocab & vocab,
|
| 1791 |
const float * probs) {
|
|
|
|
|
|
|
| 1792 |
int n_logits = vocab.id_to_token.size();
|
| 1793 |
|
| 1794 |
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
@@ -1798,24 +1805,33 @@ whisper_vocab::id whisper_sample_best(
|
|
| 1798 |
probs_id.push_back(std::make_pair(probs[i], i));
|
| 1799 |
}
|
| 1800 |
|
| 1801 |
-
|
| 1802 |
-
|
|
|
|
|
|
|
| 1803 |
|
| 1804 |
-
|
| 1805 |
-
|
| 1806 |
-
|
| 1807 |
|
| 1808 |
-
|
| 1809 |
-
|
| 1810 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1811 |
|
| 1812 |
-
|
| 1813 |
-
|
| 1814 |
-
|
| 1815 |
-
|
| 1816 |
-
|
| 1817 |
-
|
|
|
|
| 1818 |
}
|
|
|
|
|
|
|
| 1819 |
}
|
| 1820 |
|
| 1821 |
// find the top K tokens
|
|
@@ -1843,7 +1859,10 @@ whisper_vocab::id whisper_sample_best(
|
|
| 1843 |
res++;
|
| 1844 |
}
|
| 1845 |
|
| 1846 |
-
|
|
|
|
|
|
|
|
|
|
| 1847 |
}
|
| 1848 |
|
| 1849 |
// samples only from the timestamps tokens
|
|
@@ -2178,7 +2197,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) {
|
|
| 2178 |
|
| 2179 |
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 2180 |
|
| 2181 |
-
return res;
|
| 2182 |
}
|
| 2183 |
|
| 2184 |
whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
|
|
@@ -2343,7 +2362,7 @@ int whisper_full(
|
|
| 2343 |
int n_samples) {
|
| 2344 |
// clear old results
|
| 2345 |
auto & result_all = ctx->result_all;
|
| 2346 |
-
auto &
|
| 2347 |
|
| 2348 |
result_all.clear();
|
| 2349 |
|
|
@@ -2430,7 +2449,7 @@ int whisper_full(
|
|
| 2430 |
|
| 2431 |
// the accumulated transcription in the current interation
|
| 2432 |
int result_len = 0;
|
| 2433 |
-
|
| 2434 |
|
| 2435 |
for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
|
| 2436 |
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
@@ -2449,28 +2468,26 @@ int whisper_full(
|
|
| 2449 |
// feel free to experiment!
|
| 2450 |
//
|
| 2451 |
{
|
| 2452 |
-
|
| 2453 |
-
whisper_token tid = whisper_token_beg(ctx);
|
| 2454 |
|
| 2455 |
-
|
| 2456 |
-
|
| 2457 |
-
tid = whisper_sample_timestamp(ctx);
|
| 2458 |
}
|
| 2459 |
|
| 2460 |
-
// update sliding window
|
| 2461 |
-
if (id > whisper_token_beg(ctx)) {
|
| 2462 |
-
seek_delta = 2*(id - whisper_token_beg(ctx));
|
| 2463 |
result_len = i + 1;
|
| 2464 |
}
|
| 2465 |
|
| 2466 |
// add it to the context
|
| 2467 |
-
prompt.push_back(id);
|
| 2468 |
-
|
| 2469 |
|
| 2470 |
//printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
|
| 2471 |
|
| 2472 |
// end of text token
|
| 2473 |
-
if (id == whisper_token_eot(ctx)) {
|
| 2474 |
if (result_len == 0) {
|
| 2475 |
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
|
| 2476 |
result_len = i + 1;
|
|
@@ -2494,25 +2511,30 @@ int whisper_full(
|
|
| 2494 |
}
|
| 2495 |
}
|
| 2496 |
|
| 2497 |
-
|
| 2498 |
|
| 2499 |
-
for (const auto & r :
|
| 2500 |
prompt_past.push_back(r.id);
|
| 2501 |
}
|
| 2502 |
|
| 2503 |
// store the text from this iteration
|
| 2504 |
-
if (
|
| 2505 |
-
|
|
|
|
| 2506 |
|
| 2507 |
std::string text = "";
|
| 2508 |
|
| 2509 |
-
for (int i = 0; i < (int)
|
| 2510 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2511 |
} else {
|
| 2512 |
-
text += whisper_token_to_str(ctx,
|
| 2513 |
}
|
| 2514 |
-
if (
|
| 2515 |
-
const auto t1 =
|
| 2516 |
if (!text.empty()) {
|
| 2517 |
if (params.print_realtime) {
|
| 2518 |
if (params.print_timestamps) {
|
|
@@ -2523,14 +2545,18 @@ int whisper_full(
|
|
| 2523 |
}
|
| 2524 |
}
|
| 2525 |
|
| 2526 |
-
result_all.push_back({ t0, t1, text });
|
|
|
|
|
|
|
|
|
|
| 2527 |
}
|
| 2528 |
text = "";
|
| 2529 |
-
while (i < (int)
|
| 2530 |
i++;
|
| 2531 |
}
|
| 2532 |
i--;
|
| 2533 |
-
t0 =
|
|
|
|
| 2534 |
}
|
| 2535 |
}
|
| 2536 |
|
|
@@ -2546,7 +2572,10 @@ int whisper_full(
|
|
| 2546 |
}
|
| 2547 |
}
|
| 2548 |
|
| 2549 |
-
result_all.push_back({ t0, t1, text });
|
|
|
|
|
|
|
|
|
|
| 2550 |
}
|
| 2551 |
}
|
| 2552 |
|
|
@@ -2571,3 +2600,15 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
|
|
| 2571 |
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
|
| 2572 |
return ctx->result_all[i_segment].text.c_str();
|
| 2573 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
};
|
| 212 |
|
| 213 |
+
struct whisper_token_data {
|
| 214 |
+
whisper_token id; // token id
|
| 215 |
+
whisper_token tid; // forced timestamp token id
|
| 216 |
+
|
| 217 |
+
float p; // probability of the token
|
| 218 |
+
float pt; // probability of the timestamp token
|
| 219 |
};
|
| 220 |
|
| 221 |
struct whisper_segment {
|
|
|
|
| 223 |
int64_t t1;
|
| 224 |
|
| 225 |
std::string text;
|
| 226 |
+
|
| 227 |
+
std::vector<whisper_token_data> tokens;
|
| 228 |
};
|
| 229 |
|
| 230 |
// medium
|
|
|
|
| 412 |
std::vector<float> probs;
|
| 413 |
std::vector<float> logits;
|
| 414 |
|
| 415 |
+
std::vector<whisper_token_data> tokens_cur;
|
| 416 |
std::vector<whisper_segment> result_all;
|
| 417 |
|
| 418 |
std::vector<whisper_token> prompt_past;
|
|
|
|
| 1791 |
}
|
| 1792 |
|
| 1793 |
// the most basic sampling scheme - select the top token
|
| 1794 |
+
whisper_token_data whisper_sample_best(
|
| 1795 |
const whisper_vocab & vocab,
|
| 1796 |
const float * probs) {
|
| 1797 |
+
whisper_token_data result;
|
| 1798 |
+
|
| 1799 |
int n_logits = vocab.id_to_token.size();
|
| 1800 |
|
| 1801 |
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
|
| 1805 |
probs_id.push_back(std::make_pair(probs[i], i));
|
| 1806 |
}
|
| 1807 |
|
| 1808 |
+
{
|
| 1809 |
+
double sum_ts = 0.0;
|
| 1810 |
+
double max_ts = -1.0;
|
| 1811 |
+
double max_tx = -1.0;
|
| 1812 |
|
| 1813 |
+
for (int i = 0; i < vocab.token_beg; i++) {
|
| 1814 |
+
max_tx = std::max(max_tx, probs_id[i].first);
|
| 1815 |
+
}
|
| 1816 |
|
| 1817 |
+
for (int i = vocab.token_beg; i < n_logits; i++) {
|
| 1818 |
+
sum_ts += probs_id[i].first;
|
| 1819 |
+
if (probs_id[i].first > max_ts) {
|
| 1820 |
+
max_ts = probs_id[i].first;
|
| 1821 |
+
result.tid = probs_id[i].second;
|
| 1822 |
+
}
|
| 1823 |
+
}
|
| 1824 |
|
| 1825 |
+
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
|
| 1826 |
+
// timestamp token
|
| 1827 |
+
if (sum_ts > max_tx) {
|
| 1828 |
+
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
|
| 1829 |
+
for (int i = 0; i < vocab.token_beg; i++) {
|
| 1830 |
+
probs_id[i].first = -INFINITY;
|
| 1831 |
+
}
|
| 1832 |
}
|
| 1833 |
+
|
| 1834 |
+
result.pt = max_ts/(sum_ts + 1e-6);
|
| 1835 |
}
|
| 1836 |
|
| 1837 |
// find the top K tokens
|
|
|
|
| 1859 |
res++;
|
| 1860 |
}
|
| 1861 |
|
| 1862 |
+
result.id = probs_id[res].second;
|
| 1863 |
+
result.p = probs_id[res].first;
|
| 1864 |
+
|
| 1865 |
+
return result;
|
| 1866 |
}
|
| 1867 |
|
| 1868 |
// samples only from the timestamps tokens
|
|
|
|
| 2197 |
|
| 2198 |
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 2199 |
|
| 2200 |
+
return res.id;
|
| 2201 |
}
|
| 2202 |
|
| 2203 |
whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
|
|
|
|
| 2362 |
int n_samples) {
|
| 2363 |
// clear old results
|
| 2364 |
auto & result_all = ctx->result_all;
|
| 2365 |
+
auto & tokens_cur = ctx->tokens_cur;
|
| 2366 |
|
| 2367 |
result_all.clear();
|
| 2368 |
|
|
|
|
| 2449 |
|
| 2450 |
// the accumulated transcription in the current interation
|
| 2451 |
int result_len = 0;
|
| 2452 |
+
tokens_cur.clear();
|
| 2453 |
|
| 2454 |
for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
|
| 2455 |
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
| 2468 |
// feel free to experiment!
|
| 2469 |
//
|
| 2470 |
{
|
| 2471 |
+
auto token = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
|
|
|
|
| 2472 |
|
| 2473 |
+
if (i == 0) {
|
| 2474 |
+
token.tid = whisper_token_beg(ctx);
|
|
|
|
| 2475 |
}
|
| 2476 |
|
| 2477 |
+
// timestamp token - update sliding window
|
| 2478 |
+
if (token.id > whisper_token_beg(ctx)) {
|
| 2479 |
+
seek_delta = 2*(token.id - whisper_token_beg(ctx));
|
| 2480 |
result_len = i + 1;
|
| 2481 |
}
|
| 2482 |
|
| 2483 |
// add it to the context
|
| 2484 |
+
prompt.push_back(token.id);
|
| 2485 |
+
tokens_cur.push_back(token);
|
| 2486 |
|
| 2487 |
//printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
|
| 2488 |
|
| 2489 |
// end of text token
|
| 2490 |
+
if (token.id == whisper_token_eot(ctx)) {
|
| 2491 |
if (result_len == 0) {
|
| 2492 |
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
|
| 2493 |
result_len = i + 1;
|
|
|
|
| 2511 |
}
|
| 2512 |
}
|
| 2513 |
|
| 2514 |
+
tokens_cur.resize(result_len);
|
| 2515 |
|
| 2516 |
+
for (const auto & r : tokens_cur) {
|
| 2517 |
prompt_past.push_back(r.id);
|
| 2518 |
}
|
| 2519 |
|
| 2520 |
// store the text from this iteration
|
| 2521 |
+
if (tokens_cur.size() > 0) {
|
| 2522 |
+
int i0 = 0;
|
| 2523 |
+
auto t0 = 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
| 2524 |
|
| 2525 |
std::string text = "";
|
| 2526 |
|
| 2527 |
+
for (int i = 0; i < (int) tokens_cur.size(); i++) {
|
| 2528 |
+
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
|
| 2529 |
+
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
|
| 2530 |
+
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
|
| 2531 |
+
|
| 2532 |
+
if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
|
| 2533 |
} else {
|
| 2534 |
+
text += whisper_token_to_str(ctx, tokens_cur[i].id);
|
| 2535 |
}
|
| 2536 |
+
if (tokens_cur[i].id > whisper_token_beg(ctx)) {
|
| 2537 |
+
const auto t1 = 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
| 2538 |
if (!text.empty()) {
|
| 2539 |
if (params.print_realtime) {
|
| 2540 |
if (params.print_timestamps) {
|
|
|
|
| 2545 |
}
|
| 2546 |
}
|
| 2547 |
|
| 2548 |
+
result_all.push_back({ t0, t1, text, {} });
|
| 2549 |
+
for (int j = i0; j <= i; j++) {
|
| 2550 |
+
result_all.back().tokens.push_back(tokens_cur[j]);
|
| 2551 |
+
}
|
| 2552 |
}
|
| 2553 |
text = "";
|
| 2554 |
+
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
| 2555 |
i++;
|
| 2556 |
}
|
| 2557 |
i--;
|
| 2558 |
+
t0 = t1;
|
| 2559 |
+
i0 = i + 1;
|
| 2560 |
}
|
| 2561 |
}
|
| 2562 |
|
|
|
|
| 2572 |
}
|
| 2573 |
}
|
| 2574 |
|
| 2575 |
+
result_all.push_back({ t0, t1, text, {} });
|
| 2576 |
+
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
| 2577 |
+
result_all.back().tokens.push_back(tokens_cur[j]);
|
| 2578 |
+
}
|
| 2579 |
}
|
| 2580 |
}
|
| 2581 |
|
|
|
|
| 2600 |
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
|
| 2601 |
return ctx->result_all[i_segment].text.c_str();
|
| 2602 |
}
|
| 2603 |
+
|
| 2604 |
+
int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
|
| 2605 |
+
return ctx->result_all[i_segment].tokens.size();
|
| 2606 |
+
}
|
| 2607 |
+
|
| 2608 |
+
const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 2609 |
+
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
| 2610 |
+
}
|
| 2611 |
+
|
| 2612 |
+
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 2613 |
+
return ctx->result_all[i_segment].tokens[i_token].p;
|
| 2614 |
+
}
|
whisper.h
CHANGED
|
@@ -207,6 +207,15 @@ extern "C" {
|
|
| 207 |
// Get the text of the specified segment.
|
| 208 |
WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
#ifdef __cplusplus
|
| 211 |
}
|
| 212 |
#endif
|
|
|
|
| 207 |
// Get the text of the specified segment.
|
| 208 |
WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
|
| 209 |
|
| 210 |
+
// Get number of tokens in the specified segment.
|
| 211 |
+
WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
|
| 212 |
+
|
| 213 |
+
// Get the token text of the specified token in the specified segment.
|
| 214 |
+
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
| 215 |
+
|
| 216 |
+
// Get the probability of the specified token in the specified segment.
|
| 217 |
+
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
| 218 |
+
|
| 219 |
#ifdef __cplusplus
|
| 220 |
}
|
| 221 |
#endif
|