Spaces:
Running
Running
main : add command-style grammar (#1998)
Browse files
* Implemented command-style grammar in the main example.
Mostly just copied the relevant parts from the command example.
* main : code style
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- examples/main/main.cpp +54 -4
examples/main/main.cpp
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
#include "common.h"
|
| 2 |
|
| 3 |
#include "whisper.h"
|
|
|
|
| 4 |
|
| 5 |
#include <cmath>
|
| 6 |
#include <fstream>
|
|
@@ -38,9 +39,10 @@ struct whisper_params {
|
|
| 38 |
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
| 39 |
int32_t audio_ctx = 0;
|
| 40 |
|
| 41 |
-
float word_thold = 0.01f;
|
| 42 |
-
float entropy_thold = 2.40f;
|
| 43 |
-
float logprob_thold = -1.00f;
|
|
|
|
| 44 |
|
| 45 |
bool speed_up = false;
|
| 46 |
bool debug_mode = false;
|
|
@@ -70,6 +72,8 @@ struct whisper_params {
|
|
| 70 |
std::string prompt;
|
| 71 |
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
| 72 |
std::string model = "models/ggml-base.en.bin";
|
|
|
|
|
|
|
| 73 |
|
| 74 |
// [TDRZ] speaker turn string
|
| 75 |
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
|
@@ -80,6 +84,8 @@ struct whisper_params {
|
|
| 80 |
|
| 81 |
std::vector<std::string> fname_inp = {};
|
| 82 |
std::vector<std::string> fname_out = {};
|
|
|
|
|
|
|
| 83 |
};
|
| 84 |
|
| 85 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
@@ -154,6 +160,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 154 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 155 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 156 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
|
|
|
|
|
|
| 157 |
else {
|
| 158 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 159 |
whisper_print_usage(argc, argv, params);
|
|
@@ -214,6 +223,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 214 |
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
| 215 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 216 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
|
|
|
|
|
|
| 217 |
fprintf(stderr, "\n");
|
| 218 |
}
|
| 219 |
|
|
@@ -926,6 +938,29 @@ int main(int argc, char ** argv) {
|
|
| 926 |
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
| 927 |
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
| 928 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 929 |
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
| 930 |
const auto fname_inp = params.fname_inp[f];
|
| 931 |
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
|
@@ -972,7 +1007,8 @@ int main(int argc, char ** argv) {
|
|
| 972 |
{
|
| 973 |
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 974 |
|
| 975 |
-
|
|
|
|
| 976 |
|
| 977 |
wparams.print_realtime = false;
|
| 978 |
wparams.print_progress = params.print_progress;
|
|
@@ -1010,6 +1046,20 @@ int main(int argc, char ** argv) {
|
|
| 1010 |
|
| 1011 |
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
|
| 1012 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1013 |
// this callback is called on each new segment
|
| 1014 |
if (!wparams.print_realtime) {
|
| 1015 |
wparams.new_segment_callback = whisper_print_segment_callback;
|
|
|
|
| 1 |
#include "common.h"
|
| 2 |
|
| 3 |
#include "whisper.h"
|
| 4 |
+
#include "grammar-parser.h"
|
| 5 |
|
| 6 |
#include <cmath>
|
| 7 |
#include <fstream>
|
|
|
|
| 39 |
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
| 40 |
int32_t audio_ctx = 0;
|
| 41 |
|
| 42 |
+
float word_thold = 0.01f;
|
| 43 |
+
float entropy_thold = 2.40f;
|
| 44 |
+
float logprob_thold = -1.00f;
|
| 45 |
+
float grammar_penalty = 100.0f;
|
| 46 |
|
| 47 |
bool speed_up = false;
|
| 48 |
bool debug_mode = false;
|
|
|
|
| 72 |
std::string prompt;
|
| 73 |
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
| 74 |
std::string model = "models/ggml-base.en.bin";
|
| 75 |
+
std::string grammar;
|
| 76 |
+
std::string grammar_rule;
|
| 77 |
|
| 78 |
// [TDRZ] speaker turn string
|
| 79 |
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
|
|
|
| 84 |
|
| 85 |
std::vector<std::string> fname_inp = {};
|
| 86 |
std::vector<std::string> fname_out = {};
|
| 87 |
+
|
| 88 |
+
grammar_parser::parse_state grammar_parsed;
|
| 89 |
};
|
| 90 |
|
| 91 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
|
|
| 160 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 161 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 162 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 163 |
+
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
| 164 |
+
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
| 165 |
+
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
| 166 |
else {
|
| 167 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 168 |
whisper_print_usage(argc, argv, params);
|
|
|
|
| 223 |
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
| 224 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 225 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 226 |
+
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
| 227 |
+
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
| 228 |
+
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
| 229 |
fprintf(stderr, "\n");
|
| 230 |
}
|
| 231 |
|
|
|
|
| 938 |
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
| 939 |
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
| 940 |
|
| 941 |
+
if (!params.grammar.empty()) {
|
| 942 |
+
auto & grammar = params.grammar_parsed;
|
| 943 |
+
if (is_file_exist(params.grammar.c_str())) {
|
| 944 |
+
// read grammar from file
|
| 945 |
+
std::ifstream ifs(params.grammar.c_str());
|
| 946 |
+
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
| 947 |
+
grammar = grammar_parser::parse(txt.c_str());
|
| 948 |
+
} else {
|
| 949 |
+
// read grammar from string
|
| 950 |
+
grammar = grammar_parser::parse(params.grammar.c_str());
|
| 951 |
+
}
|
| 952 |
+
|
| 953 |
+
// will be empty (default) if there are parse errors
|
| 954 |
+
if (grammar.rules.empty()) {
|
| 955 |
+
fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str());
|
| 956 |
+
return 4;
|
| 957 |
+
} else {
|
| 958 |
+
fprintf(stderr, "%s: grammar:\n", __func__);
|
| 959 |
+
grammar_parser::print_grammar(stderr, grammar);
|
| 960 |
+
fprintf(stderr, "\n");
|
| 961 |
+
}
|
| 962 |
+
}
|
| 963 |
+
|
| 964 |
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
| 965 |
const auto fname_inp = params.fname_inp[f];
|
| 966 |
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
|
|
|
| 1007 |
{
|
| 1008 |
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 1009 |
|
| 1010 |
+
const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
|
| 1011 |
+
wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
| 1012 |
|
| 1013 |
wparams.print_realtime = false;
|
| 1014 |
wparams.print_progress = params.print_progress;
|
|
|
|
| 1046 |
|
| 1047 |
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
|
| 1048 |
|
| 1049 |
+
const auto & grammar_parsed = params.grammar_parsed;
|
| 1050 |
+
auto grammar_rules = grammar_parsed.c_rules();
|
| 1051 |
+
|
| 1052 |
+
if (use_grammar) {
|
| 1053 |
+
if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) {
|
| 1054 |
+
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
|
| 1055 |
+
} else {
|
| 1056 |
+
wparams.grammar_rules = grammar_rules.data();
|
| 1057 |
+
wparams.n_grammar_rules = grammar_rules.size();
|
| 1058 |
+
wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
|
| 1059 |
+
wparams.grammar_penalty = params.grammar_penalty;
|
| 1060 |
+
}
|
| 1061 |
+
}
|
| 1062 |
+
|
| 1063 |
// this callback is called on each new segment
|
| 1064 |
if (!wparams.print_realtime) {
|
| 1065 |
wparams.new_segment_callback = whisper_print_segment_callback;
|