ulatekh ggerganov commited on
Commit
7e6ea10
·
unverified ·
1 Parent(s): b1f3938

main : add command-style grammar (#1998)

Browse files

* Implemented command-style grammar in the main example.

Mostly just copied the relevant parts from the command example.

* main : code style

---------

Co-authored-by: Georgi Gerganov <[email protected]>

Files changed (1) hide show
  1. examples/main/main.cpp +54 -4
examples/main/main.cpp CHANGED
@@ -1,6 +1,7 @@
1
  #include "common.h"
2
 
3
  #include "whisper.h"
 
4
 
5
  #include <cmath>
6
  #include <fstream>
@@ -38,9 +39,10 @@ struct whisper_params {
38
  int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
39
  int32_t audio_ctx = 0;
40
 
41
- float word_thold = 0.01f;
42
- float entropy_thold = 2.40f;
43
- float logprob_thold = -1.00f;
 
44
 
45
  bool speed_up = false;
46
  bool debug_mode = false;
@@ -70,6 +72,8 @@ struct whisper_params {
70
  std::string prompt;
71
  std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
72
  std::string model = "models/ggml-base.en.bin";
 
 
73
 
74
  // [TDRZ] speaker turn string
75
  std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
@@ -80,6 +84,8 @@ struct whisper_params {
80
 
81
  std::vector<std::string> fname_inp = {};
82
  std::vector<std::string> fname_out = {};
 
 
83
  };
84
 
85
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -154,6 +160,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
154
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
155
  else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
156
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
 
 
157
  else {
158
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
159
  whisper_print_usage(argc, argv, params);
@@ -214,6 +223,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
214
  fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
215
  fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
216
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
 
 
217
  fprintf(stderr, "\n");
218
  }
219
 
@@ -926,6 +938,29 @@ int main(int argc, char ** argv) {
926
  // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
927
  whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
930
  const auto fname_inp = params.fname_inp[f];
931
  const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -972,7 +1007,8 @@ int main(int argc, char ** argv) {
972
  {
973
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
974
 
975
- wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
 
976
 
977
  wparams.print_realtime = false;
978
  wparams.print_progress = params.print_progress;
@@ -1010,6 +1046,20 @@ int main(int argc, char ** argv) {
1010
 
1011
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
1012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1013
  // this callback is called on each new segment
1014
  if (!wparams.print_realtime) {
1015
  wparams.new_segment_callback = whisper_print_segment_callback;
 
1
  #include "common.h"
2
 
3
  #include "whisper.h"
4
+ #include "grammar-parser.h"
5
 
6
  #include <cmath>
7
  #include <fstream>
 
39
  int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
40
  int32_t audio_ctx = 0;
41
 
42
+ float word_thold = 0.01f;
43
+ float entropy_thold = 2.40f;
44
+ float logprob_thold = -1.00f;
45
+ float grammar_penalty = 100.0f;
46
 
47
  bool speed_up = false;
48
  bool debug_mode = false;
 
72
  std::string prompt;
73
  std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
74
  std::string model = "models/ggml-base.en.bin";
75
+ std::string grammar;
76
+ std::string grammar_rule;
77
 
78
  // [TDRZ] speaker turn string
79
  std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
 
84
 
85
  std::vector<std::string> fname_inp = {};
86
  std::vector<std::string> fname_out = {};
87
+
88
+ grammar_parser::parse_state grammar_parsed;
89
  };
90
 
91
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
160
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
161
  else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
162
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
163
+ else if ( arg == "--grammar") { params.grammar = argv[++i]; }
164
+ else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
165
+ else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
166
  else {
167
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
168
  whisper_print_usage(argc, argv, params);
 
223
  fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
224
  fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
225
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
226
+ fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
227
+ fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
228
+ fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
229
  fprintf(stderr, "\n");
230
  }
231
 
 
938
  // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
939
  whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
940
 
941
+ if (!params.grammar.empty()) {
942
+ auto & grammar = params.grammar_parsed;
943
+ if (is_file_exist(params.grammar.c_str())) {
944
+ // read grammar from file
945
+ std::ifstream ifs(params.grammar.c_str());
946
+ const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
947
+ grammar = grammar_parser::parse(txt.c_str());
948
+ } else {
949
+ // read grammar from string
950
+ grammar = grammar_parser::parse(params.grammar.c_str());
951
+ }
952
+
953
+ // will be empty (default) if there are parse errors
954
+ if (grammar.rules.empty()) {
955
+ fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str());
956
+ return 4;
957
+ } else {
958
+ fprintf(stderr, "%s: grammar:\n", __func__);
959
+ grammar_parser::print_grammar(stderr, grammar);
960
+ fprintf(stderr, "\n");
961
+ }
962
+ }
963
+
964
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
965
  const auto fname_inp = params.fname_inp[f];
966
  const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
 
1007
  {
1008
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
1009
 
1010
+ const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
1011
+ wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
1012
 
1013
  wparams.print_realtime = false;
1014
  wparams.print_progress = params.print_progress;
 
1046
 
1047
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
1048
 
1049
+ const auto & grammar_parsed = params.grammar_parsed;
1050
+ auto grammar_rules = grammar_parsed.c_rules();
1051
+
1052
+ if (use_grammar) {
1053
+ if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) {
1054
+ fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
1055
+ } else {
1056
+ wparams.grammar_rules = grammar_rules.data();
1057
+ wparams.n_grammar_rules = grammar_rules.size();
1058
+ wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
1059
+ wparams.grammar_penalty = params.grammar_penalty;
1060
+ }
1061
+ }
1062
+
1063
  // this callback is called on each new segment
1064
  if (!wparams.print_realtime) {
1065
  wparams.new_segment_callback = whisper_print_segment_callback;