stanimirovb committed on
Commit
3a04f56
·
unverified ·
1 Parent(s): 3e54141

whisper: use global cache for sin/cos vals and Hann window (#2194)

Browse files

- also rename Hanning to Hann as it's named after Julius von Hann
as per Wikipedia

Files changed (1) hide show
  1. whisper.cpp +54 -43
whisper.cpp CHANGED
@@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
2857
  }
2858
 
2859
  #define SIN_COS_N_COUNT WHISPER_N_FFT
2860
- static float sin_vals[SIN_COS_N_COUNT];
2861
- static float cos_vals[SIN_COS_N_COUNT];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2862
 
2863
- // In FFT, we frequently use sine and cosine operations with the same values.
2864
- // We can use precalculated values to speed up the process.
2865
- static void fill_sin_cos_table() {
2866
- static bool is_filled = false;
2867
- if (is_filled) return;
2868
- for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2869
- double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
2870
- sin_vals[i] = sinf(theta);
2871
- cos_vals[i] = cosf(theta);
2872
  }
2873
- is_filled = true;
2874
  }
2875
 
2876
  // naive Discrete Fourier Transform
@@ -2888,8 +2912,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2888
 
2889
  for (int n = 0; n < N; n++) {
2890
  int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
2891
- re += in[n]*cos_vals[idx]; // cos(t)
2892
- im -= in[n]*sin_vals[idx]; // sin(t)
2893
  }
2894
 
2895
  out[k*2 + 0] = re;
@@ -2940,8 +2964,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2940
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2941
  for (int k = 0; k < N/2; k++) {
2942
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
2943
- float re = cos_vals[idx]; // cos(t)
2944
- float im = -sin_vals[idx]; // sin(t)
2945
 
2946
  float re_odd = odd_fft[2*k + 0];
2947
  float im_odd = odd_fft[2*k + 1];
@@ -2954,22 +2978,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2954
  }
2955
  }
2956
 
2957
- static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2958
- if (output.size() < static_cast<size_t>(length)) {
2959
- output.resize(length);
2960
- }
2961
- int offset = -1;
2962
- if (periodic) {
2963
- offset = 0;
2964
- }
2965
- for (int i = 0; i < length; i++) {
2966
- output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
2967
- }
2968
-
2969
- return true;
2970
- }
2971
-
2972
- static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
2973
  int n_samples, int frame_size, int frame_step, int n_threads,
2974
  const whisper_filters & filters, whisper_mel & mel) {
2975
  std::vector<float> fft_in(frame_size, 0.0);
@@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2984
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2985
  const int offset = i * frame_step;
2986
 
2987
- // apply Hanning window (~10% faster)
2988
  for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
2989
  fft_in[j] = hann[j] * samples[offset + j];
2990
  }
@@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram(
3051
  whisper_mel & mel) {
3052
  const int64_t t_start_us = ggml_time_us();
3053
 
3054
- // Hanning window (Use cosf to eliminate difference)
3055
- // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
3056
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
3057
- std::vector<float> hann;
3058
- hann_window(frame_size, true, hann);
3059
-
 
 
 
 
3060
 
3061
  // Calculate the length of padding
3062
  int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram(
3086
  std::vector<std::thread> workers(n_threads - 1);
3087
  for (int iw = 0; iw < n_threads - 1; ++iw) {
3088
  workers[iw] = std::thread(
3089
- log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
3090
  n_samples + stage_2_pad, frame_size, frame_step, n_threads,
3091
  std::cref(filters), std::ref(mel));
3092
  }
@@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
3246
  #endif
3247
 
3248
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
3249
- fill_sin_cos_table();
3250
-
3251
  whisper_state * state = new whisper_state;
3252
 
3253
  state->backend = whisper_backend_init(ctx->params);
@@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
7235
  // operation (after median filter)
7236
  // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
7237
  // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
7238
- w = ggml_norm(gctx, w, 1e-9);
7239
  w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
7240
 
7241
  // Pass median filter - this is done over AUDIO_TOKENS dimension.
 
2857
  }
2858
 
2859
  #define SIN_COS_N_COUNT WHISPER_N_FFT
2860
+ namespace {
2861
+ struct whisper_global_cache {
2862
+ // In FFT, we frequently use sine and cosine operations with the same values.
2863
+ // We can use precalculated values to speed up the process.
2864
+ float sin_vals[SIN_COS_N_COUNT];
2865
+ float cos_vals[SIN_COS_N_COUNT];
2866
+
2867
+ // Hann window (Use cosf to eliminate difference)
2868
+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2869
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2870
+ float hann_window[WHISPER_N_FFT];
2871
+ float hann_window2x[WHISPER_N_FFT * 2];
2872
+
2873
+ whisper_global_cache() {
2874
+ fill_sin_cos_table();
2875
+ #define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr)
2876
+ FILL_HANN_WINDOW(hann_window);
2877
+ FILL_HANN_WINDOW(hann_window2x);
2878
+ }
2879
+
2880
+ void fill_sin_cos_table() {
2881
+ for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2882
+ double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
2883
+ sin_vals[i] = sinf(theta);
2884
+ cos_vals[i] = cosf(theta);
2885
+ }
2886
+ }
2887
 
2888
+ void fill_hann_window(int length, bool periodic, float* output) {
2889
+ int offset = -1;
2890
+ if (periodic) {
2891
+ offset = 0;
2892
+ }
2893
+ for (int i = 0; i < length; i++) {
2894
+ output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
2895
+ }
 
2896
  }
2897
+ } global_cache;
2898
  }
2899
 
2900
  // naive Discrete Fourier Transform
 
2912
 
2913
  for (int n = 0; n < N; n++) {
2914
  int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
2915
+ re += in[n]*global_cache.cos_vals[idx]; // cos(t)
2916
+ im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
2917
  }
2918
 
2919
  out[k*2 + 0] = re;
 
2964
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2965
  for (int k = 0; k < N/2; k++) {
2966
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
2967
+ float re = global_cache.cos_vals[idx]; // cos(t)
2968
+ float im = -global_cache.sin_vals[idx]; // sin(t)
2969
 
2970
  float re_odd = odd_fft[2*k + 0];
2971
  float im_odd = odd_fft[2*k + 1];
 
2978
  }
2979
  }
2980
 
2981
+ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2982
  int n_samples, int frame_size, int frame_step, int n_threads,
2983
  const whisper_filters & filters, whisper_mel & mel) {
2984
  std::vector<float> fft_in(frame_size, 0.0);
 
2993
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2994
  const int offset = i * frame_step;
2995
 
2996
+ // apply Hann window (~10% faster)
2997
  for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
2998
  fft_in[j] = hann[j] * samples[offset + j];
2999
  }
 
3060
  whisper_mel & mel) {
3061
  const int64_t t_start_us = ggml_time_us();
3062
 
3063
+ // Hann window
3064
+ const float * hann = nullptr;
3065
+ if (frame_size == WHISPER_N_FFT) {
3066
+ hann = global_cache.hann_window;
3067
+ } else if (frame_size == 2 * WHISPER_N_FFT) {
3068
+ hann = global_cache.hann_window2x;
3069
+ } else {
3070
+ WHISPER_ASSERT(false && "Unsupported frame_size");
3071
+ return false;
3072
+ }
3073
 
3074
  // Calculate the length of padding
3075
  int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
 
3099
  std::vector<std::thread> workers(n_threads - 1);
3100
  for (int iw = 0; iw < n_threads - 1; ++iw) {
3101
  workers[iw] = std::thread(
3102
+ log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
3103
  n_samples + stage_2_pad, frame_size, frame_step, n_threads,
3104
  std::cref(filters), std::ref(mel));
3105
  }
 
3259
  #endif
3260
 
3261
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
 
3262
  whisper_state * state = new whisper_state;
3263
 
3264
  state->backend = whisper_backend_init(ctx->params);
 
7246
  // operation (after median filter)
7247
  // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
7248
  // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
7249
+ w = ggml_norm(gctx, w, 1e-9f);
7250
  w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
7251
 
7252
  // Pass median filter - this is done over AUDIO_TOKENS dimension.