Karthick committed on
whisper : support no_speech_thold (#2625)

* Implement no_speech_thold
  no_speech_thold functionality is on par with OpenAI's whisper
* Addressed review comments
- include/whisper.h +1 -1
- src/whisper.cpp +60 -32
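
The gate added by this commit mirrors OpenAI's whisper: a decoded segment is discarded as silence only when the probability of the <|nospeech|> token, measured right after the first decode step, exceeds no_speech_thold and the average token log-probability of the sequence is simultaneously below logprob_thold. A minimal standalone sketch of that rule, with illustrative names only (the real logic lives in whisper_full_with_state, shown in the src/whisper.cpp diff below):

// Sketch of the no-speech gate; names here are illustrative, not whisper.cpp's API.
struct segment_stats {
    float no_speech_prob; // P(<|nospeech|>) right after the first decode step
    float avg_logprob;    // average log-probability of the decoded tokens
};

// Both conditions must hold: the model is confident the window is silence
// AND the decoded text itself is low-confidence. A confidently decoded
// segment is kept even when the no-speech probability is high.
static bool segment_is_no_speech(const segment_stats & s,
                                 float no_speech_thold,
                                 float logprob_thold) {
    return s.no_speech_prob > no_speech_thold && s.avg_logprob < logprob_thold;
}

When the gate fires, the segment's tokens are neither emitted nor pushed into prompt_past, which is what the is_no_speech checks at the end of the src/whisper.cpp diff implement.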
include/whisper.h CHANGED

@@ -534,7 +534,7 @@ extern "C" {
         float temperature_inc;
         float entropy_thold;    // similar to OpenAI's "compression_ratio_threshold"
         float logprob_thold;
-        float no_speech_thold;  // TODO: not implemented
+        float no_speech_thold;
 
         struct {
             int best_of;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
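
For callers, the field now behaves like the other thresholds on whisper_full_params. A hedged usage sketch; the 0.6f value mirrors OpenAI's default no_speech_threshold and is an assumption here, not necessarily the library's default:

// Sketch: enabling the no-speech gate from user code (threshold values are illustrative).
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

wparams.no_speech_thold = 0.6f;  // segment is a silence candidate above this <|nospeech|> probability ...
wparams.logprob_thold   = -1.0f; // ... but only if the average token logprob is also below this

// A subsequent whisper_full(ctx, wparams, samples, n_samples) call will then
// skip segments that satisfy both conditions instead of emitting text for them.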
src/whisper.cpp CHANGED

@@ -867,6 +867,7 @@ struct whisper_state {
     whisper_token tid_last;
 
     std::vector<float> energy; // PCM signal energy
+    float no_speech_prob = 0.0f;
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     whisper_aheads_masks aheads_masks;

@@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
     "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
 };
 
+static void whisper_compute_logprobs(
+        const std::vector<float> & logits,
+        const int                  n_logits,
+              std::vector<float> & logprobs) {
+    const float logit_max = *std::max_element(logits.begin(), logits.end());
+    float logsumexp = 0.0f;
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logsumexp += expf(logits[i] - logit_max);
+        }
+    }
+    logsumexp = logf(logsumexp) + logit_max;
+
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logprobs[i] = logits[i] - logsumexp;
+        } else {
+            logprobs[i] = -INFINITY;
+        }
+    }
+}
+
+static void whisper_compute_probs(
+        const std::vector<float> & logits,
+        const int                  n_logits,
+        const std::vector<float> & logprobs,
+              std::vector<float> & probs) {
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] == -INFINITY) {
+            probs[i] = 0.0f;
+        } else {
+            probs[i] = expf(logprobs[i]);
+        }
+    }
+}
+
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs

@@ -4886,7 +4923,7 @@ static void whisper_process_logits(
 
         // suppress sot and nosp tokens
         logits[vocab.token_sot]  = -INFINITY;
-        logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+        logits[vocab.token_nosp] = -INFINITY;
 
         // [TDRZ] when tinydiarize is disabled, suppress solm token
         if (params.tdrz_enable == false) {

@@ -4985,24 +5022,7 @@ static void whisper_process_logits(
         }
 
         // populate the logprobs array (log_softmax)
-        {
-            const float logit_max = *std::max_element(logits.begin(), logits.end());
-            float logsumexp = 0.0f;
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] > -INFINITY) {
-                    logsumexp += expf(logits[i] - logit_max);
-                }
-            }
-            logsumexp = logf(logsumexp) + logit_max;
-
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] > -INFINITY) {
-                    logprobs[i] = logits[i] - logsumexp;
-                } else {
-                    logprobs[i] = -INFINITY;
-                }
-            }
-        }
+        whisper_compute_logprobs(logits, n_logits, logprobs);
 
         // if sum of probability over timestamps is above any other token, sample timestamp
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437

@@ -5060,15 +5080,7 @@ static void whisper_process_logits(
         }
 
         // compute probs
-        {
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] == -INFINITY) {
-                    probs[i] = 0.0f;
-                } else {
-                    probs[i] = expf(logprobs[i]);
-                }
-            }
-        }
+        whisper_compute_probs(logits, n_logits, logprobs, probs);
 
 #if 0
         // print first 100 logits - token string : logit

@@ -5647,6 +5659,18 @@ int whisper_full_with_state(
                 return -8;
             }
 
+            // Calculate no_speech probability after first decode.
+            // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
+            {
+                const int n_logits = ctx->vocab.id_to_token.size();
+                std::vector<float> logprobs(n_logits);
+                std::vector<float> probs(n_logits);
+
+                whisper_compute_logprobs(state->logits, n_logits, logprobs);
+                whisper_compute_probs(state->logits, n_logits, logprobs, probs);
+                state->no_speech_prob = probs[whisper_token_nosp(ctx)];
+            }
+
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 

@@ -6038,8 +6062,9 @@ int whisper_full_with_state(
             if (it != (int) temperatures.size() - 1) {
                 const auto & decoder = state->decoders[best_decoder_id];
 
-                if (decoder.failed ||
-                    decoder.sequence.avg_logprobs < params.logprob_thold) {
+                if (decoder.failed ||
+                    (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
                     success = false;
                     state->n_fail_p++;
                 }

@@ -6068,6 +6093,9 @@ int whisper_full_with_state(
             // [EXPERIMENTAL] Token-level timestamps with DTW
             const auto n_segments_before = state->result_all.size();
 
+            const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+                                       best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
             // update prompt_past

@@ -6076,11 +6104,11 @@ int whisper_full_with_state(
                 prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
             }
 
-            for (int i = 0; i < result_len; ++i) {
+            for (int i = 0; i < result_len && !is_no_speech; ++i) {
                 prompt_past.push_back(tokens_cur[i].id);
             }
 
-            if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+            if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
                 int i0 = 0;
                 auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
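
The probability used for the gate is taken from the raw logits of the first decode step, before whisper_process_logits suppresses <|nospeech|> and applies the other filters; otherwise the value would always be filtered away. The two new helpers are a standard numerically stable log-softmax followed by exponentiation. A small self-contained sketch of the same computation on a toy logit vector (illustrative only, not the library's API):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable log-softmax, then exp(), matching the shape of the
// whisper_compute_logprobs / whisper_compute_probs helpers in the diff above.
int main() {
    std::vector<float> logits = { 2.0f, 0.5f, -1.0f, 3.0f }; // toy values; pretend index 3 is <|nospeech|>

    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float logsumexp = 0.0f;
    for (float l : logits) {
        logsumexp += std::exp(l - max_logit); // subtracting the max avoids overflow in exp()
    }
    logsumexp = std::log(logsumexp) + max_logit;

    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - logsumexp); // exp(log-softmax) == softmax probability
    }

    // In whisper.cpp this would be probs[whisper_token_nosp(ctx)].
    std::printf("no_speech_prob (toy) = %.3f\n", probs[3]);
    return 0;
}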