Spaces:
Running
Running
node : add flash_attn param (#2170)
Browse files
examples/addon.node/__test__/whisper.spec.js
CHANGED
|
@@ -12,6 +12,7 @@ const whisperParamsMock = {
|
|
| 12 |
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
|
| 13 |
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
|
| 14 |
use_gpu: true,
|
|
|
|
| 15 |
no_prints: true,
|
| 16 |
comma_in_time: false,
|
| 17 |
translate: true,
|
|
|
|
| 12 |
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
|
| 13 |
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
|
| 14 |
use_gpu: true,
|
| 15 |
+
flash_attn: false,
|
| 16 |
no_prints: true,
|
| 17 |
comma_in_time: false,
|
| 18 |
translate: true,
|
examples/addon.node/addon.cpp
CHANGED
|
@@ -39,6 +39,7 @@ struct whisper_params {
|
|
| 39 |
bool no_timestamps = false;
|
| 40 |
bool no_prints = false;
|
| 41 |
bool use_gpu = true;
|
|
|
|
| 42 |
bool comma_in_time = true;
|
| 43 |
|
| 44 |
std::string language = "en";
|
|
@@ -146,6 +147,7 @@ int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
|
| 146 |
|
| 147 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 148 |
cparams.use_gpu = params.use_gpu;
|
|
|
|
| 149 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 150 |
|
| 151 |
if (ctx == nullptr) {
|
|
@@ -326,6 +328,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
|
| 326 |
std::string model = whisper_params.Get("model").As<Napi::String>();
|
| 327 |
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
|
| 328 |
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
|
|
|
|
| 329 |
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
|
| 330 |
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
|
| 331 |
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
|
|
@@ -346,6 +349,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
|
| 346 |
params.model = model;
|
| 347 |
params.fname_inp.emplace_back(input);
|
| 348 |
params.use_gpu = use_gpu;
|
|
|
|
| 349 |
params.no_prints = no_prints;
|
| 350 |
params.no_timestamps = no_timestamps;
|
| 351 |
params.audio_ctx = audio_ctx;
|
|
|
|
| 39 |
bool no_timestamps = false;
|
| 40 |
bool no_prints = false;
|
| 41 |
bool use_gpu = true;
|
| 42 |
+
bool flash_attn = false;
|
| 43 |
bool comma_in_time = true;
|
| 44 |
|
| 45 |
std::string language = "en";
|
|
|
|
| 147 |
|
| 148 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 149 |
cparams.use_gpu = params.use_gpu;
|
| 150 |
+
cparams.flash_attn = params.flash_attn;
|
| 151 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 152 |
|
| 153 |
if (ctx == nullptr) {
|
|
|
|
| 328 |
std::string model = whisper_params.Get("model").As<Napi::String>();
|
| 329 |
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
|
| 330 |
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
|
| 331 |
+
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
|
| 332 |
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
|
| 333 |
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
|
| 334 |
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
|
|
|
|
| 349 |
params.model = model;
|
| 350 |
params.fname_inp.emplace_back(input);
|
| 351 |
params.use_gpu = use_gpu;
|
| 352 |
+
params.flash_attn = flash_attn;
|
| 353 |
params.no_prints = no_prints;
|
| 354 |
params.no_timestamps = no_timestamps;
|
| 355 |
params.audio_ctx = audio_ctx;
|
examples/addon.node/index.js
CHANGED
|
@@ -12,6 +12,7 @@ const whisperParams = {
|
|
| 12 |
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
|
| 13 |
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
|
| 14 |
use_gpu: true,
|
|
|
|
| 15 |
no_prints: true,
|
| 16 |
comma_in_time: false,
|
| 17 |
translate: true,
|
|
|
|
| 12 |
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
|
| 13 |
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
|
| 14 |
use_gpu: true,
|
| 15 |
+
flash_attn: false,
|
| 16 |
no_prints: true,
|
| 17 |
comma_in_time: false,
|
| 18 |
translate: true,
|