Spaces:
Sleeping
Sleeping
mkiol
committed on
whisper : add abort callback (#1335)
Browse files
- whisper.cpp +31 -19
- whisper.h +9 -0
whisper.cpp
CHANGED
|
@@ -125,9 +125,17 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
|
| 125 |
// ggml helpers
|
| 126 |
//
|
| 127 |
|
| 128 |
-
static void ggml_graph_compute_helper(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
| 130 |
|
|
|
|
|
|
|
|
|
|
| 131 |
if (plan.work_size > 0) {
|
| 132 |
buf.resize(plan.work_size);
|
| 133 |
plan.work_data = buf.data();
|
|
@@ -1922,7 +1930,9 @@ static bool whisper_encode_internal(
|
|
| 1922 |
whisper_context & wctx,
|
| 1923 |
whisper_state & wstate,
|
| 1924 |
const int mel_offset,
|
| 1925 |
-
const int n_threads
|
|
|
|
|
|
|
| 1926 |
const int64_t t_start_us = ggml_time_us();
|
| 1927 |
|
| 1928 |
// conv
|
|
@@ -1936,7 +1946,7 @@ static bool whisper_encode_internal(
|
|
| 1936 |
ggml_allocr_alloc_graph(alloc, gf);
|
| 1937 |
|
| 1938 |
if (!whisper_encode_external(wstate)) {
|
| 1939 |
-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
| 1940 |
}
|
| 1941 |
}
|
| 1942 |
|
|
@@ -1955,10 +1965,10 @@ static bool whisper_encode_internal(
|
|
| 1955 |
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
| 1956 |
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
| 1957 |
} else {
|
| 1958 |
-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
| 1959 |
}
|
| 1960 |
#else
|
| 1961 |
-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
| 1962 |
#endif
|
| 1963 |
}
|
| 1964 |
|
|
@@ -1977,10 +1987,10 @@ static bool whisper_encode_internal(
|
|
| 1977 |
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
| 1978 |
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
| 1979 |
} else {
|
| 1980 |
-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
| 1981 |
}
|
| 1982 |
#else
|
| 1983 |
-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
| 1984 |
#endif
|
| 1985 |
}
|
| 1986 |
|
|
@@ -2346,7 +2356,9 @@ static bool whisper_decode_internal(
|
|
| 2346 |
const whisper_token * tokens,
|
| 2347 |
const int n_tokens,
|
| 2348 |
const int n_past,
|
| 2349 |
-
const int n_threads
|
|
|
|
|
|
|
| 2350 |
const int64_t t_start_us = ggml_time_us();
|
| 2351 |
|
| 2352 |
const auto & model = wctx.model;
|
|
@@ -2375,10 +2387,10 @@ static bool whisper_decode_internal(
|
|
| 2375 |
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
| 2376 |
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
| 2377 |
} else {
|
| 2378 |
-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
| 2379 |
}
|
| 2380 |
#else
|
| 2381 |
-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
| 2382 |
#endif
|
| 2383 |
}
|
| 2384 |
|
|
@@ -3290,7 +3302,7 @@ int whisper_set_mel(
|
|
| 3290 |
}
|
| 3291 |
|
| 3292 |
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
| 3293 |
-
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
|
| 3294 |
log("%s: failed to eval\n", __func__);
|
| 3295 |
return -1;
|
| 3296 |
}
|
|
@@ -3299,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
| 3299 |
}
|
| 3300 |
|
| 3301 |
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
| 3302 |
-
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
|
| 3303 |
log("%s: failed to eval\n", __func__);
|
| 3304 |
return -1;
|
| 3305 |
}
|
|
@@ -3310,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
| 3310 |
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
| 3311 |
const int selected_decoder_id = 0;
|
| 3312 |
|
| 3313 |
-
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
| 3314 |
log("%s: failed to eval\n", __func__);
|
| 3315 |
return 1;
|
| 3316 |
}
|
|
@@ -3327,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
|
| 3327 |
return false;
|
| 3328 |
}
|
| 3329 |
|
| 3330 |
-
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
| 3331 |
log("%s: failed to eval\n", __func__);
|
| 3332 |
return 1;
|
| 3333 |
}
|
|
@@ -4594,7 +4606,7 @@ int whisper_full_with_state(
|
|
| 4594 |
}
|
| 4595 |
|
| 4596 |
// encode audio features starting at offset seek
|
| 4597 |
-
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
|
| 4598 |
log("%s: failed to encode\n", __func__);
|
| 4599 |
return -6;
|
| 4600 |
}
|
|
@@ -4677,7 +4689,7 @@ int whisper_full_with_state(
|
|
| 4677 |
}
|
| 4678 |
WHISPER_PRINT_DEBUG("\n\n");
|
| 4679 |
|
| 4680 |
-
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
|
| 4681 |
log("%s: failed to decode\n", __func__);
|
| 4682 |
return -7;
|
| 4683 |
}
|
|
@@ -4901,7 +4913,7 @@ int whisper_full_with_state(
|
|
| 4901 |
|
| 4902 |
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
| 4903 |
|
| 4904 |
-
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
|
| 4905 |
log("%s: failed to decode\n", __func__);
|
| 4906 |
return -8;
|
| 4907 |
}
|
|
@@ -5473,12 +5485,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
| 5473 |
double tsum = 0.0;
|
| 5474 |
|
| 5475 |
// heat-up
|
| 5476 |
-
ggml_graph_compute_helper(work, &gf, n_threads);
|
| 5477 |
|
| 5478 |
for (int i = 0; i < n_max; ++i) {
|
| 5479 |
const int64_t t0 = ggml_time_us();
|
| 5480 |
|
| 5481 |
-
ggml_graph_compute_helper(work, &gf, n_threads);
|
| 5482 |
|
| 5483 |
const int64_t t1 = ggml_time_us();
|
| 5484 |
|
|
|
|
| 125 |
// ggml helpers
|
| 126 |
//
|
| 127 |
|
| 128 |
+
static void ggml_graph_compute_helper(
|
| 129 |
+
std::vector<uint8_t> & buf,
|
| 130 |
+
ggml_cgraph * graph,
|
| 131 |
+
int n_threads,
|
| 132 |
+
whisper_abort_callback abort_callback,
|
| 133 |
+
void * abort_callback_data) {
|
| 134 |
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
| 135 |
|
| 136 |
+
plan.abort_callback = abort_callback;
|
| 137 |
+
plan.abort_callback_data = abort_callback_data;
|
| 138 |
+
|
| 139 |
if (plan.work_size > 0) {
|
| 140 |
buf.resize(plan.work_size);
|
| 141 |
plan.work_data = buf.data();
|
|
|
|
| 1930 |
whisper_context & wctx,
|
| 1931 |
whisper_state & wstate,
|
| 1932 |
const int mel_offset,
|
| 1933 |
+
const int n_threads,
|
| 1934 |
+
whisper_abort_callback abort_callback,
|
| 1935 |
+
void * abort_callback_data) {
|
| 1936 |
const int64_t t_start_us = ggml_time_us();
|
| 1937 |
|
| 1938 |
// conv
|
|
|
|
| 1946 |
ggml_allocr_alloc_graph(alloc, gf);
|
| 1947 |
|
| 1948 |
if (!whisper_encode_external(wstate)) {
|
| 1949 |
+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
| 1950 |
}
|
| 1951 |
}
|
| 1952 |
|
|
|
|
| 1965 |
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
| 1966 |
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
| 1967 |
} else {
|
| 1968 |
+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
| 1969 |
}
|
| 1970 |
#else
|
| 1971 |
+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
| 1972 |
#endif
|
| 1973 |
}
|
| 1974 |
|
|
|
|
| 1987 |
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
| 1988 |
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
| 1989 |
} else {
|
| 1990 |
+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
| 1991 |
}
|
| 1992 |
#else
|
| 1993 |
+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
| 1994 |
#endif
|
| 1995 |
}
|
| 1996 |
|
|
|
|
| 2356 |
const whisper_token * tokens,
|
| 2357 |
const int n_tokens,
|
| 2358 |
const int n_past,
|
| 2359 |
+
const int n_threads,
|
| 2360 |
+
whisper_abort_callback abort_callback,
|
| 2361 |
+
void * abort_callback_data) {
|
| 2362 |
const int64_t t_start_us = ggml_time_us();
|
| 2363 |
|
| 2364 |
const auto & model = wctx.model;
|
|
|
|
| 2387 |
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
| 2388 |
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
| 2389 |
} else {
|
| 2390 |
+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
| 2391 |
}
|
| 2392 |
#else
|
| 2393 |
+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
| 2394 |
#endif
|
| 2395 |
}
|
| 2396 |
|
|
|
|
| 3302 |
}
|
| 3303 |
|
| 3304 |
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
| 3305 |
+
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
|
| 3306 |
log("%s: failed to eval\n", __func__);
|
| 3307 |
return -1;
|
| 3308 |
}
|
|
|
|
| 3311 |
}
|
| 3312 |
|
| 3313 |
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
| 3314 |
+
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
|
| 3315 |
log("%s: failed to eval\n", __func__);
|
| 3316 |
return -1;
|
| 3317 |
}
|
|
|
|
| 3322 |
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
| 3323 |
const int selected_decoder_id = 0;
|
| 3324 |
|
| 3325 |
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
| 3326 |
log("%s: failed to eval\n", __func__);
|
| 3327 |
return 1;
|
| 3328 |
}
|
|
|
|
| 3339 |
return false;
|
| 3340 |
}
|
| 3341 |
|
| 3342 |
+
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
| 3343 |
log("%s: failed to eval\n", __func__);
|
| 3344 |
return 1;
|
| 3345 |
}
|
|
|
|
| 4606 |
}
|
| 4607 |
|
| 4608 |
// encode audio features starting at offset seek
|
| 4609 |
+
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
| 4610 |
log("%s: failed to encode\n", __func__);
|
| 4611 |
return -6;
|
| 4612 |
}
|
|
|
|
| 4689 |
}
|
| 4690 |
WHISPER_PRINT_DEBUG("\n\n");
|
| 4691 |
|
| 4692 |
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
| 4693 |
log("%s: failed to decode\n", __func__);
|
| 4694 |
return -7;
|
| 4695 |
}
|
|
|
|
| 4913 |
|
| 4914 |
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
| 4915 |
|
| 4916 |
+
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
| 4917 |
log("%s: failed to decode\n", __func__);
|
| 4918 |
return -8;
|
| 4919 |
}
|
|
|
|
| 5485 |
double tsum = 0.0;
|
| 5486 |
|
| 5487 |
// heat-up
|
| 5488 |
+
ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
|
| 5489 |
|
| 5490 |
for (int i = 0; i < n_max; ++i) {
|
| 5491 |
const int64_t t0 = ggml_time_us();
|
| 5492 |
|
| 5493 |
+
ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
|
| 5494 |
|
| 5495 |
const int64_t t1 = ggml_time_us();
|
| 5496 |
|
whisper.h
CHANGED
|
@@ -334,6 +334,11 @@ extern "C" {
|
|
| 334 |
// If it returns false, the computation is aborted
|
| 335 |
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
// Logits filter callback
|
| 338 |
// Can be used to modify the logits before sampling
|
| 339 |
// If not NULL, called after applying temperature to logits
|
|
@@ -428,6 +433,10 @@ extern "C" {
|
|
| 428 |
whisper_encoder_begin_callback encoder_begin_callback;
|
| 429 |
void * encoder_begin_callback_user_data;
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
// called by each decoder to filter obtained logits
|
| 432 |
whisper_logits_filter_callback logits_filter_callback;
|
| 433 |
void * logits_filter_callback_user_data;
|
|
|
|
| 334 |
// If it returns false, the computation is aborted
|
| 335 |
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
|
| 336 |
|
| 337 |
+
// Abort callback
|
| 338 |
+
// If not NULL, called before ggml computation
|
| 339 |
+
// If it returns true, the computation is aborted
|
| 340 |
+
typedef bool (*whisper_abort_callback)(void * user_data);
|
| 341 |
+
|
| 342 |
// Logits filter callback
|
| 343 |
// Can be used to modify the logits before sampling
|
| 344 |
// If not NULL, called after applying temperature to logits
|
|
|
|
| 433 |
whisper_encoder_begin_callback encoder_begin_callback;
|
| 434 |
void * encoder_begin_callback_user_data;
|
| 435 |
|
| 436 |
+
// called each time before ggml computation starts
|
| 437 |
+
whisper_abort_callback abort_callback;
|
| 438 |
+
void * abort_callback_user_data;
|
| 439 |
+
|
| 440 |
// called by each decoder to filter obtained logits
|
| 441 |
whisper_logits_filter_callback logits_filter_callback;
|
| 442 |
void * logits_filter_callback_user_data;
|