Karthick committed on
Commit
adb5837
·
unverified ·
1 Parent(s): 7655c06

whisper : support no_speech_thold (#2625)

Browse files

* Implement no_speech_thold

no_speech_thold functionality is on par with OpenAI's whisper

* Addressed review comments

Files changed (2) hide show
  1. include/whisper.h +1 -1
  2. src/whisper.cpp +60 -32
include/whisper.h CHANGED
@@ -534,7 +534,7 @@ extern "C" {
534
  float temperature_inc;
535
  float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
536
  float logprob_thold;
537
- float no_speech_thold; // TODO: not implemented
538
 
539
  struct {
540
  int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
 
534
  float temperature_inc;
535
  float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
536
  float logprob_thold;
537
+ float no_speech_thold;
538
 
539
  struct {
540
  int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
src/whisper.cpp CHANGED
@@ -867,6 +867,7 @@ struct whisper_state {
867
  whisper_token tid_last;
868
 
869
  std::vector<float> energy; // PCM signal energy
 
870
 
871
  // [EXPERIMENTAL] Token-level timestamps with DTW
872
  whisper_aheads_masks aheads_masks;
@@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
4825
  "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
4826
  };
4827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4828
  // process the logits for the selected decoder
4829
  // - applies logit filters
4830
  // - computes logprobs and probs
@@ -4886,7 +4923,7 @@ static void whisper_process_logits(
4886
 
4887
  // suppress sot and nosp tokens
4888
  logits[vocab.token_sot] = -INFINITY;
4889
- logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
4890
 
4891
  // [TDRZ] when tinydiarize is disabled, suppress solm token
4892
  if (params.tdrz_enable == false) {
@@ -4985,24 +5022,7 @@ static void whisper_process_logits(
4985
  }
4986
 
4987
  // populate the logprobs array (log_softmax)
4988
- {
4989
- const float logit_max = *std::max_element(logits.begin(), logits.end());
4990
- float logsumexp = 0.0f;
4991
- for (int i = 0; i < n_logits; ++i) {
4992
- if (logits[i] > -INFINITY) {
4993
- logsumexp += expf(logits[i] - logit_max);
4994
- }
4995
- }
4996
- logsumexp = logf(logsumexp) + logit_max;
4997
-
4998
- for (int i = 0; i < n_logits; ++i) {
4999
- if (logits[i] > -INFINITY) {
5000
- logprobs[i] = logits[i] - logsumexp;
5001
- } else {
5002
- logprobs[i] = -INFINITY;
5003
- }
5004
- }
5005
- }
5006
 
5007
  // if sum of probability over timestamps is above any other token, sample timestamp
5008
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -5060,15 +5080,7 @@ static void whisper_process_logits(
5060
  }
5061
 
5062
  // compute probs
5063
- {
5064
- for (int i = 0; i < n_logits; ++i) {
5065
- if (logits[i] == -INFINITY) {
5066
- probs[i] = 0.0f;
5067
- } else {
5068
- probs[i] = expf(logprobs[i]);
5069
- }
5070
- }
5071
- }
5072
 
5073
  #if 0
5074
  // print first 100 logits - token string : logit
@@ -5647,6 +5659,18 @@ int whisper_full_with_state(
5647
  return -8;
5648
  }
5649
 
 
 
 
 
 
 
 
 
 
 
 
 
5650
  {
5651
  const int64_t t_start_sample_us = ggml_time_us();
5652
 
@@ -6038,8 +6062,9 @@ int whisper_full_with_state(
6038
  if (it != (int) temperatures.size() - 1) {
6039
  const auto & decoder = state->decoders[best_decoder_id];
6040
 
6041
- if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
6042
- WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
 
6043
  success = false;
6044
  state->n_fail_p++;
6045
  }
@@ -6068,6 +6093,9 @@ int whisper_full_with_state(
6068
  // [EXPERIMENTAL] Token-level timestamps with DTW
6069
  const auto n_segments_before = state->result_all.size();
6070
 
 
 
 
6071
  //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
6072
 
6073
  // update prompt_past
@@ -6076,11 +6104,11 @@ int whisper_full_with_state(
6076
  prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
6077
  }
6078
 
6079
- for (int i = 0; i < result_len; ++i) {
6080
  prompt_past.push_back(tokens_cur[i].id);
6081
  }
6082
 
6083
- if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
6084
  int i0 = 0;
6085
  auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
6086
 
 
867
  whisper_token tid_last;
868
 
869
  std::vector<float> energy; // PCM signal energy
870
+ float no_speech_prob = 0.0f;
871
 
872
  // [EXPERIMENTAL] Token-level timestamps with DTW
873
  whisper_aheads_masks aheads_masks;
 
4826
  "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
4827
  };
4828
 
4829
+ static void whisper_compute_logprobs(
4830
+ const std::vector<float> & logits,
4831
+ const int n_logits,
4832
+ std::vector<float> & logprobs) {
4833
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
4834
+ float logsumexp = 0.0f;
4835
+ for (int i = 0; i < n_logits; ++i) {
4836
+ if (logits[i] > -INFINITY) {
4837
+ logsumexp += expf(logits[i] - logit_max);
4838
+ }
4839
+ }
4840
+ logsumexp = logf(logsumexp) + logit_max;
4841
+
4842
+ for (int i = 0; i < n_logits; ++i) {
4843
+ if (logits[i] > -INFINITY) {
4844
+ logprobs[i] = logits[i] - logsumexp;
4845
+ } else {
4846
+ logprobs[i] = -INFINITY;
4847
+ }
4848
+ }
4849
+ }
4850
+
4851
+ static void whisper_compute_probs(
4852
+ const std::vector<float> & logits,
4853
+ const int n_logits,
4854
+ const std::vector<float> & logprobs,
4855
+ std::vector<float> & probs) {
4856
+ for (int i = 0; i < n_logits; ++i) {
4857
+ if (logits[i] == -INFINITY) {
4858
+ probs[i] = 0.0f;
4859
+ } else {
4860
+ probs[i] = expf(logprobs[i]);
4861
+ }
4862
+ }
4863
+ }
4864
+
4865
  // process the logits for the selected decoder
4866
  // - applies logit filters
4867
  // - computes logprobs and probs
 
4923
 
4924
  // suppress sot and nosp tokens
4925
  logits[vocab.token_sot] = -INFINITY;
4926
+ logits[vocab.token_nosp] = -INFINITY;
4927
 
4928
  // [TDRZ] when tinydiarize is disabled, suppress solm token
4929
  if (params.tdrz_enable == false) {
 
5022
  }
5023
 
5024
  // populate the logprobs array (log_softmax)
5025
+ whisper_compute_logprobs(logits, n_logits, logprobs);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5026
 
5027
  // if sum of probability over timestamps is above any other token, sample timestamp
5028
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
 
5080
  }
5081
 
5082
  // compute probs
5083
+ whisper_compute_probs(logits, n_logits, logprobs, probs);
 
 
 
 
 
 
 
 
5084
 
5085
  #if 0
5086
  // print first 100 logits - token string : logit
 
5659
  return -8;
5660
  }
5661
 
5662
+ // Calculate no_speech probability after first decode.
5663
+ // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
5664
+ {
5665
+ const int n_logits = ctx->vocab.id_to_token.size();
5666
+ std::vector<float> logprobs(n_logits);
5667
+ std::vector<float> probs(n_logits);
5668
+
5669
+ whisper_compute_logprobs(state->logits, n_logits, logprobs);
5670
+ whisper_compute_probs(state->logits, n_logits, logprobs, probs);
5671
+ state->no_speech_prob = probs[whisper_token_nosp(ctx)];
5672
+ }
5673
+
5674
  {
5675
  const int64_t t_start_sample_us = ggml_time_us();
5676
 
 
6062
  if (it != (int) temperatures.size() - 1) {
6063
  const auto & decoder = state->decoders[best_decoder_id];
6064
 
6065
+ if (decoder.failed ||
6066
+ (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
6067
+ WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
6068
  success = false;
6069
  state->n_fail_p++;
6070
  }
 
6093
  // [EXPERIMENTAL] Token-level timestamps with DTW
6094
  const auto n_segments_before = state->result_all.size();
6095
 
6096
+ const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
6097
+ best_decoder.sequence.avg_logprobs < params.logprob_thold);
6098
+
6099
  //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
6100
 
6101
  // update prompt_past
 
6104
  prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
6105
  }
6106
 
6107
+ for (int i = 0; i < result_len && !is_no_speech; ++i) {
6108
  prompt_past.push_back(tokens_cur[i].id);
6109
  }
6110
 
6111
+ if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
6112
  int i0 = 0;
6113
  auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
6114