ggerganov commited on
Commit
07043b9
·
unverified ·
1 Parent(s): 53613f1

ref #52 : improve greedy sampling strategy

Browse files

Force timestamp token to be sampled if the probability sum over all
timestamp tokens is above the probability of any other token

Files changed (2) hide show
  1. whisper.cpp +25 -14
  2. whisper.h +1 -1
whisper.cpp CHANGED
@@ -1784,7 +1784,7 @@ bool whisper_decode(
1784
  // the most basic sampling scheme - select the top token
1785
  whisper_vocab::id whisper_sample_best(
1786
  const whisper_vocab & vocab,
1787
- const float * probs, bool need_timestamp) {
1788
  int n_logits = vocab.id_to_token.size();
1789
 
1790
  std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
1794
  probs_id.push_back(std::make_pair(probs[i], i));
1795
  }
1796
 
1797
- const int top_k = 4;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1798
 
1799
  // find the top K tokens
 
 
1800
  std::partial_sort(
1801
  probs_id.begin(),
1802
  probs_id.begin() + top_k, probs_id.end(),
@@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
1811
  // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1812
  //}
1813
 
1814
- if (need_timestamp) {
1815
- // at the end of the 30-second audio segment, we start giving preference to time tokens
1816
- for (int i = 0; i < top_k; i++) {
1817
- if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
1818
- return probs_id[i].second;
1819
- }
1820
- }
1821
- }
1822
-
1823
  int res = 0;
1824
  while ((probs_id[res].second == vocab.token_sot ||
1825
  probs_id[res].second == vocab.token_solm ||
@@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
2155
  return 0;
2156
  }
2157
 
2158
- whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
2159
  const int64_t t_start_sample_us = ggml_time_us();
2160
 
2161
  // TODO: simplify
2162
- auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
2163
 
2164
  ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
2165
 
@@ -2437,7 +2448,7 @@ int whisper_full(
2437
  whisper_token id = 0;
2438
  whisper_token tid = whisper_token_beg(ctx);
2439
 
2440
- id = whisper_sample_best(ctx, result_len == 0);
2441
  if (i > 0) {
2442
  tid = whisper_sample_timestamp(ctx);
2443
  }
 
1784
  // the most basic sampling scheme - select the top token
1785
  whisper_vocab::id whisper_sample_best(
1786
  const whisper_vocab & vocab,
1787
+ const float * probs) {
1788
  int n_logits = vocab.id_to_token.size();
1789
 
1790
  std::vector<std::pair<double, whisper_vocab::id>> probs_id;
 
1794
  probs_id.push_back(std::make_pair(probs[i], i));
1795
  }
1796
 
1797
+ double sum_ts = 0.0;
1798
+ double max_tx = 0.0;
1799
+
1800
+ for (int i = 0; i < vocab.token_beg; i++) {
1801
+ max_tx = std::max(max_tx, probs_id[i].first);
1802
+ }
1803
+
1804
+ for (int i = vocab.token_beg; i < n_logits; i++) {
1805
+ sum_ts += probs_id[i].first;
1806
+ }
1807
+
1808
+ // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
1809
+ // timestamp token
1810
+ if (sum_ts > max_tx) {
1811
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
1812
+ for (int i = 0; i < vocab.token_beg; i++) {
1813
+ probs_id[i].first = -INFINITY;
1814
+ }
1815
+ }
1816
 
1817
  // find the top K tokens
1818
+ const int top_k = 4;
1819
+
1820
  std::partial_sort(
1821
  probs_id.begin(),
1822
  probs_id.begin() + top_k, probs_id.end(),
 
1831
  // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1832
  //}
1833
 
 
 
 
 
 
 
 
 
 
1834
  int res = 0;
1835
  while ((probs_id[res].second == vocab.token_sot ||
1836
  probs_id[res].second == vocab.token_solm ||
 
2166
  return 0;
2167
  }
2168
 
2169
+ whisper_token whisper_sample_best(struct whisper_context * ctx) {
2170
  const int64_t t_start_sample_us = ggml_time_us();
2171
 
2172
  // TODO: simplify
2173
+ auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
2174
 
2175
  ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
2176
 
 
2448
  whisper_token id = 0;
2449
  whisper_token tid = whisper_token_beg(ctx);
2450
 
2451
+ id = whisper_sample_best(ctx);
2452
  if (i > 0) {
2453
  tid = whisper_sample_timestamp(ctx);
2454
  }
whisper.h CHANGED
@@ -120,7 +120,7 @@ extern "C" {
120
  // You can also implement your own sampling method using the whisper_get_probs() function.
121
  // whisper_sample_best() returns the token with the highest probability
122
  // whisper_sample_timestamp() returns the most probable timestamp token
123
- WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
124
  WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
125
 
126
  // Return the id of the specified language, returns -1 if not found
 
120
  // You can also implement your own sampling method using the whisper_get_probs() function.
121
  // whisper_sample_best() returns the token with the highest probability
122
  // whisper_sample_timestamp() returns the most probable timestamp token
123
+ WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
124
  WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
125
 
126
  // Return the id of the specified language, returns -1 if not found