Spaces:
Sleeping
Sleeping
whisper: use global cache for sin/cos vals and Hann window (#2194)
Browse files- also rename Hanning to Hann as it's named after Julius von Hann
as per Wikipedia
- whisper.cpp +54 -43
whisper.cpp
CHANGED
|
@@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
| 2857 |
}
|
| 2858 |
|
| 2859 |
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
| 2860 |
-
|
| 2861 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2862 |
|
| 2863 |
-
|
| 2864 |
-
|
| 2865 |
-
|
| 2866 |
-
|
| 2867 |
-
|
| 2868 |
-
|
| 2869 |
-
|
| 2870 |
-
|
| 2871 |
-
cos_vals[i] = cosf(theta);
|
| 2872 |
}
|
| 2873 |
-
|
| 2874 |
}
|
| 2875 |
|
| 2876 |
// naive Discrete Fourier Transform
|
|
@@ -2888,8 +2912,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
| 2888 |
|
| 2889 |
for (int n = 0; n < N; n++) {
|
| 2890 |
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
| 2891 |
-
re += in[n]*cos_vals[idx]; // cos(t)
|
| 2892 |
-
im -= in[n]*sin_vals[idx]; // sin(t)
|
| 2893 |
}
|
| 2894 |
|
| 2895 |
out[k*2 + 0] = re;
|
|
@@ -2940,8 +2964,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
| 2940 |
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
| 2941 |
for (int k = 0; k < N/2; k++) {
|
| 2942 |
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
| 2943 |
-
float re = cos_vals[idx]; // cos(t)
|
| 2944 |
-
float im = -sin_vals[idx]; // sin(t)
|
| 2945 |
|
| 2946 |
float re_odd = odd_fft[2*k + 0];
|
| 2947 |
float im_odd = odd_fft[2*k + 1];
|
|
@@ -2954,22 +2978,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
| 2954 |
}
|
| 2955 |
}
|
| 2956 |
|
| 2957 |
-
static
|
| 2958 |
-
if (output.size() < static_cast<size_t>(length)) {
|
| 2959 |
-
output.resize(length);
|
| 2960 |
-
}
|
| 2961 |
-
int offset = -1;
|
| 2962 |
-
if (periodic) {
|
| 2963 |
-
offset = 0;
|
| 2964 |
-
}
|
| 2965 |
-
for (int i = 0; i < length; i++) {
|
| 2966 |
-
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
|
| 2967 |
-
}
|
| 2968 |
-
|
| 2969 |
-
return true;
|
| 2970 |
-
}
|
| 2971 |
-
|
| 2972 |
-
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
|
| 2973 |
int n_samples, int frame_size, int frame_step, int n_threads,
|
| 2974 |
const whisper_filters & filters, whisper_mel & mel) {
|
| 2975 |
std::vector<float> fft_in(frame_size, 0.0);
|
|
@@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|
| 2984 |
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
| 2985 |
const int offset = i * frame_step;
|
| 2986 |
|
| 2987 |
-
// apply
|
| 2988 |
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
| 2989 |
fft_in[j] = hann[j] * samples[offset + j];
|
| 2990 |
}
|
|
@@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram(
|
|
| 3051 |
whisper_mel & mel) {
|
| 3052 |
const int64_t t_start_us = ggml_time_us();
|
| 3053 |
|
| 3054 |
-
//
|
| 3055 |
-
|
| 3056 |
-
|
| 3057 |
-
|
| 3058 |
-
|
| 3059 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3060 |
|
| 3061 |
// Calculate the length of padding
|
| 3062 |
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
|
@@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram(
|
|
| 3086 |
std::vector<std::thread> workers(n_threads - 1);
|
| 3087 |
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
| 3088 |
workers[iw] = std::thread(
|
| 3089 |
-
log_mel_spectrogram_worker_thread, iw + 1,
|
| 3090 |
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
| 3091 |
std::cref(filters), std::ref(mel));
|
| 3092 |
}
|
|
@@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
|
|
| 3246 |
#endif
|
| 3247 |
|
| 3248 |
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
| 3249 |
-
fill_sin_cos_table();
|
| 3250 |
-
|
| 3251 |
whisper_state * state = new whisper_state;
|
| 3252 |
|
| 3253 |
state->backend = whisper_backend_init(ctx->params);
|
|
@@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
| 7235 |
// operation (after median filter)
|
| 7236 |
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
|
| 7237 |
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
|
| 7238 |
-
w = ggml_norm(gctx, w, 1e-
|
| 7239 |
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
|
| 7240 |
|
| 7241 |
// Pass median filter - this is done over AUDIO_TOKENS dimension.
|
|
|
|
| 2857 |
}
|
| 2858 |
|
| 2859 |
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
| 2860 |
+
namespace {
|
| 2861 |
+
struct whisper_global_cache {
|
| 2862 |
+
// In FFT, we frequently use sine and cosine operations with the same values.
|
| 2863 |
+
// We can use precalculated values to speed up the process.
|
| 2864 |
+
float sin_vals[SIN_COS_N_COUNT];
|
| 2865 |
+
float cos_vals[SIN_COS_N_COUNT];
|
| 2866 |
+
|
| 2867 |
+
// Hann window (Use cosf to eliminate difference)
|
| 2868 |
+
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
| 2869 |
+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
| 2870 |
+
float hann_window[WHISPER_N_FFT];
|
| 2871 |
+
float hann_window2x[WHISPER_N_FFT * 2];
|
| 2872 |
+
|
| 2873 |
+
whisper_global_cache() {
|
| 2874 |
+
fill_sin_cos_table();
|
| 2875 |
+
#define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr)
|
| 2876 |
+
FILL_HANN_WINDOW(hann_window);
|
| 2877 |
+
FILL_HANN_WINDOW(hann_window2x);
|
| 2878 |
+
}
|
| 2879 |
+
|
| 2880 |
+
void fill_sin_cos_table() {
|
| 2881 |
+
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
| 2882 |
+
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
|
| 2883 |
+
sin_vals[i] = sinf(theta);
|
| 2884 |
+
cos_vals[i] = cosf(theta);
|
| 2885 |
+
}
|
| 2886 |
+
}
|
| 2887 |
|
| 2888 |
+
void fill_hann_window(int length, bool periodic, float* output) {
|
| 2889 |
+
int offset = -1;
|
| 2890 |
+
if (periodic) {
|
| 2891 |
+
offset = 0;
|
| 2892 |
+
}
|
| 2893 |
+
for (int i = 0; i < length; i++) {
|
| 2894 |
+
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
| 2895 |
+
}
|
|
|
|
| 2896 |
}
|
| 2897 |
+
} global_cache;
|
| 2898 |
}
|
| 2899 |
|
| 2900 |
// naive Discrete Fourier Transform
|
|
|
|
| 2912 |
|
| 2913 |
for (int n = 0; n < N; n++) {
|
| 2914 |
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
| 2915 |
+
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
| 2916 |
+
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
| 2917 |
}
|
| 2918 |
|
| 2919 |
out[k*2 + 0] = re;
|
|
|
|
| 2964 |
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
| 2965 |
for (int k = 0; k < N/2; k++) {
|
| 2966 |
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
| 2967 |
+
float re = global_cache.cos_vals[idx]; // cos(t)
|
| 2968 |
+
float im = -global_cache.sin_vals[idx]; // sin(t)
|
| 2969 |
|
| 2970 |
float re_odd = odd_fft[2*k + 0];
|
| 2971 |
float im_odd = odd_fft[2*k + 1];
|
|
|
|
| 2978 |
}
|
| 2979 |
}
|
| 2980 |
|
| 2981 |
+
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2982 |
int n_samples, int frame_size, int frame_step, int n_threads,
|
| 2983 |
const whisper_filters & filters, whisper_mel & mel) {
|
| 2984 |
std::vector<float> fft_in(frame_size, 0.0);
|
|
|
|
| 2993 |
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
| 2994 |
const int offset = i * frame_step;
|
| 2995 |
|
| 2996 |
+
// apply Hann window (~10% faster)
|
| 2997 |
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
| 2998 |
fft_in[j] = hann[j] * samples[offset + j];
|
| 2999 |
}
|
|
|
|
| 3060 |
whisper_mel & mel) {
|
| 3061 |
const int64_t t_start_us = ggml_time_us();
|
| 3062 |
|
| 3063 |
+
// Hann window
|
| 3064 |
+
const float * hann = nullptr;
|
| 3065 |
+
if (frame_size == WHISPER_N_FFT) {
|
| 3066 |
+
hann = global_cache.hann_window;
|
| 3067 |
+
} else if (frame_size == 2 * WHISPER_N_FFT) {
|
| 3068 |
+
hann = global_cache.hann_window2x;
|
| 3069 |
+
} else {
|
| 3070 |
+
WHISPER_ASSERT(false && "Unsupported frame_size");
|
| 3071 |
+
return false;
|
| 3072 |
+
}
|
| 3073 |
|
| 3074 |
// Calculate the length of padding
|
| 3075 |
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
|
|
|
| 3099 |
std::vector<std::thread> workers(n_threads - 1);
|
| 3100 |
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
| 3101 |
workers[iw] = std::thread(
|
| 3102 |
+
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
|
| 3103 |
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
| 3104 |
std::cref(filters), std::ref(mel));
|
| 3105 |
}
|
|
|
|
| 3259 |
#endif
|
| 3260 |
|
| 3261 |
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
|
|
|
|
| 3262 |
whisper_state * state = new whisper_state;
|
| 3263 |
|
| 3264 |
state->backend = whisper_backend_init(ctx->params);
|
|
|
|
| 7246 |
// operation (after median filter)
|
| 7247 |
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
|
| 7248 |
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
|
| 7249 |
+
w = ggml_norm(gctx, w, 1e-9f);
|
| 7250 |
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
|
| 7251 |
|
| 7252 |
// Pass median filter - this is done over AUDIO_TOKENS dimension.
|