pprobst committed on
Commit
b4d05df
·
unverified ·
1 Parent(s): e9954d9

node : add flash_attn param (#2170)

Browse files
examples/addon.node/__test__/whisper.spec.js CHANGED
@@ -12,6 +12,7 @@ const whisperParamsMock = {
12
  model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
13
  fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
14
  use_gpu: true,
 
15
  no_prints: true,
16
  comma_in_time: false,
17
  translate: true,
 
12
  model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
13
  fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
14
  use_gpu: true,
15
+ flash_attn: false,
16
  no_prints: true,
17
  comma_in_time: false,
18
  translate: true,
examples/addon.node/addon.cpp CHANGED
@@ -39,6 +39,7 @@ struct whisper_params {
39
  bool no_timestamps = false;
40
  bool no_prints = false;
41
  bool use_gpu = true;
 
42
  bool comma_in_time = true;
43
 
44
  std::string language = "en";
@@ -146,6 +147,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
146
 
147
  struct whisper_context_params cparams = whisper_context_default_params();
148
  cparams.use_gpu = params.use_gpu;
 
149
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
150
 
151
  if (ctx == nullptr) {
@@ -326,6 +328,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
326
  std::string model = whisper_params.Get("model").As<Napi::String>();
327
  std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
328
  bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
 
329
  bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
330
  bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
331
  int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
@@ -346,6 +349,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
346
  params.model = model;
347
  params.fname_inp.emplace_back(input);
348
  params.use_gpu = use_gpu;
 
349
  params.no_prints = no_prints;
350
  params.no_timestamps = no_timestamps;
351
  params.audio_ctx = audio_ctx;
 
39
  bool no_timestamps = false;
40
  bool no_prints = false;
41
  bool use_gpu = true;
42
+ bool flash_attn = false;
43
  bool comma_in_time = true;
44
 
45
  std::string language = "en";
 
147
 
148
  struct whisper_context_params cparams = whisper_context_default_params();
149
  cparams.use_gpu = params.use_gpu;
150
+ cparams.flash_attn = params.flash_attn;
151
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
152
 
153
  if (ctx == nullptr) {
 
328
  std::string model = whisper_params.Get("model").As<Napi::String>();
329
  std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
330
  bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
331
+ bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
332
  bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
333
  bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
334
  int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
 
349
  params.model = model;
350
  params.fname_inp.emplace_back(input);
351
  params.use_gpu = use_gpu;
352
+ params.flash_attn = flash_attn;
353
  params.no_prints = no_prints;
354
  params.no_timestamps = no_timestamps;
355
  params.audio_ctx = audio_ctx;
examples/addon.node/index.js CHANGED
@@ -12,6 +12,7 @@ const whisperParams = {
12
  model: path.join(__dirname, "../../models/ggml-base.en.bin"),
13
  fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
14
  use_gpu: true,
 
15
  no_prints: true,
16
  comma_in_time: false,
17
  translate: true,
 
12
  model: path.join(__dirname, "../../models/ggml-base.en.bin"),
13
  fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
14
  use_gpu: true,
15
+ flash_attn: false,
16
  no_prints: true,
17
  comma_in_time: false,
18
  translate: true,