Spaces:
Sleeping
Sleeping
Andy Maloney
committed on
examples : small code cleanups (#322)
Browse files
- remove unnecessary initialization of string to ""
- use empty() instead of checking size()
- use emplace_back instead of push_back
- use nullptr instead of NULL
- remove unnecessary call to .data() on string
- use character overload of find_first_of() instead of passing a string
- examples/command/command.cpp +5 -5
- examples/main/main.cpp +8 -8
- examples/stream/stream.cpp +1 -1
- examples/talk/gpt-2.cpp +6 -6
- examples/talk/talk.cpp +6 -6
examples/command/command.cpp
CHANGED
|
@@ -41,8 +41,8 @@ struct whisper_params {
|
|
| 41 |
|
| 42 |
std::string language = "en";
|
| 43 |
std::string model = "models/ggml-base.en.bin";
|
| 44 |
-
std::string fname_out = "";
|
| 45 |
-
std::string commands = "";
|
| 46 |
};
|
| 47 |
|
| 48 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
@@ -576,10 +576,10 @@ int main(int argc, char ** argv) {
|
|
| 576 |
std::vector<std::string> allowed_commands;
|
| 577 |
std::vector<std::vector<whisper_token>> allowed_tokens;
|
| 578 |
|
| 579 |
-
std::string k_prompt = "";
|
| 580 |
std::vector<whisper_token> k_tokens;
|
| 581 |
|
| 582 |
-
if (params.commands
|
| 583 |
fprintf(stderr, "\n");
|
| 584 |
fprintf(stderr, "%s: guided mode\n", __func__);
|
| 585 |
|
|
@@ -808,7 +808,7 @@ int main(int argc, char ** argv) {
|
|
| 808 |
|
| 809 |
double psum = 0.0;
|
| 810 |
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 811 |
-
probs_id.
|
| 812 |
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
| 813 |
probs_id.back().first += probs[allowed_tokens[i][j]];
|
| 814 |
}
|
|
|
|
| 41 |
|
| 42 |
std::string language = "en";
|
| 43 |
std::string model = "models/ggml-base.en.bin";
|
| 44 |
+
std::string fname_out;
|
| 45 |
+
std::string commands;
|
| 46 |
};
|
| 47 |
|
| 48 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
|
|
| 576 |
std::vector<std::string> allowed_commands;
|
| 577 |
std::vector<std::vector<whisper_token>> allowed_tokens;
|
| 578 |
|
| 579 |
+
std::string k_prompt;
|
| 580 |
std::vector<whisper_token> k_tokens;
|
| 581 |
|
| 582 |
+
if (!params.commands.empty()) {
|
| 583 |
fprintf(stderr, "\n");
|
| 584 |
fprintf(stderr, "%s: guided mode\n", __func__);
|
| 585 |
|
|
|
|
| 808 |
|
| 809 |
double psum = 0.0;
|
| 810 |
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 811 |
+
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
| 812 |
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
| 813 |
probs_id.back().first += probs[allowed_tokens[i][j]];
|
| 814 |
}
|
examples/main/main.cpp
CHANGED
|
@@ -75,7 +75,7 @@ struct whisper_params {
|
|
| 75 |
bool no_timestamps = false;
|
| 76 |
|
| 77 |
std::string language = "en";
|
| 78 |
-
std::string prompt = "";
|
| 79 |
std::string model = "models/ggml-base.en.bin";
|
| 80 |
|
| 81 |
std::vector<std::string> fname_inp = {};
|
|
@@ -118,7 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 118 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 119 |
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
| 120 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 121 |
-
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
| 122 |
else {
|
| 123 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 124 |
whisper_print_usage(argc, argv, params);
|
|
@@ -206,7 +206,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
|
| 206 |
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 207 |
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 208 |
|
| 209 |
-
std::string speaker = "";
|
| 210 |
|
| 211 |
if (params.diarize && pcmf32s.size() == 2) {
|
| 212 |
const int64_t n_samples = pcmf32s[0].size();
|
|
@@ -468,7 +468,7 @@ int main(int argc, char ** argv) {
|
|
| 468 |
// initial prompt
|
| 469 |
std::vector<whisper_token> prompt_tokens;
|
| 470 |
|
| 471 |
-
if (params.prompt.size() > 0) {
|
| 472 |
prompt_tokens.resize(1024);
|
| 473 |
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
| 474 |
|
|
@@ -505,14 +505,14 @@ int main(int argc, char ** argv) {
|
|
| 505 |
}
|
| 506 |
}
|
| 507 |
|
| 508 |
-
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
|
| 509 |
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
| 510 |
return 4;
|
| 511 |
}
|
| 512 |
|
| 513 |
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
| 514 |
}
|
| 515 |
-
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
|
| 516 |
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
| 517 |
return 5;
|
| 518 |
}
|
|
@@ -617,8 +617,8 @@ int main(int argc, char ** argv) {
|
|
| 617 |
|
| 618 |
wparams.speed_up = params.speed_up;
|
| 619 |
|
| 620 |
-
wparams.prompt_tokens = prompt_tokens.
|
| 621 |
-
wparams.prompt_n_tokens = prompt_tokens.
|
| 622 |
|
| 623 |
whisper_print_user_data user_data = { &params, &pcmf32s };
|
| 624 |
|
|
|
|
| 75 |
bool no_timestamps = false;
|
| 76 |
|
| 77 |
std::string language = "en";
|
| 78 |
+
std::string prompt;
|
| 79 |
std::string model = "models/ggml-base.en.bin";
|
| 80 |
|
| 81 |
std::vector<std::string> fname_inp = {};
|
|
|
|
| 118 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 119 |
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
| 120 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 121 |
+
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
| 122 |
else {
|
| 123 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 124 |
whisper_print_usage(argc, argv, params);
|
|
|
|
| 206 |
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 207 |
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 208 |
|
| 209 |
+
std::string speaker;
|
| 210 |
|
| 211 |
if (params.diarize && pcmf32s.size() == 2) {
|
| 212 |
const int64_t n_samples = pcmf32s[0].size();
|
|
|
|
| 468 |
// initial prompt
|
| 469 |
std::vector<whisper_token> prompt_tokens;
|
| 470 |
|
| 471 |
+
if (!params.prompt.empty()) {
|
| 472 |
prompt_tokens.resize(1024);
|
| 473 |
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
| 474 |
|
|
|
|
| 505 |
}
|
| 506 |
}
|
| 507 |
|
| 508 |
+
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
| 509 |
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
| 510 |
return 4;
|
| 511 |
}
|
| 512 |
|
| 513 |
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
| 514 |
}
|
| 515 |
+
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
| 516 |
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
| 517 |
return 5;
|
| 518 |
}
|
|
|
|
| 617 |
|
| 618 |
wparams.speed_up = params.speed_up;
|
| 619 |
|
| 620 |
+
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
| 621 |
+
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
| 622 |
|
| 623 |
whisper_print_user_data user_data = { &params, &pcmf32s };
|
| 624 |
|
examples/stream/stream.cpp
CHANGED
|
@@ -51,7 +51,7 @@ struct whisper_params {
|
|
| 51 |
|
| 52 |
std::string language = "en";
|
| 53 |
std::string model = "models/ggml-base.en.bin";
|
| 54 |
-
std::string fname_out = "";
|
| 55 |
};
|
| 56 |
|
| 57 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
|
|
| 51 |
|
| 52 |
std::string language = "en";
|
| 53 |
std::string model = "models/ggml-base.en.bin";
|
| 54 |
+
std::string fname_out;
|
| 55 |
};
|
| 56 |
|
| 57 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
examples/talk/gpt-2.cpp
CHANGED
|
@@ -40,7 +40,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
|
|
| 40 |
// find the longest tokens that form the words:
|
| 41 |
std::vector<gpt_vocab::id> tokens;
|
| 42 |
for (const auto & word : words) {
|
| 43 |
-
if (word.size() == 0) continue;
|
| 44 |
|
| 45 |
int i = 0;
|
| 46 |
int n = word.size();
|
|
@@ -86,7 +86,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
|
| 86 |
logits_id.reserve(n_logits);
|
| 87 |
|
| 88 |
for (int i = 0; i < n_logits; i++) {
|
| 89 |
-
logits_id.
|
| 90 |
}
|
| 91 |
|
| 92 |
// find the top K tokens
|
|
@@ -327,7 +327,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
|
| 327 |
{
|
| 328 |
struct ggml_init_params params;
|
| 329 |
params.mem_size = ctx_size;
|
| 330 |
-
params.mem_buffer = NULL;
|
| 331 |
|
| 332 |
model.ctx = ggml_init(params);
|
| 333 |
if (!model.ctx) {
|
|
@@ -448,7 +448,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
|
| 448 |
std::string name(length, 0);
|
| 449 |
fin.read(&name[0], length);
|
| 450 |
|
| 451 |
-
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
| 452 |
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
| 453 |
return false;
|
| 454 |
}
|
|
@@ -833,7 +833,7 @@ Me too.
|
|
| 833 |
struct gpt2_context * gpt2_init(const char * path_model) {
|
| 834 |
gpt2_context * ctx = new gpt2_context;
|
| 835 |
|
| 836 |
-
ctx->rng = std::mt19937(time(NULL));
|
| 837 |
|
| 838 |
// load the model
|
| 839 |
{
|
|
@@ -886,7 +886,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
|
|
| 886 |
|
| 887 |
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
|
| 888 |
// predict
|
| 889 |
-
if (embd.size() > 0) {
|
| 890 |
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
| 891 |
printf("gpt-2: failed to generate text\n");
|
| 892 |
return "";
|
|
|
|
| 40 |
// find the longest tokens that form the words:
|
| 41 |
std::vector<gpt_vocab::id> tokens;
|
| 42 |
for (const auto & word : words) {
|
| 43 |
+
if (word.empty()) continue;
|
| 44 |
|
| 45 |
int i = 0;
|
| 46 |
int n = word.size();
|
|
|
|
| 86 |
logits_id.reserve(n_logits);
|
| 87 |
|
| 88 |
for (int i = 0; i < n_logits; i++) {
|
| 89 |
+
logits_id.emplace_back(logits[i], i);
|
| 90 |
}
|
| 91 |
|
| 92 |
// find the top K tokens
|
|
|
|
| 327 |
{
|
| 328 |
struct ggml_init_params params;
|
| 329 |
params.mem_size = ctx_size;
|
| 330 |
+
params.mem_buffer = nullptr;
|
| 331 |
|
| 332 |
model.ctx = ggml_init(params);
|
| 333 |
if (!model.ctx) {
|
|
|
|
| 448 |
std::string name(length, 0);
|
| 449 |
fin.read(&name[0], length);
|
| 450 |
|
| 451 |
+
if (model.tensors.find(name) == model.tensors.end()) {
|
| 452 |
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
| 453 |
return false;
|
| 454 |
}
|
|
|
|
| 833 |
struct gpt2_context * gpt2_init(const char * path_model) {
|
| 834 |
gpt2_context * ctx = new gpt2_context;
|
| 835 |
|
| 836 |
+
ctx->rng = std::mt19937(time(nullptr));
|
| 837 |
|
| 838 |
// load the model
|
| 839 |
{
|
|
|
|
| 886 |
|
| 887 |
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
|
| 888 |
// predict
|
| 889 |
+
if (!embd.empty()) {
|
| 890 |
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
| 891 |
printf("gpt-2: failed to generate text\n");
|
| 892 |
return "";
|
examples/talk/talk.cpp
CHANGED
|
@@ -39,7 +39,7 @@ struct whisper_params {
|
|
| 39 |
std::string model_wsp = "models/ggml-base.en.bin";
|
| 40 |
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
|
| 41 |
std::string speak = "./examples/talk/speak.sh";
|
| 42 |
-
std::string fname_out = "";
|
| 43 |
};
|
| 44 |
|
| 45 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
@@ -588,7 +588,7 @@ int main(int argc, char ** argv) {
|
|
| 588 |
|
| 589 |
audio.get(params.voice_ms, pcmf32_cur);
|
| 590 |
|
| 591 |
-
std::string text_heard = "";
|
| 592 |
|
| 593 |
if (!force_speak) {
|
| 594 |
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
|
|
@@ -610,7 +610,7 @@ int main(int argc, char ** argv) {
|
|
| 610 |
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
| 611 |
|
| 612 |
// take first line
|
| 613 |
-
text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
|
| 614 |
|
| 615 |
// remove leading and trailing whitespace
|
| 616 |
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
|
@@ -640,18 +640,18 @@ int main(int argc, char ** argv) {
|
|
| 640 |
|
| 641 |
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
|
| 642 |
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
| 643 |
-
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
|
| 644 |
|
| 645 |
// remove first 2 lines of base prompt
|
| 646 |
if (n_iter > 4) {
|
| 647 |
{
|
| 648 |
-
const size_t pos = prompt_base.find_first_of("\n");
|
| 649 |
if (pos != std::string::npos) {
|
| 650 |
prompt_base = prompt_base.substr(pos + 1);
|
| 651 |
}
|
| 652 |
}
|
| 653 |
{
|
| 654 |
-
const size_t pos = prompt_base.find_first_of("\n");
|
| 655 |
if (pos != std::string::npos) {
|
| 656 |
prompt_base = prompt_base.substr(pos + 1);
|
| 657 |
}
|
|
|
|
| 39 |
std::string model_wsp = "models/ggml-base.en.bin";
|
| 40 |
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
|
| 41 |
std::string speak = "./examples/talk/speak.sh";
|
| 42 |
+
std::string fname_out;
|
| 43 |
};
|
| 44 |
|
| 45 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
|
|
| 588 |
|
| 589 |
audio.get(params.voice_ms, pcmf32_cur);
|
| 590 |
|
| 591 |
+
std::string text_heard;
|
| 592 |
|
| 593 |
if (!force_speak) {
|
| 594 |
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
|
|
|
|
| 610 |
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
| 611 |
|
| 612 |
// take first line
|
| 613 |
+
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
|
| 614 |
|
| 615 |
// remove leading and trailing whitespace
|
| 616 |
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
|
|
|
| 640 |
|
| 641 |
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
|
| 642 |
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
| 643 |
+
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
|
| 644 |
|
| 645 |
// remove first 2 lines of base prompt
|
| 646 |
if (n_iter > 4) {
|
| 647 |
{
|
| 648 |
+
const size_t pos = prompt_base.find_first_of('\n');
|
| 649 |
if (pos != std::string::npos) {
|
| 650 |
prompt_base = prompt_base.substr(pos + 1);
|
| 651 |
}
|
| 652 |
}
|
| 653 |
{
|
| 654 |
+
const size_t pos = prompt_base.find_first_of('\n');
|
| 655 |
if (pos != std::string::npos) {
|
| 656 |
prompt_base = prompt_base.substr(pos + 1);
|
| 657 |
}
|