ggerganov committed
Commit 0131aa6 · unverified · 1 parent: d03f526

whisper : add batched decoding (#1486)

* whisper : add whisper_batch

* whisper : move kv_self to whisper_state

* whisper : full batched decoding support

* whisper : fix memory leak in whisper_batch

* whisper : fix mem leak again + remove obsolete function

* whisper : clear kv cache when using whisper_decode API

* whisper : speed-up sampling

* whisper : fix decoders initializer

* bench : add batch size 5 bench

* whisper : add comment about the KV cache size

* whisper : add check for max number of decoders

* whisper : avoid starting sampling threads with bs=1

* whisper : enable beam-search by default

* cuda : sync llama.cpp fixes
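
For users of the library, the visible effect of this commit is that whisper_full now defaults to beam-search sampling and, per the commit messages, advances the beam candidates together as one decoder batch instead of one at a time. A minimal sketch of driving the API with the new defaults (the transcribe() wrapper and the pcmf32 buffer are illustrative, not part of this commit):

    // sketch: run whisper_full with the beam-search defaults enabled by this commit
    #include "whisper.h"
    #include <cstdio>
    #include <vector>

    int transcribe(const char * model_path, const std::vector<float> & pcmf32) {
        struct whisper_context * ctx = whisper_init_from_file(model_path);
        if (ctx == nullptr) {
            return 1;
        }

        // beam search is now the default strategy; the beam_size candidate decoders
        // are decoded together as a batch by the library
        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
        wparams.beam_search.beam_size = 5;

        if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
            whisper_free(ctx);
            return 2;
        }

        for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
            printf("%s\n", whisper_full_get_segment_text(ctx, i));
        }

        whisper_free(ctx);
        return 0;
    }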

Files changed (7)
  1. examples/bench/bench.cpp +20 -10
  2. examples/main/main.cpp +4 -4
  3. extra/bench-all.sh +4 -3
  4. ggml-cuda.cu +188 -118
  5. ggml-cuda.h +5 -0
  6. whisper.cpp +602 -426
  7. whisper.h +3 -1
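
The whisper.cpp and whisper.h hunks are not reproduced on this page. Judging from the commit messages, whisper_batch is an internal container that groups the tokens of all active decoders so the KV cache (now owned by whisper_state) can be updated for every sequence in one pass. A purely hypothetical sketch of such a structure, modeled on llama.cpp's llama_batch; the real field names and layout in whisper.cpp may differ:

    // hypothetical illustration only -- not taken from this commit
    struct whisper_batch_sketch {
        int32_t    n_tokens;  // tokens submitted in one decode step (one per active decoder)
        int32_t  * token;     // [n_tokens] token id fed to each decoder
        int32_t  * pos;       // [n_tokens] position of the token within its sequence
        int32_t  * n_seq_id;  // [n_tokens] number of KV-cache sequences each token belongs to
        int32_t ** seq_id;    // [n_tokens][n_seq_id[i]] ids of those sequences
        int8_t   * logits;    // [n_tokens] whether logits should be computed for the token
    };
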
examples/bench/bench.cpp CHANGED
@@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
     }
     // heat encoder
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to encode: %d\n", ret);
         return 4;
     }
 
@@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
 
     // prompt heat
     if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to decode: %d\n", ret);
         return 4;
     }
 
     // text-generation heat
     if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to decode: %d\n", ret);
         return 4;
     }
 
@@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
 
     // actual run
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to encode: %d\n", ret);
         return 4;
     }
 
-    for (int i = 0; i < 16; i++) {
-        if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
-            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+    // text-generation
+    for (int i = 0; i < 256; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
             return 4;
         }
     }
 
-    for (int i = 0; i < 256; i++) {
-        if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
-            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+    // batched decoding
+    for (int i = 0; i < 64; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
+            return 4;
+        }
+    }
+
+    // prompt processing
+    for (int i = 0; i < 16; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
             return 4;
         }
     }
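
For readers unfamiliar with the bench: whisper_decode(ctx, tokens, n_tokens, n_past, n_threads) evaluates n_tokens tokens starting at offset n_past, so the three timed loops above measure three different decoder workloads. A commented restatement (the calls match the diff; the wrapper and the interpretation in the comments are mine):

    static int run_decoder_phases(whisper_context * ctx, const whisper_token * tokens, const whisper_params & params) {
        // text-generation: one token per call at a growing offset -- the usual
        // autoregressive path, dominated by reads of the KV cache
        for (int i = 0; i < 256; i++) {
            if (whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) return 4;
        }

        // batched decoding: 5 tokens at n_past = 0 -- roughly one step of five
        // parallel decoders (e.g. beam_size = 5); reported as "Bch5" by bench-all.sh
        for (int i = 0; i < 64; i++) {
            if (whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) return 4;
        }

        // prompt processing: 256 tokens in a single call -- long-prompt ingestion throughput
        for (int i = 0; i < 16; i++) {
            if (whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) return 4;
        }

        return 0;
    }
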
examples/main/main.cpp CHANGED
@@ -62,8 +62,8 @@ struct whisper_params {
     int32_t progress_step = 5;
     int32_t max_context = -1;
     int32_t max_len = 0;
-    int32_t best_of = 2;
-    int32_t beam_size = -1;
+    int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
 
     float word_thold = 0.01f;
     float entropy_thold = 2.40f;
@@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
         if (params.detect_language) {
             params.language = "auto";
         }
-        fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
+        fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
                 __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
-                params.n_threads, params.n_processors,
+                params.n_threads, params.n_processors, params.beam_size, params.best_of,
                 params.language.c_str(),
                 params.translate ? "translate" : "transcribe",
                 params.tinydiarize ? "tdrz = 1, " : "",
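
The whisper_params defaults in main.cpp are now seeded from the library instead of being hard-coded, so changing the sampling defaults in whisper.cpp (as this commit does by enabling beam search) propagates to the example automatically. The pattern, restated as a small sketch (simplified, not a verbatim excerpt of main.cpp):

    // defaults come from whisper_full_default_params(), so they track the library
    struct cli_sampling_params {
        int32_t best_of   = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
        int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
    };

    static void apply_sampling(const cli_sampling_params & p, whisper_full_params & wparams) {
        // a command-line override simply replaces the seeded value before this point
        wparams.greedy.best_of        = p.best_of;
        wparams.beam_search.beam_size = p.beam_size;
    }
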
extra/bench-all.sh CHANGED
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi
 
-printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
-printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
 
 for model in "${models[@]}"; do
     # actual run
@@ -56,6 +56,7 @@ for model in "${models[@]}"; do
     # parse the output:
     encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
     decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
+    batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
    prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
     system_info=$(echo "$output" | grep "system_info")
     n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
@@ -94,6 +95,6 @@ for model in "${models[@]}"; do
     commit=$(git rev-parse --short HEAD)
 
     if [ $ret -eq 0 ]; then
-        printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
+        printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
     fi
 done
ggml-cuda.cu CHANGED
@@ -39,7 +39,6 @@
39
  #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
40
  #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
41
  #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
42
- #define cudaDeviceGetMemPool hipDeviceGetMemPool
43
  #define cudaDeviceProp hipDeviceProp_t
44
  #define cudaDeviceSynchronize hipDeviceSynchronize
45
  #define cudaError_t hipError_t
@@ -49,7 +48,6 @@
49
  #define cudaEvent_t hipEvent_t
50
  #define cudaEventDestroy hipEventDestroy
51
  #define cudaFree hipFree
52
- #define cudaFreeAsync hipFreeAsync
53
  #define cudaFreeHost hipHostFree
54
  #define cudaGetDevice hipGetDevice
55
  #define cudaGetDeviceCount hipGetDeviceCount
@@ -57,7 +55,6 @@
57
  #define cudaGetErrorString hipGetErrorString
58
  #define cudaGetLastError hipGetLastError
59
  #define cudaMalloc hipMalloc
60
- #define cudaMallocFromPoolAsync hipMallocFromPoolAsync
61
  #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
62
  #define cudaMemcpy hipMemcpy
63
  #define cudaMemcpy2DAsync hipMemcpy2DAsync
@@ -66,9 +63,6 @@
66
  #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
67
  #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
68
  #define cudaMemcpyKind hipMemcpyKind
69
- #define cudaMemPool_t hipMemPool_t
70
- #define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
71
- #define cudaMemPoolSetAttribute hipMemPoolSetAttribute
72
  #define cudaMemset hipMemset
73
  #define cudaMemsetAsync hipMemsetAsync
74
  #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
@@ -94,6 +88,8 @@
94
  #define CC_OFFSET_AMD 1000000
95
  #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
96
 
 
 
97
  // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
98
  // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
99
  // for large computational tasks. the drawback is that this requires some extra amount of VRAM:
@@ -188,11 +184,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
188
  do { \
189
  cudaError_t err_ = (err); \
190
  if (err_ != cudaSuccess) { \
191
- int dev_id; \
192
- cudaGetDevice(&dev_id); \
193
  fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
194
  cudaGetErrorString(err_)); \
195
- fprintf(stderr, "current device: %d\n", dev_id); \
196
  exit(1); \
197
  } \
198
  } while (0)
@@ -202,11 +198,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
202
  do { \
203
  cublasStatus_t err_ = (err); \
204
  if (err_ != CUBLAS_STATUS_SUCCESS) { \
205
- int dev_id; \
206
- cudaGetDevice(&dev_id); \
207
  fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
208
  err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
209
- fprintf(stderr, "current device: %d\n", dev_id); \
210
  exit(1); \
211
  } \
212
  } while (0)
@@ -440,6 +436,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
440
  #define CUDA_MUL_BLOCK_SIZE 256
441
  #define CUDA_GELU_BLOCK_SIZE 256
442
  #define CUDA_SILU_BLOCK_SIZE 256
 
 
443
  #define CUDA_CPY_BLOCK_SIZE 32
444
  #define CUDA_SCALE_BLOCK_SIZE 256
445
  #define CUDA_CLAMP_BLOCK_SIZE 256
@@ -472,7 +470,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
472
 
473
  #define MAX_STREAMS 8
474
  static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
475
- static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
476
 
477
  struct ggml_tensor_extra_gpu {
478
  void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -561,6 +558,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
561
  dst[i] = x[i] / (1.0f + expf(-x[i]));
562
  }
563
564
  static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
565
  #pragma unroll
566
  for (int mask = 16; mask > 0; mask >>= 1) {
@@ -990,7 +1005,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
990
 
991
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
992
 
993
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
994
  if (row > nrows) return;
995
 
996
  const int num_blocks_per_row = ncols / QK_K;
@@ -1094,7 +1109,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
1094
 
1095
  static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1096
 
1097
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
1098
  if (row > nrows) return;
1099
 
1100
  const int num_blocks_per_row = ncols / QK_K;
@@ -1198,7 +1213,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
1198
 
1199
  static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1200
 
1201
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
1202
  if (row > nrows) return;
1203
  const int num_blocks_per_row = ncols / QK_K;
1204
  const int ib0 = row*num_blocks_per_row;
@@ -1452,7 +1467,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
1452
 
1453
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
1454
 
1455
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
1456
  if (row > nrows) return;
1457
 
1458
  const int num_blocks_per_row = ncols / QK_K;
@@ -4262,7 +4277,7 @@ template <bool need_check> static __global__ void
4262
 
4263
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
4264
  static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
4265
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
4266
 
4267
  if (row >= nrows) {
4268
  return;
@@ -4302,7 +4317,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
4302
  static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
4303
  // qk = quantized weights per x block
4304
  // qr = number of quantized weights per data value in x block
4305
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
4306
 
4307
  if (row >= nrows) {
4308
  return;
@@ -4741,7 +4756,7 @@ static __global__ void im2col_f32_f16(
4741
  int ofs0, int ofs1, int IW, int IH, int CHW,
4742
  int s0, int s1, int p0, int p1, int d0, int d1) {
4743
  const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
4744
- const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
4745
 
4746
  const int offset_dst =
4747
  (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
@@ -4793,6 +4808,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
4793
  silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
4794
  }
4795
4796
  static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4797
  GGML_ASSERT(ncols % WARP_SIZE == 0);
4798
  if (ncols < 1024) {
@@ -4901,7 +4926,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
4901
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4902
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4903
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4904
- const dim3 block_nums(1, block_num_y, 1);
 
4905
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4906
  dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
4907
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4910,7 +4936,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
4910
  static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4911
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4912
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4913
- const dim3 block_nums(1, block_num_y, 1);
4914
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4915
  dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
4916
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4919,7 +4945,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
4919
  static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4920
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4921
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4922
- const dim3 block_nums(1, block_num_y, 1);
4923
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4924
  dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
4925
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4928,7 +4954,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
4928
  static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4929
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4930
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4931
- const dim3 block_nums(1, block_num_y, 1);
4932
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4933
  dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
4934
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4937,7 +4963,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
4937
  static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4938
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4939
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4940
- const dim3 block_nums(1, block_num_y, 1);
4941
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4942
  dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
4943
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4947,7 +4973,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
4947
  GGML_ASSERT(ncols % QK_K == 0);
4948
  const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
4949
  const int block_num_y = (nrows + ny - 1) / ny;
4950
- const dim3 block_nums(1, block_num_y, 1);
4951
  const dim3 block_dims(32, ny, 1);
4952
  dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4953
  }
@@ -4956,7 +4982,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
4956
  GGML_ASSERT(ncols % QK_K == 0);
4957
  const int ny = 2 / K_QUANTS_PER_ITERATION;
4958
  const int block_num_y = (nrows + ny - 1) / ny;
4959
- const dim3 block_nums(1, block_num_y, 1);
4960
  const dim3 block_dims(32, ny, 1);
4961
  dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4962
  }
@@ -4965,7 +4991,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f
4965
  GGML_ASSERT(ncols % QK_K == 0);
4966
  const int ny = 2 / K_QUANTS_PER_ITERATION;
4967
  const int block_num_y = (nrows + ny - 1) / ny;
4968
- const dim3 block_nums(1, block_num_y, 1);
4969
  const dim3 block_dims(32, ny, 1);
4970
  dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4971
  }
@@ -4980,7 +5006,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
4980
  GGML_ASSERT(ncols % QK_K == 0);
4981
  const int ny = 2 / K_QUANTS_PER_ITERATION;
4982
  const int block_num_y = (nrows + ny - 1) / ny;
4983
- const dim3 block_nums(1, block_num_y, 1);
4984
  const dim3 block_dims(32, ny, 1);
4985
  dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4986
  }
@@ -4988,7 +5014,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
4988
  static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4989
  GGML_ASSERT(ncols % QK4_0 == 0);
4990
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4991
- const dim3 block_nums(1, block_num_y, 1);
4992
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4993
  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
4994
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4997,7 +5023,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
4997
  static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4998
  GGML_ASSERT(ncols % QK4_1 == 0);
4999
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5000
- const dim3 block_nums(1, block_num_y, 1);
5001
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5002
  mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
5003
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5006,7 +5032,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
5006
  static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5007
  GGML_ASSERT(ncols % QK5_0 == 0);
5008
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5009
- const dim3 block_nums(1, block_num_y, 1);
5010
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5011
  mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
5012
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5015,7 +5041,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
5015
  static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5016
  GGML_ASSERT(ncols % QK5_1 == 0);
5017
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5018
- const dim3 block_nums(1, block_num_y, 1);
5019
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5020
  mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
5021
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5024,7 +5050,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
5024
  static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5025
  GGML_ASSERT(ncols % QK8_0 == 0);
5026
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5027
- const dim3 block_nums(1, block_num_y, 1);
5028
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5029
  mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
5030
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5033,7 +5059,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
5033
  static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5034
  GGML_ASSERT(ncols % QK_K == 0);
5035
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5036
- const dim3 block_nums(1, block_num_y, 1);
5037
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5038
  mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
5039
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5042,7 +5068,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
5042
  static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5043
  GGML_ASSERT(ncols % QK_K == 0);
5044
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5045
- const dim3 block_nums(1, block_num_y, 1);
5046
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5047
  mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
5048
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5051,7 +5077,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
5051
  static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5052
  GGML_ASSERT(ncols % QK_K == 0);
5053
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5054
- const dim3 block_nums(1, block_num_y, 1);
5055
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5056
  mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
5057
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5060,7 +5086,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
5060
  static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5061
  GGML_ASSERT(ncols % QK_K == 0);
5062
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5063
- const dim3 block_nums(1, block_num_y, 1);
5064
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5065
  mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
5066
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5069,7 +5095,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
5069
  static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5070
  GGML_ASSERT(ncols % QK_K == 0);
5071
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5072
- const dim3 block_nums(1, block_num_y, 1);
5073
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5074
  mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
5075
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5088,7 +5114,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu
5088
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5089
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
5090
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5091
- const dim3 block_nums(1, block_num_y, 1);
5092
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5093
  dequantize_mul_mat_vec<1, 1, convert_f16>
5094
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -5825,16 +5851,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
5825
  return ptr;
5826
  }
5827
 
5828
- static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
5829
- if (g_cudaMemPools[id] == nullptr) {
5830
- return ggml_cuda_pool_malloc(size, actual_size);
5831
- }
5832
- void *ptr;
5833
- CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
5834
- *actual_size = size;
5835
- return ptr;
5836
- }
5837
-
5838
  static void ggml_cuda_pool_free(void * ptr, size_t size) {
5839
  scoped_spin_lock lock(g_cuda_pool_lock);
5840
  int id;
@@ -5852,12 +5868,10 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
5852
  CUDA_CHECK(cudaFree(ptr));
5853
  }
5854
 
 
5855
 
5856
- static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
5857
- if (g_cudaMemPools[id] == nullptr) {
5858
- return ggml_cuda_pool_free(ptr, actual_size);
5859
- }
5860
- CUDA_CHECK(cudaFreeAsync(ptr, stream));
5861
  }
5862
 
5863
  void ggml_init_cublas() {
@@ -5872,7 +5886,12 @@ void ggml_init_cublas() {
5872
  CUDA_CHECK(cudaDeviceSynchronize());
5873
  #endif
5874
 
5875
- CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
 
 
 
 
 
5876
  GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
5877
  int64_t total_vram = 0;
5878
  #if defined(GGML_CUDA_FORCE_MMQ)
@@ -5914,19 +5933,13 @@ void ggml_init_cublas() {
5914
  // create cublas handle
5915
  CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
5916
  CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
5917
-
5918
- // configure memory pool
5919
- cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
5920
- if (err == cudaSuccess) {
5921
- size_t treshold = UINT64_MAX;
5922
- CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
5923
- }
5924
  }
5925
 
5926
  // configure logging to stdout
5927
  // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
5928
 
5929
  initialized = true;
 
5930
  }
5931
  }
5932
 
@@ -6193,6 +6206,34 @@ inline void ggml_cuda_op_silu(
6193
  (void) src1_dd;
6194
  }
6195
6196
  inline void ggml_cuda_op_norm(
6197
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6198
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6514,7 +6555,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
6514
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
6515
  GGML_ASSERT(to_fp16_cuda != nullptr);
6516
  size_t ne = row_diff*ne00;
6517
- src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
6518
  to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
6519
  }
6520
  const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6525,12 +6566,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
6525
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
6526
  GGML_ASSERT(to_fp16_cuda != nullptr);
6527
  size_t ne = src1_ncols*ne10;
6528
- src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
6529
  to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
6530
  }
6531
  const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
6532
- size_t dst_f16_as = 0;
6533
- half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
6534
 
6535
  const half alpha_f16 = 1.0f;
6536
  const half beta_f16 = 0.0f;
@@ -6548,15 +6589,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
6548
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
6549
  to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
6550
 
6551
- if (dst_f16_as != 0) {
6552
- ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
6553
- }
6554
 
6555
  if (src0_as != 0) {
6556
- ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
6557
  }
 
6558
  if (src1_as != 0) {
6559
- ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
6560
  }
6561
  }
6562
  else {
@@ -6566,7 +6606,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
6566
  if (src0->type != GGML_TYPE_F32) {
6567
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
6568
  GGML_ASSERT(to_fp32_cuda != nullptr);
6569
- src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
6570
  to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
6571
  }
6572
  const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6583,7 +6623,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
6583
  &beta, dst_dd_i, ldc));
6584
 
6585
  if (src0_as != 0) {
6586
- ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
6587
  }
6588
  }
6589
 
@@ -7008,6 +7048,8 @@ static void ggml_cuda_op_mul_mat(
7008
  int64_t row_low[GGML_CUDA_MAX_DEVICES];
7009
  int64_t row_high[GGML_CUDA_MAX_DEVICES];
7010
 
 
 
7011
  for (int64_t id = 0; id < g_device_count; ++id) {
7012
  // by default, use all rows
7013
  row_low[id] = 0;
@@ -7035,6 +7077,8 @@ static void ggml_cuda_op_mul_mat(
7035
  continue;
7036
  }
7037
 
 
 
7038
  const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
7039
  const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
7040
 
@@ -7045,22 +7089,21 @@ static void ggml_cuda_op_mul_mat(
7045
  src0_dd[id] = (char *) src0_extra->data_device[id];
7046
  } else {
7047
  const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
7048
- src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
7049
  }
7050
 
7051
  if (src1_on_device && src1_is_contiguous) {
7052
  src1_ddf[id] = (float *) src1_extra->data_device[id];
7053
  } else {
7054
- src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
7055
  }
7056
 
7057
  if (convert_src1_to_q8_1) {
7058
- const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
7059
- src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
7060
 
7061
  if (src1_on_device && src1_is_contiguous) {
7062
  quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
7063
- // CUDA_CHECK(cudaGetLastError());
7064
  }
7065
  }
7066
 
@@ -7068,18 +7111,18 @@ static void ggml_cuda_op_mul_mat(
7068
  dst_dd[id] = (float *) dst_extra->data_device[id];
7069
  } else {
7070
  const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
7071
- dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
7072
  }
7073
  }
7074
 
7075
  // if multiple devices are used they need to wait for the main device
7076
  // here an event is recorded that signals that the main device has finished calculating the input data
7077
- if (split && g_device_count > 1) {
7078
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7079
  CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
7080
  }
7081
 
7082
- const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
7083
  for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
7084
  const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
7085
  const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
@@ -7194,6 +7237,27 @@ static void ggml_cuda_op_mul_mat(
7194
  }
7195
  }
7196
7197
  // main device waits for all other devices to be finished
7198
  if (split && g_device_count > 1) {
7199
  int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7201,6 +7265,9 @@ static void ggml_cuda_op_mul_mat(
7201
 
7202
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7203
  for (int64_t id = 0; id < g_device_count; ++id) {
 
 
 
7204
  for (int64_t is = 0; is < is_max; ++is) {
7205
  CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
7206
  }
@@ -7211,21 +7278,6 @@ static void ggml_cuda_op_mul_mat(
7211
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7212
  CUDA_CHECK(cudaDeviceSynchronize());
7213
  }
7214
-
7215
- for (int64_t id = 0; id < g_device_count; ++id) {
7216
- if (src0_as[id] > 0) {
7217
- ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
7218
- }
7219
- if (src1_asf[id] > 0) {
7220
- ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
7221
- }
7222
- if (src1_asq[id] > 0) {
7223
- ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
7224
- }
7225
- if (dst_as[id] > 0) {
7226
- ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
7227
- }
7228
- }
7229
  }
7230
 
7231
  static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7252,6 +7304,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
7252
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
7253
  }
7254
  static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7256
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
7257
  }
@@ -7261,6 +7321,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
7261
  }
7262
 
7263
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
 
 
7264
  const int64_t ne10 = src1->ne[0];
7265
 
7266
  const int64_t ne0 = dst->ne[0];
@@ -7412,11 +7474,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
7412
  GGML_ASSERT(to_fp16_cuda != nullptr);
7413
 
7414
  size_t src1_as = 0;
7415
- half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
7416
  to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
7417
 
7418
  size_t dst_as = 0;
7419
- half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
7420
 
7421
  GGML_ASSERT(ne12 % ne02 == 0);
7422
  GGML_ASSERT(ne13 % ne03 == 0);
@@ -7470,8 +7532,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
7470
  size_t ptrs_src_s = 0;
7471
  size_t ptrs_dst_s = 0;
7472
 
7473
- ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
7474
- ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
7475
 
7476
  dim3 block_dims(ne13, ne12);
7477
  k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7484,6 +7546,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
7484
  dst->nb[2], dst->nb[3],
7485
  r2, r3);
7486
  CUDA_CHECK(cudaGetLastError());
 
7487
  CUBLAS_CHECK(
7488
  cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7489
  ne01, ne11, ne10,
@@ -7495,30 +7558,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
7495
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7496
 
7497
  if (ptrs_src_s != 0) {
7498
- ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
7499
  }
7500
  if (ptrs_dst_s != 0) {
7501
- ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
7502
  }
7503
  }
7504
  #endif
7505
 
7506
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
7507
  to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
7508
- if (src1_as != 0) {
7509
- ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
7510
- }
7511
- if (dst_as != 0) {
7512
- ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
7513
- }
7514
  }
7515
 
7516
  static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7517
  const bool all_on_device =
7518
- (src0->backend == GGML_BACKEND_GPU) &&
7519
  (src1->backend == GGML_BACKEND_GPU) &&
7520
  ( dst->backend == GGML_BACKEND_GPU);
7521
 
 
 
7522
  int64_t min_compute_capability = INT_MAX;
7523
  for (int64_t id = 0; id < g_device_count; ++id) {
7524
  if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
@@ -7540,13 +7602,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
7540
  //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
7541
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
7542
 
7543
- if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
7544
  // KQ single-batch
7545
  ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
7546
- } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
7547
  // KQV single-batch
7548
  ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
7549
- } else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
7550
  // KQ + KQV multi-batch
7551
  ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
7552
  } else if (src0->type == GGML_TYPE_F32) {
@@ -7667,7 +7729,7 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
7667
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
7668
  }
7669
 
7670
- void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7671
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
7672
  }
7673
 
@@ -7782,11 +7844,11 @@ static size_t g_temp_tensor_extra_index = 0;
7782
 
7783
  static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
7784
  if (g_temp_tensor_extras == nullptr) {
7785
- g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
7786
  }
7787
 
7788
  size_t alloc_index = g_temp_tensor_extra_index;
7789
- g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
7790
  ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
7791
  memset(extra, 0, sizeof(*extra));
7792
 
@@ -7953,6 +8015,8 @@ void ggml_cuda_free_scratch() {
7953
  }
7954
 
7955
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
 
 
7956
  ggml_cuda_func_t func;
7957
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
7958
  || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
@@ -7995,6 +8059,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
7995
  case GGML_UNARY_OP_SILU:
7996
  func = ggml_cuda_silu;
7997
  break;
 
 
 
7998
  default:
7999
  return false;
8000
  } break;
@@ -8013,6 +8080,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8013
  case GGML_OP_SCALE:
8014
  func = ggml_cuda_scale;
8015
  break;
 
 
 
8016
  case GGML_OP_CLAMP:
8017
  if (!any_on_device) {
8018
  return false;
@@ -8105,11 +8175,11 @@ struct ggml_backend_buffer_context_cuda {
8105
 
8106
  ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
8107
  if (temp_tensor_extras == nullptr) {
8108
- temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
8109
  }
8110
 
8111
  size_t alloc_index = temp_tensor_extra_index;
8112
- temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
8113
  ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
8114
  memset(extra, 0, sizeof(*extra));
8115
 
 
39
  #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
40
  #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
41
  #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 
42
  #define cudaDeviceProp hipDeviceProp_t
43
  #define cudaDeviceSynchronize hipDeviceSynchronize
44
  #define cudaError_t hipError_t
 
48
  #define cudaEvent_t hipEvent_t
49
  #define cudaEventDestroy hipEventDestroy
50
  #define cudaFree hipFree
 
51
  #define cudaFreeHost hipHostFree
52
  #define cudaGetDevice hipGetDevice
53
  #define cudaGetDeviceCount hipGetDeviceCount
 
55
  #define cudaGetErrorString hipGetErrorString
56
  #define cudaGetLastError hipGetLastError
57
  #define cudaMalloc hipMalloc
 
58
  #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
59
  #define cudaMemcpy hipMemcpy
60
  #define cudaMemcpy2DAsync hipMemcpy2DAsync
 
63
  #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
64
  #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
65
  #define cudaMemcpyKind hipMemcpyKind
 
 
 
66
  #define cudaMemset hipMemset
67
  #define cudaMemsetAsync hipMemsetAsync
68
  #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
 
88
  #define CC_OFFSET_AMD 1000000
89
  #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
90
 
91
+ #define GGML_CUDA_MAX_NODES 8192
92
+
93
  // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
94
  // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
95
  // for large computational tasks. the drawback is that this requires some extra amount of VRAM:
 
184
  do { \
185
  cudaError_t err_ = (err); \
186
  if (err_ != cudaSuccess) { \
187
+ int id; \
188
+ cudaGetDevice(&id); \
189
  fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
190
  cudaGetErrorString(err_)); \
191
+ fprintf(stderr, "current device: %d\n", id); \
192
  exit(1); \
193
  } \
194
  } while (0)
 
198
  do { \
199
  cublasStatus_t err_ = (err); \
200
  if (err_ != CUBLAS_STATUS_SUCCESS) { \
201
+ int id; \
202
+ cudaGetDevice(&id); \
203
  fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
204
  err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
205
+ fprintf(stderr, "current device: %d\n", id); \
206
  exit(1); \
207
  } \
208
  } while (0)
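
These two hunks only rename the local variable inside the macros (dev_id becomes id) to stay in sync with llama.cpp; the behaviour, reporting which device was current when the failure happened, is unchanged. Illustrative call sites (not from this diff), showing why the device report matters when tensors are split across GPUs:

    // wrapping CUDA / cuBLAS calls so a failure also prints "current device: N"
    float * d_buf = nullptr;
    CUDA_CHECK(cudaSetDevice(1));   // errors below would be reported against device 1
    CUDA_CHECK(cudaMalloc((void **) &d_buf, 1024*sizeof(float)));
    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[1], g_cudaStreams[1][0]));
    CUDA_CHECK(cudaFree(d_buf));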
 
436
  #define CUDA_MUL_BLOCK_SIZE 256
437
  #define CUDA_GELU_BLOCK_SIZE 256
438
  #define CUDA_SILU_BLOCK_SIZE 256
439
+ #define CUDA_RELU_BLOCK_SIZE 256
440
+ #define CUDA_SQR_BLOCK_SIZE 256
441
  #define CUDA_CPY_BLOCK_SIZE 32
442
  #define CUDA_SCALE_BLOCK_SIZE 256
443
  #define CUDA_CLAMP_BLOCK_SIZE 256
 
470
 
471
  #define MAX_STREAMS 8
472
  static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
 
473
 
474
  struct ggml_tensor_extra_gpu {
475
  void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
 
558
  dst[i] = x[i] / (1.0f + expf(-x[i]));
559
  }
560
 
561
+ static __global__ void relu_f32(const float * x, float * dst, const int k) {
562
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
563
+
564
+ if (i >= k) {
565
+ return;
566
+ }
567
+ dst[i] = fmaxf(x[i], 0);
568
+ }
569
+
570
+ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
571
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
572
+
573
+ if (i >= k) {
574
+ return;
575
+ }
576
+ dst[i] = x[i] * x[i];
577
+ }
578
+
579
  static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
580
  #pragma unroll
581
  for (int mask = 16; mask > 0; mask >>= 1) {
 
1005
 
1006
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
1007
 
1008
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
1009
  if (row > nrows) return;
1010
 
1011
  const int num_blocks_per_row = ncols / QK_K;
 
1109
 
1110
  static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1111
 
1112
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
1113
  if (row > nrows) return;
1114
 
1115
  const int num_blocks_per_row = ncols / QK_K;
 
1213
 
1214
  static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1215
 
1216
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
1217
  if (row > nrows) return;
1218
  const int num_blocks_per_row = ncols / QK_K;
1219
  const int ib0 = row*num_blocks_per_row;
 
1467
 
1468
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
1469
 
1470
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
1471
  if (row > nrows) return;
1472
 
1473
  const int num_blocks_per_row = ncols / QK_K;
 
4277
 
4278
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
4279
  static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
4280
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
4281
 
4282
  if (row >= nrows) {
4283
  return;
 
4317
  static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
4318
  // qk = quantized weights per x block
4319
  // qr = number of quantized weights per data value in x block
4320
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
4321
 
4322
  if (row >= nrows) {
4323
  return;
 
4756
  int ofs0, int ofs1, int IW, int IH, int CHW,
4757
  int s0, int s1, int p0, int p1, int d0, int d1) {
4758
  const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
4759
+ const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
4760
 
4761
  const int offset_dst =
4762
  (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
 
4808
  silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
4809
  }
4810
 
4811
+ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
4812
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
4813
+ relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
4814
+ }
4815
+
4816
+ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
4817
+ const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
4818
+ sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
4819
+ }
4820
+
4821
  static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4822
  GGML_ASSERT(ncols % WARP_SIZE == 0);
4823
  if (ncols < 1024) {
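
The new relu_f32_cuda / sqr_f32_cuda launchers above size their grid with the same ceil-division as the existing gelu/silu helpers: num_blocks = (k + BLOCK_SIZE - 1) / BLOCK_SIZE. A self-contained host-side sketch that exercises relu_f32_cuda on a small buffer (the test harness is mine; it only assumes the launcher and CUDA_CHECK from this file are visible):

    static void test_relu_f32_cuda(const int k) {
        float * h = (float *) malloc(k*sizeof(float));
        for (int i = 0; i < k; ++i) {
            h[i] = (i % 2 == 0) ? -float(i) : float(i);   // mix of negative and positive inputs
        }

        float * d = nullptr;
        CUDA_CHECK(cudaMalloc((void **) &d, k*sizeof(float)));
        CUDA_CHECK(cudaMemcpy(d, h, k*sizeof(float), cudaMemcpyHostToDevice));

        relu_f32_cuda(d, d, k, nullptr);                  // in-place, default stream
        CUDA_CHECK(cudaMemcpy(h, d, k*sizeof(float), cudaMemcpyDeviceToHost));

        for (int i = 0; i < k; ++i) {
            if (h[i] < 0.0f) {                            // every negative input must be clamped to 0
                fprintf(stderr, "relu_f32 mismatch at %d\n", i);
                break;
            }
        }

        CUDA_CHECK(cudaFree(d));
        free(h);
    }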
 
4926
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4927
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4928
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4929
+ // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
4930
+ const dim3 block_nums(block_num_y, 1, 1);
4931
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4932
  dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
4933
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
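
The comment above explains the change that is repeated across all of the *_cuda launchers and kernels below: CUDA limits gridDim.y and gridDim.z to 65535 blocks, while gridDim.x can go up to 2^31 - 1, so indexing rows through blockIdx.y stops launching once a matrix has more than 65535 * GGML_CUDA_MMV_Y rows. A minimal sketch of the updated pattern (my own placeholder kernel, not part of the diff):

    // rows go on the x axis of the grid (limit ~2^31 - 1), threads-per-row stay in y
    static __global__ void rows_kernel_sketch(float * dst, const int nrows) {
        const int row = blockIdx.x*blockDim.y + threadIdx.y;  // was blockIdx.y*blockDim.y + threadIdx.y
        if (row >= nrows) {
            return;
        }
        dst[row] = (float) row;  // placeholder per-row work
    }

    static void rows_kernel_sketch_cuda(float * dst, const int nrows, cudaStream_t stream) {
        const int block_num = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
        const dim3 block_nums(block_num, 1, 1);               // x instead of y: no 65535 cap
        const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
        rows_kernel_sketch<<<block_nums, block_dims, 0, stream>>>(dst, nrows);
    }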
 
4936
  static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4937
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4938
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4939
+ const dim3 block_nums(block_num_y, 1, 1);
4940
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4941
  dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
4942
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 
4945
  static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4946
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4947
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4948
+ const dim3 block_nums(block_num_y, 1, 1);
4949
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4950
  dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
4951
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 
4954
  static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4955
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4956
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4957
+ const dim3 block_nums(block_num_y, 1, 1);
4958
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4959
  dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
4960
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 
4963
  static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4964
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4965
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4966
+ const dim3 block_nums(block_num_y, 1, 1);
4967
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4968
  dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
4969
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 
4973
  GGML_ASSERT(ncols % QK_K == 0);
4974
  const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
4975
  const int block_num_y = (nrows + ny - 1) / ny;
4976
+ const dim3 block_nums(block_num_y, 1, 1);
4977
  const dim3 block_dims(32, ny, 1);
4978
  dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4979
  }
 
4982
  GGML_ASSERT(ncols % QK_K == 0);
4983
  const int ny = 2 / K_QUANTS_PER_ITERATION;
4984
  const int block_num_y = (nrows + ny - 1) / ny;
4985
+ const dim3 block_nums(block_num_y, 1, 1);
4986
  const dim3 block_dims(32, ny, 1);
4987
  dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4988
  }
 
4991
  GGML_ASSERT(ncols % QK_K == 0);
4992
  const int ny = 2 / K_QUANTS_PER_ITERATION;
4993
  const int block_num_y = (nrows + ny - 1) / ny;
4994
+ const dim3 block_nums(block_num_y, 1, 1);
4995
  const dim3 block_dims(32, ny, 1);
4996
  dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4997
  }
 
5006
  GGML_ASSERT(ncols % QK_K == 0);
5007
  const int ny = 2 / K_QUANTS_PER_ITERATION;
5008
  const int block_num_y = (nrows + ny - 1) / ny;
5009
+ const dim3 block_nums(block_num_y, 1, 1);
5010
  const dim3 block_dims(32, ny, 1);
5011
  dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
5012
  }
 
5014
  static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5015
  GGML_ASSERT(ncols % QK4_0 == 0);
5016
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5017
+ const dim3 block_nums(block_num_y, 1, 1);
5018
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5019
  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
5020
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5023
  static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5024
  GGML_ASSERT(ncols % QK4_1 == 0);
5025
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5026
+ const dim3 block_nums(block_num_y, 1, 1);
5027
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5028
  mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
5029
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5032
  static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5033
  GGML_ASSERT(ncols % QK5_0 == 0);
5034
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5035
+ const dim3 block_nums(block_num_y, 1, 1);
5036
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5037
  mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
5038
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5041
  static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5042
  GGML_ASSERT(ncols % QK5_1 == 0);
5043
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5044
+ const dim3 block_nums(block_num_y, 1, 1);
5045
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5046
  mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
5047
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5050
  static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5051
  GGML_ASSERT(ncols % QK8_0 == 0);
5052
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5053
+ const dim3 block_nums(block_num_y, 1, 1);
5054
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5055
  mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
5056
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5059
  static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5060
  GGML_ASSERT(ncols % QK_K == 0);
5061
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5062
+ const dim3 block_nums(block_num_y, 1, 1);
5063
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5064
  mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
5065
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5068
  static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5069
  GGML_ASSERT(ncols % QK_K == 0);
5070
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5071
+ const dim3 block_nums(block_num_y, 1, 1);
5072
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5073
  mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
5074
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5077
  static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5078
  GGML_ASSERT(ncols % QK_K == 0);
5079
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5080
+ const dim3 block_nums(block_num_y, 1, 1);
5081
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5082
  mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
5083
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5086
  static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5087
  GGML_ASSERT(ncols % QK_K == 0);
5088
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5089
+ const dim3 block_nums(block_num_y, 1, 1);
5090
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5091
  mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
5092
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5095
  static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5096
  GGML_ASSERT(ncols % QK_K == 0);
5097
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5098
+ const dim3 block_nums(block_num_y, 1, 1);
5099
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5100
  mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
5101
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
5114
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5115
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
5116
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5117
+ const dim3 block_nums(block_num_y, 1, 1);
5118
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5119
  dequantize_mul_mat_vec<1, 1, convert_f16>
5120
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 
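All of the launchers above size their grid the same way: the row count is divided by the rows handled per block, rounded up, so the final partial block is still launched, and the kernel itself guards against out-of-range rows. A minimal sketch of that pattern follows; the kernel name and the rows-per-block constant are illustrative stand-ins, not definitions from this file.

    // illustrative only: the ceil-division grid sizing used by the launchers above
    __global__ void scale_rows_example(const float * x, float * y, const int ncols, const int nrows) {
        const int row = blockIdx.x*blockDim.y + threadIdx.y; // one row per (block, y-lane)
        if (row >= nrows) {
            return; // the last block may cover fewer rows than blockDim.y
        }
        for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
            y[row*ncols + col] = 2.0f*x[row*ncols + col];
        }
    }

    static void scale_rows_example_cuda(const float * x, float * y, const int ncols, const int nrows, cudaStream_t stream) {
        const int rows_per_block = 2;                                             // stands in for GGML_CUDA_MMV_Y
        const int block_num_y    = (nrows + rows_per_block - 1) / rows_per_block; // ceil(nrows / rows_per_block)
        const dim3 block_nums(block_num_y, 1, 1);
        const dim3 block_dims(32, rows_per_block, 1);                             // 32 lanes (WARP_SIZE in the real launchers) iterate over the columns
        scale_rows_example<<<block_nums, block_dims, 0, stream>>>(x, y, ncols, nrows);
    }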
5851
  return ptr;
5852
  }
5853
 
5854
  static void ggml_cuda_pool_free(void * ptr, size_t size) {
5855
  scoped_spin_lock lock(g_cuda_pool_lock);
5856
  int id;
 
5868
  CUDA_CHECK(cudaFree(ptr));
5869
  }
5870
 
5871
+ static bool g_cublas_loaded = false;
5872
 
5873
+ bool ggml_cublas_loaded(void) {
5874
+ return g_cublas_loaded;
 
 
 
5875
  }
5876
 
5877
  void ggml_init_cublas() {
 
5886
  CUDA_CHECK(cudaDeviceSynchronize());
5887
  #endif
5888
 
5889
+ if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
5890
+ initialized = true;
5891
+ g_cublas_loaded = false;
5892
+ return;
5893
+ }
5894
+
5895
  GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
5896
  int64_t total_vram = 0;
5897
  #if defined(GGML_CUDA_FORCE_MMQ)
 
5933
  // create cublas handle
5934
  CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
5935
  CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
 
 
 
 
 
 
 
5936
  }
5937
 
5938
  // configure logging to stdout
5939
  // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
5940
 
5941
  initialized = true;
5942
+ g_cublas_loaded = true;
5943
  }
5944
  }
5945
 
 
6206
  (void) src1_dd;
6207
  }
6208
 
6209
+ inline void ggml_cuda_op_relu(
6210
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6211
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6212
+
6213
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
6214
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
6215
+
6216
+ relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
6217
+
6218
+ (void) src1;
6219
+ (void) dst;
6220
+ (void) src1_dd;
6221
+ }
6222
+
6223
+ inline void ggml_cuda_op_sqr(
6224
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6225
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6226
+
6227
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
6228
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
6229
+
6230
+ sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
6231
+
6232
+ (void) src1;
6233
+ (void) dst;
6234
+ (void) src1_dd;
6235
+ }
6236
+
6237
  inline void ggml_cuda_op_norm(
6238
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6239
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
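The new ggml_cuda_op_relu and ggml_cuda_op_sqr wrappers above only validate types and forward to relu_f32_cuda / sqr_f32_cuda, which are element-wise kernel launchers defined earlier in the file. A rough sketch of what such a kernel/launcher pair typically looks like; the names and block size below are illustrative, not the file's actual definitions.

    #define EXAMPLE_RELU_BLOCK_SIZE 256

    __global__ void relu_f32_example(const float * x, float * dst, const int k) {
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i >= k) {
            return;
        }
        dst[i] = fmaxf(x[i], 0.0f); // max(x, 0)
    }

    static void relu_f32_example_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
        const int num_blocks = (k + EXAMPLE_RELU_BLOCK_SIZE - 1) / EXAMPLE_RELU_BLOCK_SIZE;
        relu_f32_example<<<num_blocks, EXAMPLE_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
    }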
6555
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
6556
  GGML_ASSERT(to_fp16_cuda != nullptr);
6557
  size_t ne = row_diff*ne00;
6558
+ src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
6559
  to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
6560
  }
6561
  const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
 
6566
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
6567
  GGML_ASSERT(to_fp16_cuda != nullptr);
6568
  size_t ne = src1_ncols*ne10;
6569
+ src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
6570
  to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
6571
  }
6572
  const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
6573
+ size_t dst_as = 0;
6574
+ half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
6575
 
6576
  const half alpha_f16 = 1.0f;
6577
  const half beta_f16 = 0.0f;
 
6589
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
6590
  to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
6591
 
6592
+ ggml_cuda_pool_free(dst_f16, dst_as);
 
 
6593
 
6594
  if (src0_as != 0) {
6595
+ ggml_cuda_pool_free(src0_as_f16, src0_as);
6596
  }
6597
+
6598
  if (src1_as != 0) {
6599
+ ggml_cuda_pool_free(src1_as_f16, src1_as);
6600
  }
6601
  }
6602
  else {
 
6606
  if (src0->type != GGML_TYPE_F32) {
6607
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
6608
  GGML_ASSERT(to_fp32_cuda != nullptr);
6609
+ src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
6610
  to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
6611
  }
6612
  const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
 
6623
  &beta, dst_dd_i, ldc));
6624
 
6625
  if (src0_as != 0) {
6626
+ ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
6627
  }
6628
  }
6629
 
 
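Throughout this hunk, ggml_cuda_pool_malloc hands back a device buffer and writes the size it actually reserved into a size_t out-parameter, and that same size is later passed back to ggml_cuda_pool_free; a non-zero size doubles as the "was anything allocated" flag. A small sketch of that pairing as it would appear inside ggml-cuda.cu; the buffer size is arbitrary.

    // sketch of the pool-allocator pairing used above
    size_t tmp_as = 0; // receives the size actually reserved by the pool
    float * tmp   = (float *) ggml_cuda_pool_malloc(1024*sizeof(float), &tmp_as);

    // ... enqueue kernels that read/write tmp on the current stream ...

    if (tmp_as != 0) {
        ggml_cuda_pool_free(tmp, tmp_as); // return the buffer together with the reserved size
    }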
7048
  int64_t row_low[GGML_CUDA_MAX_DEVICES];
7049
  int64_t row_high[GGML_CUDA_MAX_DEVICES];
7050
 
7051
+ int used_devices = 0;
7052
+
7053
  for (int64_t id = 0; id < g_device_count; ++id) {
7054
  // by default, use all rows
7055
  row_low[id] = 0;
 
7077
  continue;
7078
  }
7079
 
7080
+ used_devices++;
7081
+
7082
  const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
7083
  const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
7084
 
 
7089
  src0_dd[id] = (char *) src0_extra->data_device[id];
7090
  } else {
7091
  const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
7092
+ src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
7093
  }
7094
 
7095
  if (src1_on_device && src1_is_contiguous) {
7096
  src1_ddf[id] = (float *) src1_extra->data_device[id];
7097
  } else {
7098
+ src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
7099
  }
7100
 
7101
  if (convert_src1_to_q8_1) {
7102
+ src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
 
7103
 
7104
  if (src1_on_device && src1_is_contiguous) {
7105
  quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
7106
+ CUDA_CHECK(cudaGetLastError());
7107
  }
7108
  }
7109
 
 
7111
  dst_dd[id] = (float *) dst_extra->data_device[id];
7112
  } else {
7113
  const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
7114
+ dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
7115
  }
7116
  }
7117
 
7118
  // if multiple devices are used they need to wait for the main device
7119
  // here an event is recorded that signals that the main device has finished calculating the input data
7120
+ if (split && used_devices > 1) {
7121
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7122
  CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
7123
  }
7124
 
7125
+ const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
7126
  for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
7127
  const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
7128
  const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
 
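The event handshake described a few lines above is the standard CUDA pattern for cross-stream (and cross-device) ordering: the producer stream records an event, and each consumer stream waits on that event before running work that reads the produced data. A generic sketch of the pattern; CUDA_CHECK is the file's existing error-check macro, and stream_a / stream_b stand for any two streams.

    // sketch: make work queued on stream_b start only after work already queued on stream_a has finished
    cudaEvent_t ev;
    CUDA_CHECK(cudaEventCreateWithFlags(&ev, cudaEventDisableTiming));

    // ... enqueue producer kernels on stream_a ...
    CUDA_CHECK(cudaEventRecord(ev, stream_a));        // marks the point the consumer must wait for

    CUDA_CHECK(cudaStreamWaitEvent(stream_b, ev, 0)); // subsequent work on stream_b waits for ev
    // ... enqueue consumer kernels on stream_b ...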
7237
  }
7238
  }
7239
 
7240
+ for (int64_t id = 0; id < g_device_count; ++id) {
7241
+ if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
7242
+ continue;
7243
+ }
7244
+ CUDA_CHECK(ggml_cuda_set_device(id));
7245
+
7246
+ // free buffers again when done
7247
+ if (src0_as[id] > 0) {
7248
+ ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
7249
+ }
7250
+ if (src1_asf[id] > 0) {
7251
+ ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
7252
+ }
7253
+ if (src1_asq[id] > 0) {
7254
+ ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
7255
+ }
7256
+ if (dst_as[id] > 0) {
7257
+ ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
7258
+ }
7259
+ }
7260
+
7261
  // main device waits for all other devices to be finished
7262
  if (split && g_device_count > 1) {
7263
  int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
 
7265
 
7266
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7267
  for (int64_t id = 0; id < g_device_count; ++id) {
7268
+ if (row_low[id] == row_high[id]) {
7269
+ continue;
7270
+ }
7271
  for (int64_t is = 0; is < is_max; ++is) {
7272
  CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
7273
  }
 
7278
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7279
  CUDA_CHECK(cudaDeviceSynchronize());
7280
  }
 
7281
  }
7282
 
7283
  static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
7304
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
7305
  }
7306
 
7307
+ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7308
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
7309
+ }
7310
+
7311
+ static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7312
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
7313
+ }
7314
+
7315
  static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7316
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
7317
  }
 
7321
  }
7322
 
7323
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
7324
+ if (!g_cublas_loaded) return false;
7325
+
7326
  const int64_t ne10 = src1->ne[0];
7327
 
7328
  const int64_t ne0 = dst->ne[0];
 
7474
  GGML_ASSERT(to_fp16_cuda != nullptr);
7475
 
7476
  size_t src1_as = 0;
7477
+ half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
7478
  to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
7479
 
7480
  size_t dst_as = 0;
7481
+ half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
7482
 
7483
  GGML_ASSERT(ne12 % ne02 == 0);
7484
  GGML_ASSERT(ne13 % ne03 == 0);
 
7532
  size_t ptrs_src_s = 0;
7533
  size_t ptrs_dst_s = 0;
7534
 
7535
+ ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
7536
+ ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
7537
 
7538
  dim3 block_dims(ne13, ne12);
7539
  k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
 
7546
  dst->nb[2], dst->nb[3],
7547
  r2, r3);
7548
  CUDA_CHECK(cudaGetLastError());
7549
+
7550
  CUBLAS_CHECK(
7551
  cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7552
  ne01, ne11, ne10,
 
7558
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7559
 
7560
  if (ptrs_src_s != 0) {
7561
+ ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
7562
  }
7563
  if (ptrs_dst_s != 0) {
7564
+ ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
7565
  }
7566
  }
7567
  #endif
7568
 
7569
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
7570
  to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
7571
+
7572
+ ggml_cuda_pool_free(src1_as_f16, src1_as);
7573
+ ggml_cuda_pool_free(dst_f16, dst_as);
 
 
 
7574
  }
7575
 
7576
  static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7577
  const bool all_on_device =
7578
+ (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
7579
  (src1->backend == GGML_BACKEND_GPU) &&
7580
  ( dst->backend == GGML_BACKEND_GPU);
7581
 
7582
+ const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
7583
+
7584
  int64_t min_compute_capability = INT_MAX;
7585
  for (int64_t id = 0; id < g_device_count; ++id) {
7586
  if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
 
7602
  //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
7603
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
7604
 
7605
+ if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
7606
  // KQ single-batch
7607
  ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
7608
+ } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
7609
  // KQV single-batch
7610
  ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
7611
+ } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
7612
  // KQ + KQV multi-batch
7613
  ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
7614
  } else if (src0->type == GGML_TYPE_F32) {
 
7729
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
7730
  }
7731
 
7732
+ static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7733
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
7734
  }
7735
 
 
7844
 
7845
  static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
7846
  if (g_temp_tensor_extras == nullptr) {
7847
+ g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
7848
  }
7849
 
7850
  size_t alloc_index = g_temp_tensor_extra_index;
7851
+ g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
7852
  ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
7853
  memset(extra, 0, sizeof(*extra));
7854
 
 
8015
  }
8016
 
8017
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
8018
+ if (!g_cublas_loaded) return false;
8019
+
8020
  ggml_cuda_func_t func;
8021
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
8022
  || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
 
8059
  case GGML_UNARY_OP_SILU:
8060
  func = ggml_cuda_silu;
8061
  break;
8062
+ case GGML_UNARY_OP_RELU:
8063
+ func = ggml_cuda_relu;
8064
+ break;
8065
  default:
8066
  return false;
8067
  } break;
 
8080
  case GGML_OP_SCALE:
8081
  func = ggml_cuda_scale;
8082
  break;
8083
+ case GGML_OP_SQR:
8084
+ func = ggml_cuda_sqr;
8085
+ break;
8086
  case GGML_OP_CLAMP:
8087
  if (!any_on_device) {
8088
  return false;
 
8175
 
8176
  ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
8177
  if (temp_tensor_extras == nullptr) {
8178
+ temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
8179
  }
8180
 
8181
  size_t alloc_index = temp_tensor_extra_index;
8182
+ temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
8183
  ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
8184
  memset(extra, 0, sizeof(*extra));
8185
 
ggml-cuda.h CHANGED
@@ -17,7 +17,12 @@ extern "C" {
17
 
18
  #define GGML_CUDA_MAX_DEVICES 16
19
 
 
20
  GGML_API void ggml_init_cublas(void);
 
 
 
 
21
  GGML_API void * ggml_cuda_host_malloc(size_t size);
22
  GGML_API void ggml_cuda_host_free(void * ptr);
23
 
 
17
 
18
  #define GGML_CUDA_MAX_DEVICES 16
19
 
20
+ // Always succeeds. To check whether CUDA was actually loaded, use `ggml_cublas_loaded`.
21
  GGML_API void ggml_init_cublas(void);
22
+
23
+ // Returns `true` if CUDA devices are available and cuBLAS loaded successfully; otherwise returns `false`.
24
+ GGML_API bool ggml_cublas_loaded(void);
25
+
26
  GGML_API void * ggml_cuda_host_malloc(size_t size);
27
  GGML_API void ggml_cuda_host_free(void * ptr);
28
 
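With these declarations, initialization is split from availability: ggml_init_cublas() always returns, and callers are expected to probe ggml_cublas_loaded() before taking any CUDA path. A minimal sketch of the intended call pattern; the fallback branch is illustrative.

    // sketch: initialize once, then check whether the GPU path is actually usable
    ggml_init_cublas(); // never aborts, even when no CUDA device is present

    if (ggml_cublas_loaded()) {
        // safe to use the CUDA code paths
    } else {
        // fall back to CPU-only execution
    }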
whisper.cpp CHANGED
@@ -20,6 +20,7 @@
20
  #include "ggml-alloc.h"
21
  #include "ggml-backend.h"
22
 
 
23
  #include <algorithm>
24
  #include <cassert>
25
  #define _USE_MATH_DEFINES
@@ -147,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
147
 
148
  //#define WHISPER_USE_FLASH_ATTN
149
  //#define WHISPER_USE_FLASH_FF
150
- #define WHISPER_MAX_DECODERS 16
151
  #define WHISPER_MAX_NODES 4096
152
 
153
  //
@@ -406,6 +407,121 @@ struct whisper_segment {
406
  bool speaker_turn_next;
407
  };
408
 
409
  // medium
410
  // hparams: {
411
  // 'n_mels': 80,
@@ -523,15 +639,31 @@ struct whisper_layer_decoder {
523
  struct ggml_tensor * mlp_1_b;
524
  };
525
 
526
  struct whisper_kv_cache {
 
  struct ggml_tensor * k;
528
  struct ggml_tensor * v;
529
 
530
  struct ggml_context * ctx;
531
 
532
  ggml_backend_buffer_t buffer;
533
-
534
- int n; // number of tokens currently in the cache
535
  };
536
 
537
  struct whisper_model {
@@ -585,11 +717,11 @@ struct whisper_partial_utf8 {
585
  };
586
 
587
  struct whisper_grammar {
588
- /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
589
- std::vector<std::vector<const whisper_grammar_element *>> stacks;
590
 
591
  // buffer for partially generated UTF-8 sequence from accepted tokens
592
- whisper_partial_utf8 partial_utf8;
593
  };
594
 
595
  struct whisper_grammar_candidate {
@@ -613,15 +745,13 @@ struct whisper_sequence {
613
 
614
  // TAGS: WHISPER_DECODER_INIT
615
  struct whisper_decoder {
616
- // each decoder keeps its own KV-cache
617
- whisper_kv_cache kv_self;
618
-
619
  // the currently generated sequence of tokens
620
  whisper_sequence sequence;
621
 
622
  // grammar parse state of generated sequence of tokens
623
  whisper_grammar grammar;
624
 
 
625
  int seek_delta; // the window shift found so far based on the decoded timestamp tokens
626
 
627
  bool failed; // has the current segment failed to decode?
@@ -633,100 +763,40 @@ struct whisper_decoder {
633
  std::vector<float> logits;
634
  std::vector<float> logprobs;
635
 
636
- std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
637
- };
638
-
639
- // replace std::pair by using customized pair struct (reason: std::pair is very slow)
640
- template<typename A, typename B>
641
- struct whisper_pair {
642
- A first;
643
- B second;
644
-
645
- // Define a constructor that takes two arguments.
646
- whisper_pair(const A& a, const B& b) : first(a), second(b) {}
647
- // Define a constructor that takes no argument.
648
- whisper_pair() : first(A()), second(B()) {}
649
- };
650
-
651
- // beam-search helpers
652
- struct kv_buf {
653
- std::vector<uint8_t> k;
654
- std::vector<uint8_t> v;
655
- };
656
-
657
- // ggml_allocr wrapper for whisper usage
658
- struct whisper_allocr {
659
- ggml_allocr * alloc = nullptr;
660
-
661
- std::vector<uint8_t> meta;
662
 
663
- ggml_backend_buffer_t buffer;
664
  };
665
 
666
- static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
667
- return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
668
- }
669
-
670
- // measure the memory usage of a graph and prepare the allocr's internal data buffer
671
- static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
672
- auto & alloc = allocr.alloc;
673
- auto & meta = allocr.meta;
674
-
675
- alloc = ggml_allocr_new_measure_from_backend(backend);
676
-
677
- meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
678
-
679
- ggml_allocr_alloc_graph(alloc, get_graph());
680
- }
681
-
682
- static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
683
- if (allocr.alloc == nullptr) {
684
- // this can be null if we use external encoder like CoreML or OpenVINO
685
- return;
686
- }
687
-
688
- auto & alloc = allocr.alloc;
689
- auto & buffer = allocr.buffer;
690
-
691
- size_t size = ggml_allocr_max_size(alloc);
692
-
693
- ggml_allocr_free(alloc);
694
-
695
- buffer = ggml_backend_alloc_buffer(backend, size);
696
- alloc = ggml_allocr_new_from_buffer(buffer);
697
- }
698
-
699
- static void whisper_allocr_free(struct whisper_allocr & allocr) {
700
- if (allocr.alloc) {
701
- ggml_allocr_free(allocr.alloc);
702
- ggml_backend_buffer_free(allocr.buffer);
703
- allocr.alloc = nullptr;
704
- }
705
- }
706
-
707
  struct whisper_state {
708
  int64_t t_sample_us = 0;
709
  int64_t t_encode_us = 0;
710
  int64_t t_decode_us = 0;
 
711
  int64_t t_prompt_us = 0;
712
  int64_t t_mel_us = 0;
713
 
714
  int32_t n_sample = 0; // number of tokens sampled
715
  int32_t n_encode = 0; // number of encoder calls
716
- int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
717
- int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
 
718
  int32_t n_fail_p = 0; // number of logprob threshold failures
719
  int32_t n_fail_h = 0; // number of entropy threshold failures
720
 
 
 
 
721
  // cross-attention KV cache for the decoders
722
  // shared between all decoders
723
  whisper_kv_cache kv_cross;
 
724
  whisper_mel mel;
725
 
726
- whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
727
 
728
- // buffer for swapping KV caches between decoders during beam-search
729
- std::vector<kv_buf> kv_swap_bufs;
730
 
731
  ggml_backend_t backend = nullptr;
732
 
@@ -742,8 +812,9 @@ struct whisper_state {
742
  struct ggml_tensor * embd_conv = nullptr;
743
  struct ggml_tensor * embd_enc = nullptr;
744
 
745
- // helper for GPU offloading
746
  std::vector<float> inp_mel;
 
747
 
748
  // decode output (2-dimensional array: [n_tokens][n_vocab])
749
  std::vector<float> logits;
@@ -751,11 +822,6 @@ struct whisper_state {
751
  std::vector<whisper_segment> result_all;
752
  std::vector<whisper_token> prompt_past;
753
 
754
- // work container used to avoid memory allocations
755
- std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
756
-
757
- mutable std::mt19937 rng; // used for sampling at t > 0.0
758
-
759
  int lang_id = 0; // english by default
760
 
761
  std::string path_model; // populated by whisper_init_from_file_with_params()
@@ -831,6 +897,12 @@ static bool kv_cache_init(
831
  /*.no_alloc =*/ true,
832
  };
833
 
 
 
 
 
 
 
834
  cache.ctx = ggml_init(params);
835
 
836
  if (!cache.ctx) {
@@ -858,54 +930,129 @@ static bool kv_cache_init(
858
  return true;
859
  }
860
 
861
- // TODO: remove after batched decoding
862
- static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
863
- WHISPER_ASSERT(cache.ctx);
 
 
 
 
864
 
865
- const int n_elements = ggml_nelements(cache.k);
866
- WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
 
 
 
867
 
868
- const ggml_type wtype = cache.k->type;
869
- WHISPER_ASSERT(wtype == cache.v->type);
 
 
870
 
871
- struct ggml_init_params params = {
872
- /*.mem_size =*/ 2*ggml_tensor_overhead(),
873
- /*.mem_buffer =*/ nullptr,
874
- /*.no_alloc =*/ true,
875
- };
876
 
877
- cache.ctx = ggml_init(params);
 
 
 
 
 
878
 
879
- if (!cache.ctx) {
880
- WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
881
- return false;
 
882
  }
883
 
884
- cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
885
- cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
886
 
887
- const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
 
 
 
888
 
889
- cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
 
890
 
891
- // allocate the tensors into the backend buffer
892
- {
893
- ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
 
 
 
 
894
 
895
- ggml_allocr_alloc(alloc, cache.k);
896
- ggml_allocr_alloc(alloc, cache.v);
897
 
898
- ggml_allocr_free(alloc);
 
 
 
899
  }
 
 
900
 
901
- return true;
 
902
  }
903
 
904
- static void kv_cache_free(struct whisper_kv_cache & cache) {
905
- if (cache.ctx) {
906
- ggml_free(cache.ctx);
907
- ggml_backend_buffer_free(cache.buffer);
908
- cache.ctx = nullptr;
 
  }
910
  }
911
 
@@ -914,7 +1061,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
914
 
915
  // initialize the backends
916
  #ifdef GGML_USE_CUBLAS
917
- if (params.use_gpu) {
918
  WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
919
  backend_gpu = ggml_backend_cuda_init();
920
  if (!backend_gpu) {
@@ -1116,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1116
  word = "[_EOT_]";
1117
  } else if (i == vocab.token_sot) {
1118
  word = "[_SOT_]";
 
 
 
 
1119
  } else if (i == vocab.token_solm) {
1120
  word = "[_SOLM_]";
1121
  } else if (i == vocab.token_prev) {
@@ -1126,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1126
  word = "[_NOT_]";
1127
  } else if (i == vocab.token_beg) {
1128
  word = "[_BEG_]";
 
 
1129
  } else {
1130
  word = "[_extra_token_" + std::to_string(i) + "]";
1131
  }
@@ -2031,26 +2184,28 @@ static bool whisper_encode_internal(
2031
  static struct ggml_cgraph * whisper_build_graph_decoder(
2032
  whisper_context & wctx,
2033
  whisper_state & wstate,
2034
- whisper_decoder & decoder,
2035
- const whisper_token * tokens,
2036
- int n_tokens,
2037
- int n_past) {
2038
  const auto & model = wctx.model;
2039
  const auto & hparams = model.hparams;
2040
 
2041
- auto & kv_self = decoder.kv_self;
2042
 
2043
  WHISPER_ASSERT(!!kv_self.ctx);
2044
 
2045
- const int n_ctx = hparams.n_text_ctx;
 
 
2046
  const int n_state = hparams.n_text_state;
2047
  const int n_head = hparams.n_text_head;
2048
  const int n_layer = hparams.n_text_layer;
2049
 
2050
- const int N = n_tokens;
2051
- const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2052
 
2053
- //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
 
 
2054
 
2055
  struct ggml_init_params params = {
2056
  /*.mem_size =*/ wstate.alloc_decode.meta.size(),
@@ -2062,21 +2217,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2062
 
2063
  ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
2064
 
2065
- ggml_allocr * alloc = wstate.alloc_decode.alloc;
2066
-
2067
- struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
2068
  ggml_allocr_alloc(alloc, embd);
2069
 
2070
  if (!ggml_allocr_is_measure(alloc)) {
2071
- ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
2072
  }
2073
 
2074
- struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
2075
  ggml_allocr_alloc(alloc, position);
2076
 
2077
  if (!ggml_allocr_is_measure(alloc)) {
2078
- for (int i = 0; i < N; ++i) {
2079
- const int32_t val = n_past + i;
2080
  ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2081
  }
2082
  }
@@ -2089,6 +2242,31 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2089
  ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
2090
  }
2091
 
2092
  // token encoding + position encoding
2093
  struct ggml_tensor * cur =
2094
  ggml_add(ctx0,
@@ -2141,12 +2319,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2141
  Vcur,
2142
  layer.attn_v_b);
2143
 
2144
- Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
2145
 
2146
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
2147
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
2148
  ( n_ctx)*ggml_element_size(kv_self.v),
2149
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
2150
 
2151
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
2152
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2156,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2156
 
2157
  struct ggml_tensor * Q =
2158
  ggml_permute(ctx0,
2159
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2160
  0, 2, 1, 3);
2161
 
2162
  struct ggml_tensor * K =
2163
  ggml_view_3d(ctx0, kv_self.k,
2164
- n_state/n_head, n_past + N, n_head,
2165
  ggml_element_size(kv_self.k)*n_state,
2166
  ggml_element_size(kv_self.k)*n_state/n_head,
2167
  ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2171,16 +2349,17 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2171
 
2172
  //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
2173
 
2174
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
2175
 
2176
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
2177
 
2178
  struct ggml_tensor * V =
2179
  ggml_view_3d(ctx0, kv_self.v,
2180
- n_past + N, n_state/n_head, n_head,
2181
  n_ctx*ggml_element_size(kv_self.v),
2182
  n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
2183
- il*n_ctx*ggml_element_size(kv_self.v)*n_state);
2184
 
2185
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2186
 
@@ -2188,7 +2367,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2188
 
2189
  cur = ggml_cpy(ctx0,
2190
  KQV_merged,
2191
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
2192
  }
2193
 
2194
  // projection
@@ -2232,33 +2411,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2232
  // Kcross is already scaled
2233
  struct ggml_tensor * Kcross =
2234
  ggml_view_3d(ctx0, wstate.kv_cross.k,
2235
- n_state/n_head, M, n_head,
2236
  ggml_element_size(wstate.kv_cross.k)*n_state,
2237
  ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2238
- ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
2239
 
2240
  //struct ggml_tensor * Vcross =
2241
  // ggml_reshape_3d(ctx0,
2242
- // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
2243
- // n_state/n_head, n_head, M);
2244
 
2245
  //struct ggml_tensor * V_trans =
2246
  // ggml_cpy(ctx0,
2247
  // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2248
- // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
2249
 
2250
  struct ggml_tensor * V =
2251
  ggml_view_3d(ctx0, wstate.kv_cross.v,
2252
- M, n_state/n_head, n_head,
2253
- M*ggml_element_size(wstate.kv_cross.v),
2254
- M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2255
- il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
2256
 
2257
  // ------
2258
 
2259
  struct ggml_tensor * Q =
2260
  ggml_permute(ctx0,
2261
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2262
  0, 2, 1, 3);
2263
 
2264
  // K * Q
@@ -2279,10 +2458,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2279
 
2280
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2281
 
2282
- // cur = KQV_merged.contiguous().view(n_state, N)
2283
  cur = ggml_cpy(ctx0,
2284
  KQV_merged,
2285
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
2286
  }
2287
 
2288
  // projection
@@ -2354,9 +2533,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2354
  }
2355
 
2356
  // compute logits only for the last token
2357
- // comment this line to compute logits for all N tokens
2358
  // might be useful in the future
2359
- cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2360
 
2361
  struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2362
 
@@ -2380,10 +2559,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2380
  static bool whisper_decode_internal(
2381
  whisper_context & wctx,
2382
  whisper_state & wstate,
2383
- whisper_decoder & decoder,
2384
- const whisper_token * tokens,
2385
- const int n_tokens,
2386
- const int n_past,
2387
  const int n_threads,
2388
  whisper_abort_callback abort_callback,
2389
  void * abort_callback_data) {
@@ -2392,19 +2568,33 @@ static bool whisper_decode_internal(
2392
  const auto & model = wctx.model;
2393
  const auto & hparams = model.hparams;
2394
 
2395
- const int n_vocab = hparams.n_vocab;
 
2396
 
2397
  auto & logits_out = wstate.logits;
2398
 
2399
  struct ggml_tensor * logits;
2400
 
 
2401
  // decoder
2402
  {
2403
  auto & alloc = wstate.alloc_decode.alloc;
2404
 
2405
  ggml_allocr_reset(alloc);
2406
 
2407
- ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
2408
 
2409
  ggml_allocr_alloc_graph(alloc, gf);
2410
 
@@ -2413,17 +2603,15 @@ static bool whisper_decode_internal(
2413
  ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2414
  }
2415
 
2416
- // extract logits for all N tokens
2417
- //logits_out.resize(n_tokens*n_vocab);
2418
- //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
2419
- //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
2420
-
2421
- // extract logits only for the last token
2422
- logits_out.resize(n_vocab);
2423
- //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
2424
- ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
2425
 
2426
- if (n_tokens > 1) {
2427
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2428
  // ggml_used_mem(ctx0)/1024.0/1024.0,
2429
  // wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2432,18 +2620,20 @@ static bool whisper_decode_internal(
2432
  // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2433
  }
2434
 
2435
- if (n_tokens == 1) {
2436
  wstate.t_decode_us += ggml_time_us() - t_start_us;
2437
  wstate.n_decode++;
 
 
 
2438
  } else {
2439
  wstate.t_prompt_us += ggml_time_us() - t_start_us;
2440
- wstate.n_prompt++;
2441
  }
2442
 
2443
  return !(abort_callback && abort_callback(abort_callback_data));
2444
  }
2445
 
2446
-
2447
  // 500 -> 00:05.000
2448
  // 6000 -> 01:00.000
2449
  static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2855,14 +3045,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2855
 
2856
  state->backend = whisper_backend_init(ctx->params);
2857
 
2858
- if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
 
 
 
 
2859
  WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2860
  delete state;
2861
  return nullptr;
2862
  }
2863
 
2864
  {
2865
- const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
2866
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2867
  }
2868
 
@@ -2897,14 +3091,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2897
 
2898
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2899
 
2900
- state->logits_id.reserve(ctx->model.hparams.n_vocab);
2901
 
2902
  // TAGS: WHISPER_DECODER_INIT
2903
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2904
 
2905
- state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
2906
- state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
2907
- state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
 
 
 
2908
 
2909
  // conv allocator
2910
  {
@@ -2946,7 +3143,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2946
  const int n_tokens = hparams.n_text_ctx;
2947
  const int n_past = 0;
2948
 
2949
- return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
 
 
2950
  });
2951
 
2952
  WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@@ -2957,8 +3156,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2957
  whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
2958
  whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
2959
 
2960
- state->rng = std::mt19937(0);
2961
-
2962
  return state;
2963
  }
2964
 
@@ -3183,12 +3380,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
3183
  void whisper_free_state(struct whisper_state * state)
3184
  {
3185
  if (state) {
 
3186
  kv_cache_free(state->kv_cross);
3187
 
3188
- for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
3189
- kv_cache_free(state->decoders[i].kv_self);
3190
- }
3191
-
3192
  #ifdef WHISPER_USE_COREML
3193
  if (state->ctx_coreml != nullptr) {
3194
  whisper_coreml_free(state->ctx_coreml);
@@ -3203,6 +3397,8 @@ void whisper_free_state(struct whisper_state * state)
3203
  }
3204
  #endif
3205
 
 
 
3206
  whisper_allocr_free(state->alloc_conv);
3207
  whisper_allocr_free(state->alloc_encode);
3208
  whisper_allocr_free(state->alloc_cross);
@@ -3329,9 +3525,11 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3329
  }
3330
 
3331
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3332
- const int selected_decoder_id = 0;
 
 
3333
 
3334
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3335
  WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3336
  return 1;
3337
  }
@@ -3340,15 +3538,16 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
3340
  }
3341
 
3342
  int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3343
- // TODO: add selected_decoder_id to state
3344
- const int selected_decoder_id = 0;
3345
-
3346
  if (ctx->state == nullptr) {
3347
  WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
3348
  return false;
3349
  }
3350
 
3351
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
 
 
 
 
3352
  WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3353
  return 1;
3354
  }
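The public calling convention is unchanged here: whisper_decode still takes a token buffer, its length, the current n_past and a thread count, and reports failure with a non-zero return. A usage sketch on an already-initialized context; the one-token prompt and the thread count are placeholder values.

    #include <cstdio>
    #include <vector>
    #include "whisper.h"

    // sketch: single-sequence use of the public decode API
    static void decode_one(struct whisper_context * ctx) {
        std::vector<whisper_token> tokens = { whisper_token_sot(ctx) }; // placeholder one-token prompt

        if (whisper_decode(ctx, tokens.data(), (int) tokens.size(), /*n_past =*/ 0, /*n_threads =*/ 4) != 0) {
            fprintf(stderr, "decode failed\n");
            return;
        }

        const float * logits = whisper_get_logits(ctx); // logits of the last decoded token
        (void) logits;
    }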
@@ -3436,7 +3635,7 @@ int whisper_lang_auto_detect_with_state(
3436
  return -7;
3437
  }
3438
 
3439
- auto & logits_id = state->logits_id;
3440
  logits_id.clear();
3441
 
3442
  for (const auto & kv : g_lang) {
@@ -3639,6 +3838,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
3639
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3640
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3641
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
 
3642
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3643
 
3644
  WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
@@ -3646,6 +3846,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
3646
  WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3647
  WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3648
  WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
 
3649
  WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3650
  }
3651
  WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
@@ -3662,6 +3863,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
3662
  ctx->state->n_sample = 0;
3663
  ctx->state->n_encode = 0;
3664
  ctx->state->n_decode = 0;
 
3665
  ctx->state->n_prompt = 0;
3666
  }
3667
  }
@@ -3969,8 +4171,7 @@ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_
3969
  if (*tok.code_points == 0) {
3970
  // reached end of full codepoints in token, reject iff it ended in a partial sequence
3971
  // that cannot satisfy this position in grammar
3972
- if (tok.partial_utf8.n_remain != 0 &&
3973
- !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
3974
  rejects.push_back(tok);
3975
  }
3976
  } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
@@ -4189,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4189
  /*.max_initial_ts =*/ 1.0f,
4190
  /*.length_penalty =*/ -1.0f,
4191
 
4192
- /*.temperature_inc =*/ 0.4f,
4193
  /*.entropy_thold =*/ 2.4f,
4194
  /*.logprob_thold =*/ -1.0f,
4195
  /*.no_speech_thold =*/ 0.6f,
@@ -4229,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4229
  case WHISPER_SAMPLING_GREEDY:
4230
  {
4231
  result.greedy = {
4232
- /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
4233
  };
4234
  } break;
4235
  case WHISPER_SAMPLING_BEAM_SEARCH:
4236
  {
4237
  result.beam_search = {
4238
- /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
4239
 
4240
  /*.patience =*/ -1.0f,
4241
  };
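For reference, this is how a caller picks between the two strategies and overrides whichever defaults are set here. The values below are example choices, not necessarily the new defaults.

    #include <vector>
    #include "whisper.h"

    // sketch: request beam-search decoding and override the search width
    static int transcribe_beam(struct whisper_context * ctx, const std::vector<float> & pcm) {
        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);

        wparams.beam_search.beam_size = 5; // example value; larger beams mean more decoder work per step
        wparams.greedy.best_of        = 5; // example value; used by the temperature-fallback path

        return whisper_full(ctx, wparams, pcm.data(), (int) pcm.size());
    }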
@@ -4325,11 +4526,12 @@ static const std::vector<std::string> non_speech_tokens = {
4325
  // process the logits for the selected decoder
4326
  // - applies logit filters
4327
  // - computes logprobs and probs
 
4328
  static void whisper_process_logits(
4329
  struct whisper_context & ctx,
4330
  struct whisper_state & state,
4331
- const struct whisper_full_params params,
4332
  struct whisper_decoder & decoder,
 
4333
  float temperature) {
4334
  const auto & vocab = ctx.vocab;
4335
  const auto & tokens_cur = decoder.sequence.tokens;
@@ -4346,7 +4548,7 @@ static void whisper_process_logits(
4346
  auto & logprobs = decoder.logprobs;
4347
  {
4348
  logits.resize(n_logits);
4349
- memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
4350
 
4351
  if (temperature > 0.0f) {
4352
  for (int i = 0; i < n_logits; i++) {
@@ -4512,30 +4714,31 @@ static void whisper_process_logits(
4512
  //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
4513
 
4514
  if (timestamp_logprob > max_text_token_logprob) {
4515
- //printf("sampling timestamp\n");
4516
  for (int i = 0; i < vocab.token_beg; ++i) {
4517
  logits[i] = -INFINITY;
4518
  logprobs[i] = -INFINITY;
4519
  }
4520
- } else if (params.n_grammar_rules > 0) {
4521
- whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
 
4522
 
4523
- // populate the logprobs array (log_softmax)
4524
- {
4525
- const float logit_max = *std::max_element(logits.begin(), logits.end());
4526
- float logsumexp = 0.0f;
4527
- for (int i = 0; i < n_logits; ++i) {
4528
- if (logits[i] > -INFINITY) {
4529
- logsumexp += expf(logits[i] - logit_max);
 
4530
  }
4531
- }
4532
- logsumexp = logf(logsumexp) + logit_max;
4533
 
4534
- for (int i = 0; i < n_logits; ++i) {
4535
- if (logits[i] > -INFINITY) {
4536
- logprobs[i] = logits[i] - logsumexp;
4537
- } else {
4538
- logprobs[i] = -INFINITY;
 
4539
  }
4540
  }
4541
  }
@@ -4610,7 +4813,6 @@ static void whisper_process_logits(
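The log-probabilities here are computed with a max-shifted log-softmax so that the exponentials cannot overflow. A standalone sketch of the same computation; this is an illustrative helper, not the file's actual code.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // sketch: numerically stable log-softmax over a logits vector
    static void log_softmax_example(const std::vector<float> & logits, std::vector<float> & logprobs) {
        if (logits.empty()) {
            return;
        }

        const float logit_max = *std::max_element(logits.begin(), logits.end());

        float logsumexp = 0.0f;
        for (const float l : logits) {
            if (l > -INFINITY) {
                logsumexp += expf(l - logit_max); // shifting by the max keeps expf in range
            }
        }
        logsumexp = logf(logsumexp) + logit_max;

        logprobs.resize(logits.size());
        for (size_t i = 0; i < logits.size(); ++i) {
            logprobs[i] = logits[i] > -INFINITY ? logits[i] - logsumexp : -INFINITY;
        }
    }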
4610
 
4611
  static whisper_token_data whisper_sample_token(
4612
  whisper_context & ctx,
4613
- whisper_state & state,
4614
  const whisper_decoder & decoder,
4615
  bool best) {
4616
  whisper_token_data result = {
@@ -4655,7 +4857,7 @@ static whisper_token_data whisper_sample_token(
4655
  } else {
4656
  std::discrete_distribution<> dist(probs.begin(), probs.end());
4657
 
4658
- result.id = dist(state.rng);
4659
  result.p = probs[result.id];
4660
  result.plog = logprobs[result.id];
4661
  }
@@ -4665,15 +4867,12 @@ static whisper_token_data whisper_sample_token(
4665
  result.pt = result.p;
4666
  }
4667
 
4668
- state.n_sample++;
4669
-
4670
  return result;
4671
  }
4672
 
4673
  static std::vector<whisper_token_data> whisper_sample_token_topk(
4674
  whisper_context & ctx,
4675
- whisper_state & state,
4676
- const whisper_decoder & decoder,
4677
  int k) {
4678
  const auto & vocab = ctx.vocab;
4679
 
@@ -4683,7 +4882,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
4683
 
4684
  const int n_logits = vocab.n_vocab;
4685
 
4686
- auto & logits_id = state.logits_id;
4687
 
4688
  logits_id.resize(n_logits);
4689
  for (int i = 0; i < n_logits; ++i) {
@@ -4732,7 +4931,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
4732
  std::discrete_distribution<> dist(probs.begin(), probs.end());
4733
 
4734
  for (int i = 0; i < k; ++i) {
4735
- const auto id = dist(state.rng);
4736
  //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
4737
 
4738
  result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
@@ -4743,8 +4942,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
4743
  }
4744
  }
4745
 
4746
- state.n_sample++;
4747
-
4748
  return result;
4749
  }
4750
 
@@ -4797,125 +4994,6 @@ static void whisper_sequence_score(
4797
  }
4798
  }
4799
 
4800
- static bool whisper_kv_swap_fast(
4801
- std::vector<int> & view,
4802
- whisper_decoder src[],
4803
- std::vector<kv_buf> & kv_swap_bufs,
4804
- const int & n_decoders) {
4805
- WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
4806
-
4807
- // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
4808
- std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
4809
-
4810
- // (buffer->decoder or decoder->decoder)
4811
- std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
4812
-
4813
- // (decoder<->decoder)
4814
- std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
4815
- std::vector<whisper_pair<int, int>> p_swap_vec;
4816
- p_swap_vec.reserve(n_decoders);
4817
-
4818
- // see https://github.com/ggerganov/whisper.cpp/wiki
4819
- for (int i = 0; i < n_decoders; i++) {
4820
- // zero-copy (no modification)
4821
- if (i == view[i] || view[i] < 0) {
4822
- continue;
4823
- }
4824
-
4825
- bool is_one_copy = true;
4826
- // since we modify data sequentially, we only consider decoder indices after current index
4827
- for (int j = i + 1; j < n_decoders; j++) {
4828
- if (i == view[j]) {
4829
- // detect symmetric diagram
4830
- if (j == view[i]) {
4831
- p_swap_set.insert(i);
4832
- p_swap_set.insert(j);
4833
- p_swap_vec.emplace_back(i, j);
4834
- } else {
4835
- two_copy.insert(i);
4836
- is_one_copy = false;
4837
- }
4838
- break;
4839
- }
4840
- }
4841
- if (is_one_copy) {
4842
- one_copy.insert(i);
4843
- }
4844
- }
4845
-
4846
- kv_swap_bufs.resize(n_decoders);
4847
-
4848
- for (int i = 0; i < n_decoders; i++) {
4849
- kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
4850
- kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
4851
- }
4852
-
4853
- for (auto & i : two_copy) {
4854
- // make a copy of KV caches
4855
- WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
4856
- //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
4857
- //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
4858
- ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
4859
- ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
4860
- }
4861
-
4862
- // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
4863
- for (auto & i : two_copy) {
4864
- // skip the decoder indices that require pointer swapping
4865
- if (p_swap_set.find(i) != p_swap_set.end()) {
4866
- continue;
4867
- }
4868
-
4869
- if (two_copy.find(view[i]) != two_copy.end()) {
4870
- // modify KV caches of decoder using data from kv_swap_bufs
4871
- WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4872
- //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4873
- //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4874
- ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
4875
- ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
4876
- } else {
4877
- // modify KV caches of decoder using data from correspond decoder KV caches directly
4878
- WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4879
- //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4880
- //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
4881
- ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
4882
- ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
4883
- }
4884
- }
4885
-
4886
- // then modify one-copy decoder KV caches
4887
- for (auto & i : one_copy) {
4888
- // skip the decoder indices that require pointer swapping
4889
- if (p_swap_set.find(i) != p_swap_set.end()) {
4890
- continue;
4891
- }
4892
-
4893
- if (two_copy.find(view[i]) != two_copy.end()) {
4894
- // modify KV caches of decoder using data from kv_swap_bufs
4895
- WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4896
- //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4897
- //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4898
- ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
4899
- ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
4900
- } else {
4901
- // modify KV caches of decoder using data from correspond decoder KV caches directly
4902
- WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4903
- //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4904
- //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
4905
- ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
4906
- ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
4907
- }
4908
- }
4909
-
4910
- // swap the pointers
4911
- for (auto & i : p_swap_vec) {
4912
- WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
4913
- std::swap(src[i.first].kv_self, src[i.second].kv_self);
4914
- }
4915
-
4916
- return true;
4917
- }
4918
-
4919
  int whisper_full_with_state(
4920
  struct whisper_context * ctx,
4921
  struct whisper_state * state,
@@ -5005,25 +5083,23 @@ int whisper_full_with_state(
5005
 
5006
  n_decoders = std::max(1, n_decoders);
5007
 
 
 
 
 
 
5008
  // TAGS: WHISPER_DECODER_INIT
5009
  for (int j = 1; j < n_decoders; j++) {
5010
  auto & decoder = state->decoders[j];
5011
 
5012
- if (decoder.kv_self.ctx == nullptr) {
5013
- decoder.kv_self = state->decoders[0].kv_self;
5014
- if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
5015
- WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
5016
- return -4;
5017
- }
5018
-
5019
- WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
5020
 
5021
- decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
 
 
5022
 
5023
- decoder.probs.resize (ctx->vocab.n_vocab);
5024
- decoder.logits.resize (ctx->vocab.n_vocab);
5025
- decoder.logprobs.resize(ctx->vocab.n_vocab);
5026
- }
5027
  }
5028
 
5029
  // the accumulated text context so far
@@ -5100,8 +5176,10 @@ int whisper_full_with_state(
5100
  bool has_ts;
5101
 
5102
  whisper_sequence sequence;
 
5103
  };
5104
 
 
5105
  std::vector<beam_candidate> beam_candidates;
5106
 
5107
  // main loop
@@ -5169,8 +5247,6 @@ int whisper_full_with_state(
5169
  for (int j = 0; j < n_decoders_cur; ++j) {
5170
  auto & decoder = state->decoders[j];
5171
 
5172
- decoder.kv_self.n = 0;
5173
-
5174
  decoder.sequence.tokens.clear();
5175
  decoder.sequence.result_len = 0;
5176
  decoder.sequence.sum_logprobs_all = 0.0;
@@ -5186,15 +5262,14 @@ int whisper_full_with_state(
5186
  decoder.has_ts = false;
5187
 
5188
  if (params.grammar_rules != nullptr) {
5189
- decoder.grammar = whisper_grammar_init(
5190
- params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
5191
  } else {
5192
  decoder.grammar = {};
5193
  }
5194
  }
5195
 
5196
  // init prompt and kv cache for the current iteration
5197
- // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
5198
  {
5199
  prompt.clear();
5200
 
@@ -5216,7 +5291,11 @@ int whisper_full_with_state(
5216
  }
5217
  WHISPER_PRINT_DEBUG("\n\n");
5218
 
5219
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
 
 
 
 
5220
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5221
  return -7;
5222
  }
@@ -5224,20 +5303,14 @@ int whisper_full_with_state(
5224
  {
5225
  const int64_t t_start_sample_us = ggml_time_us();
5226
 
5227
- whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
5228
 
5229
- state->decoders[0].kv_self.n += prompt.size();
5230
 
5231
  for (int j = 1; j < n_decoders_cur; ++j) {
5232
  auto & decoder = state->decoders[j];
5233
 
5234
- // TODO: fix CUDA
5235
- //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
5236
- //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
5237
- ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
5238
- ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
5239
-
5240
- decoder.kv_self.n += prompt.size();
5241
 
5242
  memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
5243
  memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -5252,41 +5325,81 @@ int whisper_full_with_state(
5252
  const int64_t t_start_sample_us = ggml_time_us();
5253
 
5254
  if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
5255
- beam_candidates.clear();
 
 
5256
  }
5257
 
5258
- // generate new sequence candidates for each decoder
5259
- for (int j = 0; j < n_decoders_cur; ++j) {
5260
- auto & decoder = state->decoders[j];
 
5261
 
5262
- if (decoder.completed || decoder.failed) {
5263
- continue;
5264
- }
5265
 
5266
- switch (params.strategy) {
5267
- case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
5268
- {
5269
- if (t_cur < 1e-6f) {
5270
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
5271
- } else {
5272
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
5273
- }
5274
 
5275
- decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
5276
- } break;
5277
- case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
5278
- {
5279
- const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
5280
 
5281
- for (const auto & token : tokens_new) {
5282
- beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
5283
- beam_candidates.back().sequence.tokens.push_back(token);
5284
- beam_candidates.back().sequence.sum_logprobs_all += token.plog;
5285
 
5286
- //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
5287
- }
5288
- } break;
5289
  };
5290
  }
5291
 
5292
  // for beam-search, choose the top candidates and update the KV caches
@@ -5299,7 +5412,6 @@ int whisper_full_with_state(
5299
  });
5300
 
5301
  uint32_t cur_c = 0;
5302
- std::vector<int> decoder_idx(n_decoders_cur, -1);
5303
 
5304
  for (int j = 0; j < n_decoders_cur; ++j) {
5305
  auto & decoder = state->decoders[j];
@@ -5318,17 +5430,28 @@ int whisper_full_with_state(
5318
  ++cur_c;
5319
  }
5320
 
5321
- decoder.sequence = cur.sequence;
5322
  decoder.seek_delta = cur.seek_delta;
5323
  decoder.has_ts = cur.has_ts;
 
 
 
 
5324
 
5325
- decoder_idx[j] = cur.decoder_idx;
5326
  WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
5327
  __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
5328
  }
5329
 
5330
- // update KV caches
5331
- whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
5332
  }
5333
 
5334
  // update the decoder state
@@ -5437,32 +5560,83 @@ int whisper_full_with_state(
5437
  state->t_sample_us += ggml_time_us() - t_start_sample_us;
5438
 
5439
  // obtain logits for the next token
5440
- for (int j = 0; j < n_decoders_cur; ++j) {
5441
- auto & decoder = state->decoders[j];
5442
 
5443
- if (decoder.failed || decoder.completed) {
5444
- continue;
5445
- }
 
 
 
5446
 
5447
- decoder.tokens_tmp.resize(1);
5448
- decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
 
5449
 
5450
- //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
5451
 
5452
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
5453
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5454
  return -8;
5455
  }
5456
 
 
 
 
5457
  {
5458
- const int64_t t_start_sample_us = ggml_time_us();
5459
 
5460
- whisper_process_logits(*ctx, *state, params, decoder, t_cur);
5461
 
5462
- ++decoder.kv_self.n;
 
 
5463
 
5464
- state->t_sample_us += ggml_time_us() - t_start_sample_us;
5465
  }
 
 
5466
  }
5467
  }
5468
 
@@ -5759,11 +5933,13 @@ int whisper_full_parallel(
5759
  ctx->state->t_sample_us += states[i]->t_sample_us;
5760
  ctx->state->t_encode_us += states[i]->t_encode_us;
5761
  ctx->state->t_decode_us += states[i]->t_decode_us;
 
5762
  ctx->state->t_prompt_us += states[i]->t_prompt_us;
5763
 
5764
  ctx->state->n_sample += states[i]->n_sample;
5765
  ctx->state->n_encode += states[i]->n_encode;
5766
  ctx->state->n_decode += states[i]->n_decode;
 
5767
  ctx->state->n_prompt += states[i]->n_prompt;
5768
 
5769
  whisper_free_state(states[i]);
 
20
  #include "ggml-alloc.h"
21
  #include "ggml-backend.h"
22
 
23
+ #include <atomic>
24
  #include <algorithm>
25
  #include <cassert>
26
  #define _USE_MATH_DEFINES
 
148
 
149
  //#define WHISPER_USE_FLASH_ATTN
150
  //#define WHISPER_USE_FLASH_FF
151
+ #define WHISPER_MAX_DECODERS 8
152
  #define WHISPER_MAX_NODES 4096
153
 
154
  //
 
407
  bool speaker_turn_next;
408
  };
409
 
410
+ struct whisper_batch {
411
+ int32_t n_tokens;
412
+
413
+ whisper_token * token;
414
+ whisper_pos * pos;
415
+ int32_t * n_seq_id;
416
+ whisper_seq_id ** seq_id; // null terminated
417
+ int8_t * logits;
418
+ };
419
+
420
+ static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
421
+ whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
422
+
423
+ batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens));
424
+ batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens));
425
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
426
+ batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
427
+ for (int i = 0; i < n_tokens; ++i) {
428
+ batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
429
+ }
430
+ batch.seq_id[n_tokens] = nullptr;
431
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
432
+
433
+ return batch;
434
+ }
435
+
436
+ static void whisper_batch_free(struct whisper_batch batch) {
437
+ if (batch.token) free(batch.token);
438
+ if (batch.pos) free(batch.pos);
439
+ if (batch.n_seq_id) free(batch.n_seq_id);
440
+ if (batch.seq_id) {
441
+ for (int i = 0; batch.seq_id[i]; ++i) {
442
+ free(batch.seq_id[i]);
443
+ }
444
+ free(batch.seq_id);
445
+ }
446
+ if (batch.logits) free(batch.logits);
447
+ }
448
+
449
+ static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
450
+ batch.n_tokens = n_tokens;
451
+ for (int i = 0; i < n_tokens; ++i) {
452
+ if (tokens) {
453
+ batch.token[i] = tokens[i];
454
+ }
455
+ batch.pos [i] = n_past + i;
456
+ batch.n_seq_id[i] = 1;
457
+ batch.seq_id [i][0] = seq_id;
458
+ batch.logits [i] = 0;
459
+ }
460
+ batch.logits[n_tokens - 1] = 1;
461
+ }
462
+
463
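As an editorial illustration (not part of the commit), this is how the batch helpers above would typically be used to prepare and release a single-sequence batch; the token values and sizes are placeholders:

whisper_batch batch = whisper_batch_init(/*n_tokens*/ 3, /*n_seq_max*/ 1);

const whisper_token prompt[3] = { 1, 2, 3 }; // placeholder token ids
whisper_batch_prep_legacy(batch, prompt, /*n_tokens*/ 3, /*n_past*/ 0, /*seq_id*/ 0);

// after the call: batch.pos = {0, 1, 2}, batch.n_seq_id[i] = 1, batch.seq_id[i][0] = 0,
// and batch.logits = {0, 0, 1} -> logits are requested only for the last token

whisper_batch_free(batch);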
+ // replace std::pair with a custom pair struct (reason: std::pair is very slow)
464
+ template<typename A, typename B>
465
+ struct whisper_pair {
466
+ A first;
467
+ B second;
468
+
469
+ // Define a constructor that takes two arguments.
470
+ whisper_pair(const A& a, const B& b) : first(a), second(b) {}
471
+ // Define a constructor that takes no argument.
472
+ whisper_pair() : first(A()), second(B()) {}
473
+ };
474
+
475
+ // ggml_allocr wrapper for whisper usage
476
+ struct whisper_allocr {
477
+ ggml_allocr * alloc = nullptr;
478
+
479
+ std::vector<uint8_t> meta;
480
+
481
+ ggml_backend_buffer_t buffer;
482
+ };
483
+
484
+ static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
485
+ return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
486
+ }
487
+
488
+ // measure the memory usage of a graph and prepare the allocr's internal data buffer
489
+ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
490
+ auto & alloc = allocr.alloc;
491
+ auto & meta = allocr.meta;
492
+
493
+ alloc = ggml_allocr_new_measure_from_backend(backend);
494
+
495
+ meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
496
+
497
+ ggml_allocr_alloc_graph(alloc, get_graph());
498
+ }
499
+
500
+ static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
501
+ if (allocr.alloc == nullptr) {
502
+ // this can be null if we use an external encoder such as CoreML or OpenVINO
503
+ return;
504
+ }
505
+
506
+ auto & alloc = allocr.alloc;
507
+ auto & buffer = allocr.buffer;
508
+
509
+ size_t size = ggml_allocr_max_size(alloc);
510
+
511
+ ggml_allocr_free(alloc);
512
+
513
+ buffer = ggml_backend_alloc_buffer(backend, size);
514
+ alloc = ggml_allocr_new_from_buffer(buffer);
515
+ }
516
+
517
+ static void whisper_allocr_free(struct whisper_allocr & allocr) {
518
+ if (allocr.alloc) {
519
+ ggml_allocr_free(allocr.alloc);
520
+ ggml_backend_buffer_free(allocr.buffer);
521
+ allocr.alloc = nullptr;
522
+ }
523
+ }
524
+
525
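A minimal sketch (editorial, assuming a valid `backend` and a hypothetical `build_graph()` callback that returns the graph to be evaluated) of the measure/realloc lifecycle implemented by the wrapper above:

whisper_allocr allocr;

// measure pass: records the worst-case allocation size without backend memory
whisper_allocr_graph_init(allocr, backend, [&]() { return build_graph(); });

// once all graphs have been measured, back the allocator with a real buffer
whisper_allocr_graph_realloc(allocr, backend);

// ... ggml_allocr_reset() + ggml_allocr_alloc_graph() before each evaluation ...

whisper_allocr_free(allocr);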
  // medium
526
  // hparams: {
527
  // 'n_mels': 80,
 
639
  struct ggml_tensor * mlp_1_b;
640
  };
641
 
642
+ struct whisper_kv_cell {
643
+ whisper_pos pos = -1;
644
+
645
+ std::set<whisper_seq_id> seq_id;
646
+
647
+ bool has_seq_id(const whisper_seq_id & id) const {
648
+ return seq_id.find(id) != seq_id.end();
649
+ }
650
+ };
651
+
652
  struct whisper_kv_cache {
653
+ uint32_t head = 0;
654
+ uint32_t size = 0;
655
+
656
+ // computed before each graph build
657
+ uint32_t n = 0;
658
+
659
+ std::vector<whisper_kv_cell> cells;
660
+
661
  struct ggml_tensor * k;
662
  struct ggml_tensor * v;
663
 
664
  struct ggml_context * ctx;
665
 
666
  ggml_backend_buffer_t buffer;
 
 
667
  };
668
 
669
  struct whisper_model {
 
717
  };
718
 
719
  struct whisper_grammar {
720
+ /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
721
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
722
 
723
  // buffer for partially generated UTF-8 sequence from accepted tokens
724
+ whisper_partial_utf8 partial_utf8;
725
  };
726
 
727
  struct whisper_grammar_candidate {
 
745
 
746
  // TAGS: WHISPER_DECODER_INIT
747
  struct whisper_decoder {
 
 
 
748
  // the currently generated sequence of tokens
749
  whisper_sequence sequence;
750
 
751
  // grammar parse state of generated sequence of tokens
752
  whisper_grammar grammar;
753
 
754
+ int i_batch; // the index of the token in the current batch
755
  int seek_delta; // the window shift found so far based on the decoded timestamp tokens
756
 
757
  bool failed; // has the current segment failed to decode?
 
763
  std::vector<float> logits;
764
  std::vector<float> logprobs;
765
 
766
+ // work container used to avoid memory allocations
767
+ std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
768
 
769
+ mutable std::mt19937 rng; // used for sampling at t > 0.0
770
  };
771
 
772
  struct whisper_state {
773
  int64_t t_sample_us = 0;
774
  int64_t t_encode_us = 0;
775
  int64_t t_decode_us = 0;
776
+ int64_t t_batchd_us = 0;
777
  int64_t t_prompt_us = 0;
778
  int64_t t_mel_us = 0;
779
 
780
  int32_t n_sample = 0; // number of tokens sampled
781
  int32_t n_encode = 0; // number of encoder calls
782
+ int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
783
+ int32_t n_batchd = 0; // number of decoder calls with 1 < n_tokens < 16 (batch decoding)
784
+ int32_t n_prompt = 0; // number of decoder calls with n_tokens >= 16 (prompt encoding)
785
  int32_t n_fail_p = 0; // number of logprob threshold failures
786
  int32_t n_fail_h = 0; // number of entropy threshold failures
787
 
788
+ // unified self-attention KV cache for all decoders
789
+ whisper_kv_cache kv_self;
790
+
791
  // cross-attention KV cache for the decoders
792
  // shared between all decoders
793
  whisper_kv_cache kv_cross;
794
+
795
  whisper_mel mel;
796
 
797
+ whisper_batch batch;
798
 
799
+ whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
800
 
801
  ggml_backend_t backend = nullptr;
802
 
 
812
  struct ggml_tensor * embd_conv = nullptr;
813
  struct ggml_tensor * embd_enc = nullptr;
814
 
815
+ // helpers for GPU offloading
816
  std::vector<float> inp_mel;
817
+ std::vector<float> inp_mask;
818
 
819
  // decode output (2-dimensional array: [n_tokens][n_vocab])
820
  std::vector<float> logits;
 
822
  std::vector<whisper_segment> result_all;
823
  std::vector<whisper_token> prompt_past;
824
 
 
 
 
 
 
825
  int lang_id = 0; // english by default
826
 
827
  std::string path_model; // populated by whisper_init_from_file_with_params()
 
897
  /*.no_alloc =*/ true,
898
  };
899
 
900
+ cache.head = 0;
901
+ cache.size = n_ctx;
902
+
903
+ cache.cells.clear();
904
+ cache.cells.resize(n_ctx);
905
+
906
  cache.ctx = ggml_init(params);
907
 
908
  if (!cache.ctx) {
 
930
  return true;
931
  }
932
 
933
+ static void kv_cache_free(struct whisper_kv_cache & cache) {
934
+ if (cache.ctx) {
935
+ ggml_free(cache.ctx);
936
+ ggml_backend_buffer_free(cache.buffer);
937
+ cache.ctx = nullptr;
938
+ }
939
+ }
940
 
941
+ static bool whisper_kv_cache_find_slot(
942
+ struct whisper_kv_cache & cache,
943
+ const struct whisper_batch & batch) {
944
+ const uint32_t n_ctx = cache.size;
945
+ const uint32_t n_tokens = batch.n_tokens;
946
 
947
+ if (n_tokens > n_ctx) {
948
+ WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
949
+ return false;
950
+ }
951
 
952
+ uint32_t n_tested = 0;
 
 
 
 
953
 
954
+ while (true) {
955
+ if (cache.head + n_tokens > n_ctx) {
956
+ n_tested += n_ctx - cache.head;
957
+ cache.head = 0;
958
+ continue;
959
+ }
960
 
961
+ bool found = true;
962
+ for (uint32_t i = 0; i < n_tokens; i++) {
963
+ if (cache.cells[cache.head + i].pos >= 0) {
964
+ found = false;
965
+ cache.head += i + 1;
966
+ n_tested += i + 1;
967
+ break;
968
+ }
969
+ }
970
+
971
+ if (found) {
972
+ break;
973
+ }
974
+
975
+ if (n_tested >= n_ctx) {
976
+ //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
977
+ return false;
978
+ }
979
  }
980
 
981
+ for (uint32_t i = 0; i < n_tokens; i++) {
982
+ cache.cells[cache.head + i].pos = batch.pos[i];
983
 
984
+ for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
985
+ cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
986
+ }
987
+ }
988
 
989
+ return true;
990
+ }
991
 
992
+ // find how many cells are currently in use
993
+ static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
994
+ for (uint32_t i = cache.size - 1; i > 0; --i) {
995
+ if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
996
+ return i + 1;
997
+ }
998
+ }
999
 
1000
+ return 1;
1001
+ }
1002
 
1003
+ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
1004
+ for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
1005
+ cache.cells[i].pos = -1;
1006
+ cache.cells[i].seq_id.clear();
1007
  }
1008
+ cache.head = 0;
1009
+ }
1010
 
1011
+ static void whisper_kv_cache_seq_rm(
1012
+ struct whisper_kv_cache & cache,
1013
+ whisper_seq_id seq_id,
1014
+ whisper_pos p0,
1015
+ whisper_pos p1) {
1016
+ uint32_t new_head = cache.size;
1017
+
1018
+ if (p0 < 0) p0 = 0;
1019
+ if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
1020
+
1021
+ for (uint32_t i = 0; i < cache.size; ++i) {
1022
+ if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1023
+ if (seq_id < 0) {
1024
+ cache.cells[i].seq_id.clear();
1025
+ } else if (cache.cells[i].has_seq_id(seq_id)) {
1026
+ cache.cells[i].seq_id.erase(seq_id);
1027
+ } else {
1028
+ continue;
1029
+ }
1030
+ if (cache.cells[i].seq_id.empty()) {
1031
+ cache.cells[i].pos = -1;
1032
+ if (new_head == cache.size) new_head = i;
1033
+ }
1034
+ }
1035
+ }
1036
+
1037
+ // If we freed up a slot, set head to it so searching can start there.
1038
+ if (new_head != cache.size) cache.head = new_head;
1039
  }
1040
 
1041
+ static void whisper_kv_cache_seq_cp(
1042
+ struct whisper_kv_cache & cache,
1043
+ whisper_seq_id seq_id_src,
1044
+ whisper_seq_id seq_id_dst,
1045
+ whisper_pos p0,
1046
+ whisper_pos p1) {
1047
+ if (p0 < 0) p0 = 0;
1048
+ if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
1049
+
1050
+ cache.head = 0;
1051
+
1052
+ for (uint32_t i = 0; i < cache.size; ++i) {
1053
+ if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1054
+ cache.cells[i].seq_id.insert(seq_id_dst);
1055
+ }
1056
  }
1057
  }
1058
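For orientation, a hedged sketch of how these helpers are combined later in whisper_full_with_state() to share one cache between decoders; `kv_self` and `n_decoders` are assumed from the surrounding scope:

whisper_kv_cache_clear(kv_self);                 // start of a new decoding iteration

// ... decode the common prompt once, writing it to the cache as sequence 0 ...

for (int j = 1; j < n_decoders; ++j) {
    // expose the prompt cells to decoder j without duplicating any tensor data
    whisper_kv_cache_seq_cp(kv_self, 0, j, -1, -1);
}

// later, when decoder j must adopt the state of decoder `src` (beam search):
// whisper_kv_cache_seq_rm(kv_self, j, -1, -1);
// whisper_kv_cache_seq_cp(kv_self, src, j, -1, -1);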
 
 
1061
 
1062
  // initialize the backends
1063
  #ifdef GGML_USE_CUBLAS
1064
+ if (params.use_gpu && ggml_cublas_loaded()) {
1065
  WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
1066
  backend_gpu = ggml_backend_cuda_init();
1067
  if (!backend_gpu) {
 
1263
  word = "[_EOT_]";
1264
  } else if (i == vocab.token_sot) {
1265
  word = "[_SOT_]";
1266
+ } else if (i == vocab.token_translate) {
1267
+ word = "[_TRANSLATE_]";
1268
+ } else if (i == vocab.token_transcribe) {
1269
+ word = "[_TRANSCRIBE_]";
1270
  } else if (i == vocab.token_solm) {
1271
  word = "[_SOLM_]";
1272
  } else if (i == vocab.token_prev) {
 
1277
  word = "[_NOT_]";
1278
  } else if (i == vocab.token_beg) {
1279
  word = "[_BEG_]";
1280
+ } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
1281
+ word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
1282
  } else {
1283
  word = "[_extra_token_" + std::to_string(i) + "]";
1284
  }
 
2184
  static struct ggml_cgraph * whisper_build_graph_decoder(
2185
  whisper_context & wctx,
2186
  whisper_state & wstate,
2187
+ const whisper_batch & batch) {
 
 
 
2188
  const auto & model = wctx.model;
2189
  const auto & hparams = model.hparams;
2190
 
2191
+ auto & kv_self = wstate.kv_self;
2192
 
2193
  WHISPER_ASSERT(!!kv_self.ctx);
2194
 
2195
+ ggml_allocr * alloc = wstate.alloc_decode.alloc;
2196
+
2197
+ const int n_ctx = kv_self.size;
2198
  const int n_state = hparams.n_text_state;
2199
  const int n_head = hparams.n_text_head;
2200
  const int n_layer = hparams.n_text_layer;
2201
 
2202
+ const int n_tokens = batch.n_tokens;
2203
+ const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2204
 
2205
+ const int32_t n_kv = ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
2206
+ const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
2207
+
2208
+ //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
2209
 
2210
  struct ggml_init_params params = {
2211
  /*.mem_size =*/ wstate.alloc_decode.meta.size(),
 
2217
 
2218
  ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
2219
 
2220
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
 
 
2221
  ggml_allocr_alloc(alloc, embd);
2222
 
2223
  if (!ggml_allocr_is_measure(alloc)) {
2224
+ ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
2225
  }
2226
 
2227
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
2228
  ggml_allocr_alloc(alloc, position);
2229
 
2230
  if (!ggml_allocr_is_measure(alloc)) {
2231
+ for (int i = 0; i < n_tokens; ++i) {
2232
+ const int32_t val = batch.pos[i];
2233
  ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2234
  }
2235
  }
 
2242
  ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
2243
  }
2244
 
2245
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
2246
+ ggml_allocr_alloc(alloc, KQ_mask);
2247
+
2248
+ if (!ggml_allocr_is_measure(alloc)) {
2249
+ wstate.inp_mask.resize(n_kv*n_tokens);
2250
+
2251
+ float * data = wstate.inp_mask.data();
2252
+ memset(data, 0, ggml_nbytes(KQ_mask));
2253
+
2254
+ for (int h = 0; h < 1; ++h) {
2255
+ for (int j = 0; j < n_tokens; ++j) {
2256
+ const whisper_pos pos = batch.pos[j];
2257
+ const whisper_seq_id seq_id = batch.seq_id[j][0];
2258
+
2259
+ for (int i = 0; i < n_kv; ++i) {
2260
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
2261
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
2262
+ }
2263
+ }
2264
+ }
2265
+ }
2266
+
2267
+ ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
2268
+ }
2269
+
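A small worked example of the mask produced above (editorial illustration; the cell contents are hypothetical):

// suppose n_kv = 4 and the batch holds one token per decoder, both at pos = 2:
//   cells 0,1 hold the shared prompt (pos 0,1, seq {0,1} after seq_cp)
//   cell  2   holds seq 0 at pos 2, cell 3 holds seq 1 at pos 2
//
//                    cell 0   cell 1   cell 2   cell 3
//   row j=0 (seq 0):    0        0        0      -inf     <- cell 3 is not in seq 0
//   row j=1 (seq 1):    0        0      -inf       0      <- cell 2 is not in seq 1
//
// each row is added to the corresponding row of KQ before the soft-max, so every
// token attends only to cells of its own sequence with pos <= its own pos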
2270
  // token encoding + position encoding
2271
  struct ggml_tensor * cur =
2272
  ggml_add(ctx0,
 
2319
  Vcur,
2320
  layer.attn_v_b);
2321
 
2322
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
2323
 
2324
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2325
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2326
  ( n_ctx)*ggml_element_size(kv_self.v),
2327
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
2328
 
2329
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
2330
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
 
2334
 
2335
  struct ggml_tensor * Q =
2336
  ggml_permute(ctx0,
2337
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2338
  0, 2, 1, 3);
2339
 
2340
  struct ggml_tensor * K =
2341
  ggml_view_3d(ctx0, kv_self.k,
2342
+ n_state/n_head, n_kv, n_head,
2343
  ggml_element_size(kv_self.k)*n_state,
2344
  ggml_element_size(kv_self.k)*n_state/n_head,
2345
  ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
2349
 
2350
  //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
2351
 
2352
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
2353
+ struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
2354
 
2355
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
2356
 
2357
  struct ggml_tensor * V =
2358
  ggml_view_3d(ctx0, kv_self.v,
2359
+ n_kv, n_state/n_head, n_head,
2360
  n_ctx*ggml_element_size(kv_self.v),
2361
  n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
2362
+ n_ctx*ggml_element_size(kv_self.v)*n_state*il);
2363
 
2364
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2365
 
 
2367
 
2368
  cur = ggml_cpy(ctx0,
2369
  KQV_merged,
2370
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2371
  }
2372
 
2373
  // projection
 
2411
  // Kcross is already scaled
2412
  struct ggml_tensor * Kcross =
2413
  ggml_view_3d(ctx0, wstate.kv_cross.k,
2414
+ n_state/n_head, n_audio_ctx, n_head,
2415
  ggml_element_size(wstate.kv_cross.k)*n_state,
2416
  ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2417
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2418
 
2419
  //struct ggml_tensor * Vcross =
2420
  // ggml_reshape_3d(ctx0,
2421
+ // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
2422
+ // n_state/n_head, n_head, n_audio_ctx);
2423
 
2424
  //struct ggml_tensor * V_trans =
2425
  // ggml_cpy(ctx0,
2426
  // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2427
+ // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
2428
 
2429
  struct ggml_tensor * V =
2430
  ggml_view_3d(ctx0, wstate.kv_cross.v,
2431
+ n_audio_ctx, n_state/n_head, n_head,
2432
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
2433
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2434
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
2435
 
2436
  // ------
2437
 
2438
  struct ggml_tensor * Q =
2439
  ggml_permute(ctx0,
2440
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2441
  0, 2, 1, 3);
2442
 
2443
  // K * Q
 
2458
 
2459
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2460
 
2461
+ // cur = KQV_merged.contiguous().view(n_state, n_tokens)
2462
  cur = ggml_cpy(ctx0,
2463
  KQV_merged,
2464
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2465
  }
2466
 
2467
  // projection
 
2533
  }
2534
 
2535
  // compute logits only for the last token
2536
+ // (the line below is left commented out so that logits are computed for all n_tokens, as needed for batched decoding)
2537
  // might be useful in the future
2538
+ //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2539
 
2540
  struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2541
 
 
2559
  static bool whisper_decode_internal(
2560
  whisper_context & wctx,
2561
  whisper_state & wstate,
2562
+ const whisper_batch & batch,
 
 
 
2563
  const int n_threads,
2564
  whisper_abort_callback abort_callback,
2565
  void * abort_callback_data) {
 
2568
  const auto & model = wctx.model;
2569
  const auto & hparams = model.hparams;
2570
 
2571
+ const int n_vocab = hparams.n_vocab;
2572
+ const int n_tokens = batch.n_tokens;
2573
 
2574
  auto & logits_out = wstate.logits;
2575
 
2576
  struct ggml_tensor * logits;
2577
 
2578
+ // find KV slot for the batch
2579
+ {
2580
+ auto & kv_self = wstate.kv_self;
2581
+
2582
+ if (!whisper_kv_cache_find_slot(kv_self, batch)) {
2583
+ return false;
2584
+ }
2585
+
2586
+ kv_self.n = whisper_kv_cache_cell_max(kv_self);
2587
+ //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
2588
+ //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
2589
+ }
2590
+
2591
  // decoder
2592
  {
2593
  auto & alloc = wstate.alloc_decode.alloc;
2594
 
2595
  ggml_allocr_reset(alloc);
2596
 
2597
+ ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
2598
 
2599
  ggml_allocr_alloc_graph(alloc, gf);
2600
 
 
2603
  ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2604
  }
2605
 
2606
+ logits_out.resize(n_tokens*n_vocab);
2607
+ for (int i = 0; i < n_tokens; i++) {
2608
+ if (batch.logits[i] == 0) {
2609
+ continue;
2610
+ }
2611
+ ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
2612
+ }
 
 
2613
 
2614
+ if (batch.n_tokens > 1) {
2615
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2616
  // ggml_used_mem(ctx0)/1024.0/1024.0,
2617
  // wstate.get_buf_max_mem(0)/1024.0/1024.0,
 
2620
  // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2621
  }
2622
 
2623
+ if (batch.n_tokens == 1) {
2624
  wstate.t_decode_us += ggml_time_us() - t_start_us;
2625
  wstate.n_decode++;
2626
+ } else if (batch.n_tokens < 16) {
2627
+ wstate.t_batchd_us += ggml_time_us() - t_start_us;
2628
+ wstate.n_batchd += n_tokens;
2629
  } else {
2630
  wstate.t_prompt_us += ggml_time_us() - t_start_us;
2631
+ wstate.n_prompt += n_tokens;
2632
  }
2633
 
2634
  return !(abort_callback && abort_callback(abort_callback_data));
2635
  }
2636
 
 
2637
  // 500 -> 00:05.000
2638
  // 6000 -> 01:00.000
2639
  static std::string to_timestamp(int64_t t, bool comma = false) {
 
3045
 
3046
  state->backend = whisper_backend_init(ctx->params);
3047
 
3048
+ // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3049
+ // in theory, there can be a case where this is not enough, but in practice it should always be enough
3050
+ const int factor = 3;
3051
+
3052
+ if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
3053
  WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3054
  delete state;
3055
  return nullptr;
3056
  }
3057
 
3058
  {
3059
+ const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
3060
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
3061
  }
3062
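A back-of-the-envelope check (editorial; it assumes the usual layout where `k` and `v` each hold n_text_layer * n_ctx * n_text_state elements of the chosen `itype`, here taken to be F16):

// for the base model: n_text_layer = 6, n_text_ctx = 448, n_text_state = 512
//   elements per tensor = 6 * (3 * 448) * 512 = 4,128,768
//   bytes (F16)         = 4,128,768 * 2      ≈ 7.9 MB
//   k + v               ≈ 15.7 MB, roughly what the "kv self size" line above would log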
 
 
3091
 
3092
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
3093
 
3094
+ state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
3095
 
3096
  // TAGS: WHISPER_DECODER_INIT
3097
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
3098
 
3099
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
3100
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
3101
+ state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
3102
+ state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
3103
+
3104
+ state->decoders[0].rng = std::mt19937(0);
3105
 
3106
  // conv allocator
3107
  {
 
3143
  const int n_tokens = hparams.n_text_ctx;
3144
  const int n_past = 0;
3145
 
3146
+ whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
3147
+
3148
+ return whisper_build_graph_decoder(*ctx, *state, state->batch);
3149
  });
3150
 
3151
  WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
 
3156
  whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
3157
  whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
3158
 
 
 
3159
  return state;
3160
  }
3161
 
 
3380
  void whisper_free_state(struct whisper_state * state)
3381
  {
3382
  if (state) {
3383
+ kv_cache_free(state->kv_self);
3384
  kv_cache_free(state->kv_cross);
3385
 
 
 
 
 
3386
  #ifdef WHISPER_USE_COREML
3387
  if (state->ctx_coreml != nullptr) {
3388
  whisper_coreml_free(state->ctx_coreml);
 
3397
  }
3398
  #endif
3399
 
3400
+ whisper_batch_free(state->batch);
3401
+
3402
  whisper_allocr_free(state->alloc_conv);
3403
  whisper_allocr_free(state->alloc_encode);
3404
  whisper_allocr_free(state->alloc_cross);
 
3525
  }
3526
 
3527
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3528
+ whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
3529
+
3530
+ whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
3531
 
3532
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
3533
  WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3534
  return 1;
3535
  }
 
3538
  }
3539
 
3540
  int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
 
 
 
3541
  if (ctx->state == nullptr) {
3542
  WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
3543
  return false;
3544
  }
3545
 
3546
+ whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
3547
+
3548
+ whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
3549
+
3550
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
3551
  WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3552
  return 1;
3553
  }
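An editorial usage note on the whisper_kv_cache_seq_rm() call added above: since entries of sequence 0 at positions >= n_past are dropped before each evaluation, re-decoding from an earlier position simply overwrites the tail of the cache. A hedged sketch (token buffer and thread count are placeholders):

whisper_decode(ctx, tokens, /*n_tokens*/ 8, /*n_past*/ 0, /*n_threads*/ 4); // fills positions 0..7
whisper_decode(ctx, tokens, /*n_tokens*/ 4, /*n_past*/ 4, /*n_threads*/ 4); // rewinds to pos 4 and refills 4..7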
 
3635
  return -7;
3636
  }
3637
 
3638
+ auto & logits_id = state->decoders[0].logits_id;
3639
  logits_id.clear();
3640
 
3641
  for (const auto & kv : g_lang) {
 
3838
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3839
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3840
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3841
+ const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
3842
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3843
 
3844
  WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
 
3846
  WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3847
  WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3848
  WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3849
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
3850
  WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3851
  }
3852
  WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 
3863
  ctx->state->n_sample = 0;
3864
  ctx->state->n_encode = 0;
3865
  ctx->state->n_decode = 0;
3866
+ ctx->state->n_batchd = 0;
3867
  ctx->state->n_prompt = 0;
3868
  }
3869
  }
 
4171
  if (*tok.code_points == 0) {
4172
  // reached end of full codepoints in token, reject iff it ended in a partial sequence
4173
  // that cannot satisfy this position in grammar
4174
+ if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
 
4175
  rejects.push_back(tok);
4176
  }
4177
  } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
 
4390
  /*.max_initial_ts =*/ 1.0f,
4391
  /*.length_penalty =*/ -1.0f,
4392
 
4393
+ /*.temperature_inc =*/ 0.2f,
4394
  /*.entropy_thold =*/ 2.4f,
4395
  /*.logprob_thold =*/ -1.0f,
4396
  /*.no_speech_thold =*/ 0.6f,
 
4430
  case WHISPER_SAMPLING_GREEDY:
4431
  {
4432
  result.greedy = {
4433
+ /*.best_of =*/ 5,
4434
  };
4435
  } break;
4436
  case WHISPER_SAMPLING_BEAM_SEARCH:
4437
  {
4438
  result.beam_search = {
4439
+ /*.beam_size =*/ 5,
4440
 
4441
  /*.patience =*/ -1.0f,
4442
  };
 
4526
  // process the logits for the selected decoder
4527
  // - applies logit filters
4528
  // - computes logprobs and probs
4529
+ // TODO: optimize
4530
  static void whisper_process_logits(
4531
  struct whisper_context & ctx,
4532
  struct whisper_state & state,
 
4533
  struct whisper_decoder & decoder,
4534
+ const struct whisper_full_params params,
4535
  float temperature) {
4536
  const auto & vocab = ctx.vocab;
4537
  const auto & tokens_cur = decoder.sequence.tokens;
 
4548
  auto & logprobs = decoder.logprobs;
4549
  {
4550
  logits.resize(n_logits);
4551
+ memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
4552
 
4553
  if (temperature > 0.0f) {
4554
  for (int i = 0; i < n_logits; i++) {
 
4714
  //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
4715
 
4716
  if (timestamp_logprob > max_text_token_logprob) {
 
4717
  for (int i = 0; i < vocab.token_beg; ++i) {
4718
  logits[i] = -INFINITY;
4719
  logprobs[i] = -INFINITY;
4720
  }
4721
+ } else {
4722
+ if (params.n_grammar_rules > 0) {
4723
+ whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
4724
 
4725
+ // populate the logprobs array (log_softmax)
4726
+ {
4727
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
4728
+ float logsumexp = 0.0f;
4729
+ for (int i = 0; i < n_logits; ++i) {
4730
+ if (logits[i] > -INFINITY) {
4731
+ logsumexp += expf(logits[i] - logit_max);
4732
+ }
4733
  }
4734
+ logsumexp = logf(logsumexp) + logit_max;
 
4735
 
4736
+ for (int i = 0; i < n_logits; ++i) {
4737
+ if (logits[i] > -INFINITY) {
4738
+ logprobs[i] = logits[i] - logsumexp;
4739
+ } else {
4740
+ logprobs[i] = -INFINITY;
4741
+ }
4742
  }
4743
  }
4744
  }
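The same log-softmax in isolation, as a hedged reference sketch for sanity-checking the in-place computation above (the -INFINITY filtering is omitted for brevity):

#include <algorithm>
#include <cmath>
#include <vector>

static std::vector<float> log_softmax_ref(const std::vector<float> & logits) {
    const float max_val = *std::max_element(logits.begin(), logits.end());

    float sum = 0.0f;
    for (const float l : logits) {
        sum += expf(l - max_val);
    }
    const float logsumexp = logf(sum) + max_val;

    std::vector<float> out(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        out[i] = logits[i] - logsumexp;
    }
    return out;
}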
 
4813
 
4814
  static whisper_token_data whisper_sample_token(
4815
  whisper_context & ctx,
 
4816
  const whisper_decoder & decoder,
4817
  bool best) {
4818
  whisper_token_data result = {
 
4857
  } else {
4858
  std::discrete_distribution<> dist(probs.begin(), probs.end());
4859
 
4860
+ result.id = dist(decoder.rng);
4861
  result.p = probs[result.id];
4862
  result.plog = logprobs[result.id];
4863
  }
 
4867
  result.pt = result.p;
4868
  }
4869
 
 
 
4870
  return result;
4871
  }
4872
 
4873
  static std::vector<whisper_token_data> whisper_sample_token_topk(
4874
  whisper_context & ctx,
4875
+ whisper_decoder & decoder,
 
4876
  int k) {
4877
  const auto & vocab = ctx.vocab;
4878
 
 
4882
 
4883
  const int n_logits = vocab.n_vocab;
4884
 
4885
+ auto & logits_id = decoder.logits_id;
4886
 
4887
  logits_id.resize(n_logits);
4888
  for (int i = 0; i < n_logits; ++i) {
 
4931
  std::discrete_distribution<> dist(probs.begin(), probs.end());
4932
 
4933
  for (int i = 0; i < k; ++i) {
4934
+ const auto id = dist(decoder.rng);
4935
  //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
4936
 
4937
  result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
 
4942
  }
4943
  }
4944
 
 
 
4945
  return result;
4946
  }
4947
 
 
4994
  }
4995
  }
4996
 
4997
  int whisper_full_with_state(
4998
  struct whisper_context * ctx,
4999
  struct whisper_state * state,
 
5083
 
5084
  n_decoders = std::max(1, n_decoders);
5085
 
5086
+ if (n_decoders > WHISPER_MAX_DECODERS) {
5087
+ WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
5088
+ return -4;
5089
+ }
5090
+
5091
  // TAGS: WHISPER_DECODER_INIT
5092
  for (int j = 1; j < n_decoders; j++) {
5093
  auto & decoder = state->decoders[j];
5094
 
5095
+ decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
 
 
 
 
 
 
5096
 
5097
+ decoder.probs.resize (ctx->vocab.n_vocab);
5098
+ decoder.logits.resize (ctx->vocab.n_vocab);
5099
+ decoder.logprobs.resize(ctx->vocab.n_vocab);
5100
+ decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
5101
 
5102
+ decoder.rng = std::mt19937(0);
 
 
 
5103
  }
5104
 
5105
  // the accumulated text context so far
 
5176
  bool has_ts;
5177
 
5178
  whisper_sequence sequence;
5179
+ whisper_grammar grammar;
5180
  };
5181
 
5182
+ std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
5183
  std::vector<beam_candidate> beam_candidates;
5184
 
5185
  // main loop
 
5247
  for (int j = 0; j < n_decoders_cur; ++j) {
5248
  auto & decoder = state->decoders[j];
5249
 
 
 
5250
  decoder.sequence.tokens.clear();
5251
  decoder.sequence.result_len = 0;
5252
  decoder.sequence.sum_logprobs_all = 0.0;
 
5262
  decoder.has_ts = false;
5263
 
5264
  if (params.grammar_rules != nullptr) {
5265
+ decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
 
5266
  } else {
5267
  decoder.grammar = {};
5268
  }
5269
  }
5270
 
5271
  // init prompt and kv cache for the current iteration
5272
+ // TODO: do not recompute the prompt if it is the same as previous time
5273
  {
5274
  prompt.clear();
5275
 
 
5291
  }
5292
  WHISPER_PRINT_DEBUG("\n\n");
5293
 
5294
+ whisper_kv_cache_clear(state->kv_self);
5295
+
5296
+ whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
5297
+
5298
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
5299
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5300
  return -7;
5301
  }
 
5303
  {
5304
  const int64_t t_start_sample_us = ggml_time_us();
5305
 
5306
+ state->decoders[0].i_batch = prompt.size() - 1;
5307
 
5308
+ whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
5309
 
5310
  for (int j = 1; j < n_decoders_cur; ++j) {
5311
  auto & decoder = state->decoders[j];
5312
 
5313
+ whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
 
 
 
 
 
5314
 
5315
  memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
5316
  memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
 
5325
  const int64_t t_start_sample_us = ggml_time_us();
5326
 
5327
  if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
5328
+ for (auto & bc : bc_per_dec) {
5329
+ bc.clear();
5330
+ }
5331
  }
5332
 
5333
+ // sampling
5334
+ // TODO: avoid memory allocations, optimize, avoid threads?
5335
+ {
5336
+ std::atomic<int> j_cur(0);
5337
 
5338
+ auto process = [&]() {
5339
+ while (true) {
5340
+ const int j = j_cur.fetch_add(1);
5341
 
5342
+ if (j >= n_decoders_cur) {
5343
+ break;
5344
+ }
 
 
 
 
 
5345
 
5346
+ auto & decoder = state->decoders[j];
 
 
 
 
5347
 
5348
+ if (decoder.completed || decoder.failed) {
5349
+ continue;
5350
+ }
 
5351
 
5352
+ switch (params.strategy) {
5353
+ case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
5354
+ {
5355
+ if (t_cur < 1e-6f) {
5356
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
5357
+ } else {
5358
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
5359
+ }
5360
+
5361
+ decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
5362
+ } break;
5363
+ case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
5364
+ {
5365
+ const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
5366
+
5367
+ for (const auto & token : tokens_new) {
5368
+ bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
5369
+ bc_per_dec[j].back().sequence.tokens.push_back(token);
5370
+ bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
5371
+ }
5372
+ } break;
5373
+ };
5374
+ }
5375
  };
5376
+
5377
+ const int n_threads = std::min(params.n_threads, n_decoders_cur);
5378
+
5379
+ if (n_threads == 1) {
5380
+ process();
5381
+ } else {
5382
+ std::vector<std::thread> threads(n_threads - 1);
5383
+
5384
+ for (int t = 0; t < n_threads - 1; ++t) {
5385
+ threads[t] = std::thread(process);
5386
+ }
5387
+
5388
+ process();
5389
+
5390
+ for (int t = 0; t < n_threads - 1; ++t) {
5391
+ threads[t].join();
5392
+ }
5393
+ }
5394
+ }
5395
+
5396
+ beam_candidates.clear();
5397
+ for (const auto & bc : bc_per_dec) {
5398
+ beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
5399
+
5400
+ if (!bc.empty()) {
5401
+ state->n_sample += 1;
5402
+ }
5403
  }
5404
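The thread fan-out above follows a simple atomic work-claiming pattern; a self-contained sketch of the same idea (editorial, not part of the commit):

#include <atomic>
#include <functional>
#include <thread>
#include <vector>

// run fn(0..n-1), letting up to n_threads workers claim indices atomically
static void for_each_index(int n, int n_threads, const std::function<void(int)> & fn) {
    std::atomic<int> next(0);

    auto worker = [&]() {
        while (true) {
            const int i = next.fetch_add(1);
            if (i >= n) {
                break;
            }
            fn(i);
        }
    };

    std::vector<std::thread> threads;
    for (int t = 0; t < n_threads - 1; ++t) {
        threads.emplace_back(worker);
    }
    worker(); // the calling thread participates as well
    for (auto & th : threads) {
        th.join();
    }
}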
 
5405
  // for beam-search, choose the top candidates and update the KV caches
 
5412
  });
5413
 
5414
  uint32_t cur_c = 0;
 
5415
 
5416
  for (int j = 0; j < n_decoders_cur; ++j) {
5417
  auto & decoder = state->decoders[j];
 
5430
  ++cur_c;
5431
  }
5432
 
 
5433
  decoder.seek_delta = cur.seek_delta;
5434
  decoder.has_ts = cur.has_ts;
5435
+ decoder.sequence = cur.sequence;
5436
+ decoder.grammar = cur.grammar;
5437
+
5438
+ whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
5439
 
 
5440
  WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
5441
  __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
5442
  }
5443
 
5444
+ for (int j = 0; j < n_decoders_cur; ++j) {
5445
+ auto & decoder = state->decoders[j];
5446
+
5447
+ if (decoder.completed || decoder.failed) {
5448
+ continue;
5449
+ }
5450
+
5451
+ whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
5452
+ whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
5453
+ whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
5454
+ }
5455
  }
5456
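A worked example (editorial) of why the copies above are staged through temporary sequence ids instead of copying directly between decoder ids:

// suppose n_decoders_cur = 2 and the winning candidates are decoder_idx = { 1, 0 },
// i.e. decoder 0 must adopt decoder 1's cache and vice versa. Copying 1 -> 0
// directly would clobber decoder 0's cells before they reach decoder 1.
// The two loops above therefore stage the move (MAX = WHISPER_MAX_DECODERS):
//   first loop : 1 -> MAX+0 and 0 -> MAX+1                   (sources are only read)
//   second loop: clear 0 and 1, then MAX+0 -> 0, MAX+1 -> 1, then clear MAX+0, MAX+1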
 
5457
  // update the decoder state
 
5560
  state->t_sample_us += ggml_time_us() - t_start_sample_us;
5561
 
5562
  // obtain logits for the next token
5563
+ {
5564
+ auto & batch = state->batch;
5565
 
5566
+ batch.n_tokens = 0;
5567
+
5568
+ const int n_past = prompt.size() + i;
5569
+
5570
+ for (int j = 0; j < n_decoders_cur; ++j) {
5571
+ auto & decoder = state->decoders[j];
5572
 
5573
+ if (decoder.failed || decoder.completed) {
5574
+ continue;
5575
+ }
5576
 
5577
+ //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
5578
 
5579
+ decoder.i_batch = batch.n_tokens;
5580
+
5581
+ batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
5582
+ batch.pos [batch.n_tokens] = n_past;
5583
+ batch.n_seq_id[batch.n_tokens] = 1;
5584
+ batch.seq_id [batch.n_tokens][0] = j;
5585
+ batch.logits [batch.n_tokens] = 1;
5586
+ batch.n_tokens++;
5587
+ }
5588
+
5589
+ assert(batch.n_tokens > 0);
5590
+
5591
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
5592
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5593
  return -8;
5594
  }
5595
 
5596
+ const int64_t t_start_sample_us = ggml_time_us();
5597
+
5598
+ // TODO: avoid memory allocations, optimize, avoid threads?
5599
  {
5600
+ std::atomic<int> j_cur(0);
5601
+
5602
+ auto process = [&]() {
5603
+ while (true) {
5604
+ const int j = j_cur.fetch_add(1);
5605
+
5606
+ if (j >= n_decoders_cur) {
5607
+ break;
5608
+ }
5609
 
5610
+ auto & decoder = state->decoders[j];
5611
 
5612
+ if (decoder.failed || decoder.completed) {
5613
+ continue;
5614
+ }
5615
 
5616
+ whisper_process_logits(*ctx, *state, decoder, params, t_cur);
5617
+ }
5618
+ };
5619
+
5620
+ const int n_threads = std::min(params.n_threads, n_decoders_cur);
5621
+
5622
+ if (n_threads == 1) {
5623
+ process();
5624
+ } else {
5625
+ std::vector<std::thread> threads(n_threads - 1);
5626
+
5627
+ for (int t = 0; t < n_threads - 1; ++t) {
5628
+ threads[t] = std::thread(process);
5629
+ }
5630
+
5631
+ process();
5632
+
5633
+ for (int t = 0; t < n_threads - 1; ++t) {
5634
+ threads[t].join();
5635
+ }
5636
+ }
5637
  }
5638
+
5639
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
5640
  }
5641
  }
5642
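To make the next-token batch layout above concrete, a hedged example of its contents when 5 decoders are running and decoders 1 and 3 have already completed:

// batch.n_tokens = 3, one entry per still-active decoder:
//   batch.token    = { <last token of decoder 0>, <decoder 2>, <decoder 4> }
//   batch.pos      = { n_past, n_past, n_past }
//   batch.n_seq_id = { 1, 1, 1 }
//   batch.seq_id   = { {0}, {2}, {4} }
//   batch.logits   = { 1, 1, 1 }
// decoder.i_batch records each decoder's row (0, 1, 2 here) so that
// whisper_process_logits() reads its logits at offset i_batch * n_vocab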
 
 
5933
  ctx->state->t_sample_us += states[i]->t_sample_us;
5934
  ctx->state->t_encode_us += states[i]->t_encode_us;
5935
  ctx->state->t_decode_us += states[i]->t_decode_us;
5936
+ ctx->state->t_batchd_us += states[i]->t_batchd_us;
5937
  ctx->state->t_prompt_us += states[i]->t_prompt_us;
5938
 
5939
  ctx->state->n_sample += states[i]->n_sample;
5940
  ctx->state->n_encode += states[i]->n_encode;
5941
  ctx->state->n_decode += states[i]->n_decode;
5942
+ ctx->state->n_batchd += states[i]->n_batchd;
5943
  ctx->state->n_prompt += states[i]->n_prompt;
5944
 
5945
  whisper_free_state(states[i]);
whisper.h CHANGED
@@ -78,7 +78,9 @@ extern "C" {
78
  struct whisper_state;
79
  struct whisper_full_params;
80
 
81
- typedef int whisper_token;
 
 
82
 
83
  struct whisper_context_params {
84
  bool use_gpu;
 
78
  struct whisper_state;
79
  struct whisper_full_params;
80
 
81
+ typedef int32_t whisper_pos;
82
+ typedef int32_t whisper_token;
83
+ typedef int32_t whisper_seq_id;
84
 
85
  struct whisper_context_params {
86
  bool use_gpu;