duncanm5 commited on
Commit
61d4989
·
unverified ·
1 Parent(s): 44ab249

examples : add tinydiarization support for streaming (#1137)

Browse files
Files changed (1) hide show
  1. examples/stream/stream.cpp +17 -2
examples/stream/stream.cpp CHANGED
@@ -47,6 +47,7 @@ struct whisper_params {
47
  bool print_special = false;
48
  bool no_context = true;
49
  bool no_timestamps = false;
 
50
 
51
  std::string language = "en";
52
  std::string model = "models/ggml-base.en.bin";
@@ -80,6 +81,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
80
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
81
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
82
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
 
 
83
  else {
84
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
85
  whisper_print_usage(argc, argv, params);
@@ -113,6 +116,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
113
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
114
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
115
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
 
116
  fprintf(stderr, "\n");
117
  }
118
 
@@ -299,6 +303,8 @@ int main(int argc, char ** argv) {
299
  wparams.audio_ctx = params.audio_ctx;
300
  wparams.speed_up = params.speed_up;
301
 
 
 
302
  // disable temperature fallback
303
  //wparams.temperature_inc = -1.0f;
304
  wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
@@ -344,10 +350,19 @@ int main(int argc, char ** argv) {
344
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
345
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
346
 
347
- printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
 
 
 
 
 
 
 
 
 
348
 
349
  if (params.fname_out.length() > 0) {
350
- fout << "[" << to_timestamp(t0) << " --> " << to_timestamp(t1) << "] " << text << std::endl;
351
  }
352
  }
353
  }
 
47
  bool print_special = false;
48
  bool no_context = true;
49
  bool no_timestamps = false;
50
+ bool tinydiarize = false;
51
 
52
  std::string language = "en";
53
  std::string model = "models/ggml-base.en.bin";
 
81
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
82
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
83
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
84
+ else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
85
+
86
  else {
87
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
88
  whisper_print_usage(argc, argv, params);
 
116
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
117
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
118
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
119
+ fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
120
  fprintf(stderr, "\n");
121
  }
122
 
 
303
  wparams.audio_ctx = params.audio_ctx;
304
  wparams.speed_up = params.speed_up;
305
 
306
+ wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
307
+
308
  // disable temperature fallback
309
  //wparams.temperature_inc = -1.0f;
310
  wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
 
350
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
351
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
352
 
353
+ std::string output = "[" + to_timestamp(t0) + " --> " + to_timestamp(t1) + "] " + text;
354
+
355
+ if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
356
+ output += " [SPEAKER_TURN]";
357
+ }
358
+
359
+ output += "\n";
360
+
361
+ printf("%s", output.c_str());
362
+ fflush(stdout);
363
 
364
  if (params.fname_out.length() > 0) {
365
+ fout << output;
366
  }
367
  }
368
  }