whisper : add batched decoding (#1486)
* whisper : add whisper_batch
* whisper : move kv_self to whisper_state
* whisper : full batched decoding support
* whisper : fix memory leak in whisper_batch
* whisper : fix mem leak again + remove obsolete function
* whisper : clear kv cache when using whisper_decode API
* whisper : speed-up sampling
* whisper : fix decoders initializer
* bench : add batch size 5 bench
* whisper : add comment about the KV cache size
* whisper : add check for max number of decoders
* whisper : avoid starting sampling threads with bs=1
* whisper : enable beam-search by default
* cuda : sync llama.cpp fixes
- examples/bench/bench.cpp +20 -10
- examples/main/main.cpp +4 -4
- extra/bench-all.sh +4 -3
- ggml-cuda.cu +188 -118
- ggml-cuda.h +5 -0
- whisper.cpp +602 -426
- whisper.h +3 -1
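For context, here is a minimal sketch (not part of the diff; the function name `transcribe_beam` is hypothetical) of what the new default buys at the API level: beam-search decoding, whose candidate decoders are now evaluated as a single batch against a shared KV cache.

```cpp
#include "whisper.h"

// assumes `ctx` was created beforehand, e.g. with whisper_init_from_file()
int transcribe_beam(struct whisper_context * ctx, const float * pcm, int n_samples) {
    // beam search is now the default strategy exercised by examples/main
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);

    wparams.beam_search.beam_size = 5; // the 5 decoders are evaluated as one batch
    wparams.greedy.best_of        = 5; // candidates used by the temperature-fallback re-runs

    return whisper_full(ctx, wparams, pcm, n_samples);
}
```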
examples/bench/bench.cpp CHANGED (in the reconstructed hunks below, `…` marks line content cut off in the page render)

```diff
@@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
     }
     // heat encoder
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to encode: %d\n", ret);
         return 4;
     }
 
@@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
 
     // prompt heat
     if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to decode: %d\n", ret);
         return 4;
     }
 
     // text-generation heat
     if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to decode: %d\n", ret);
         return 4;
     }
 
@@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
 
     // actual run
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to encode: %d\n", ret);
         return 4;
     }
 
-    …
-    …
-    …
+    // text-generation
+    for (int i = 0; i < 256; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
             return 4;
         }
     }
 
-    …
-    …
-    …
+    // batched decoding
+    for (int i = 0; i < 64; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
+            return 4;
+        }
+    }
+
+    // prompt processing
+    for (int i = 0; i < 16; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
             return 4;
         }
    }
```
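The three loops map directly onto the `whisper_decode(ctx, tokens, n_tokens, n_past, n_threads)` signature. A sketch of the shapes being timed (the function name `bench_shapes` is hypothetical, and `tokens` is assumed to hold at least 256 valid ids):

```cpp
#include "whisper.h"

// the three decode shapes the benchmark exercises: only n_tokens / n_past vary
int bench_shapes(struct whisper_context * ctx, const whisper_token * tokens, int n_threads) {
    if (whisper_decode(ctx, tokens,   1, 256, n_threads) != 0) { return 1; } // text-generation: 1 new token, 256 past
    if (whisper_decode(ctx, tokens,   5,   0, n_threads) != 0) { return 1; } // batched: 5 tokens as one batch, mimicking 5 decoders
    if (whisper_decode(ctx, tokens, 256,   0, n_threads) != 0) { return 1; } // prompt processing: 256 tokens in one pass
    return 0;
}
```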
examples/main/main.cpp CHANGED

```diff
@@ -62,8 +62,8 @@ struct whisper_params {
     int32_t progress_step = 5;
     int32_t max_context = -1;
     int32_t max_len = 0;
-    int32_t best_of = 2;
-    int32_t beam_size = -1;
+    int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
 
     float word_thold = 0.01f;
     float entropy_thold = 2.40f;
@@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
         if (params.detect_language) {
             params.language = "auto";
         }
-        fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
+        fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
                 __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
-                params.n_threads, params.n_processors,
+                params.n_threads, params.n_processors, params.beam_size, params.best_of,
                 params.language.c_str(),
                 params.translate ? "translate" : "transcribe",
                 params.tinydiarize ? "tdrz = 1, " : "",
```
extra/bench-all.sh CHANGED

```diff
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi
 
-printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
-printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
 
 for model in "${models[@]}"; do
     # actual run
@@ -56,6 +56,7 @@ for model in "${models[@]}"; do
     # parse the output:
     encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
     decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
+    batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
     prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
     system_info=$(echo "$output" | grep "system_info")
     n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
@@ -94,6 +95,6 @@ for model in "${models[@]}"; do
     commit=$(git rev-parse --short HEAD)
 
     if [ $ret -eq 0 ]; then
-        printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
+        printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
     fi
 done
```
ggml-cuda.cu CHANGED

```diff
@@ -39,7 +39,6 @@
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
-#define cudaDeviceGetMemPool hipDeviceGetMemPool
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -49,7 +48,6 @@
 #define cudaEvent_t hipEvent_t
 #define cudaEventDestroy hipEventDestroy
 #define cudaFree hipFree
-#define cudaFreeAsync hipFreeAsync
 #define cudaFreeHost hipHostFree
 #define cudaGetDevice hipGetDevice
 #define cudaGetDeviceCount hipGetDeviceCount
@@ -57,7 +55,6 @@
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
 #define cudaMalloc hipMalloc
-#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
@@ -66,9 +63,6 @@
 #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
 #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
 #define cudaMemcpyKind hipMemcpyKind
-#define cudaMemPool_t hipMemPool_t
-#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
-#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
 #define cudaMemset hipMemset
 #define cudaMemsetAsync hipMemsetAsync
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
@@ -94,6 +88,8 @@
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 
+#define GGML_CUDA_MAX_NODES 8192
+
 // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
 // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
 // for large computational tasks. the drawback is that this requires some extra amount of VRAM:
@@ -188,11 +184,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cudaError_t err_ = (err); \
         if (err_ != cudaSuccess) { \
-            int …
-            cudaGetDevice(&…
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_)); \
-            fprintf(stderr, "current device: %d\n", …
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
@@ -202,11 +198,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
-            int …
-            cudaGetDevice(&…
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                 err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
-            fprintf(stderr, "current device: %d\n", …
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
@@ -440,6 +436,8 @@
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
@@ -472,7 +470,6 @@
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
-static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -561,6 +558,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -990,7 +1005,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
```

The same one-line change — indexing rows over `blockIdx.x` instead of `blockIdx.y` — is applied in `dequantize_mul_mat_vec_q3_k` (@@ -1094,7), `dequantize_mul_mat_vec_q4_k` (@@ -1198,7), `dequantize_mul_mat_vec_q6_k` (@@ -1452,7), `mul_mat_vec_q` (@@ -4262,7), and `dequantize_mul_mat_vec` (@@ -4302,7).

```diff
@@ -4741,7 +4756,7 @@ static __global__ void im2col_f32_f16(
     int ofs0, int ofs1, int IW, int IH, int CHW,
     int s0, int s1, int p0, int p1, int d0, int d1) {
     const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
-    …
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
 
     const int offset_dst =
         (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
@@ -4793,6 +4808,16 @@
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -4901,7 +4926,8 @@
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
```

The identical `block_nums` fix — `dim3(1, block_num_y, 1)` becoming `dim3(block_num_y, 1, 1)` — is applied to every remaining mat-vec launcher (@@ -4910 through @@ -5088): `dequantize_mul_mat_vec_q4_1/q5_0/q5_1/q8_0_cuda`, `dequantize_mul_mat_vec_q2_K/q3_K/q4_K/q6_K_cuda`, `mul_mat_vec_q4_0/q4_1/q5_0/q5_1/q8_0/q2_K/q3_K/q4_K/q5_K/q6_K_q8_1_cuda`, and `convert_mul_mat_vec_f16_cuda`.
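Why the grid change matters: CUDA caps `gridDim.y` and `gridDim.z` at 65535 blocks, while `gridDim.x` allows up to 2^31 − 1 on compute capability ≥ 3.0, so launching row-parallel work over x avoids an invalid-configuration error once batched decoding pushes the row count past 65535. A standalone sketch of the arithmetic (the row counts are illustrative, not taken from the PR):

```cpp
#include <cstdio>

int main() {
    const int nrows          = 256000; // hypothetical row count with many decoders in flight
    const int rows_per_block = 2;      // plays the role of GGML_CUDA_MMV_Y
    const int nblocks        = (nrows + rows_per_block - 1) / rows_per_block;

    std::printf("blocks needed: %d\n", nblocks);
    std::printf("launchable as dim3(1, N, 1)? %s (gridDim.y max is 65535)\n",
                nblocks <= 65535 ? "yes" : "no");
    std::printf("launchable as dim3(N, 1, 1)? %s (gridDim.x max is 2147483647)\n",
                nblocks <= 2147483647 ? "yes" : "no");
    return 0;
}
```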
The experimental CUDA memory pool is removed; every `ggml_cuda_pool_malloc_async` / `ggml_cuda_pool_free_async` call site reverts to the synchronous pool:

```diff
@@ -5825,16 +5851,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_malloc(size, actual_size);
-    }
-    void *ptr;
-    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
-    *actual_size = size;
-    return ptr;
-}
-
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5852,12 +5868,10 @@
     CUDA_CHECK(cudaFree(ptr));
 }
 
-static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_free(ptr, actual_size);
-    }
-    CUDA_CHECK(cudaFreeAsync(ptr, stream));
-}
 
 void ggml_init_cublas() {
@@ -5872,7 +5886,12 @@ void ggml_init_cublas() {
     CUDA_CHECK(cudaDeviceSynchronize());
 #endif
 
-    CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+    if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
+        initialized = true;
+        g_cublas_loaded = false;
+        return;
+    }
+
     GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
     int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
@@ -5914,19 +5933,13 @@
         // create cublas handle
         CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
         CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
-
-        // configure memory pool
-        cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
-        if (err == cudaSuccess) {
-            size_t treshold = UINT64_MAX;
-            CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
-        }
     }
 
     // configure logging to stdout
     // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
     initialized = true;
+    g_cublas_loaded = true;
 }
```

The new `ggml_cuda_op_relu` / `ggml_cuda_op_sqr` wrappers mirror `ggml_cuda_op_silu`:

```diff
@@ -6193,6 +6206,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_sqr(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6514,7 +6555,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
         GGML_ASSERT(to_fp16_cuda != nullptr);
         size_t ne = row_diff*ne00;
-        src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
+        src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
         to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
     }
     const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6525,12 +6566,12 @@
         const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
         GGML_ASSERT(to_fp16_cuda != nullptr);
         size_t ne = src1_ncols*ne10;
-        src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
+        src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
         to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
     }
     const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
-    size_t …
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
     const half alpha_f16 = 1.0f;
     const half beta_f16 = 0.0f;
@@ -6548,15 +6589,14 @@
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-    if (dst_f16_as != 0) {
-        ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
-    }
+    ggml_cuda_pool_free(dst_f16, dst_as);
 
     if (src0_as != 0) {
-        ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
+        ggml_cuda_pool_free(src0_as_f16, src0_as);
     }
     if (src1_as != 0) {
-        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
+        ggml_cuda_pool_free(src1_as_f16, src1_as);
     }
 }
 else {
@@ -6566,7 +6606,7 @@
     if (src0->type != GGML_TYPE_F32) {
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
         GGML_ASSERT(to_fp32_cuda != nullptr);
-        src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream);
+        src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as);
         to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
     }
     const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6583,7 +6623,7 @@
             &beta, dst_dd_i, ldc));
 
     if (src0_as != 0) {
-        ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
+        ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
     }
 }
@@ -7008,6 +7048,8 @@ static void ggml_cuda_op_mul_mat(
     int64_t row_low[GGML_CUDA_MAX_DEVICES];
     int64_t row_high[GGML_CUDA_MAX_DEVICES];
 
+    int used_devices = 0;
+
     for (int64_t id = 0; id < g_device_count; ++id) {
         // by default, use all rows
         row_low[id] = 0;
@@ -7035,6 +7077,8 @@
             continue;
         }
 
+        used_devices++;
+
         const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
         const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
@@ -7045,22 +7089,21 @@
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(size_src0_ddq, &src0_as[id], id, stream);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(size_src0_ddq, &src0_as[id]);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
         }
 
         if (convert_src1_to_q8_1) {
-            …
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                …
+                CUDA_CHECK(cudaGetLastError());
             }
         }
@@ -7068,18 +7111,18 @@
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
         }
     }
 
     // if multiple devices are used they need to wait for the main device
     // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && g_device_count > 1) {
+    if (split && used_devices > 1) {
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }
 
-    const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
```

The asynchronous frees at the end of `ggml_cuda_op_mul_mat` move up into the device loop as synchronous `ggml_cuda_pool_free` calls:

```diff
@@ -7194,6 +7237,27 @@
         }
     }
 
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        // free buffers again when done
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+        }
+    }
+
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7201,6 +7265,9 @@
 
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
+            if (row_low[id] == row_high[id]) {
+                continue;
+            }
             for (int64_t is = 0; is < is_max; ++is) {
                 CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
@@ -7211,21 +7278,6 @@
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
-
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
-        }
-    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7252,6 +7304,14 @@
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
+}
+
+static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
+}
+
 static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
@@ -7261,6 +7321,8 @@
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (!g_cublas_loaded) return false;
+
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
@@ -7412,11 +7474,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7470,8 +7532,8 @@
         size_t ptrs_src_s = 0;
         size_t ptrs_dst_s = 0;
 
-        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
-        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
+        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7484,6 +7546,7 @@
             dst->nb[2], dst->nb[3],
             r2, r3);
         CUDA_CHECK(cudaGetLastError());
+
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
             ne01, ne11, ne10,
@@ -7495,30 +7558,29 @@
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
         }
         if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
        }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    if (src1_as != 0) {
-        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
-    }
-    if (dst_as != 0) {
-        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
-    }
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
-        (src0->backend == GGML_BACKEND_GPU) &&
+        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         (src1->backend == GGML_BACKEND_GPU) &&
         ( dst->backend == GGML_BACKEND_GPU);
 
+    const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+
     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
         if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
@@ -7540,13 +7602,13 @@
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
@@ -7667,7 +7729,7 @@
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
-void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
 
@@ -7782,11 +7844,11 @@ static size_t g_temp_tensor_extra_index = 0;
 
 static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     if (g_temp_tensor_extras == nullptr) {
-        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
     }
 
     size_t alloc_index = g_temp_tensor_extra_index;
-    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
     ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
     memset(extra, 0, sizeof(*extra));
@@ -7953,6 +8015,8 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    if (!g_cublas_loaded) return false;
+
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
@@ -7995,6 +8059,9 @@
             case GGML_UNARY_OP_SILU:
                 func = ggml_cuda_silu;
                 break;
+            case GGML_UNARY_OP_RELU:
+                func = ggml_cuda_relu;
+                break;
             default:
                 return false;
         } break;
@@ -8013,6 +8080,9 @@
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_SQR:
+            func = ggml_cuda_sqr;
+            break;
         case GGML_OP_CLAMP:
             if (!any_on_device) {
                 return false;
@@ -8105,11 +8175,11 @@ struct ggml_backend_buffer_context_cuda {
 
     ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
         if (temp_tensor_extras == nullptr) {
-            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
         }
 
         size_t alloc_index = temp_tensor_extra_index;
-        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
         ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
         memset(extra, 0, sizeof(*extra));
```
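Both allocators above hand out entries from a fixed ring sized by `GGML_CUDA_MAX_NODES`: one reusable slot per graph node, recycled on every evaluation. A self-contained sketch of that ring-buffer pattern (names and the struct layout are simplified stand-ins, not the PR's types):

```cpp
#include <cstring>

// stand-in for ggml_tensor_extra_gpu: one device pointer slot per GPU
struct tensor_extra { void * data_device[16]; };

static const int    kMaxNodes = 8192;     // mirrors GGML_CUDA_MAX_NODES
static tensor_extra g_extras[kMaxNodes];  // one slot per graph node
static size_t       g_extra_index = 0;

// hand out the next slot and zero it; an entry stays valid only until the
// ring wraps, which is safe because one graph uses at most kMaxNodes extras
static tensor_extra * alloc_temp_extra() {
    tensor_extra * extra = &g_extras[g_extra_index];
    g_extra_index = (g_extra_index + 1) % kMaxNodes;
    std::memset(extra, 0, sizeof(*extra));
    return extra;
}
```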
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
| 5060 |
GGML_ASSERT(ncols % QK_K == 0);
|
| 5061 |
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
| 5062 |
+
const dim3 block_nums(block_num_y, 1, 1);
|
| 5063 |
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
| 5064 |
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
| 5065 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
|
|
| 5068 |
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
| 5069 |
GGML_ASSERT(ncols % QK_K == 0);
|
| 5070 |
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
| 5071 |
+
const dim3 block_nums(block_num_y, 1, 1);
|
| 5072 |
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
| 5073 |
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
| 5074 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
|
|
| 5077 |
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
| 5078 |
GGML_ASSERT(ncols % QK_K == 0);
|
| 5079 |
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
| 5080 |
+
const dim3 block_nums(block_num_y, 1, 1);
|
| 5081 |
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
| 5082 |
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
| 5083 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
|
|
| 5086 |
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
| 5087 |
GGML_ASSERT(ncols % QK_K == 0);
|
| 5088 |
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
| 5089 |
+
const dim3 block_nums(block_num_y, 1, 1);
|
| 5090 |
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
| 5091 |
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
| 5092 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
|
|
| 5095 |
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
| 5096 |
GGML_ASSERT(ncols % QK_K == 0);
|
| 5097 |
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
| 5098 |
+
const dim3 block_nums(block_num_y, 1, 1);
|
| 5099 |
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
| 5100 |
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
| 5101 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
|
|
| 5114 |
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
| 5115 |
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
| 5116 |
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
| 5117 |
+
const dim3 block_nums(block_num_y, 1, 1);
|
| 5118 |
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
| 5119 |
dequantize_mul_mat_vec<1, 1, convert_f16>
|
| 5120 |
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
|
|
|
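Why every launcher now builds its grid in the x dimension: CUDA caps grid.y and grid.z at 65535 blocks, while grid.x allows up to 2^31-1, so putting one block per matrix row in y overflowed for very tall weight matrices. A minimal standalone sketch of the same pattern (the kernel and names here are illustrative, not code from the diff):

    #include <cuda_runtime.h>

    __global__ void row_op(const float * x, float * y, int ncols) {
        // one block per row; blockIdx.x may exceed 65535, blockIdx.y may not
        const long long base = (long long) blockIdx.x * ncols;
        for (int c = threadIdx.x; c < ncols; c += blockDim.x) {
            y[base + c] = 2.0f * x[base + c];
        }
    }

    void launch_row_op(const float * x, float * y, int nrows, int ncols, cudaStream_t stream) {
        // nrows goes in the x dimension, mirroring the block_nums change above
        const dim3 block_nums(nrows, 1, 1);
        const dim3 block_dims(256, 1, 1);
        row_op<<<block_nums, block_dims, 0, stream>>>(x, y, ncols);
    }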
5851       return ptr;
5852   }
5853 

5854   static void ggml_cuda_pool_free(void * ptr, size_t size) {
5855       scoped_spin_lock lock(g_cuda_pool_lock);
5856       int id;

5868       CUDA_CHECK(cudaFree(ptr));
5869   }
5870 
5871 + static bool g_cublas_loaded = false;
5872 
5873 + bool ggml_cublas_loaded(void) {
5874 +     return g_cublas_loaded;
5875 + }
5876 
5877   void ggml_init_cublas() {

5886       CUDA_CHECK(cudaDeviceSynchronize());
5887   #endif
5888 
5889 +         if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
5890 +             initialized = true;
5891 +             g_cublas_loaded = false;
5892 +             return;
5893 +         }
5894 + 
5895           GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
5896           int64_t total_vram = 0;
5897   #if defined(GGML_CUDA_FORCE_MMQ)

5933           // create cublas handle
5934           CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
5935           CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
5936       }
5937 
5938       // configure logging to stdout
5939       // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
5940 
5941       initialized = true;
5942 +     g_cublas_loaded = true;
5943   }
5944   }
5945 
6206       (void) src1_dd;
6207   }
6208 
6209 + inline void ggml_cuda_op_relu(
6210 +     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6211 +     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6212 + 
6213 +     GGML_ASSERT(src0->type == GGML_TYPE_F32);
6214 +     GGML_ASSERT( dst->type == GGML_TYPE_F32);
6215 + 
6216 +     relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
6217 + 
6218 +     (void) src1;
6219 +     (void) dst;
6220 +     (void) src1_dd;
6221 + }
6222 + 
6223 + inline void ggml_cuda_op_sqr(
6224 +     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6225 +     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6226 + 
6227 +     GGML_ASSERT(src0->type == GGML_TYPE_F32);
6228 +     GGML_ASSERT( dst->type == GGML_TYPE_F32);
6229 + 
6230 +     sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
6231 + 
6232 +     (void) src1;
6233 +     (void) dst;
6234 +     (void) src1_dd;
6235 + }
6236 + 
6237   inline void ggml_cuda_op_norm(
6238       const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6239       const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
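The hunk only shows the call sites; `relu_f32_cuda` and `sqr_f32_cuda` themselves are elsewhere in the file. A plausible shape for the kernel such a helper wraps (the kernel body and block size below are assumptions, only the helper's signature appears in the diff):

    #include <cuda_runtime.h>

    __global__ void relu_f32(const float * x, float * dst, const int k) {
        const int i = blockDim.x * blockIdx.x + threadIdx.x;
        if (i >= k) {
            return;
        }
        dst[i] = fmaxf(x[i], 0.0f);   // ReLU: max(x, 0), elementwise
    }

    static void relu_f32_cuda_sketch(const float * x, float * dst, const int k, cudaStream_t stream) {
        const int block_size = 256;                        // hypothetical choice
        const int num_blocks = (k + block_size - 1) / block_size;
        relu_f32<<<num_blocks, block_size, 0, stream>>>(x, dst, k);
    }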
6555           const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
6556           GGML_ASSERT(to_fp16_cuda != nullptr);
6557           size_t ne = row_diff*ne00;
6558 +         src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
6559           to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
6560       }
6561       const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;

6566           const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
6567           GGML_ASSERT(to_fp16_cuda != nullptr);
6568           size_t ne = src1_ncols*ne10;
6569 +         src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
6570           to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
6571       }
6572       const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
6573 +     size_t dst_as = 0;
6574 +     half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
6575 
6576       const half alpha_f16 = 1.0f;
6577       const half beta_f16  = 0.0f;

6589       const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
6590       to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
6591 
6592 +     ggml_cuda_pool_free(dst_f16, dst_as);
6593 
6594       if (src0_as != 0) {
6595 +         ggml_cuda_pool_free(src0_as_f16, src0_as);
6596       }
6597 + 
6598       if (src1_as != 0) {
6599 +         ggml_cuda_pool_free(src1_as_f16, src1_as);
6600       }
6601   }
6602   else {

6606       if (src0->type != GGML_TYPE_F32) {
6607           const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
6608           GGML_ASSERT(to_fp32_cuda != nullptr);
6609 +         src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
6610           to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
6611       }
6612       const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;

6623           &beta, dst_dd_i, ldc));
6624 
6625       if (src0_as != 0) {
6626 +         ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
6627       }
6628   }
6629 
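The pool allocator pattern used throughout these hunks: `ggml_cuda_pool_malloc` writes the size it actually reserved into an out-parameter, the matching `ggml_cuda_pool_free` must be passed that same size back, and a recorded size of 0 means nothing was allocated. A minimal sketch of the calling contract (the helper `scratch_example` is hypothetical; the two pool functions are declared with the signatures seen at the call sites above):

    #include <cuda_fp16.h>
    #include <cstddef>

    // defined in ggml-cuda.cu; signatures as used in the diff
    void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
    void   ggml_cuda_pool_free(void * ptr, size_t size);

    void scratch_example(size_t n) {
        size_t dst_as = 0;   // stays 0 if nothing is allocated
        half * dst_f16 = (half *) ggml_cuda_pool_malloc(n * sizeof(half), &dst_as);

        // ... use dst_f16 as f16 scratch for the GEMM ...

        if (dst_as != 0) {
            ggml_cuda_pool_free(dst_f16, dst_as);   // pass back the recorded size
        }
    }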
7048       int64_t row_low[GGML_CUDA_MAX_DEVICES];
7049       int64_t row_high[GGML_CUDA_MAX_DEVICES];
7050 
7051 +     int used_devices = 0;
7052 + 
7053       for (int64_t id = 0; id < g_device_count; ++id) {
7054           // by default, use all rows
7055           row_low[id]  = 0;

7077               continue;
7078           }
7079 
7080 +         used_devices++;
7081 + 
7082           const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
7083           const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
7084 

7089               src0_dd[id] = (char *) src0_extra->data_device[id];
7090           } else {
7091               const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
7092 +             src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
7093           }
7094 
7095           if (src1_on_device && src1_is_contiguous) {
7096               src1_ddf[id] = (float *) src1_extra->data_device[id];
7097           } else {
7098 +             src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
7099           }
7100 
7101           if (convert_src1_to_q8_1) {
7102 +             src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
7103 
7104               if (src1_on_device && src1_is_contiguous) {
7105                   quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
7106 +                 CUDA_CHECK(cudaGetLastError());
7107               }
7108           }
7109 

7111               dst_dd[id] = (float *) dst_extra->data_device[id];
7112           } else {
7113               const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
7114 +             dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
7115           }
7116       }
7117 
7118       // if multiple devices are used they need to wait for the main device
7119       // here an event is recorded that signals that the main device has finished calculating the input data
7120 +     if (split && used_devices > 1) {
7121           CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7122           CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
7123       }
7124 
7125 +     const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
7126       for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
7127           const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
7128           const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;

7237           }
7238       }
7239 
7240 +     for (int64_t id = 0; id < g_device_count; ++id) {
7241 +         if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
7242 +             continue;
7243 +         }
7244 +         CUDA_CHECK(ggml_cuda_set_device(id));
7245 + 
7246 +         // free buffers again when done
7247 +         if (src0_as[id] > 0) {
7248 +             ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
7249 +         }
7250 +         if (src1_asf[id] > 0) {
7251 +             ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
7252 +         }
7253 +         if (src1_asq[id] > 0) {
7254 +             ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
7255 +         }
7256 +         if (dst_as[id] > 0) {
7257 +             ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
7258 +         }
7259 +     }
7260 + 
7261       // main device waits for all other devices to be finished
7262       if (split && g_device_count > 1) {
7263           int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;

7265 
7266           CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7267           for (int64_t id = 0; id < g_device_count; ++id) {
7268 +             if (row_low[id] == row_high[id]) {
7269 +                 continue;
7270 +             }
7271               for (int64_t is = 0; is < is_max; ++is) {
7272                   CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
7273               }

7278       CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7279       CUDA_CHECK(cudaDeviceSynchronize());
7280   }
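The row ranges that `used_devices` counts come from the cumulative `g_tensor_split` fractions seen later in this file. A hedged sketch of how such fractions could map to per-device row ranges (the rounding the real code applies may differ; this only illustrates why `row_low[id] == row_high[id]` marks a device as unused):

    #include <cstdint>

    void rows_for_device(int id, int n_devices, int64_t nrows,
                         const float * tensor_split,   // cumulative fractions, tensor_split[0] == 0
                         int64_t * row_low, int64_t * row_high) {
        const float lo = tensor_split[id];
        const float hi = (id + 1 < n_devices) ? tensor_split[id + 1] : 1.0f;
        *row_low  = (int64_t)(lo * nrows);
        *row_high = (int64_t)(hi * nrows);
        // an empty range (row_low == row_high) means this device gets no rows,
        // so the loops above skip it entirely
    }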
7281   }
7282 
7283   static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

7304       ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
7305   }
7306 
7307 + static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7308 +     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
7309 + }
7310 + 
7311 + static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7312 +     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
7313 + }
7314 + 
7315   static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7316       ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
7317   }
7321   }
7322 
7323   bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
7324 +     if (!g_cublas_loaded) return false;
7325 + 
7326       const int64_t ne10 = src1->ne[0];
7327 
7328       const int64_t ne0 = dst->ne[0];
7474       GGML_ASSERT(to_fp16_cuda != nullptr);
7475 
7476       size_t src1_as = 0;
7477 +     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
7478       to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
7479 
7480       size_t dst_as = 0;
7481 +     half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
7482 
7483       GGML_ASSERT(ne12 % ne02 == 0);
7484       GGML_ASSERT(ne13 % ne03 == 0);

7532       size_t ptrs_src_s = 0;
7533       size_t ptrs_dst_s = 0;
7534 
7535 +     ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
7536 +     ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
7537 
7538       dim3 block_dims(ne13, ne12);
7539       k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(

7546           dst->nb[2], dst->nb[3],
7547           r2, r3);
7548       CUDA_CHECK(cudaGetLastError());
7549 + 
7550       CUBLAS_CHECK(
7551       cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7552               ne01, ne11, ne10,

7558               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7559 
7560       if (ptrs_src_s != 0) {
7561 +         ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
7562       }
7563       if (ptrs_dst_s != 0) {
7564 +         ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
7565       }
7566   }
7567   #endif
7568 
7569       const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
7570       to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
7571 + 
7572 +     ggml_cuda_pool_free(src1_as_f16, src1_as);
7573 +     ggml_cuda_pool_free(dst_f16, dst_as);
7574   }
7575 
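The `r2`/`r3` values passed to `k_compute_batched_ptrs` are broadcast ratios: src0 has `ne02 x ne03` batch dimensions, src1/dst have `ne12 x ne13`, and the asserts above require the latter to be multiples of the former. A sketch of how those ratios index a shared src0 batch (the function name is hypothetical; the arithmetic follows GGML's broadcast rule):

    #include <cassert>
    #include <cstdint>

    void batched_ptr_indices(int64_t ne02, int64_t ne03, int64_t ne12, int64_t ne13) {
        assert(ne12 % ne02 == 0 && ne13 % ne03 == 0);
        const int64_t r2 = ne12 / ne02;   // src1 batches sharing one src0 batch along dim 2
        const int64_t r3 = ne13 / ne03;   // same along dim 3
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                const int64_t i03 = i13 / r3;   // src0 batch index, reused r3 times
                const int64_t i02 = i12 / r2;   // src0 batch index, reused r2 times
                (void) i03; (void) i02;         // a pointer-table entry would use these
            }
        }
    }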
7576   static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7577       const bool all_on_device =
7578 +         (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
7579           (src1->backend == GGML_BACKEND_GPU) &&
7580           ( dst->backend == GGML_BACKEND_GPU);
7581 
7582 +     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
7583 + 
7584       int64_t min_compute_capability = INT_MAX;
7585       for (int64_t id = 0; id < g_device_count; ++id) {
7586           if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {

7602       //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
7603       //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
7604 
7605 +     if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
7606           // KQ single-batch
7607           ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
7608 +     } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
7609           // KQV single-batch
7610           ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
7611 +     } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
7612           // KQ + KQV multi-batch
7613           ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
7614       } else if (src0->type == GGML_TYPE_F32) {
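A condensed restatement of the dispatch above, to make the three fast paths easier to compare side by side (the enum and function are hypothetical; only the boolean conditions mirror the diff):

    enum class mul_mat_path { vec_p021, vec_nc, batched_cublas, other };

    mul_mat_path pick_path(bool split, bool all_on_device, bool use_tensor_cores,
                           bool src0_f16, bool src1_f32,
                           bool src0_permuted, bool src1_permuted,
                           bool src0_contig, bool src0_t, bool src1_t, bool src1_is_vec) {
        if (!split && all_on_device && !use_tensor_cores && src0_f16 && src0_permuted && src1_permuted && src1_is_vec)
            return mul_mat_path::vec_p021;        // KQ single-batch
        if (!split && all_on_device && !use_tensor_cores && src0_f16 && !src0_contig && !src1_t && src1_is_vec)
            return mul_mat_path::vec_nc;          // KQV single-batch
        if (!split && all_on_device && use_tensor_cores && src0_f16 && src1_f32 && !src0_t && !src1_t)
            return mul_mat_path::batched_cublas;  // KQ + KQV multi-batch
        return mul_mat_path::other;
    }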
7729       ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
7730   }
7731 
7732 + static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7733       ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
7734   }
7735 
7844 
7845   static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
7846       if (g_temp_tensor_extras == nullptr) {
7847 +         g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
7848       }
7849 
7850       size_t alloc_index = g_temp_tensor_extra_index;
7851 +     g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
7852       ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
7853       memset(extra, 0, sizeof(*extra));
7854 
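The temp-extra allocator is a fixed-size ring: the index wraps modulo `GGML_CUDA_MAX_NODES`, so entries are recycled rather than freed. A self-contained sketch of the same idea (all names here are stand-ins, not the real ggml types):

    #include <cstring>
    #include <cstddef>

    #define MAX_NODES 4096                       // stands in for GGML_CUDA_MAX_NODES

    struct extra_t { void * data_device[16]; };  // simplified stand-in for ggml_tensor_extra_gpu

    static extra_t g_extras[MAX_NODES];
    static size_t  g_extra_index = 0;

    extra_t * alloc_temp_extra(void) {
        extra_t * e = &g_extras[g_extra_index];
        g_extra_index = (g_extra_index + 1) % MAX_NODES;   // wrap around; entries are
        memset(e, 0, sizeof(*e));                          // short-lived and never freed
        return e;
    }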
8015   }
8016 
8017   bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
8018 +     if (!g_cublas_loaded) return false;
8019 + 
8020       ggml_cuda_func_t func;
8021       const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
8022           || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
8059           case GGML_UNARY_OP_SILU:
8060               func = ggml_cuda_silu;
8061               break;
8062 +         case GGML_UNARY_OP_RELU:
8063 +             func = ggml_cuda_relu;
8064 +             break;
8065           default:
8066               return false;
8067       } break;

8080       case GGML_OP_SCALE:
8081           func = ggml_cuda_scale;
8082           break;
8083 +     case GGML_OP_SQR:
8084 +         func = ggml_cuda_sqr;
8085 +         break;
8086       case GGML_OP_CLAMP:
8087           if (!any_on_device) {
8088               return false;
8175 
8176       ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
8177           if (temp_tensor_extras == nullptr) {
8178 +             temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
8179           }
8180 
8181           size_t alloc_index = temp_tensor_extra_index;
8182 +         temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
8183           ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
8184           memset(extra, 0, sizeof(*extra));
8185 
ggml-cuda.h
CHANGED

@@ -17,7 +17,12 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
+// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
 GGML_API void   ggml_init_cublas(void);
+
+// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
+GGML_API bool   ggml_cublas_loaded(void);
+
 GGML_API void * ggml_cuda_host_malloc(size_t size);
 GGML_API void   ggml_cuda_host_free(void * ptr);
 
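A minimal caller-side sketch of the new contract (the surrounding program is hypothetical; the two API calls are the ones declared above):

    #include "ggml-cuda.h"
    #include <cstdio>

    int main() {
        // ggml_init_cublas() now always "succeeds": instead of failing hard when
        // no CUDA device is usable, it records the failure internally.
        ggml_init_cublas();

        // The caller decides how to proceed based on ggml_cublas_loaded().
        if (!ggml_cublas_loaded()) {
            fprintf(stderr, "no usable CUDA device - falling back to CPU\n");
            // ... CPU-only path (hypothetical in this sketch) ...
        }
        return 0;
    }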
whisper.cpp
CHANGED

@@ -20,6 +20,7 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
 #include <algorithm>
 #include <cassert>
 #define _USE_MATH_DEFINES

@@ -147,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 
 //#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
-#define WHISPER_MAX_DECODERS
 #define WHISPER_MAX_NODES 4096
 
 //
@@ -406,6 +407,121 @@ struct whisper_segment {
     bool speaker_turn_next;
 };
 
 // medium
 // hparams: {
 //  'n_mels': 80,
@@ -523,15 +639,31 @@ struct whisper_layer_decoder {
     struct ggml_tensor * mlp_1_b;
 };
 
 struct whisper_kv_cache {
     struct ggml_tensor * k;
     struct ggml_tensor * v;
 
     struct ggml_context * ctx;
 
     ggml_backend_buffer_t buffer;
-
-    int n; // number of tokens currently in the cache
 };
 
 struct whisper_model {
@@ -585,11 +717,11 @@ struct whisper_partial_utf8 {
 };
 
 struct whisper_grammar {
-    /*const*/ std::vector<std::vector<whisper_grammar_element>>
-    std::vector<std::vector<const whisper_grammar_element *>>
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
-    whisper_partial_utf8
 };
 
 struct whisper_grammar_candidate {

@@ -613,15 +745,13 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoder keeps its own KV-cache
-    whisper_kv_cache kv_self;
-
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
     // grammar parse state of generated sequence of tokens
     whisper_grammar grammar;
 
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed; // has the current segment failed to decode?
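With the per-decoder `kv_self` removed, all decoders share one self-attention cache owned by the state. One common way to make that work, and the approach llama.cpp uses, is to tag each cache cell with the sequences that own it; the exact fields of the reworked whisper_kv_cache are not visible in these hunks, so treat this as a hedged sketch of the idea rather than the actual layout:

    #include <set>
    #include <vector>

    struct kv_cell_sketch {
        int           pos = -1;    // position of the cached token
        std::set<int> seq_id;      // which decoder sequences reference this cell
    };

    struct kv_cache_sketch {
        std::vector<kv_cell_sketch> cells;   // one per slot in the shared K/V tensors
        int n = 0;                           // number of occupied slots

        // drop everything a given decoder wrote, without touching other decoders
        void seq_rm(int seq) {
            for (auto & c : cells) {
                c.seq_id.erase(seq);
                if (c.seq_id.empty()) { c.pos = -1; }
            }
        }
    };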
@@ -633,100 +763,40 @@ struct whisper_decoder {
     std::vector<float> logits;
     std::vector<float> logprobs;
 
-
-// replace std::pair by using customized pair struct (reason: std::pair is very slow)
-template<typename A, typename B>
-struct whisper_pair {
-    A first;
-    B second;
-
-    // Define a constructor that takes two arguments.
-    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
-    // Define a constructor that takes no argument.
-    whisper_pair() : first(A()), second(B()) {}
-};
-
-// beam-search helpers
-struct kv_buf {
-    std::vector<uint8_t> k;
-    std::vector<uint8_t> v;
-};
-
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    ggml_allocr * alloc = nullptr;
-
-    std::vector<uint8_t> meta;
-};
-
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
-}
-
-// measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-
-    alloc = ggml_allocr_new_measure_from_backend(backend);
-
-    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
-
-    ggml_allocr_alloc_graph(alloc, get_graph());
-}
-
-static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
-    if (allocr.alloc == nullptr) {
-        // this can be null if we use external encoder like CoreML or OpenVINO
-        return;
-    }
-
-    auto & alloc  = allocr.alloc;
-    auto & buffer = allocr.buffer;
-
-    size_t size = ggml_allocr_max_size(alloc);
-
-    ggml_allocr_free(alloc);
-
-    buffer = ggml_backend_alloc_buffer(backend, size);
-    alloc = ggml_allocr_new_from_buffer(buffer);
-}
-
-static void whisper_allocr_free(struct whisper_allocr & allocr) {
-    if (allocr.alloc) {
-        ggml_allocr_free(allocr.alloc);
-        ggml_backend_buffer_free(allocr.buffer);
-        allocr.alloc = nullptr;
-    }
-}
-
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
     int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
-    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1
-    int32_t
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
     whisper_mel mel;
 
-
-    std::vector<kv_buf> kv_swap_bufs;
 
     ggml_backend_t backend = nullptr;
@@ -742,8 +812,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
-    //
     std::vector<float> inp_mel;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;

@@ -751,11 +822,6 @@ struct whisper_state {
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
-    // work container used to avoid memory allocations
-    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
-
-    mutable std::mt19937 rng; // used for sampling at t > 0.0
-
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()

@@ -831,6 +897,12 @@ static bool kv_cache_init(
         /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
@@ -858,54 +930,129 @@ static bool kv_cache_init(
     return true;
 }
 
-
-
-
-        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
-        /*.no_alloc   =*/ true,
-    };
 
-
 }
 
-
-
-
 }
 
-
 }
 
-static void
-
-
     }
 }
@@ -914,7 +1061,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 
     // initialize the backends
 #ifdef GGML_USE_CUBLAS
-    if (params.use_gpu) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
         backend_gpu = ggml_backend_cuda_init();
         if (!backend_gpu) {

@@ -1116,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word = "[_EOT_]";
             } else if (i == vocab.token_sot) {
                 word = "[_SOT_]";
             } else if (i == vocab.token_solm) {
                 word = "[_SOLM_]";
             } else if (i == vocab.token_prev) {

@@ -1126,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word = "[_NOT_]";
             } else if (i == vocab.token_beg) {
                 word = "[_BEG_]";
             } else {
                 word = "[_extra_token_" + std::to_string(i) + "]";
             }
@@ -2031,26 +2184,28 @@ static bool whisper_encode_internal(
 static struct ggml_cgraph * whisper_build_graph_decoder(
          whisper_context & wctx,
          whisper_state   & wstate,
-
-    const whisper_token * tokens,
-          int   n_tokens,
-          int   n_past) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    auto & kv_self =
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
-    const int
-    const int
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_decode.meta.size(),

@@ -2062,21 +2217,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
-
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_backend_tensor_set(embd,
     }
 
-    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32,
     ggml_allocr_alloc(alloc, position);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        for (int i = 0; i <
-            const int32_t val =
             ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
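The right-hand side of the position fill is cut off above. In the pre-batching code each new token i simply got position n_past + i; a hedged sketch of that single-sequence pattern (with batching, positions come from the batch instead):

    #include <cstdint>

    void fill_positions(int32_t * position, int n_tokens, int n_past) {
        for (int i = 0; i < n_tokens; ++i) {
            position[i] = n_past + i;   // absolute position inside the text context
        }
    }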
@@ -2089,6 +2242,31 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
@@ -2141,12 +2319,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         Vcur,
                         layer.attn_v_b);
 
-            Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state,
 
-            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k,
-            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v,
                     (   n_ctx)*ggml_element_size(kv_self.v),
-                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_state +
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));

@@ -2156,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         struct ggml_tensor * Q =
             ggml_permute(ctx0,
-                    ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head,
                     0, 2, 1, 3);
 
         struct ggml_tensor * K =
             ggml_view_3d(ctx0, kv_self.k,
-                    n_state/n_head,
                     ggml_element_size(kv_self.k)*n_state,
                     ggml_element_size(kv_self.k)*n_state/n_head,
                     ggml_element_size(kv_self.k)*n_state*n_ctx*il);

@@ -2171,16 +2349,17 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
-        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
         struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
         struct ggml_tensor * V =
             ggml_view_3d(ctx0, kv_self.v,
-
                     n_ctx*ggml_element_size(kv_self.v),
                     n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-
 
         struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
@@ -2188,7 +2367,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         cur = ggml_cpy(ctx0,
                 KQV_merged,
-                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state,
     }
 
     // projection

@@ -2232,33 +2411,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state,
                         ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v,
-            //            n_state/n_head, n_head,
 
             //struct ggml_tensor * V_trans =
             //    ggml_cpy(ctx0,
             //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type,
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, wstate.kv_cross.v,
-
-
-
-
 
             // ------
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head,
                         0, 2, 1, 3);
 
             // K * Q

@@ -2279,10 +2458,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state,
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state,
     }
 
     // projection

@@ -2354,9 +2533,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     }
 
     // compute logits only for the last token
-    // comment this line to compute logits for all
     // might be useful in the future
-    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
@@ -2380,10 +2559,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 static bool whisper_decode_internal(
         whisper_context & wctx,
         whisper_state   & wstate,
-
-    const whisper_token * tokens,
-    const int   n_tokens,
-    const int   n_past,
     const int   n_threads,
                 whisper_abort_callback abort_callback,
                 void * abort_callback_data) {

@@ -2392,19 +2568,33 @@ static bool whisper_decode_internal(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    const int n_vocab
 
     auto & logits_out = wstate.logits;
 
     struct ggml_tensor * logits;
 
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate,
 
         ggml_allocr_alloc_graph(alloc, gf);

@@ -2413,17 +2603,15 @@ static bool whisper_decode_internal(
         ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-        //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-        ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
 
-    if (n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
         //        wstate.get_buf_max_mem(0)/1024.0/1024.0,

@@ -2432,18 +2620,20 @@ static bool whisper_decode_internal(
         //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
-    if (n_tokens == 1) {
         wstate.t_decode_us += ggml_time_us() - t_start_us;
         wstate.n_decode++;
     } else {
         wstate.t_prompt_us += ggml_time_us() - t_start_us;
-        wstate.n_prompt
     }
 
     return !(abort_callback && abort_callback(abort_callback_data));
 }
 
-
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2855,14 +3045,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->backend = whisper_backend_init(ctx->params);
 
-
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
-        const size_t memory_size = ggml_nbytes(state->
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }

@@ -2897,14 +3091,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
-    state->
 
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    state->decoders[0].probs.reserve
-    state->decoders[0].logits.reserve
-    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
 
     // conv allocator
     {

@@ -2946,7 +3143,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         const int n_tokens = hparams.n_text_ctx;
         const int n_past   = 0;
 
         });
 
         WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@@ -2957,8 +3156,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
-    state->rng = std::mt19937(0);
-
     return state;
 }

@@ -3183,12 +3380,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
         kv_cache_free(state->kv_cross);
 
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            kv_cache_free(state->decoders[i].kv_self);
-        }
-
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
             whisper_coreml_free(state->ctx_coreml);

@@ -3203,6 +3397,8 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
         whisper_allocr_free(state->alloc_conv);
         whisper_allocr_free(state->alloc_encode);
         whisper_allocr_free(state->alloc_cross);
@@ -3329,9 +3525,11 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 }
 
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-
-    if (!whisper_decode_internal(*ctx, *state, state->
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }

@@ -3340,15 +3538,16 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to state
-    const int selected_decoder_id = 0;
-
     if (ctx->state == nullptr) {
         WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
-
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
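A usage sketch of the public decode API whose signature appears above (the surrounding function, model setup, and token values are placeholders; the note about the KV cache follows the commit's "clear kv cache when using whisper_decode API" change):

    #include "whisper.h"

    void decode_example(struct whisper_context * ctx) {
        const whisper_token tokens[1] = { 0 };   // placeholder token
        const int n_tokens  = 1;
        const int n_past    = 0;     // after this commit, calling whisper_decode
        const int n_threads = 4;     // clears the internal KV cache first
        if (whisper_decode(ctx, tokens, n_tokens, n_past, n_threads) != 0) {
            // handle failure (non-zero return indicates an error)
        }
    }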
@@ -3436,7 +3635,7 @@ int whisper_lang_auto_detect_with_state(
         return -7;
     }
 
-    auto & logits_id = state->logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {

@@ -3639,6 +3838,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
         WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);

@@ -3646,6 +3846,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
         WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
         WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);

@@ -3662,6 +3863,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
         ctx->state->n_sample = 0;
         ctx->state->n_encode = 0;
         ctx->state->n_decode = 0;
         ctx->state->n_prompt = 0;
     }
 }
@@ -3969,8 +4171,7 @@ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
             // that cannot satisfy this position in grammar
-            if (tok.partial_utf8.n_remain != 0 &&
-                !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
                 rejects.push_back(tok);
             }
         } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {

@@ -4189,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_initial_ts   =*/  1.0f,
         /*.length_penalty   =*/ -1.0f,
 
-        /*.temperature_inc  =*/  0.
         /*.entropy_thold    =*/  2.4f,
         /*.logprob_thold    =*/ -1.0f,
         /*.no_speech_thold  =*/  0.6f,

@@ -4229,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         case WHISPER_SAMPLING_GREEDY:
             {
                 result.greedy = {
-                    /*.best_of   =*/
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result.beam_search = {
-                    /*.beam_size =*/
                     /*.patience  =*/ -1.0f,
                 };
             } break;
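This commit enables beam search by default, though the new default values for best_of and beam_size are cut off in the hunk above, so they are left unspecified here. A caller can always request beam search explicitly and set the beam size themselves:

    #include "whisper.h"

    struct whisper_full_params make_params(void) {
        struct whisper_full_params p = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
        p.beam_search.beam_size = 5;   // explicit caller choice; larger beams cost more decode time
        return p;
    }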
@@ -4325,11 +4526,12 @@ static const std::vector<std::string> non_speech_tokens = {
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
 static void whisper_process_logits(
               struct whisper_context & ctx,
                struisper_state - struct whisper_state  & state,
-              const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
                              float   temperature) {
     const auto & vocab      = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;

@@ -4346,7 +4548,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), state.logits.data() +
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -4512,30 +4714,31 @@ static void whisper_process_logits(
             //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
-                //printf("sampling timestamp\n");
                 for (int i = 0; i < vocab.token_beg; ++i) {
                     logits[i]   = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
-            } else
-
             }
-
-            logsumexp = logf(logsumexp) + logit_max;
 
             }
         }
     }

@@ -4610,7 +4813,6 @@ static void whisper_process_logits(
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
-            whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -4655,7 +4857,7 @@ static whisper_token_data whisper_sample_token(
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-        result.id   = dist(
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }

@@ -4665,15 +4867,12 @@ static whisper_token_data whisper_sample_token(
         result.pt = result.p;
     }
 
-    state.n_sample++;
-
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
-
-      const whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;

@@ -4683,7 +4882,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id =
 
     logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {

@@ -4732,7 +4931,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
     std::discrete_distribution<> dist(probs.begin(), probs.end());
 
     for (int i = 0; i < k; ++i) {
-        const auto id = dist(
         //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
         result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });

@@ -4743,8 +4942,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    state.n_sample++;
-
     return result;
 }
@@ -4797,125 +4994,6 @@ static void whisper_sequence_score(
     }
 }
 
-static bool whisper_kv_swap_fast(
-                   std::vector<int> & view,
-                   whisper_decoder   src[],
-              std::vector<kv_buf> & kv_swap_bufs,
-                        const int & n_decoders) {
-    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
-    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
-    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
-    // (buffer->decoder or decoder->decoder)
-    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
-    // (decoder<->decoder)
-    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
-    std::vector<whisper_pair<int, int>> p_swap_vec;
-    p_swap_vec.reserve(n_decoders);
-
-    // see https://github.com/ggerganov/whisper.cpp/wiki
-    for (int i = 0; i < n_decoders; i++) {
-        // zero-copy (no modification)
-        if (i == view[i] || view[i] < 0) {
-            continue;
-        }
-
-        bool is_one_copy = true;
-        // since we modify data sequentially, we only consider decoder indices after current index
-        for (int j = i + 1; j < n_decoders; j++) {
-            if (i == view[j]) {
-                // detect symmetric diagram
-                if (j == view[i]) {
-                    p_swap_set.insert(i);
-                    p_swap_set.insert(j);
-                    p_swap_vec.emplace_back(i, j);
-                } else {
-                    two_copy.insert(i);
-                    is_one_copy = false;
-                }
-                break;
-            }
-        }
-        if (is_one_copy) {
-            one_copy.insert(i);
-        }
-    }
-
-    kv_swap_bufs.resize(n_decoders);
-
-    for (int i = 0; i < n_decoders; i++) {
-        kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
-        kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
-    }
-
-    for (auto & i : two_copy) {
-        // make a copy of KV caches
-        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
-        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
-        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
-    }
-
-    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
-    for (auto & i : two_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder using  swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // then modify one-copy decoder KV caches
-    for (auto & i : one_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder using  swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // swap the pointers
-    for (auto & i : p_swap_vec) {
-        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
-        std::swap(src[i.first].kv_self, src[i.second].kv_self);
-    }
-
-    return true;
-}
-
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -5005,25 +5083,23 @@ int whisper_full_with_state(
 
     n_decoders = std::max(1, n_decoders);
 
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = state->decoders[j];
 
-        if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
-                WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
-                return -4;
-            }
-
-            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
-
-            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
-
-            decoder.probs.resize   (ctx->vocab.n_vocab);
-            decoder.logits.resize  (ctx->vocab.n_vocab);
-            decoder.logprobs.resize(ctx->vocab.n_vocab);
-        }
     }
 
     // the accumulated text context so far
@@ -5100,8 +5176,10 @@ int whisper_full_with_state(
         bool has_ts;
 
         whisper_sequence sequence;
     };
 
    std::vector<beam_candidate> beam_candidates;
 
    // main loop
@@ -5169,8 +5247,6 @@ int whisper_full_with_state(
         for (int j = 0; j < n_decoders_cur; ++j) {
             auto & decoder = state->decoders[j];
 
-            decoder.kv_self.n = 0;
-
             decoder.sequence.tokens.clear();
             decoder.sequence.result_len       = 0;
             decoder.sequence.sum_logprobs_all = 0.0;
@@ -5186,15 +5262,14 @@ int whisper_full_with_state(
             decoder.has_ts = false;
 
             if (params.grammar_rules != nullptr) {
-                decoder.grammar = whisper_grammar_init(
-                    params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
             } else {
                 decoder.grammar = {};
             }
         }
 
         // init prompt and kv cache for the current iteration
-        //
         {
             prompt.clear();
@@ -5216,7 +5291,11 @@ int whisper_full_with_state(
             }
             WHISPER_PRINT_DEBUG("\n\n");
 
-            if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                 return -7;
             }
@@ -5224,20 +5303,14 @@ int whisper_full_with_state(
         {
             const int64_t t_start_sample_us = ggml_time_us();
 
-            whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
 
-            state->decoders[0].kv_self.n += prompt.size();
 
             for (int j = 1; j < n_decoders_cur; ++j) {
                 auto & decoder = state->decoders[j];
 
-
-                //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
-                ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
-                ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
-
-                decoder.kv_self.n += prompt.size();
 
                 memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
                 memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -5252,41 +5325,81 @@ int whisper_full_with_state(
             const int64_t t_start_sample_us = ggml_time_us();
 
             if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
-                beam_candidates.clear();
             }
 
-            // generate new sequence candidates for each decoder
-            for (int j = 0; j < n_decoders_cur; ++j) {
-                auto & decoder = state->decoders[j];
 
-                if (decoder.completed || decoder.failed) {
-                    continue;
-                }
-
-                switch (params.strategy) {
-                    case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
-                        {
-                            if (t_cur < 1e-6f) {
-                                decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
-                            } else {
-                                decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
-                            }
 
-                        } break;
-                    case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
-                        {
-                            const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
 
-                            for (const auto & token : tokens_new) {
-                                beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
-                                beam_candidates.back().sequence.tokens.push_back(token);
-                                beam_candidates.back().sequence.sum_logprobs_all += token.plog;
-                            }
-                        } break;
                };
            }
 
            // for beam-search, choose the top candidates and update the KV caches
@@ -5299,7 +5412,6 @@ int whisper_full_with_state(
                 });
 
                 uint32_t cur_c = 0;
-                std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                 for (int j = 0; j < n_decoders_cur; ++j) {
                     auto & decoder = state->decoders[j];
@@ -5318,17 +5430,28 @@ int whisper_full_with_state(
                         ++cur_c;
                     }
 
-                    decoder.sequence   = cur.sequence;
                     decoder.seek_delta = cur.seek_delta;
                     decoder.has_ts     = cur.has_ts;
 
-                    decoder_idx[j] = cur.decoder_idx;
                     WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                             __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                 }
 
-                // update KV caches
-                whisper_kv_swap_fast(decoder_idx, state->decoders, kv_swap_bufs, n_decoders_cur);
             }
 
             // update the decoder state
@@ -5437,32 +5560,83 @@ int whisper_full_with_state(
             state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
             // obtain logits for the next token
-            for (int j = 0; j < n_decoders_cur; ++j) {
-                auto & decoder = state->decoders[j];
 
-                if (decoder.failed || decoder.completed) {
-                    continue;
-                }
 
-                decoder.tokens_tmp.resize(1);
-                decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
 
-
-                if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -8;
                 }
 
                 {
-                    const int64_t t_start_sample_us = ggml_time_us();
 
-                    whisper_process_logits(*ctx, *state, params, decoder, t_cur);
 
-                    ++decoder.kv_self.n;
 
-                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
             }
         }
 
@@ -5759,11 +5933,13 @@ int whisper_full_parallel(
         ctx->state->t_sample_us += states[i]->t_sample_us;
         ctx->state->t_encode_us += states[i]->t_encode_us;
         ctx->state->t_decode_us += states[i]->t_decode_us;
+        ctx->state->t_batchd_us += states[i]->t_batchd_us;
         ctx->state->t_prompt_us += states[i]->t_prompt_us;
 
         ctx->state->n_sample += states[i]->n_sample;
         ctx->state->n_encode += states[i]->n_encode;
         ctx->state->n_decode += states[i]->n_decode;
+        ctx->state->n_batchd += states[i]->n_batchd;
         ctx->state->n_prompt += states[i]->n_prompt;
 
         whisper_free_state(states[i]);
@@ +20 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
+#include <atomic>
 #include <algorithm>
 #include <cassert>
 #define _USE_MATH_DEFINES
@@ +148 @@
 
 //#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
+#define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
 //
@@ +407 @@
     bool speaker_turn_next;
 };
 
+struct whisper_batch {
+    int32_t n_tokens;
+
+    whisper_token  *  token;
+    whisper_pos    *  pos;
+    int32_t        *  n_seq_id;
+    whisper_seq_id ** seq_id;   // null terminated
+    int8_t         *  logits;
+};
+
+static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
+    whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
+
+    batch.token    = (whisper_token *  ) malloc(sizeof(whisper_token)    * (n_tokens));
+    batch.pos      = (whisper_pos *)     malloc(sizeof(whisper_pos)      * (n_tokens));
+    batch.n_seq_id = (int32_t *)         malloc(sizeof(int32_t)          * (n_tokens));
+    batch.seq_id   = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
+    }
+    batch.seq_id[n_tokens] = nullptr;
+    batch.logits   = (int8_t *)          malloc(sizeof(int8_t)           * n_tokens);
+
+    return batch;
+}
+
+static void whisper_batch_free(struct whisper_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; batch.seq_id[i]; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
+
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
+    batch.n_tokens = n_tokens;
+    for (int i = 0; i < n_tokens; ++i) {
+        if (tokens) {
+            batch.token[i] = tokens[i];
+        }
+        batch.pos     [i]    = n_past + i;
+        batch.n_seq_id[i]    = 1;
+        batch.seq_id  [i][0] = seq_id;
+        batch.logits  [i]    = 0;
+    }
+    batch.logits[n_tokens - 1] = 1;
+}
+
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A & a, const B & b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+
+    ggml_backend_buffer_t buffer;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+
+    alloc = ggml_allocr_new_measure_from_backend(backend);
+
+    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
+
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
+
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = ggml_allocr_max_size(alloc);
+
+    ggml_allocr_free(alloc);
+
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc  = ggml_allocr_new_from_buffer(buffer);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
+        allocr.alloc = nullptr;
+    }
+}
+
 // medium
 // hparams: {
 // 'n_mels': 80,
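whisper_batch is the llama.cpp-style batch introduced above: per-token arrays for the token id, position, and sequence ids, plus a flag selecting which tokens get logits. A hedged sketch of how a caller fills it with one pending token per decoder, mirroring the step loop later in whisper_full_with_state (fill_step_batch and next_token are illustrative names, not from the patch):

#include <vector>

// illustrative only: one token per active decoder, all at the same position
static void fill_step_batch(whisper_batch & batch, const std::vector<whisper_token> & next_token, int n_past) {
    batch.n_tokens = 0;
    for (int j = 0; j < (int) next_token.size(); ++j) {
        batch.token   [batch.n_tokens]    = next_token[j];
        batch.pos     [batch.n_tokens]    = n_past; // same position for every decoder this step
        batch.n_seq_id[batch.n_tokens]    = 1;
        batch.seq_id  [batch.n_tokens][0] = j;      // sequence id == decoder index
        batch.logits  [batch.n_tokens]    = 1;      // request logits for every token
        batch.n_tokens++;
    }
}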
@@ +639 @@
     struct ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cell {
+    whisper_pos pos = -1;
+
+    std::set<whisper_seq_id> seq_id;
+
+    bool has_seq_id(const whisper_seq_id & id) const {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
 struct whisper_kv_cache {
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<whisper_kv_cell> cells;
+
     struct ggml_tensor * k;
     struct ggml_tensor * v;
 
     struct ggml_context * ctx;
 
     ggml_backend_buffer_t buffer;
 };
 
 struct whisper_model {
@@ +717 @@
 };
 
 struct whisper_grammar {
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
+    whisper_partial_utf8 partial_utf8;
 };
 
 struct whisper_grammar_candidate {
@@ +745 @@
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
     // grammar parse state of generated sequence of tokens
     whisper_grammar grammar;
 
+    int i_batch;    // the index of the token in the current batch
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?

@@ +763 @@
     std::vector<float> logits;
     std::vector<float> logprobs;
 
+    // work container used to avoid memory allocations
+    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
 
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
@@ +772 @@
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
+    int64_t t_batchd_us = 0;
     int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
+    int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16  (batch decoding)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1  (prompt encoding)
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // unified self-attention KV cache for all decoders
+    whisper_kv_cache kv_self;
+
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+
     whisper_mel mel;
 
+    whisper_batch batch;
 
+    whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
     ggml_backend_t backend = nullptr;
@@ +812 @@
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
+    // helpers for GPU offloading
     std::vector<float> inp_mel;
+    std::vector<float> inp_mask;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;

@@ +822 @@
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
@@ +897 @@
         /*.no_alloc   =*/ true,
     };
 
+    cache.head = 0;
+    cache.size = n_ctx;
+
+    cache.cells.clear();
+    cache.cells.resize(n_ctx);
+
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {

@@ +930 @@
     return true;
 }
 
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
+        cache.ctx = nullptr;
+    }
+}
+
+static bool whisper_kv_cache_find_slot(
+           struct whisper_kv_cache & cache,
+        const struct whisper_batch & batch) {
+    const uint32_t n_ctx    = cache.size;
+    const uint32_t n_tokens = batch.n_tokens;
+
+    if (n_tokens > n_ctx) {
+        WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+        return false;
+    }
+
+    uint32_t n_tested = 0;
+
+    while (true) {
+        if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
+            cache.head = 0;
+            continue;
+        }
+
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            if (cache.cells[cache.head + i].pos >= 0) {
+                found = false;
+                cache.head += i + 1;
+                n_tested   += i + 1;
+                break;
+            }
+        }
+
+        if (found) {
+            break;
+        }
+
+        if (n_tested >= n_ctx) {
+            //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return false;
+        }
+    }
+
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        cache.cells[cache.head + i].pos = batch.pos[i];
+
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+        }
+    }
+
+    return true;
+}
+
+// find how many cells are currently in use
+static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+            return i + 1;
+        }
+    }
+
+    return 1;
+}
+
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
+    }
+    cache.head = 0;
+}
+
+static void whisper_kv_cache_seq_rm(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    uint32_t new_head = cache.size;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
+}
+
+static void whisper_kv_cache_seq_cp(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id_src,
+                 whisper_seq_id   seq_id_dst,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    cache.head = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
+    }
+}
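The helpers above manage the unified cache purely through cell metadata: seq_cp only inserts another sequence id into matching cells and seq_rm only erases ids, so no K/V tensor data ever moves. A small illustrative exercise, assuming the definitions above (kv_seq_demo is not from the patch):

// illustrative only: sequence bookkeeping on a toy cache (tensors unused)
static void kv_seq_demo() {
    whisper_kv_cache cache;
    cache.size = 8;
    cache.cells.resize(cache.size);

    for (int i = 0; i < 4; ++i) {       // positions 0..3 belong to sequence 0
        cache.cells[i].pos = i;
        cache.cells[i].seq_id.insert(0);
    }

    whisper_kv_cache_seq_cp(cache, 0, 1, -1, -1); // seq 1 now shares cells 0..3 - no data copied
    whisper_kv_cache_seq_rm(cache, 0,  2, -1);    // seq 0 forgets pos >= 2; cells 2..3 survive via seq 1
}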
@@ +1061 @@
 
     // initialize the backends
 #ifdef GGML_USE_CUBLAS
+    if (params.use_gpu && ggml_cublas_loaded()) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
         backend_gpu = ggml_backend_cuda_init();
         if (!backend_gpu) {
@@ +1263 @@
             word = "[_EOT_]";
         } else if (i == vocab.token_sot) {
             word = "[_SOT_]";
+        } else if (i == vocab.token_translate) {
+            word = "[_TRANSLATE_]";
+        } else if (i == vocab.token_transcribe) {
+            word = "[_TRANSCRIBE_]";
         } else if (i == vocab.token_solm) {
             word = "[_SOLM_]";
         } else if (i == vocab.token_prev) {

@@ +1277 @@
             word = "[_NOT_]";
         } else if (i == vocab.token_beg) {
             word = "[_BEG_]";
+        } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+            word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
         } else {
             word = "[_extra_token_" + std::to_string(i) + "]";
         }
@@ +2184 @@
 static struct ggml_cgraph * whisper_build_graph_decoder(
          whisper_context & wctx,
          whisper_state   & wstate,
+    const whisper_batch  & batch) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
+    auto & kv_self = wstate.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;
+
+    const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
+    const int n_tokens    = batch.n_tokens;
+    const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
+    const int32_t n_kv    = ggml_allocr_is_measure(alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
             /*.mem_size   =*/ wstate.alloc_decode.meta.size(),

@@ +2217 @@
 
     ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
+        ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
     }
 
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, position);
 
     if (!ggml_allocr_is_measure(alloc)) {
+        for (int i = 0; i < n_tokens; ++i) {
+            const int32_t val = batch.pos[i];
             ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
@@ +2242 @@
         ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    ggml_allocr_alloc(alloc, KQ_mask);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        wstate.inp_mask.resize(n_kv*n_tokens);
+
+        float * data = wstate.inp_mask.data();
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const whisper_pos    pos    = batch.pos[j];
+                const whisper_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
+    }
+
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
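The KQ_mask built above does double duty: it enforces causality within a sequence and isolates sequences from each other inside the shared cache. A worked example as comments (illustrative, assuming one seq_id per token as in the code above):

// illustrative only: two interleaved sequences in the unified cache.
// A query token may attend to a cell only if the cell (a) carries the query's
// seq_id and (b) has pos <= the query's pos; every other entry gets -INFINITY.
//
// cache cells:   idx   0    1    2    3
//                seq   0    1    0    1
//                pos   0    0    1    1
//
// query (seq=1, pos=1) -> mask row:  -INF    0   -INF    0
// query (seq=0, pos=1) -> mask row:    0   -INF    0   -INF
//
// after ggml_add(KQ, KQ_mask), softmax assigns zero weight to the -INF entries,
// so the decoders never see each other's keys and values.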
@@ +2319 @@
                         Vcur,
                         layer.attn_v_b);
 
+            Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
+            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
                     (   n_ctx)*ggml_element_size(kv_self.v),
+                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ +2334 @@
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
+                        n_state/n_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
                         ggml_element_size(kv_self.k)*n_state/n_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);

@@ +2349 @@
 
             //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
+                        n_kv, n_state/n_head, n_head,
                         n_ctx*ggml_element_size(kv_self.v),
                         n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
+                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 

@@ +2367 @@
 
             cur = ggml_cpy(ctx0,
                     KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ +2411 @@
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_view_3d(ctx0, wstate.kv_cross.k,
+                        n_state/n_head, n_audio_ctx, n_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state,
                         ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
+            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
+            //            n_state/n_head, n_head, n_audio_ctx);
 
             //struct ggml_tensor * V_trans =
             //    ggml_cpy(ctx0,
             //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, wstate.kv_cross.v,
+                        n_audio_ctx, n_state/n_head, n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
 
             // ------
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             // K * Q

@@ +2458 @@
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
+            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ +2533 @@
     }
 
     // compute logits only for the last token
+    // comment this line to compute logits for all n_tokens
     // might be useful in the future
+    //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
@@ +2559 @@
 static bool whisper_decode_internal(
         whisper_context & wctx,
           whisper_state & wstate,
+    const whisper_batch & batch,
               const int   n_threads,
  whisper_abort_callback   abort_callback,
                    void * abort_callback_data) {

@@ +2568 @@
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
+    const int n_vocab  = hparams.n_vocab;
+    const int n_tokens = batch.n_tokens;
 
     auto & logits_out = wstate.logits;
 
     struct ggml_tensor * logits;
 
+    // find KV slot for the batch
+    {
+        auto & kv_self = wstate.kv_self;
+
+        if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+            return false;
+        }
+
+        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+        //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+    }
+
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
         ggml_allocr_reset(alloc);
 
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
 
         ggml_allocr_alloc_graph(alloc, gf);

@@ +2603 @@
         ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
+    logits_out.resize(n_tokens*n_vocab);
+    for (int i = 0; i < n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+        ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+    }
 
+    if (batch.n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
         //        wstate.get_buf_max_mem(0)/1024.0/1024.0,

@@ +2620 @@
         //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
+    if (batch.n_tokens == 1) {
         wstate.t_decode_us += ggml_time_us() - t_start_us;
         wstate.n_decode++;
+    } else if (batch.n_tokens < 16) {
+        wstate.t_batchd_us += ggml_time_us() - t_start_us;
+        wstate.n_batchd += n_tokens;
     } else {
         wstate.t_prompt_us += ggml_time_us() - t_start_us;
+        wstate.n_prompt += n_tokens;
     }
 
     return !(abort_callback && abort_callback(abort_callback_data));
 }
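The timing branches at the end of whisper_decode_internal bucket each call by batch size, which is what feeds the decode/batchd/prompt rows of the timing report. A trivial restatement of that classification (decode_kind is an illustrative helper, not part of the patch):

// illustrative only: how a decode call is bucketed for the timing stats
static const char * decode_kind(int n_tokens) {
    if (n_tokens == 1) return "decode"; // single-token text generation
    if (n_tokens < 16) return "batchd"; // small batch, e.g. one token per decoder
    return "prompt";                    // large batch: prompt/context encoding
}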
@@ +2637 @@
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
@@ +3045 @@
 
     state->backend = whisper_backend_init(ctx->params);
 
+    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
+    // in theory, there can be a case where this is not enough, but in practice it should always be enough
+    const int factor = 3;
+
+    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
+        const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
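For intuition on the factor-3 overallocation above, a rough size estimate; kv_self_bytes is illustrative, assumes F16 elements (2 bytes), and counts both the K and V tensors:

// illustrative only: approximate unified self-attention cache size
static size_t kv_self_bytes(int n_layer, int n_state, int n_ctx, int factor = 3, size_t elt = 2) {
    return 2u /* K and V */ * (size_t) n_layer * (size_t) n_state * (size_t) (factor*n_ctx) * elt;
}

For the medium model (24 text layers, state 1024, text ctx 448) this gives 2 * 24 * 1024 * 1344 * 2 bytes, roughly 126 MB, of the same order as the "kv self size" log above.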
@@ +3091 @@
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
+    state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
 
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
+    state->decoders[0].probs.reserve    (ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve   (ctx->vocab.n_vocab);
+    state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
+    state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
+
+    state->decoders[0].rng = std::mt19937(0);
 
     // conv allocator
     {

@@ +3143 @@
             const int n_tokens = hparams.n_text_ctx;
             const int n_past   = 0;
 
+            whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
+
+            return whisper_build_graph_decoder(*ctx, *state, state->batch);
         });
 
         WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);

@@ +3156 @@
     whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
     return state;
 }
@@ +3380 @@
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
+        kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
 
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
             whisper_coreml_free(state->ctx_coreml);

@@ +3397 @@
         }
 #endif
 
+        whisper_batch_free(state->batch);
+
         whisper_allocr_free(state->alloc_conv);
         whisper_allocr_free(state->alloc_encode);
         whisper_allocr_free(state->alloc_cross);
@@ +3525 @@
 }
 
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
+    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
+
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
 
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }

@@ +3538 @@
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     if (ctx->state == nullptr) {
         WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
+    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
+
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
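The public decode API keeps its (tokens, n_tokens, n_past) shape; internally each call now evicts cells at pos >= n_past on sequence 0 before routing through the batch, so rewinding to an earlier n_past stays consistent. A hedged usage sketch (decode_twice and the token ids are illustrative, not from the patch):

// illustrative only: decoding twice from the same n_past works because the
// overlapping cache cells are cleared before each call
static int decode_twice(whisper_context * ctx, whisper_token tok_a, whisper_token tok_b) {
    whisper_token tokens[2] = { tok_a, tok_b };
    if (whisper_decode(ctx, tokens, 2, /*n_past =*/ 0, /*n_threads =*/ 4) != 0) {
        return 1;
    }
    // rewinding: this overwrites the previous cells instead of appending
    return whisper_decode(ctx, tokens, 2, /*n_past =*/ 0, /*n_threads =*/ 4);
}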
@@ +3635 @@
         return -7;
     }
 
+    auto & logits_id = state->decoders[0].logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
@@ +3838 @@
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
+        const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
         WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);

@@ +3846 @@
         WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:   batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
         WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);

@@ +3863 @@
         ctx->state->n_sample = 0;
         ctx->state->n_encode = 0;
         ctx->state->n_decode = 0;
+        ctx->state->n_batchd = 0;
         ctx->state->n_prompt = 0;
     }
 }
@@ +4171 @@
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
             // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
                 rejects.push_back(tok);
             }
         } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
@@ +4390 @@
         /*.max_initial_ts   =*/  1.0f,
         /*.length_penalty   =*/ -1.0f,
 
+        /*.temperature_inc  =*/  0.2f,
         /*.entropy_thold    =*/  2.4f,
         /*.logprob_thold    =*/ -1.0f,
         /*.no_speech_thold  =*/  0.6f,

@@ +4430 @@
         case WHISPER_SAMPLING_GREEDY:
             {
                 result.greedy = {
+                    /*.best_of   =*/ 5,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result.beam_search = {
+                    /*.beam_size =*/ 5,
 
                     /*.patience  =*/ -1.0f,
                 };
@@ +4526 @@
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
+// TODO: optimize
 static void whisper_process_logits(
               struct whisper_context & ctx,
                struct whisper_state  & state,
               struct whisper_decoder & decoder,
+    const struct whisper_full_params   params,
                                float   temperature) {
     const auto & vocab      = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;

@@ +4548 @@
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ +4714 @@
     //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
     if (timestamp_logprob > max_text_token_logprob) {
         for (int i = 0; i < vocab.token_beg; ++i) {
             logits[i]   = -INFINITY;
             logprobs[i] = -INFINITY;
         }
+    } else {
+        if (params.n_grammar_rules > 0) {
+            whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
 
+            // populate the logprobs array (log_softmax)
+            {
+                const float logit_max = *std::max_element(logits.begin(), logits.end());
+                float logsumexp = 0.0f;
+                for (int i = 0; i < n_logits; ++i) {
+                    if (logits[i] > -INFINITY) {
+                        logsumexp += expf(logits[i] - logit_max);
+                    }
                 }
+                logsumexp = logf(logsumexp) + logit_max;
 
+                for (int i = 0; i < n_logits; ++i) {
+                    if (logits[i] > -INFINITY) {
+                        logprobs[i] = logits[i] - logsumexp;
+                    } else {
+                        logprobs[i] = -INFINITY;
+                    }
                 }
             }
         }
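The populate-logprobs block above is a max-shifted log-softmax that skips the -INFINITY entries left by the grammar filter. The same computation on a plain vector, as a standalone illustration (log_softmax is not a name from the patch):

#include <algorithm>
#include <cmath>
#include <vector>

// illustrative only: numerically stable log_softmax with -INF masking
static void log_softmax(const std::vector<float> & logits, std::vector<float> & logprobs) {
    const float max_l = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float l : logits) {
        if (l > -INFINITY) sum += expf(l - max_l); // subtract max to avoid overflow in expf
    }
    const float logsumexp = logf(sum) + max_l;
    logprobs.resize(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        logprobs[i] = (logits[i] > -INFINITY) ? logits[i] - logsumexp : -INFINITY;
    }
}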
@@ +4813 @@
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {

@@ +4857 @@
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
+        result.id   = dist(decoder.rng);
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }

@@ +4867 @@
         result.pt = result.p;
     }
 
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
         whisper_context & ctx,
+        whisper_decoder & decoder,
                     int   k) {
     const auto & vocab = ctx.vocab;

@@ +4882 @@
 
     const int n_logits = vocab.n_vocab;
 
+    auto & logits_id = decoder.logits_id;
 
     logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {

@@ +4931 @@
     std::discrete_distribution<> dist(probs.begin(), probs.end());
 
     for (int i = 0; i < k; ++i) {
+        const auto id = dist(decoder.rng);
         //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
         result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });

@@ +4942 @@
         }
     }
 
     return result;
 }
 

@@ +4994 @@
     }
 }
 
@@ +4997 @@
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,

@@ +5083 @@
 
     n_decoders = std::max(1, n_decoders);
 
+    if (n_decoders > WHISPER_MAX_DECODERS) {
+        WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+        return -4;
+    }
+
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = state->decoders[j];
 
+        decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
+        decoder.probs.resize   (ctx->vocab.n_vocab);
+        decoder.logits.resize  (ctx->vocab.n_vocab);
+        decoder.logprobs.resize(ctx->vocab.n_vocab);
+        decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
 
+        decoder.rng = std::mt19937(0);
     }
 
     // the accumulated text context so far
@@ +5176 @@
         bool has_ts;
 
         whisper_sequence sequence;
+        whisper_grammar grammar;
     };
 
+    std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
     std::vector<beam_candidate> beam_candidates;
 
     // main loop
@@ +5247 @@
         for (int j = 0; j < n_decoders_cur; ++j) {
             auto & decoder = state->decoders[j];
 
             decoder.sequence.tokens.clear();
             decoder.sequence.result_len       = 0;
             decoder.sequence.sum_logprobs_all = 0.0;

@@ +5262 @@
             decoder.has_ts = false;
 
             if (params.grammar_rules != nullptr) {
+                decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
             } else {
                 decoder.grammar = {};
             }
         }
 
         // init prompt and kv cache for the current iteration
+        // TODO: do not recompute the prompt if it is the same as previous time
         {
             prompt.clear();
@@ +5291 @@
             }
             WHISPER_PRINT_DEBUG("\n\n");
 
+            whisper_kv_cache_clear(state->kv_self);
+
+            whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+            if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                 return -7;
             }

@@ +5303 @@
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
+                state->decoders[0].i_batch = prompt.size() - 1;
 
+                whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
 
                 for (int j = 1; j < n_decoders_cur; ++j) {
                     auto & decoder = state->decoders[j];
 
+                    whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
                     memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
                     memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
@@ +5325 @@
             const int64_t t_start_sample_us = ggml_time_us();
 
             if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
+                for (auto & bc : bc_per_dec) {
+                    bc.clear();
+                }
             }
 
+            // sampling
+            // TODO: avoid memory allocations, optimize, avoid threads?
+            {
+                std::atomic<int> j_cur(0);
 
+                auto process = [&]() {
+                    while (true) {
+                        const int j = j_cur.fetch_add(1);
 
+                        if (j >= n_decoders_cur) {
+                            break;
+                        }
 
+                        auto & decoder = state->decoders[j];
 
+                        if (decoder.completed || decoder.failed) {
+                            continue;
+                        }
 
+                        switch (params.strategy) {
+                            case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                                {
+                                    if (t_cur < 1e-6f) {
+                                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                    } else {
+                                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                    }
+
+                                    decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                                } break;
+                            case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                                {
+                                    const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                    for (const auto & token : tokens_new) {
+                                        bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+                                        bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                        bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                    }
+                                } break;
+                        };
+                    }
+                };
+
+                const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                if (n_threads == 1) {
+                    process();
+                } else {
+                    std::vector<std::thread> threads(n_threads - 1);
+
+                    for (int t = 0; t < n_threads - 1; ++t) {
+                        threads[t] = std::thread(process);
+                    }
+
+                    process();
+
+                    for (int t = 0; t < n_threads - 1; ++t) {
+                        threads[t].join();
+                    }
+                }
+            }
+
+            beam_candidates.clear();
+            for (const auto & bc : bc_per_dec) {
+                beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+                if (!bc.empty()) {
+                    state->n_sample += 1;
+                }
             }
 
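The sampling block above distributes decoders across threads by having each worker atomically claim the next index. The same pattern in isolation (for_each_index is an illustrative name; std::function is used for brevity and is an assumption, not how the patch spells it):

#include <atomic>
#include <functional>
#include <thread>
#include <vector>

// illustrative only: fetch_add-based work claiming across n_threads workers
static void for_each_index(int n, int n_threads, const std::function<void(int)> & fn) {
    std::atomic<int> next(0);
    auto worker = [&]() {
        for (int i = next.fetch_add(1); i < n; i = next.fetch_add(1)) {
            fn(i); // each index is claimed by exactly one worker
        }
    };
    std::vector<std::thread> threads;
    for (int t = 0; t < n_threads - 1; ++t) threads.emplace_back(worker);
    worker(); // the calling thread participates too
    for (auto & th : threads) th.join();
}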
| 5412 |
});
|
| 5413 |
|
| 5414 |
uint32_t cur_c = 0;
|
|
|
|
| 5415 |
|
| 5416 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 5417 |
auto & decoder = state->decoders[j];
|
|
|
|
| 5430 |
++cur_c;
|
| 5431 |
}
|
| 5432 |
|
|
|
|
| 5433 |
decoder.seek_delta = cur.seek_delta;
|
| 5434 |
decoder.has_ts = cur.has_ts;
|
| 5435 |
+
decoder.sequence = cur.sequence;
|
| 5436 |
+
decoder.grammar = cur.grammar;
|
| 5437 |
+
|
| 5438 |
+
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
|
| 5439 |
|
| 5440 |
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
| 5441 |
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
| 5442 |
}
|
| 5443 |
|
| 5444 |
+
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 5445 |
+
auto & decoder = state->decoders[j];
|
| 5446 |
+
|
| 5447 |
+
if (decoder.completed || decoder.failed) {
|
| 5448 |
+
continue;
|
| 5449 |
+
}
|
| 5450 |
+
|
| 5451 |
+
whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
|
| 5452 |
+
whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
|
| 5453 |
+
whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
|
| 5454 |
+
}
|
| 5455 |
}
|
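The staging through sequence ids WHISPER_MAX_DECODERS + j above is what keeps the candidate shuffle correct: a winning candidate can come from any decoder's cache, so copying cur.decoder_idx straight into slot j could clobber a cache that a later slot still needs as its source. Every winner is therefore copied into the scratch range first, and only after all winners are staged are the live sequences replaced. A toy model of the two phases, with plain vectors standing in for KV-cache sequences:

    // Toy model of the scratch-copy above: cache[s] stands for KV sequence s,
    // slots 0..n-1 are live decoders, n..2n-1 is the scratch range.
    #include <cassert>
    #include <vector>

    int main() {
        const int n = 2;
        std::vector<std::vector<int>> cache(2 * n);
        cache[0] = { 1, 2, 3 }; // decoder 0
        cache[1] = { 1, 2, 4 }; // decoder 1

        // each slot's winner comes from the *other* decoder; a direct copy
        // (cache[0] = cache[1]; cache[1] = cache[0];) would lose {1, 2, 3}
        int winner_of[n] = { 1, 0 };

        for (int j = 0; j < n; ++j) { cache[n + j] = cache[winner_of[j]]; }           // phase 1: stage
        for (int j = 0; j < n; ++j) { cache[j] = cache[n + j]; cache[n + j].clear(); } // phase 2: commit

        assert(cache[0][2] == 4 && cache[1][2] == 3); // both sequences survived intact
        return 0;
    }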
| 5456 |
|
| 5457 |
// update the decoder state
|
| 5560 |
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 5561 |
|
| 5562 |
// obtain logits for the next token
|
| 5563 |
+
{
|
| 5564 |
+
auto & batch = state->batch;
|
| 5565 |
|
| 5566 |
+
batch.n_tokens = 0;
|
| 5567 |
+
|
| 5568 |
+
const int n_past = prompt.size() + i;
|
| 5569 |
+
|
| 5570 |
+
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 5571 |
+
auto & decoder = state->decoders[j];
|
| 5572 |
|
| 5573 |
+
if (decoder.failed || decoder.completed) {
|
| 5574 |
+
continue;
|
| 5575 |
+
}
|
| 5576 |
|
| 5577 |
+
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
|
| 5578 |
|
| 5579 |
+
decoder.i_batch = batch.n_tokens;
|
| 5580 |
+
|
| 5581 |
+
batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
|
| 5582 |
+
batch.pos [batch.n_tokens] = n_past;
|
| 5583 |
+
batch.n_seq_id[batch.n_tokens] = 1;
|
| 5584 |
+
batch.seq_id [batch.n_tokens][0] = j;
|
| 5585 |
+
batch.logits [batch.n_tokens] = 1;
|
| 5586 |
+
batch.n_tokens++;
|
| 5587 |
+
}
|
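Each still-active decoder contributes exactly one row to the batch: token is the last sampled id, pos is the shared position n_past, seq_id[i][0] = j routes the row into decoder j's KV-cache sequence, and logits = 1 asks the graph to return logits for that row; decoder.i_batch records which row belongs to which decoder so the logits can be found again after the call. A sketch of the same layout over plain vectors (an illustrative mirror of the fields used above, not the library's own definition):

    // Illustrative batch layout: one row per active decoder, all rows at the
    // same position, each routed to its own KV-cache sequence.
    #include <cstdint>
    #include <vector>

    struct toy_batch {
        std::vector<int32_t> token;  // last sampled token id per row
        std::vector<int32_t> pos;    // position in the sequence (n_past)
        std::vector<int32_t> seq_id; // KV-cache sequence the row extends
        std::vector<int8_t>  logits; // 1 = compute logits for this row
    };

    toy_batch build_batch(const std::vector<int32_t> & last_token, int n_past) {
        toy_batch batch;
        for (int j = 0; j < (int) last_token.size(); ++j) {
            batch.token .push_back(last_token[j]);
            batch.pos   .push_back(n_past);
            batch.seq_id.push_back(j); // row j extends decoder j's sequence
            batch.logits.push_back(1); // logits needed for every sampled row
        }
        return batch;
    }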
| 5588 |
+
|
| 5589 |
+
assert(batch.n_tokens > 0);
|
| 5590 |
+
|
| 5591 |
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
| 5592 |
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
| 5593 |
return -8;
|
| 5594 |
}
|
| 5595 |
|
| 5596 |
+
const int64_t t_start_sample_us = ggml_time_us();
|
| 5597 |
+
|
| 5598 |
+
// TODO: avoid memory allocations, optimize, avoid threads?
|
| 5599 |
{
|
| 5600 |
+
std::atomic<int> j_cur(0);
|
| 5601 |
+
|
| 5602 |
+
auto process = [&]() {
|
| 5603 |
+
while (true) {
|
| 5604 |
+
const int j = j_cur.fetch_add(1);
|
| 5605 |
+
|
| 5606 |
+
if (j >= n_decoders_cur) {
|
| 5607 |
+
break;
|
| 5608 |
+
}
|
| 5609 |
|
| 5610 |
+
auto & decoder = state->decoders[j];
|
| 5611 |
|
| 5612 |
+
if (decoder.failed || decoder.completed) {
|
| 5613 |
+
continue;
|
| 5614 |
+
}
|
| 5615 |
|
| 5616 |
+
whisper_process_logits(*ctx, *state, decoder, params, t_cur);
|
| 5617 |
+
}
|
| 5618 |
+
};
|
| 5619 |
+
|
| 5620 |
+
const int n_threads = std::min(params.n_threads, n_decoders_cur);
|
| 5621 |
+
|
| 5622 |
+
if (n_threads == 1) {
|
| 5623 |
+
process();
|
| 5624 |
+
} else {
|
| 5625 |
+
std::vector<std::thread> threads(n_threads - 1);
|
| 5626 |
+
|
| 5627 |
+
for (int t = 0; t < n_threads - 1; ++t) {
|
| 5628 |
+
threads[t] = std::thread(process);
|
| 5629 |
+
}
|
| 5630 |
+
|
| 5631 |
+
process();
|
| 5632 |
+
|
| 5633 |
+
for (int t = 0; t < n_threads - 1; ++t) {
|
| 5634 |
+
threads[t].join();
|
| 5635 |
+
}
|
| 5636 |
+
}
|
| 5637 |
}
|
| 5638 |
+
|
| 5639 |
+
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 5640 |
}
|
| 5641 |
}
|
| 5642 |
|
| 5933 |
ctx->state->t_sample_us += states[i]->t_sample_us;
|
| 5934 |
ctx->state->t_encode_us += states[i]->t_encode_us;
|
| 5935 |
ctx->state->t_decode_us += states[i]->t_decode_us;
|
| 5936 |
+
ctx->state->t_batchd_us += states[i]->t_batchd_us;
|
| 5937 |
ctx->state->t_prompt_us += states[i]->t_prompt_us;
|
| 5938 |
|
| 5939 |
ctx->state->n_sample += states[i]->n_sample;
|
| 5940 |
ctx->state->n_encode += states[i]->n_encode;
|
| 5941 |
ctx->state->n_decode += states[i]->n_decode;
|
| 5942 |
+
ctx->state->n_batchd += states[i]->n_batchd;
|
| 5943 |
ctx->state->n_prompt += states[i]->n_prompt;
|
| 5944 |
|
| 5945 |
whisper_free_state(states[i]);
|
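With parallel processing, each helper state keeps its own counters and they are summed into the primary state before the helpers are freed; the new t_batchd_us / n_batchd pair tracks batched decode calls separately from single-token decodes. For illustration, a per-call average for each path could be derived like this (a hypothetical report helper, not whisper.cpp's own timing output):

    // Hypothetical reporting helper over the accumulated counters:
    // t_us is a total in microseconds, n_calls the number of calls summed above.
    #include <cstdint>
    #include <cstdio>

    void print_avg(const char * name, int64_t t_us, int32_t n_calls) {
        if (n_calls > 0) {
            printf("%10s: %8.2f ms per call (%d calls)\n", name, 1e-3 * t_us / n_calls, n_calls);
        }
    }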
whisper.h
CHANGED
|
@@ -78,7 +78,9 @@ extern "C" {
|
|
| 78 |
struct whisper_state;
|
| 79 |
struct whisper_full_params;
|
| 80 |
|
| 81 |
-
typedef int whisper_token;
|
| 82 |
|
| 83 |
struct whisper_context_params {
|
| 84 |
bool use_gpu;
|
| 78 |
struct whisper_state;
|
| 79 |
struct whisper_full_params;
|
| 80 |
|
| 81 |
+
typedef int32_t whisper_pos;
|
| 82 |
+
typedef int32_t whisper_token;
|
| 83 |
+
typedef int32_t whisper_seq_id;
|
| 84 |
|
| 85 |
struct whisper_context_params {
|
| 86 |
bool use_gpu;
|
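The three typedefs give the batch API concrete element types: whisper_token for vocabulary ids, whisper_pos for positions within a sequence, and whisper_seq_id for KV-cache sequences. For illustration, a batch over these types would be shaped like the fields exercised in the whisper.cpp hunk above (a sketch only; in this change the actual whisper_batch definition lives inside whisper.cpp):

    #include <cstdint>

    typedef int32_t whisper_pos;
    typedef int32_t whisper_token;
    typedef int32_t whisper_seq_id;

    // Sketch of a token batch over the new typedefs, mirroring the
    // batch.token/pos/n_seq_id/seq_id/logits accesses in the decoding loop.
    struct toy_whisper_batch {
        int32_t           n_tokens;  // number of rows in the batch
        whisper_token   * token;     // token[i]: id of the i-th token
        whisper_pos     * pos;       // pos[i]: its position in the sequence
        int32_t         * n_seq_id;  // n_seq_id[i]: number of sequences it belongs to
        whisper_seq_id ** seq_id;    // seq_id[i][k]: the k-th such sequence
        int8_t          * logits;    // logits[i]: 1 if logits should be returned for this row
    };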