ggerganov committed on
Commit
082b647
·
unverified ·
1 Parent(s): 2de3d0b

main : fix some edge cases for word-level timestamps

Browse files
Files changed (1) hide show
  1. examples/main/main.cpp +15 -5
examples/main/main.cpp CHANGED
@@ -424,7 +424,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
424
  //tokens[j].vlen = tokens[j].pt;
425
  tokens[j].vlen = voice_length(tokens[j].text);
426
 
427
- if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last) {
428
  if (j > 0) {
429
  tokens[j - 1].t1 = tt;
430
  }
@@ -482,15 +482,26 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
482
  tokens[j + 1].t0 = tokens[j].t1;
483
  }
484
 
 
 
 
 
 
 
 
485
  tokens[j].tt0 = tokens[j].t0;
486
  tokens[j].tt1 = tokens[j].t1;
487
  }
488
 
489
  // VAD
490
  {
491
- const int hw = WHISPER_SAMPLE_RATE; // take one second of audio around the token
492
 
493
  for (int j = 0; j < n; j++) {
 
 
 
 
494
  const int64_t t0 = tokens[j].t0;
495
  const int64_t t1 = tokens[j].t1;
496
 
@@ -503,13 +514,12 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
503
  const int n = ss1 - ss0;
504
 
505
  float sum = 0.0f;
 
506
  for (int k = ss0; k < ss1; k++) {
507
  sum += pcm_avg[k];
508
  }
509
 
510
- const float avg = sum/n;
511
-
512
- const float thold = 0.5*avg;
513
 
514
  {
515
  int k = s0;
 
424
  //tokens[j].vlen = tokens[j].pt;
425
  tokens[j].vlen = voice_length(tokens[j].text);
426
 
427
+ if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) {
428
  if (j > 0) {
429
  tokens[j - 1].t1 = tt;
430
  }
 
482
  tokens[j + 1].t0 = tokens[j].t1;
483
  }
484
 
485
+ if (j > 0) {
486
+ if (tokens[j - 1].t1 > tokens[j].t0) {
487
+ tokens[j].t0 = tokens[j - 1].t1;
488
+ tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
489
+ }
490
+ }
491
+
492
  tokens[j].tt0 = tokens[j].t0;
493
  tokens[j].tt1 = tokens[j].t1;
494
  }
495
 
496
  // VAD
497
  {
498
+ const int hw = WHISPER_SAMPLE_RATE/8;
499
 
500
  for (int j = 0; j < n; j++) {
501
+ if (tokens[j].id >= whisper_token_eot(ctx)) {
502
+ continue;
503
+ }
504
+
505
  const int64_t t0 = tokens[j].t0;
506
  const int64_t t1 = tokens[j].t1;
507
 
 
514
  const int n = ss1 - ss0;
515
 
516
  float sum = 0.0f;
517
+
518
  for (int k = ss0; k < ss1; k++) {
519
  sum += pcm_avg[k];
520
  }
521
 
522
+ const float thold = 0.5*sum/n;
 
 
523
 
524
  {
525
  int k = s0;