ggerganov committed on
Commit 3aa9e6c · unverified · 1 Parent(s): 2439857

whisper : reduce memory usage during inference (#431)


* ggml : add "scratch" buffer support

* ggml : support for scratch ring-buffer

* ggml : bug fix in ggml_repeat()

* ggml : error on scratch buffer overflow

* whisper : use scratch buffers during inference (base model only)

* whisper : update memory usage for all models

* whisper : fix encoder memory usage

* whisper : use whisper_context functions instead of macros

* whisper : fix FF + remove it from README

* ggml : reuse ggml_new_i32

* ggml : refactor the scratch buffer storage

* whisper : reorder scratch buffers in the decoder

* main : add option to disable temp fallback

* Update README.md

Files changed (7)
  1. README.md +109 -98
  2. bindings/javascript/whisper.js +0 -0
  3. examples/main/README.md +31 -21
  4. examples/main/main.cpp +19 -14
  5. ggml.c +88 -35
  6. ggml.h +9 -0
  7. whisper.cpp +392 -250
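
The core mechanism behind the memory reduction is the new `ggml_scratch` API added in `ggml.h` / `ggml.c` below: tensor data can be placed in a caller-provided scratch buffer that gets reused, instead of accumulating in the context's memory pool. A minimal sketch of how the API can be used — the buffer sizes and tensor shapes are illustrative assumptions, not values from this commit:

```cpp
#include "ggml.h"

#include <cstdint>
#include <vector>

int main() {
    // main pool: holds tensor headers and anything that must persist (sizes illustrative)
    std::vector<uint8_t> ctx_buf(16u*1024*1024);
    // scratch region: reusable storage for intermediate tensor data
    std::vector<uint8_t> scr_buf(64u*1024*1024);

    struct ggml_init_params params = {};
    params.mem_size   = ctx_buf.size();
    params.mem_buffer = ctx_buf.data();

    struct ggml_context * ctx = ggml_init(params);

    // while a scratch buffer is set, new tensors put their data into it and
    // only the small ggml_tensor header is taken from the context pool
    struct ggml_scratch scr = {};
    scr.offs = 0;
    scr.size = scr_buf.size();
    scr.data = scr_buf.data();
    ggml_set_scratch(ctx, scr);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 512);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 512);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b); // data for a, b, c lives in scr_buf

    // switch back to the main pool for tensors that must outlive the scratch reuse
    struct ggml_scratch none = {};
    ggml_set_scratch(ctx, none);

    (void) c;
    ggml_free(ctx);
    return 0;
}
```
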
README.md CHANGED
@@ -13,7 +13,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
13
  - AVX intrinsics support for x86 architectures
14
  - VSX intrinsics support for POWER architectures
15
  - Mixed F16 / F32 precision
16
- - Low memory usage (Flash Attention + Flash Forward)
17
  - Zero memory allocations at runtime
18
  - Runs on the CPU
19
  - [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
@@ -89,35 +89,37 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
89
  usage: ./main [options] file0.wav file1.wav ...
90
 
91
  options:
92
- -h, --help [default] show this help message and exit
93
- -t N, --threads N [4 ] number of threads to use during computation
94
- -p N, --processors N [1 ] number of processors to use during computation
95
- -ot N, --offset-t N [0 ] time offset in milliseconds
96
- -on N, --offset-n N [0 ] segment index offset
97
- -d N, --duration N [0 ] duration of audio to process in milliseconds
98
- -mc N, --max-context N [-1 ] maximum number of text context tokens to store
99
- -ml N, --max-len N [0 ] maximum segment length in characters
100
- -bo N, --best-of N [5 ] number of best candidates to keep
101
- -bs N, --beam-size N [-1 ] beam size for beam search
102
- -wt N, --word-thold N [0.01 ] word timestamp probability threshold
103
- -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
104
- -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
105
- -su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
106
- -tr, --translate [false ] translate from source language to english
107
- -di, --diarize [false ] stereo audio diarization
108
- -otxt, --output-txt [false ] output result in a text file
109
- -ovtt, --output-vtt [false ] output result in a vtt file
110
- -osrt, --output-srt [false ] output result in a srt file
111
- -owts, --output-words [false ] output script for generating karaoke video
112
- -ocsv, --output-csv [false ] output result in a CSV file
113
- -ps, --print-special [false ] print special tokens
114
- -pc, --print-colors [false ] print colors
115
- -pp, --print-progress [false ] print progress
116
- -nt, --no-timestamps [true ] do not print timestamps
117
- -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
118
- --prompt PROMPT [ ] initial prompt
119
- -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
120
- -f FNAME, --file FNAME [ ] input WAV file path
 
 
121
 
122
 
123
  bash ./models/download-ggml-model.sh base.en
@@ -137,7 +139,8 @@ Running base.en on all samples in ./samples ...
137
  [+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
138
  ----------------------------------------------
139
 
140
- whisper_model_load: loading model from 'models/ggml-base.en.bin'
 
141
  whisper_model_load: n_vocab = 51864
142
  whisper_model_load: n_audio_ctx = 1500
143
  whisper_model_load: n_audio_state = 512
@@ -150,13 +153,14 @@ whisper_model_load: n_text_layer = 6
150
  whisper_model_load: n_mels = 80
151
  whisper_model_load: f16 = 1
152
  whisper_model_load: type = 2
 
 
 
153
  whisper_model_load: adding 1607 extra tokens
154
- whisper_model_load: mem_required = 506.00 MB
155
- whisper_model_load: ggml ctx size = 140.60 MB
156
- whisper_model_load: memory size = 22.83 MB
157
  whisper_model_load: model size = 140.54 MB
158
 
159
- system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
160
 
161
  main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
162
 
@@ -164,12 +168,13 @@ main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 proc
164
  [00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
165
 
166
 
167
- whisper_print_timings: load time = 105.91 ms
168
- whisper_print_timings: mel time = 24.62 ms
169
- whisper_print_timings: sample time = 3.63 ms
170
- whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
171
- whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
172
- whisper_print_timings: total time = 542.81 ms
 
173
  ```
174
 
175
  The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@@ -212,11 +217,11 @@ make large
212
 
213
  | Model | Disk | Mem | SHA |
214
  | --- | --- | --- | --- |
215
- | tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
216
- | base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
217
- | small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
218
- | medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
219
- | large | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
220
 
221
  ## Limitations
222
 
@@ -234,7 +239,8 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
234
  ```java
235
  $ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
236
 
237
- whisper_model_load: loading model from 'models/ggml-medium.en.bin'
 
238
  whisper_model_load: n_vocab = 51864
239
  whisper_model_load: n_audio_ctx = 1500
240
  whisper_model_load: n_audio_state = 1024
@@ -247,55 +253,60 @@ whisper_model_load: n_text_layer = 24
247
  whisper_model_load: n_mels = 80
248
  whisper_model_load: f16 = 1
249
  whisper_model_load: type = 4
250
- whisper_model_load: mem_required = 2610.00 MB
 
 
251
  whisper_model_load: adding 1607 extra tokens
252
- whisper_model_load: ggml ctx size = 1644.97 MB
253
- whisper_model_load: memory size = 182.62 MB
254
- whisper_model_load: model size = 1462.12 MB
255
-
256
- main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, lang = en, task = transcribe, timestamps = 1 ...
257
-
258
- [00:00.000 --> 00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
259
- [00:08.000 --> 00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
260
- [00:17.000 --> 00:23.000] A short time later, debris was seen falling from the skies above Texas.
261
- [00:23.000 --> 00:29.000] The Columbia's lost. There are no survivors.
262
- [00:29.000 --> 00:32.000] On board was a crew of seven.
263
- [00:32.000 --> 00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
264
- [00:39.000 --> 00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
265
- [00:48.000 --> 00:52.000] a colonel in the Israeli Air Force.
266
- [00:52.000 --> 00:58.000] These men and women assumed great risk in the service to all humanity.
267
- [00:58.000 --> 01:03.000] In an age when space flight has come to seem almost routine,
268
- [01:03.000 --> 01:07.000] it is easy to overlook the dangers of travel by rocket
269
- [01:07.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
270
- [01:12.000 --> 01:18.000] These astronauts knew the dangers, and they faced them willingly,
271
- [01:18.000 --> 01:23.000] knowing they had a high and noble purpose in life.
272
- [01:23.000 --> 01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
273
- [01:31.000 --> 01:36.000] All Americans today are thinking as well of the families of these men and women
274
- [01:36.000 --> 01:40.000] who have been given this sudden shock and grief.
275
- [01:40.000 --> 01:45.000] You're not alone. Our entire nation grieves with you,
276
- [01:45.000 --> 01:52.000] and those you love will always have the respect and gratitude of this country.
277
- [01:52.000 --> 01:56.000] The cause in which they died will continue.
278
- [01:56.000 --> 02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
279
- [02:04.000 --> 02:11.000] and the longing to understand. Our journey into space will go on.
280
- [02:11.000 --> 02:16.000] In the skies today, we saw destruction and tragedy.
281
- [02:16.000 --> 02:22.000] Yet farther than we can see, there is comfort and hope.
282
- [02:22.000 --> 02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
283
- [02:29.000 --> 02:35.000] who created all these. He who brings out the starry hosts one by one
284
- [02:35.000 --> 02:39.000] and calls them each by name."
285
- [02:39.000 --> 02:46.000] Because of His great power and mighty strength, not one of them is missing.
286
- [02:46.000 --> 02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
287
- [02:55.000 --> 03:01.000] The crew of the shuttle Columbia did not return safely to earth,
288
- [03:01.000 --> 03:05.000] yet we can pray that all are safely home.
289
- [03:05.000 --> 03:13.000] May God bless the grieving families, and may God continue to bless America.
290
- [03:13.000 --> 03:41.000] Audio
291
-
292
-
293
- whisper_print_timings: load time = 575.92 ms
294
- whisper_print_timings: mel time = 230.60 ms
295
- whisper_print_timings: sample time = 73.19 ms
296
- whisper_print_timings: encode time = 19552.61 ms / 814.69 ms per layer
297
- whisper_print_timings: decode time = 13249.96 ms / 552.08 ms per layer
298
- whisper_print_timings: total time = 33686.27 ms
 
 
 
299
  ```
300
  </details>
301
 
@@ -321,14 +332,14 @@ to highlight words with high or low confidence:
321
 
322
  ## Controlling the length of the generated text segments (experimental)
323
 
324
- For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
325
 
326
  ```java
327
  ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
328
 
329
  whisper_model_load: loading model from './models/ggml-base.en.bin'
330
  ...
331
- system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
332
 
333
  main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
334
 
@@ -352,7 +363,7 @@ The `--max-len` argument can be used to obtain word-level timestamps. Simply use
352
 
353
  whisper_model_load: loading model from './models/ggml-base.en.bin'
354
  ...
355
- system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
356
 
357
  main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
358
 
 
13
  - AVX intrinsics support for x86 architectures
14
  - VSX intrinsics support for POWER architectures
15
  - Mixed F16 / F32 precision
16
+ - Low memory usage (Flash Attention)
17
  - Zero memory allocations at runtime
18
  - Runs on the CPU
19
  - [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
 
89
  usage: ./main [options] file0.wav file1.wav ...
90
 
91
  options:
92
+ -h, --help [default] show this help message and exit
93
+ -t N, --threads N [4 ] number of threads to use during computation
94
+ -p N, --processors N [1 ] number of processors to use during computation
95
+ -ot N, --offset-t N [0 ] time offset in milliseconds
96
+ -on N, --offset-n N [0 ] segment index offset
97
+ -d N, --duration N [0 ] duration of audio to process in milliseconds
98
+ -mc N, --max-context N [-1 ] maximum number of text context tokens to store
99
+ -ml N, --max-len N [0 ] maximum segment length in characters
100
+ -bo N, --best-of N [5 ] number of best candidates to keep
101
+ -bs N, --beam-size N [-1 ] beam size for beam search
102
+ -wt N, --word-thold N [0.01 ] word timestamp probability threshold
103
+ -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
104
+ -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
105
+ -su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
106
+ -tr, --translate [false ] translate from source language to english
107
+ -di, --diarize [false ] stereo audio diarization
108
+ -nf, --no-fallback [false ] do not use temperature fallback while decoding
109
+ -otxt, --output-txt [false ] output result in a text file
110
+ -ovtt, --output-vtt [false ] output result in a vtt file
111
+ -osrt, --output-srt [false ] output result in a srt file
112
+ -owts, --output-words [false ] output script for generating karaoke video
113
+ -ocsv, --output-csv [false ] output result in a CSV file
114
+ -of FNAME, --output-file FNAME [ ] output file path (without file extension)
115
+ -ps, --print-special [false ] print special tokens
116
+ -pc, --print-colors [false ] print colors
117
+ -pp, --print-progress [false ] print progress
118
+ -nt, --no-timestamps [true ] do not print timestamps
119
+ -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
120
+ --prompt PROMPT [ ] initial prompt
121
+ -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
122
+ -f FNAME, --file FNAME [ ] input WAV file path
123
 
124
 
125
  bash ./models/download-ggml-model.sh base.en
 
139
  [+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
140
  ----------------------------------------------
141
 
142
+ whisper_init_from_file: loading model from 'models/ggml-base.en.bin'
143
+ whisper_model_load: loading model
144
  whisper_model_load: n_vocab = 51864
145
  whisper_model_load: n_audio_ctx = 1500
146
  whisper_model_load: n_audio_state = 512
 
153
  whisper_model_load: n_mels = 80
154
  whisper_model_load: f16 = 1
155
  whisper_model_load: type = 2
156
+ whisper_model_load: mem required = 215.00 MB (+ 6.00 MB per decoder)
157
+ whisper_model_load: kv self size = 5.25 MB
158
+ whisper_model_load: kv cross size = 17.58 MB
159
  whisper_model_load: adding 1607 extra tokens
160
+ whisper_model_load: model ctx = 140.60 MB
 
 
161
  whisper_model_load: model size = 140.54 MB
162
 
163
+ system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
164
 
165
  main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
166
 
 
168
  [00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
169
 
170
 
171
+ whisper_print_timings: fallbacks = 0 p / 0 h
172
+ whisper_print_timings: load time = 113.81 ms
173
+ whisper_print_timings: mel time = 15.40 ms
174
+ whisper_print_timings: sample time = 11.58 ms / 27 runs ( 0.43 ms per run)
175
+ whisper_print_timings: encode time = 266.60 ms / 1 runs ( 266.60 ms per run)
176
+ whisper_print_timings: decode time = 66.11 ms / 27 runs ( 2.45 ms per run)
177
+ whisper_print_timings: total time = 476.31 ms
178
  ```
179
 
180
  The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
 
217
 
218
  | Model | Disk | Mem | SHA |
219
  | --- | --- | --- | --- |
220
+ | tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
221
+ | base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
222
+ | small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
223
+ | medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
224
+ | large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
225
 
226
  ## Limitations
227
 
 
239
  ```java
240
  $ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
241
 
242
+ whisper_init_from_file: loading model from 'models/ggml-medium.en.bin'
243
+ whisper_model_load: loading model
244
  whisper_model_load: n_vocab = 51864
245
  whisper_model_load: n_audio_ctx = 1500
246
  whisper_model_load: n_audio_state = 1024
 
253
  whisper_model_load: n_mels = 80
254
  whisper_model_load: f16 = 1
255
  whisper_model_load: type = 4
256
+ whisper_model_load: mem required = 1720.00 MB (+ 43.00 MB per decoder)
257
+ whisper_model_load: kv self size = 42.00 MB
258
+ whisper_model_load: kv cross size = 140.62 MB
259
  whisper_model_load: adding 1607 extra tokens
260
+ whisper_model_load: model ctx = 1462.35 MB
261
+ whisper_model_load: model size = 1462.12 MB
262
+
263
+ system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
264
+
265
+ main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
266
+
267
+
268
+ [00:00:00.000 --> 00:00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
269
+ [00:00:08.000 --> 00:00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
270
+ [00:00:17.000 --> 00:00:23.000] A short time later, debris was seen falling from the skies above Texas.
271
+ [00:00:23.000 --> 00:00:29.000] The Columbia's lost. There are no survivors.
272
+ [00:00:29.000 --> 00:00:32.000] On board was a crew of seven.
273
+ [00:00:32.000 --> 00:00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
274
+ [00:00:39.000 --> 00:00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
275
+ [00:00:48.000 --> 00:00:52.000] a colonel in the Israeli Air Force.
276
+ [00:00:52.000 --> 00:00:58.000] These men and women assumed great risk in the service to all humanity.
277
+ [00:00:58.000 --> 00:01:03.000] In an age when space flight has come to seem almost routine,
278
+ [00:01:03.000 --> 00:01:07.000] it is easy to overlook the dangers of travel by rocket
279
+ [00:01:07.000 --> 00:01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
280
+ [00:01:12.000 --> 00:01:18.000] These astronauts knew the dangers, and they faced them willingly,
281
+ [00:01:18.000 --> 00:01:23.000] knowing they had a high and noble purpose in life.
282
+ [00:01:23.000 --> 00:01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
283
+ [00:01:31.000 --> 00:01:36.000] All Americans today are thinking as well of the families of these men and women
284
+ [00:01:36.000 --> 00:01:40.000] who have been given this sudden shock and grief.
285
+ [00:01:40.000 --> 00:01:45.000] You're not alone. Our entire nation grieves with you,
286
+ [00:01:45.000 --> 00:01:52.000] and those you love will always have the respect and gratitude of this country.
287
+ [00:01:52.000 --> 00:01:56.000] The cause in which they died will continue.
288
+ [00:01:56.000 --> 00:02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
289
+ [00:02:04.000 --> 00:02:11.000] and the longing to understand. Our journey into space will go on.
290
+ [00:02:11.000 --> 00:02:16.000] In the skies today, we saw destruction and tragedy.
291
+ [00:02:16.000 --> 00:02:22.000] Yet farther than we can see, there is comfort and hope.
292
+ [00:02:22.000 --> 00:02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
293
+ [00:02:29.000 --> 00:02:35.000] who created all these. He who brings out the starry hosts one by one
294
+ [00:02:35.000 --> 00:02:39.000] and calls them each by name."
295
+ [00:02:39.000 --> 00:02:46.000] Because of His great power and mighty strength, not one of them is missing.
296
+ [00:02:46.000 --> 00:02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
297
+ [00:02:55.000 --> 00:03:01.000] The crew of the shuttle Columbia did not return safely to earth,
298
+ [00:03:01.000 --> 00:03:05.000] yet we can pray that all are safely home.
299
+ [00:03:05.000 --> 00:03:13.000] May God bless the grieving families, and may God continue to bless America.
300
+ [00:03:13.000 --> 00:03:19.000] [Silence]
301
+
302
+
303
+ whisper_print_timings: fallbacks = 1 p / 0 h
304
+ whisper_print_timings: load time = 569.03 ms
305
+ whisper_print_timings: mel time = 146.85 ms
306
+ whisper_print_timings: sample time = 238.66 ms / 553 runs ( 0.43 ms per run)
307
+ whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per run)
308
+ whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
309
+ whisper_print_timings: total time = 32733.52 ms
310
  ```
311
  </details>
312
 
 
332
 
333
  ## Controlling the length of the generated text segments (experimental)
334
 
335
+ For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
336
 
337
  ```java
338
  ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
339
 
340
  whisper_model_load: loading model from './models/ggml-base.en.bin'
341
  ...
342
+ system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
343
 
344
  main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
345
 
 
363
 
364
  whisper_model_load: loading model from './models/ggml-base.en.bin'
365
  ...
366
+ system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
367
 
368
  main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
369
 
bindings/javascript/whisper.js CHANGED
The diff for this file is too large to render. See raw diff
 
examples/main/README.md CHANGED
@@ -9,25 +9,35 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
9
  usage: ./main [options] file0.wav file1.wav ...
10
 
11
  options:
12
- -h, --help [default] show this help message and exit
13
- -t N, --threads N [4 ] number of threads to use during computation
14
- -p N, --processors N [1 ] number of processors to use during computation
15
- -ot N, --offset-t N [0 ] time offset in milliseconds
16
- -on N, --offset-n N [0 ] segment index offset
17
- -d N, --duration N [0 ] duration of audio to process in milliseconds
18
- -mc N, --max-context N [-1 ] maximum number of text context tokens to store
19
- -ml N, --max-len N [0 ] maximum segment length in characters
20
- -wt N, --word-thold N [0.01 ] word timestamp probability threshold
21
- -su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
22
- -tr, --translate [false ] translate from source language to english
23
- -otxt, --output-txt [false ] output result in a text file
24
- -ovtt, --output-vtt [false ] output result in a vtt file
25
- -osrt, --output-srt [false ] output result in a srt file
26
- -owts, --output-words [false ] output script for generating karaoke video
27
- -ps, --print-special [false ] print special tokens
28
- -pc, --print-colors [false ] print colors
29
- -nt, --no-timestamps [true ] do not print timestamps
30
- -l LANG, --language LANG [en ] spoken language
31
- -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
32
- -f FNAME, --file FNAME [ ] input WAV file path
 
 
 
 
 
 
 
 
 
 
33
  ```
 
9
  usage: ./main [options] file0.wav file1.wav ...
10
 
11
  options:
12
+ -h, --help [default] show this help message and exit
13
+ -t N, --threads N [4 ] number of threads to use during computation
14
+ -p N, --processors N [1 ] number of processors to use during computation
15
+ -ot N, --offset-t N [0 ] time offset in milliseconds
16
+ -on N, --offset-n N [0 ] segment index offset
17
+ -d N, --duration N [0 ] duration of audio to process in milliseconds
18
+ -mc N, --max-context N [-1 ] maximum number of text context tokens to store
19
+ -ml N, --max-len N [0 ] maximum segment length in characters
20
+ -bo N, --best-of N [5 ] number of best candidates to keep
21
+ -bs N, --beam-size N [-1 ] beam size for beam search
22
+ -wt N, --word-thold N [0.01 ] word timestamp probability threshold
23
+ -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
24
+ -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
25
+ -su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
26
+ -tr, --translate [false ] translate from source language to english
27
+ -di, --diarize [false ] stereo audio diarization
28
+ -nf, --no-fallback [false ] do not use temperature fallback while decoding
29
+ -otxt, --output-txt [false ] output result in a text file
30
+ -ovtt, --output-vtt [false ] output result in a vtt file
31
+ -osrt, --output-srt [false ] output result in a srt file
32
+ -owts, --output-words [false ] output script for generating karaoke video
33
+ -ocsv, --output-csv [false ] output result in a CSV file
34
+ -of FNAME, --output-file FNAME [ ] output file path (without file extension)
35
+ -ps, --print-special [false ] print special tokens
36
+ -pc, --print-colors [false ] print colors
37
+ -pp, --print-progress [false ] print progress
38
+ -nt, --no-timestamps [true ] do not print timestamps
39
+ -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
40
+ --prompt PROMPT [ ] initial prompt
41
+ -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
42
+ -f FNAME, --file FNAME [ ] input WAV file path
43
  ```
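
The new `-nf, --no-fallback` option listed above disables the temperature fallback while decoding. A typical invocation (the model and sample paths are the defaults used elsewhere in this repository):

```
./main -m models/ggml-base.en.bin -f samples/jfk.wav -nf
```
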
examples/main/main.cpp CHANGED
@@ -53,22 +53,23 @@ void replace_all(std::string & s, const std::string & search, const std::string
53
  // command-line parameters
54
  struct whisper_params {
55
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
56
- int32_t n_processors = 1;
57
- int32_t offset_t_ms = 0;
58
- int32_t offset_n = 0;
59
- int32_t duration_ms = 0;
60
  int32_t max_context = -1;
61
- int32_t max_len = 0;
62
- int32_t best_of = 5;
63
  int32_t beam_size = -1;
64
 
65
- float word_thold = 0.01f;
66
- float entropy_thold = 2.4f;
67
- float logprob_thold = -1.0f;
68
 
69
  bool speed_up = false;
70
  bool translate = false;
71
  bool diarize = false;
 
72
  bool output_txt = false;
73
  bool output_vtt = false;
74
  bool output_srt = false;
@@ -117,6 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
117
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
118
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
119
  else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
 
120
  else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
121
  else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
122
  else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
@@ -162,6 +164,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
162
  fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
163
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
164
  fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
 
165
  fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
166
  fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
167
  fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
@@ -514,7 +517,7 @@ int main(int argc, char ** argv) {
514
 
515
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
516
  const auto fname_inp = params.fname_inp[f];
517
- const auto fname_outp = f < params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
518
 
519
  std::vector<float> pcmf32; // mono-channel F32 PCM
520
  std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
@@ -647,17 +650,19 @@ int main(int argc, char ** argv) {
647
 
648
  wparams.token_timestamps = params.output_wts || params.max_len > 0;
649
  wparams.thold_pt = params.word_thold;
650
- wparams.entropy_thold = params.entropy_thold;
651
- wparams.logprob_thold = params.logprob_thold;
652
  wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
653
 
654
  wparams.speed_up = params.speed_up;
655
 
 
 
 
656
  wparams.greedy.best_of = params.best_of;
657
  wparams.beam_search.beam_size = params.beam_size;
658
 
659
- wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
660
- wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
 
661
 
662
  whisper_print_user_data user_data = { &params, &pcmf32s };
663
 
 
53
  // command-line parameters
54
  struct whisper_params {
55
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
56
+ int32_t n_processors = 1;
57
+ int32_t offset_t_ms = 0;
58
+ int32_t offset_n = 0;
59
+ int32_t duration_ms = 0;
60
  int32_t max_context = -1;
61
+ int32_t max_len = 0;
62
+ int32_t best_of = 5;
63
  int32_t beam_size = -1;
64
 
65
+ float word_thold = 0.01f;
66
+ float entropy_thold = 2.40f;
67
+ float logprob_thold = -1.00f;
68
 
69
  bool speed_up = false;
70
  bool translate = false;
71
  bool diarize = false;
72
+ bool no_fallback = false;
73
  bool output_txt = false;
74
  bool output_vtt = false;
75
  bool output_srt = false;
 
118
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
119
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
120
  else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
121
+ else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
122
  else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
123
  else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
124
  else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
 
164
  fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
165
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
166
  fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
167
+ fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
168
  fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
169
  fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
170
  fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
 
517
 
518
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
519
  const auto fname_inp = params.fname_inp[f];
520
+ const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
521
 
522
  std::vector<float> pcmf32; // mono-channel F32 PCM
523
  std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
 
650
 
651
  wparams.token_timestamps = params.output_wts || params.max_len > 0;
652
  wparams.thold_pt = params.word_thold;
 
 
653
  wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
654
 
655
  wparams.speed_up = params.speed_up;
656
 
657
+ wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
658
+ wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
659
+
660
  wparams.greedy.best_of = params.best_of;
661
  wparams.beam_search.beam_size = params.beam_size;
662
 
663
+ wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
664
+ wparams.entropy_thold = params.entropy_thold;
665
+ wparams.logprob_thold = params.logprob_thold;
666
 
667
  whisper_print_user_data user_data = { &params, &pcmf32s };
668
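
For programs calling the library directly, the same behaviour comes from the `whisper_full_params` fields the diff above wires up. A hedged sketch — the wrapper function is illustrative, while the field names and default values are taken from the diff:

```cpp
#include "whisper.h"

// Disable the temperature fallback (equivalent to the new `-nf` flag) and set
// the decoder-fail thresholds that this commit exposes through main.cpp.
static int transcribe_no_fallback(struct whisper_context * ctx,
                                  const float * pcmf32, int n_samples) {
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    wparams.temperature_inc = 0.0f;   // no fallback: keep the initial temperature
    wparams.entropy_thold   = 2.40f;  // entropy threshold for decoder fail (default)
    wparams.logprob_thold   = -1.00f; // log probability threshold for decoder fail (default)

    return whisper_full(ctx, wparams, pcmf32, n_samples);
}
```
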
 
ggml.c CHANGED
@@ -1258,7 +1258,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1258
  //
1259
 
1260
  struct ggml_object {
1261
- size_t offset;
1262
  size_t size;
1263
 
1264
  struct ggml_object * next;
@@ -1284,6 +1284,9 @@ struct ggml_context {
1284
 
1285
  struct ggml_object * objects_begin;
1286
  struct ggml_object * objects_end;
 
 
 
1287
  };
1288
 
1289
  struct ggml_context_container {
@@ -1346,7 +1349,7 @@ inline static void ggml_critical_section_end(void) {
1346
 
1347
  void ggml_print_object(const struct ggml_object * obj) {
1348
  GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
1349
- obj->offset, obj->size, (const void *) obj->next);
1350
  }
1351
 
1352
  void ggml_print_objects(const struct ggml_context * ctx) {
@@ -1542,12 +1545,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
1542
  }
1543
 
1544
  *ctx = (struct ggml_context) {
1545
- .mem_size = params.mem_size,
1546
- .mem_buffer = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
1547
- .mem_buffer_owned = params.mem_buffer ? false : true,
1548
- .n_objects = 0,
1549
- .objects_begin = NULL,
1550
- .objects_end = NULL,
 
 
1551
  };
1552
 
1553
  ggml_assert_aligned(ctx->mem_buffer);
@@ -1570,7 +1575,7 @@ void ggml_free(struct ggml_context * ctx) {
1570
  g_state.contexts[i].used = false;
1571
 
1572
  GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
1573
- __func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
1574
 
1575
  if (ctx->mem_buffer_owned) {
1576
  free(ctx->mem_buffer);
@@ -1589,7 +1594,15 @@ void ggml_free(struct ggml_context * ctx) {
1589
  }
1590
 
1591
  size_t ggml_used_mem(const struct ggml_context * ctx) {
1592
- return ctx->objects_end->offset + ctx->objects_end->size;
 
 
 
 
 
 
 
 
1593
  }
1594
 
1595
  ////////////////////////////////////////////////////////////////////////////////
@@ -1603,9 +1616,9 @@ struct ggml_tensor * ggml_new_tensor_impl(
1603
  // always insert objects at the end of the context's memory pool
1604
  struct ggml_object * obj_cur = ctx->objects_end;
1605
 
1606
- const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
1607
- const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
1608
- const size_t cur_end = cur_offset + cur_size;
1609
 
1610
  size_t size_needed = 0;
1611
 
@@ -1616,25 +1629,52 @@ struct ggml_tensor * ggml_new_tensor_impl(
1616
  }
1617
  // align to GGML_MEM_ALIGN
1618
  size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
1619
-
1620
- }
1621
- size_needed += sizeof(struct ggml_tensor);
1622
-
1623
- if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
1624
- GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
1625
- assert(false);
1626
- return NULL;
1627
  }
1628
 
1629
  char * const mem_buffer = ctx->mem_buffer;
1630
-
1631
  struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
1632
 
1633
- *obj_new = (struct ggml_object) {
1634
- .offset = cur_end + GGML_OBJECT_SIZE,
1635
- .size = size_needed,
1636
- .next = NULL,
1637
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1638
 
1639
  if (obj_cur != NULL) {
1640
  obj_cur->next = obj_new;
@@ -1645,9 +1685,9 @@ struct ggml_tensor * ggml_new_tensor_impl(
1645
 
1646
  ctx->objects_end = obj_new;
1647
 
1648
- //GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
1649
 
1650
- struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
1651
 
1652
  ggml_assert_aligned(result);
1653
 
@@ -1690,7 +1730,7 @@ struct ggml_tensor * ggml_new_tensor(
1690
  struct ggml_context * ctx,
1691
  enum ggml_type type,
1692
  int n_dims,
1693
- const int* ne) {
1694
  return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
1695
  }
1696
 
@@ -1732,16 +1772,26 @@ struct ggml_tensor * ggml_new_tensor_4d(
1732
  }
1733
 
1734
  struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
 
 
 
1735
  struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
1736
 
 
 
1737
  ggml_set_i32(result, value);
1738
 
1739
  return result;
1740
  }
1741
 
1742
  struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
 
 
 
1743
  struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
1744
 
 
 
1745
  ggml_set_f32(result, value);
1746
 
1747
  return result;
@@ -2350,7 +2400,7 @@ struct ggml_tensor * ggml_repeat(
2350
  result->op = GGML_OP_REPEAT;
2351
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
2352
  result->src0 = a;
2353
- result->src1 = NULL;
2354
 
2355
  return result;
2356
  }
@@ -2966,9 +3016,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
2966
  // TODO: when implement backward, fix this:
2967
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2968
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
2969
-
2970
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
2971
- ((int32_t *) b->data)[0] = n_past;
2972
 
2973
  result->op = GGML_OP_DIAG_MASK_INF;
2974
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4300,7 +4348,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
4300
  const int ne1 = dst->ne[1];
4301
 
4302
  // TODO: find the optimal values for these
4303
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
 
 
4304
  //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
4305
  return true;
4306
  }
@@ -7289,6 +7339,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
7289
  node->n_tasks = 1; // TODO: this actually is doing nothing
7290
  // the threads are still spinning
7291
  cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
 
 
 
7292
  } else {
7293
  cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
7294
  }
 
1258
  //
1259
 
1260
  struct ggml_object {
1261
+ size_t offs;
1262
  size_t size;
1263
 
1264
  struct ggml_object * next;
 
1284
 
1285
  struct ggml_object * objects_begin;
1286
  struct ggml_object * objects_end;
1287
+
1288
+ struct ggml_scratch scratch;
1289
+ struct ggml_scratch scratch_save;
1290
  };
1291
 
1292
  struct ggml_context_container {
 
1349
 
1350
  void ggml_print_object(const struct ggml_object * obj) {
1351
  GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
1352
+ obj->offs, obj->size, (const void *) obj->next);
1353
  }
1354
 
1355
  void ggml_print_objects(const struct ggml_context * ctx) {
 
1545
  }
1546
 
1547
  *ctx = (struct ggml_context) {
1548
+ /*.mem_size =*/ params.mem_size,
1549
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
1550
+ /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
1551
+ /*.n_objects =*/ 0,
1552
+ /*.objects_begin =*/ NULL,
1553
+ /*.objects_end =*/ NULL,
1554
+ /*.scratch =*/ { 0, 0, NULL, },
1555
+ /*.scratch_save =*/ { 0, 0, NULL, },
1556
  };
1557
 
1558
  ggml_assert_aligned(ctx->mem_buffer);
 
1575
  g_state.contexts[i].used = false;
1576
 
1577
  GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
1578
+ __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
1579
 
1580
  if (ctx->mem_buffer_owned) {
1581
  free(ctx->mem_buffer);
 
1594
  }
1595
 
1596
  size_t ggml_used_mem(const struct ggml_context * ctx) {
1597
+ return ctx->objects_end->offs + ctx->objects_end->size;
1598
+ }
1599
+
1600
+ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
1601
+ const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;
1602
+
1603
+ ctx->scratch = scratch;
1604
+
1605
+ return result;
1606
  }
1607
 
1608
  ////////////////////////////////////////////////////////////////////////////////
 
1616
  // always insert objects at the end of the context's memory pool
1617
  struct ggml_object * obj_cur = ctx->objects_end;
1618
 
1619
+ const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
1620
+ const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
1621
+ const size_t cur_end = cur_offs + cur_size;
1622
 
1623
  size_t size_needed = 0;
1624
 
 
1629
  }
1630
  // align to GGML_MEM_ALIGN
1631
  size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
 
 
 
 
 
 
 
 
1632
  }
1633
 
1634
  char * const mem_buffer = ctx->mem_buffer;
 
1635
  struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
1636
 
1637
+ if (ctx->scratch.data == NULL || data != NULL) {
1638
+ size_needed += sizeof(struct ggml_tensor);
1639
+
1640
+ if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
1641
+ GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
1642
+ __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
1643
+ assert(false);
1644
+ return NULL;
1645
+ }
1646
+
1647
+ *obj_new = (struct ggml_object) {
1648
+ .offs = cur_end + GGML_OBJECT_SIZE,
1649
+ .size = size_needed,
1650
+ .next = NULL,
1651
+ };
1652
+ } else {
1653
+ if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
1654
+ GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
1655
+ assert(false);
1656
+ return NULL;
1657
+ }
1658
+
1659
+ if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) {
1660
+ GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
1661
+ __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size);
1662
+ assert(false);
1663
+ return NULL;
1664
+ }
1665
+
1666
+ data = (char * const) ctx->scratch.data + ctx->scratch.offs;
1667
+
1668
+ *obj_new = (struct ggml_object) {
1669
+ .offs = cur_end + GGML_OBJECT_SIZE,
1670
+ .size = sizeof(struct ggml_tensor),
1671
+ .next = NULL,
1672
+ };
1673
+
1674
+ //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);
1675
+
1676
+ ctx->scratch.offs += size_needed;
1677
+ }
1678
 
1679
  if (obj_cur != NULL) {
1680
  obj_cur->next = obj_new;
 
1685
 
1686
  ctx->objects_end = obj_new;
1687
 
1688
+ //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
1689
 
1690
+ struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);
1691
 
1692
  ggml_assert_aligned(result);
1693
 
 
1730
  struct ggml_context * ctx,
1731
  enum ggml_type type,
1732
  int n_dims,
1733
+ const int * ne) {
1734
  return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
1735
  }
1736
 
 
1772
  }
1773
 
1774
  struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
1775
+ ctx->scratch_save = ctx->scratch;
1776
+ ctx->scratch.data = NULL;
1777
+
1778
  struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
1779
 
1780
+ ctx->scratch = ctx->scratch_save;
1781
+
1782
  ggml_set_i32(result, value);
1783
 
1784
  return result;
1785
  }
1786
 
1787
  struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
1788
+ ctx->scratch_save = ctx->scratch;
1789
+ ctx->scratch.data = NULL;
1790
+
1791
  struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
1792
 
1793
+ ctx->scratch = ctx->scratch_save;
1794
+
1795
  ggml_set_f32(result, value);
1796
 
1797
  return result;
 
2400
  result->op = GGML_OP_REPEAT;
2401
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
2402
  result->src0 = a;
2403
+ result->src1 = b;
2404
 
2405
  return result;
2406
  }
 
3016
  // TODO: when implement backward, fix this:
3017
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3018
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3019
+ struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
 
 
3020
 
3021
  result->op = GGML_OP_DIAG_MASK_INF;
3022
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
 
4348
  const int ne1 = dst->ne[1];
4349
 
4350
  // TODO: find the optimal values for these
4351
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
4352
+ (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)
4353
+ )) {
4354
  //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
4355
  return true;
4356
  }
 
7339
  node->n_tasks = 1; // TODO: this actually is doing nothing
7340
  // the threads are still spinning
7341
  cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
7342
+ //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
7343
+ //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
7344
+ //printf("cur = %zu\n", cur);
7345
  } else {
7346
  cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
7347
  }
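
A note on the `ggml_new_i32` / `ggml_new_f32` changes above: they temporarily save and clear the active scratch buffer so that small scalar constants are allocated from the main pool and are not overwritten when scratch storage is reused. A sketch of the user-visible effect (the wrapper function and shapes are illustrative):

```cpp
#include "ggml.h"

#include <cmath>

// While a scratch buffer is active, intermediate tensor data is placed in it and
// later overwritten. Scalar constants such as an attention scale must survive that
// reuse, which is why ggml_new_f32 / ggml_new_i32 bypass the scratch buffer internally.
static struct ggml_tensor * scaled_scores(struct ggml_context * ctx,
                                          struct ggml_tensor * KQ,
                                          int n_state, int n_head) {
    // allocated from the main pool even if ggml_set_scratch() was called earlier
    struct ggml_tensor * scale = ggml_new_f32(ctx, 1.0f/std::sqrt(float(n_state)/n_head));

    // the result's data may live in the scratch region; the constant must not
    return ggml_scale(ctx, KQ, scale);
}
```
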
ggml.h CHANGED
@@ -301,6 +301,13 @@ struct ggml_cgraph {
301
  int64_t perf_time_us;
302
  };
303
 
 
 
 
 
 
 
 
304
  struct ggml_init_params {
305
  // memory pool
306
  size_t mem_size; // bytes
@@ -327,6 +334,8 @@ void ggml_free(struct ggml_context * ctx);
327
 
328
  size_t ggml_used_mem(const struct ggml_context * ctx);
329
 
 
 
330
  struct ggml_tensor * ggml_new_tensor(
331
  struct ggml_context * ctx,
332
  enum ggml_type type,
 
301
  int64_t perf_time_us;
302
  };
303
 
304
+ // scratch buffer
305
+ struct ggml_scratch {
306
+ size_t offs;
307
+ size_t size;
308
+ void * data;
309
+ };
310
+
311
  struct ggml_init_params {
312
  // memory pool
313
  size_t mem_size; // bytes
 
334
 
335
  size_t ggml_used_mem(const struct ggml_context * ctx);
336
 
337
+ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
338
+
339
  struct ggml_tensor * ggml_new_tensor(
340
  struct ggml_context * ctx,
341
  enum ggml_type type,
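
The commit message also mentions a scratch "ring-buffer". The idea, sketched below under stated assumptions, is to alternate between scratch regions so that the output of the previous operation stays intact while the next operation allocates into another region. The `use_scratch` helper and the buffer sizes are hypothetical; the actual switching logic lives in `whisper.cpp` and is not reproduced here.

```cpp
#include "ggml.h"

#include <cstdint>
#include <vector>

// Two reusable scratch regions; sizes are illustrative only.
static std::vector<uint8_t> g_scratch[2] = {
    std::vector<uint8_t>(32u*1024*1024),
    std::vector<uint8_t>(32u*1024*1024),
};

// Hypothetical helper: route subsequent tensor data into scratch region `i`.
static void use_scratch(struct ggml_context * ctx, int i) {
    struct ggml_scratch scr = {};
    scr.offs = 0;
    scr.size = g_scratch[i].size();
    scr.data = g_scratch[i].data();
    ggml_set_scratch(ctx, scr);
}

// A layer-norm-like fragment: consecutive ops ping-pong between the two regions,
// so data produced in one region has already been consumed by the time that
// region's offset is reset and its storage is reused.
static struct ggml_tensor * norm_sketch(struct ggml_context * ctx,
                                        struct ggml_tensor * inp,
                                        struct ggml_tensor * w,
                                        struct ggml_tensor * b) {
    use_scratch(ctx, 0);
    struct ggml_tensor * cur = ggml_norm(ctx, inp);            // data in region 0

    use_scratch(ctx, 1);
    cur = ggml_mul(ctx, ggml_repeat(ctx, w, cur), cur);        // data in region 1

    use_scratch(ctx, 0);
    cur = ggml_add(ctx, cur, ggml_repeat(ctx, b, cur));        // region 0 reused
    return cur;
}
```
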
whisper.cpp CHANGED
@@ -103,6 +103,9 @@ static void byteswap_tensor(ggml_tensor * tensor) {
103
  //#define WHISPER_USE_FLASH_FF
104
  #define WHISPER_MAX_DECODERS 16
105
 
 
 
 
106
  // available whisper models
107
  enum e_model {
108
  MODEL_UNKNOWN,
@@ -217,6 +220,38 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
217
 
218
  static const size_t MB = 1024*1024;
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  static const std::map<e_model, size_t> MEM_REQ_MODEL = {
221
  { MODEL_TINY, 74ull*MB },
222
  { MODEL_BASE, 142ull*MB },
@@ -242,35 +277,19 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
242
  };
243
 
244
  static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
245
- { MODEL_TINY, 80ull*MB },
246
- { MODEL_BASE, 128ull*MB },
247
- { MODEL_SMALL, 300ull*MB },
248
- { MODEL_MEDIUM, 680ull*MB },
249
- { MODEL_LARGE, 1100ull*MB },
250
- };
251
-
252
- static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
253
- { MODEL_TINY, 104ull*MB },
254
- { MODEL_BASE, 138ull*MB },
255
- { MODEL_SMALL, 208ull*MB },
256
- { MODEL_MEDIUM, 280ull*MB },
257
- { MODEL_LARGE, 354ull*MB },
258
  };
259
 
260
  static const std::map<e_model, size_t> MEM_REQ_DECODE = {
261
- { MODEL_TINY, 200ull*MB },
262
- { MODEL_BASE, 202ull*MB },
263
- { MODEL_SMALL, 204ull*MB },
264
- { MODEL_MEDIUM, 206ull*MB },
265
- { MODEL_LARGE, 208ull*MB },
266
- };
267
-
268
- static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
269
- { MODEL_TINY, 32ull*MB },
270
- { MODEL_BASE, 44ull*MB },
271
- { MODEL_SMALL, 64ull*MB },
272
- { MODEL_MEDIUM, 84ull*MB },
273
- { MODEL_LARGE, 110ull*MB },
274
  };
275
 
276
  struct whisper_mel {
@@ -557,7 +576,10 @@ struct whisper_context {
557
 
558
  // memory buffers used by encode / decode contexts
559
  std::vector<uint8_t> buf_compute;
560
- std::vector<uint8_t> buf_compute_layer;
 
 
 
561
 
562
  // decode output (2-dimensional array: [n_tokens][n_vocab])
563
  std::vector<float> logits;
@@ -578,6 +600,37 @@ struct whisper_context {
578
 
579
  // [EXPERIMENTAL] speed-up techniques
580
  int32_t exp_n_audio_ctx; // 0 - use default
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  };
582
 
583
  template<typename T>
@@ -744,10 +797,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
744
  {
745
  // this is the total memory required to run the inference
746
  const size_t mem_required =
747
- scale*MEM_REQ_MODEL.at (model.type) +
748
- scale*MEM_REQ_KV_CROSS.at (model.type) +
749
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) +
750
- scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
 
 
 
751
 
752
  // this is the memory required by one decoder
753
  const size_t mem_required_decoder =
@@ -783,8 +839,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
783
  fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
784
  }
785
 
786
- wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
787
- wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
 
 
 
 
788
  }
789
 
790
  // load mel filters
@@ -1317,6 +1377,8 @@ static bool whisper_encode(
1317
 
1318
  struct ggml_context * ctx0 = ggml_init(params);
1319
 
 
 
1320
  struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1321
  assert(mel->type == GGML_TYPE_F32);
1322
  {
@@ -1337,6 +1399,8 @@ static bool whisper_encode(
1337
 
1338
  // convolution + gelu
1339
  {
 
 
1340
  cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1341
  cur = ggml_add(ctx0,
1342
  ggml_repeat(ctx0,
@@ -1346,6 +1410,8 @@ static bool whisper_encode(
1346
 
1347
  cur = ggml_gelu(ctx0, cur);
1348
 
 
 
1349
  cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1350
  cur = ggml_add(ctx0,
1351
  ggml_repeat(ctx0,
@@ -1356,6 +1422,8 @@ static bool whisper_encode(
1356
  cur = ggml_gelu(ctx0, cur);
1357
  }
1358
 
 
 
1359
  // ===================================================================
1360
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
1361
  //static int iter = -1;
@@ -1376,6 +1444,7 @@ static bool whisper_encode(
1376
  struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1377
 
1378
  cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
 
1379
  // ===================================================================
1380
 
1381
  // original:
@@ -1386,153 +1455,158 @@ static bool whisper_encode(
1386
  for (int il = 0; il < n_layer; ++il) {
1387
  const auto & layer = model.layers_encoder[il];
1388
 
1389
- // create separate context for each layer to reduce memory usage
1390
-
1391
- struct ggml_init_params paramsL;
1392
- paramsL.mem_size = wctx.buf_compute_layer.size();
1393
- paramsL.mem_buffer = wctx.buf_compute_layer.data();
1394
-
1395
- struct ggml_context * ctxL = ggml_init(paramsL);
1396
-
1397
  // norm
1398
  {
1399
- cur = ggml_norm(ctxL, inpL);
 
 
1400
 
1401
  // cur = ln_0_w*cur + ln_0_b
1402
- cur = ggml_add(ctxL,
1403
- ggml_mul(ctxL,
1404
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1405
  cur),
1406
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1407
  }
1408
 
1409
  // self-attention
1410
  {
1411
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
 
 
1412
  layer.attn_q_w,
1413
  cur);
1414
 
1415
- Qcur = ggml_add(ctxL,
1416
- ggml_repeat(ctxL,
1417
  layer.attn_q_b,
1418
  Qcur),
1419
  Qcur);
1420
 
1421
- //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1422
 
1423
  // note: no bias for Key
1424
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1425
  layer.attn_k_w,
1426
  cur);
1427
 
1428
- //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1429
 
1430
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1431
  layer.attn_v_w,
1432
  cur);
1433
 
1434
- Vcur = ggml_add(ctxL,
1435
- ggml_repeat(ctxL,
1436
  layer.attn_v_b,
1437
  Vcur),
1438
  Vcur);
1439
 
1440
  // ------
1441
 
 
 
1442
  #ifdef WHISPER_USE_FLASH_ATTN
1443
  struct ggml_tensor * Q =
1444
- ggml_permute(ctxL,
1445
- ggml_cpy(ctxL,
1446
  Qcur,
1447
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1448
  0, 2, 1, 3);
1449
 
1450
  struct ggml_tensor * K =
1451
- ggml_permute(ctxL,
1452
- ggml_cpy(ctxL,
1453
  Kcur,
1454
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1455
  0, 2, 1, 3);
1456
 
1457
  struct ggml_tensor * V =
1458
- ggml_cpy(ctxL,
1459
- ggml_permute(ctxL,
1460
- ggml_reshape_3d(ctxL,
1461
  Vcur,
1462
  n_state/n_head, n_head, n_ctx),
1463
  1, 2, 0, 3),
1464
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
1465
  );
1466
 
1467
- struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
1468
  #else
1469
  struct ggml_tensor * Q =
1470
- ggml_permute(ctxL,
1471
- ggml_cpy(ctxL,
1472
  Qcur,
1473
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1474
  0, 2, 1, 3);
1475
 
1476
  struct ggml_tensor * K =
1477
- ggml_permute(ctxL,
1478
- ggml_cpy(ctxL,
1479
  Kcur,
1480
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1481
  0, 2, 1, 3);
1482
 
1483
  // K * Q
1484
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1485
 
1486
  struct ggml_tensor * KQ_scaled =
1487
- ggml_scale(ctxL,
1488
  KQ,
1489
- ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1490
  );
1491
 
1492
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
1493
 
1494
  //struct ggml_tensor * V_trans =
1495
- // ggml_permute(ctxL,
1496
- // ggml_cpy(ctxL,
1497
  // Vcur,
1498
- // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1499
  // 1, 2, 0, 3);
1500
 
1501
- //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1502
 
1503
  struct ggml_tensor * V =
1504
- ggml_cpy(ctxL,
1505
- ggml_permute(ctxL,
1506
- ggml_reshape_3d(ctxL,
1507
  Vcur,
1508
  n_state/n_head, n_head, n_ctx),
1509
  0, 2, 1, 3),
1510
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
1511
  );
1512
 
1513
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
1514
  #endif
 
1515
 
1516
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1517
 
1518
- cur = ggml_cpy(ctxL,
1519
  KQV_merged,
1520
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
1521
  }
1522
 
1523
  // projection
1524
  {
1525
- cur = ggml_mul_mat(ctxL,
 
 
1526
  layer.attn_ln_1_w,
1527
  cur);
1528
 
1529
- cur = ggml_add(ctxL,
1530
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
 
 
1531
  cur);
1532
  }
1533
 
 
 
1534
  // add the input
1535
- cur = ggml_add(ctxL, cur, inpL);
1536
 
1537
  struct ggml_tensor * inpFF = cur;
1538
 
@@ -1540,75 +1614,75 @@ static bool whisper_encode(
1540
  {
1541
  // norm
1542
  {
1543
- cur = ggml_norm(ctxL, inpFF);
 
 
 
 
1544
 
1545
  // cur = mlp_ln_w*cur + mlp_ln_b
1546
- cur = ggml_add(ctxL,
1547
- ggml_mul(ctxL,
1548
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1549
  cur),
1550
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1551
  }
1552
 
1553
  #ifdef WHISPER_USE_FLASH_FF
1554
- cur = ggml_flash_ff(ctxL,
1555
- ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
 
 
1556
  layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1557
  #else
 
 
1558
  // fully connected
1559
- cur = ggml_mul_mat(ctxL,
1560
  layer.mlp_0_w,
1561
  cur);
1562
 
1563
- cur = ggml_add(ctxL,
1564
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
 
 
1565
  cur);
1566
 
 
 
1567
  // GELU activation
1568
- cur = ggml_gelu(ctxL, cur);
 
 
1569
 
1570
  // projection
1571
- cur = ggml_mul_mat(ctxL,
1572
  layer.mlp_1_w,
1573
  cur);
1574
 
1575
- cur = ggml_add(ctxL,
1576
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
 
 
1577
  cur);
1578
  #endif
1579
  }
1580
 
1581
- // output from this layer
1582
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1583
-
1584
- {
1585
- struct ggml_cgraph gf = {};
1586
- gf.n_threads = n_threads;
1587
-
1588
- ggml_build_forward_expand(&gf, inpO);
1589
- ggml_graph_compute (ctxL, &gf);
1590
-
1591
- //ggml_graph_print(&gf);
1592
- }
1593
-
1594
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1595
- // input for next layer (inpO -> inpL)
1596
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1597
- inpL->op = GGML_OP_NONE;
1598
- inpL->src0 = nullptr;
1599
- inpL->src1 = nullptr;
1600
 
1601
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1602
-
1603
- ggml_free(ctxL);
1604
  }
1605
 
1606
  cur = inpL;
1607
 
1608
  // norm
1609
  {
 
 
1610
  cur = ggml_norm(ctx0, cur);
1611
 
 
 
1612
  // cur = ln_f_g*cur + ln_f_b
1613
  cur = ggml_add(ctx0,
1614
  ggml_mul(ctx0,
@@ -1617,6 +1691,8 @@ static bool whisper_encode(
              ggml_repeat(ctx0, model.e_ln_b, cur));
      }

      // run the computation
      {
          struct ggml_cgraph gf = {};
@@ -1655,12 +1731,16 @@ static bool whisper_encode(
      for (int il = 0; il < model.hparams.n_text_layer; ++il) {
          auto & layer = model.layers_decoder[il];

          struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                  layer.cross_attn_k_w,
                  cur);

          Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

          struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                  layer.cross_attn_v_w,
                  cur);
@@ -1671,6 +1751,8 @@ static bool whisper_encode(
                      Vcross),
                  Vcross);

          //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
          //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
          struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
@@ -1686,7 +1768,12 @@ static bool whisper_encode(

      ////////////////////////////////////////////////////////////////////////////

-     //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);

      ggml_free(ctx0);

@@ -1698,7 +1785,7 @@ static bool whisper_encode(

  // evaluate the decoder
  //
- // given text prompt + audio features -> predicts the probabilities for the next token
  //
  //   - model: the model
  //   - n_threads: number of threads to use
@@ -1742,6 +1829,9 @@ static bool whisper_decode(

      struct ggml_context * ctx0 = ggml_init(params);

      struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
      memcpy(embd->data, tokens, N*ggml_element_size(embd));

@@ -1750,6 +1840,8 @@ static bool whisper_decode(
          ((int32_t *) position->data)[i] = n_past + i;
      }

      // token encoding + position encoding
      struct ggml_tensor * cur =
          ggml_add(ctx0,
@@ -1761,211 +1853,248 @@ static bool whisper_decode(
1761
  for (int il = 0; il < n_layer; ++il) {
1762
  const auto & layer = model.layers_decoder[il];
1763
 
1764
- struct ggml_init_params paramsL;
1765
- paramsL.mem_size = wctx.buf_compute_layer.size();
1766
- paramsL.mem_buffer = wctx.buf_compute_layer.data();
1767
-
1768
- struct ggml_context * ctxL = ggml_init(paramsL);
1769
- struct ggml_cgraph gf = {};
1770
- gf.n_threads = n_threads;
1771
-
1772
  // norm
1773
  {
1774
- cur = ggml_norm(ctxL, inpL);
 
 
1775
 
1776
  // cur = ln_0_w*cur + ln_0_b
1777
- cur = ggml_add(ctxL,
1778
- ggml_mul(ctxL,
1779
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1780
  cur),
1781
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1782
  }
1783
 
1784
  // self-attention
1785
  {
1786
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
 
 
1787
  layer.attn_q_w,
1788
  cur);
1789
 
1790
- Qcur = ggml_add(ctxL,
1791
- ggml_repeat(ctxL,
1792
  layer.attn_q_b,
1793
  Qcur),
1794
  Qcur);
1795
 
1796
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1797
 
1798
  // note: no bias for Key
1799
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1800
  layer.attn_k_w,
1801
  cur);
1802
 
1803
- Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1804
 
1805
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1806
  layer.attn_v_w,
1807
  cur);
1808
 
1809
- Vcur = ggml_add(ctxL,
1810
- ggml_repeat(ctxL,
1811
  layer.attn_v_b,
1812
  Vcur),
1813
  Vcur);
1814
 
1815
  // store key and value to memory
1816
  {
1817
- struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
1818
- struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
1819
 
1820
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
1821
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
1822
  }
1823
 
1824
  // ------
1825
 
 
 
1826
  struct ggml_tensor * Q =
1827
- ggml_permute(ctxL,
1828
- ggml_cpy(ctxL,
1829
  Qcur,
1830
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1831
  0, 2, 1, 3);
1832
 
1833
  struct ggml_tensor * K =
1834
- ggml_permute(ctxL,
1835
- ggml_reshape_3d(ctxL,
1836
- ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
1837
  n_state/n_head, n_head, n_past + N),
1838
  0, 2, 1, 3);
1839
 
 
 
1840
  // K * Q
1841
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
 
 
1842
 
1843
  //struct ggml_tensor * KQ_scaled =
1844
- // ggml_scale(ctxL,
1845
  // KQ,
1846
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1847
  // );
1848
 
1849
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
1850
 
1851
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
 
 
 
 
1852
 
1853
  struct ggml_tensor * V_trans =
1854
- ggml_permute(ctxL,
1855
- ggml_reshape_3d(ctxL,
1856
- ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
1857
  n_state/n_head, n_head, n_past + N),
1858
  1, 2, 0, 3);
1859
 
1860
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1861
 
1862
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1863
 
1864
- cur = ggml_cpy(ctxL,
 
 
1865
  KQV_merged,
1866
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1867
  }
1868
 
 
1869
  {
1870
- cur = ggml_mul_mat(ctxL,
 
 
1871
  layer.attn_ln_1_w,
1872
  cur);
1873
 
1874
- cur = ggml_add(ctxL,
1875
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
 
 
1876
  cur);
1877
  }
1878
 
 
 
1879
  // add the input
1880
- struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
1881
 
1882
  // norm
1883
  {
1884
- cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
 
 
 
 
1885
 
1886
  // cur = ln_0_w*cur + ln_0_b
1887
- cur = ggml_add(ctxL,
1888
- ggml_mul(ctxL,
1889
- ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
1890
  cur),
1891
- ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
1892
  }
1893
 
1894
  // cross-attention
1895
  {
1896
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
 
 
1897
  layer.cross_attn_q_w,
1898
  cur);
1899
 
1900
- Qcur = ggml_add(ctxL,
1901
- ggml_repeat(ctxL,
1902
  layer.cross_attn_q_b,
1903
  Qcur),
1904
  Qcur);
1905
 
1906
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1907
 
1908
  // Kcross is already scaled
1909
  struct ggml_tensor * Kcross =
1910
- ggml_reshape_3d(ctxL,
1911
- ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
1912
  n_state/n_head, n_head, M);
1913
 
1914
  struct ggml_tensor * Vcross =
1915
- ggml_reshape_3d(ctxL,
1916
- ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
1917
  n_state/n_head, n_head, M);
1918
 
 
 
1919
  // ------
1920
 
 
 
1921
  struct ggml_tensor * Q =
1922
- ggml_permute(ctxL,
1923
- ggml_cpy(ctxL,
1924
  Qcur,
1925
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1926
  0, 2, 1, 3);
1927
 
1928
- struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
 
 
1929
 
1930
  // K * Q
1931
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1932
 
1933
  //struct ggml_tensor * KQ_scaled =
1934
- // ggml_scale(ctxL,
1935
  // KQ,
1936
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1937
  // );
1938
 
1939
  // no masking for cross-attention
1940
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
 
 
1941
 
1942
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
1943
 
1944
- struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
1945
 
1946
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1947
 
1948
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
 
 
1949
 
1950
  // cur = KQV_merged.contiguous().view(n_state, N)
1951
- cur = ggml_cpy(ctxL,
1952
  KQV_merged,
1953
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1954
  }
1955
 
1956
  // projection
1957
  {
1958
- cur = ggml_mul_mat(ctxL,
 
 
1959
  layer.cross_attn_ln_1_w,
1960
  cur);
1961
 
1962
- cur = ggml_add(ctxL,
1963
- ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
 
 
1964
  cur);
1965
  }
1966
 
 
 
1967
  // add the input
1968
- cur = ggml_add(ctxL, cur, inpCA);
1969
 
1970
  struct ggml_tensor * inpFF = cur;
1971
 
@@ -1973,68 +2102,67 @@ static bool whisper_decode(
1973
  {
1974
  // norm
1975
  {
1976
- cur = ggml_norm(ctxL, inpFF);
 
 
 
 
1977
 
1978
  // cur = mlp_ln_w*cur + mlp_ln_b
1979
- cur = ggml_add(ctxL,
1980
- ggml_mul(ctxL,
1981
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1982
  cur),
1983
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1984
  }
1985
 
 
 
1986
  // fully connected
1987
- cur = ggml_mul_mat(ctxL,
1988
  layer.mlp_0_w,
1989
  cur);
1990
 
1991
- cur = ggml_add(ctxL,
1992
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
 
 
1993
  cur);
1994
 
 
 
1995
  // GELU activation
1996
- cur = ggml_gelu(ctxL, cur);
 
 
1997
 
1998
  // projection
1999
- cur = ggml_mul_mat(ctxL,
2000
  layer.mlp_1_w,
2001
  cur);
2002
 
2003
- cur = ggml_add(ctxL,
2004
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
2005
- cur);
2006
- }
2007
 
2008
- // output from this layer
2009
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
2010
-
2011
- {
2012
- ggml_build_forward_expand(&gf, inpO);
2013
- ggml_graph_compute (ctxL, &gf);
2014
-
2015
- //ggml_graph_print(&gf);
2016
  }
2017
 
2018
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
2019
- // input for next layer (inpO -> inpL)
2020
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
2021
- inpL->op = GGML_OP_NONE;
2022
- inpL->src0 = nullptr;
2023
- inpL->src1 = nullptr;
2024
 
2025
- if (N > 1) {
2026
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
2027
- }
2028
-
2029
- ggml_free(ctxL);
2030
  }
2031
 
2032
  cur = inpL;
2033
 
2034
  // norm
2035
  {
 
 
2036
  cur = ggml_norm(ctx0, cur);
2037
 
 
 
2038
  cur = ggml_add(ctx0,
2039
  ggml_mul(ctx0,
2040
  ggml_repeat(ctx0, model.d_ln_w, cur),
@@ -2042,24 +2170,38 @@ static bool whisper_decode(
                  ggml_repeat(ctx0, model.d_ln_b, cur));
      }

      struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);

      // run the computation
      {
-         struct ggml_cgraph gf = {};
-         gf.n_threads = n_threads;
-
          ggml_build_forward_expand(&gf, logits);
          ggml_graph_compute       (ctx0, &gf);
      }

-     logits_out.resize(N*n_vocab);
-     memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);

      if (N > 1) {
-         //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
-         //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
-         //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
      }

      ggml_free(ctx0);
 
  //#define WHISPER_USE_FLASH_FF
  #define WHISPER_MAX_DECODERS 16

+ #define WHISPER_USE_SCRATCH
+ #define WHISPER_MAX_SCRATCH_BUFFERS 16
+
  // available whisper models
  enum e_model {
      MODEL_UNKNOWN,
 

  static const size_t MB = 1024*1024;

+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+     { MODEL_TINY,   12ull*MB },
+     { MODEL_BASE,   15ull*MB },
+     { MODEL_SMALL,  23ull*MB },
+     { MODEL_MEDIUM, 31ull*MB },
+     { MODEL_LARGE,  38ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
+     { MODEL_TINY,   18ull*MB },
+     { MODEL_BASE,   24ull*MB },
+     { MODEL_SMALL,  36ull*MB },
+     { MODEL_MEDIUM, 48ull*MB },
+     { MODEL_LARGE,  60ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
+     { MODEL_TINY,   4ull*MB },
+     { MODEL_BASE,   4ull*MB },
+     { MODEL_SMALL,  6ull*MB },
+     { MODEL_MEDIUM, 7ull*MB },
+     { MODEL_LARGE,  9ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
+     { MODEL_TINY,   4ull*MB },
+     { MODEL_BASE,   4ull*MB },
+     { MODEL_SMALL,  6ull*MB },
+     { MODEL_MEDIUM, 7ull*MB },
+     { MODEL_LARGE,  9ull*MB },
+ };
+
  static const std::map<e_model, size_t> MEM_REQ_MODEL = {
      { MODEL_TINY,    74ull*MB },
      { MODEL_BASE,   142ull*MB },
 
  };

  static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
+     { MODEL_TINY,    6ull*MB },
+     { MODEL_BASE,    8ull*MB },
+     { MODEL_SMALL,  13ull*MB },
+     { MODEL_MEDIUM, 22ull*MB },
+     { MODEL_LARGE,  33ull*MB },
  };

  static const std::map<e_model, size_t> MEM_REQ_DECODE = {
+     { MODEL_TINY,    3ull*MB },
+     { MODEL_BASE,    5ull*MB },
+     { MODEL_SMALL,  10ull*MB },
+     { MODEL_MEDIUM, 18ull*MB },
+     { MODEL_LARGE,  27ull*MB },
  };
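As a worked example from the tables above: the base model reserves 15 + 24 + 4 + 4 = 47 MB across the four scratch buffers, while the single shared compute buffer is sized from max(MEM_REQ_ENCODE, MEM_REQ_DECODE) = max(8, 5) = 8 MB (times the model-wide scale factor applied below), replacing the per-layer compute buffer (buf_compute_layer) that the previous code allocated.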

  struct whisper_mel {
 

      // memory buffers used by encode / decode contexts
      std::vector<uint8_t> buf_compute;
+     std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
+
+     int    buf_last = 0;
+     size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };

      // decode output (2-dimensional array: [n_tokens][n_vocab])
      std::vector<float> logits;

      // [EXPERIMENTAL] speed-up techniques
      int32_t exp_n_audio_ctx; // 0 - use default
+
+     void use_buf(struct ggml_context * ctx, int i) {
+ #if defined(WHISPER_USE_SCRATCH)
+         size_t last_size = 0;
+
+         if (i == -1) {
+             last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+         } else {
+             auto & buf = buf_scratch[i];
+             last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+         }
+
+         if (buf_last >= 0) {
+             buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+         }
+
+         buf_last = i;
+ #else
+         (void) i;
+         (void) ctx;
+ #endif
+     }
+
+     size_t get_buf_max_mem(int i) const {
+ #if defined(WHISPER_USE_SCRATCH)
+         return buf_max_size[i];
+ #else
+         (void) i;
+         return 0;
+ #endif
+     }
  };
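In short, use_buf(ctx, i) points all subsequent ggml tensor allocations at the pre-allocated scratch buffer i, and use_buf(ctx, -1) switches back to the main compute buffer; judging by how last_size feeds buf_max_size, ggml_set_scratch() returns how much of the previously active scratch buffer was actually used, so the high-water mark of each buffer can be tracked. A minimal sketch of the intended pattern, mirroring the encoder/decoder code further down:

      wctx.use_buf(ctx0, 0);                          // this step's intermediate tensor lands in scratch buffer 0
      cur = ggml_norm(ctx0, inpL);

      wctx.use_buf(ctx0, 1);                          // the next result goes to scratch buffer 1, so buffer 0 can be overwritten later
      cur = ggml_mul_mat(ctx0, layer.attn_q_w, cur);

      wctx.use_buf(ctx0, -1);                         // scratch disabled again: following tensors come from the main compute buffer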

  template<typename T>
 
      {
          // this is the total memory required to run the inference
          const size_t mem_required =
+                  MEM_REQ_SCRATCH0.at (model.type) +
+                  MEM_REQ_SCRATCH1.at (model.type) +
+                  MEM_REQ_SCRATCH2.at (model.type) +
+                  MEM_REQ_SCRATCH3.at (model.type) +
+            scale*MEM_REQ_MODEL.at   (model.type) +
+            scale*MEM_REQ_KV_CROSS.at(model.type) +
+            scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));

          // this is the memory required by one decoder
          const size_t mem_required_decoder =

          fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
      }

+     wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+
+     wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
+     wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
+     wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
+     wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
  }

  // load mel filters
 

      struct ggml_context * ctx0 = ggml_init(params);

+     wctx.use_buf(ctx0, 0);
+
      struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
      assert(mel->type == GGML_TYPE_F32);
      {
 

      // convolution + gelu
      {
+         wctx.use_buf(ctx0, 1);
+
          cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
          cur = ggml_add(ctx0,
                  ggml_repeat(ctx0,
 

          cur = ggml_gelu(ctx0, cur);

+         wctx.use_buf(ctx0, 0);
+
          cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
          cur = ggml_add(ctx0,
                  ggml_repeat(ctx0,
 
          cur = ggml_gelu(ctx0, cur);
      }

+     wctx.use_buf(ctx0, 3);
+
      // ===================================================================
      // NOTE: experimenting with partial evaluation of the encoder (ignore)
      //static int iter = -1;
 
      struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);

      cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+
      // ===================================================================

      // original:
 
      for (int il = 0; il < n_layer; ++il) {
          const auto & layer = model.layers_encoder[il];

          // norm
          {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_norm(ctx0, inpL);

              // cur = ln_0_w*cur + ln_0_b
+             cur = ggml_add(ctx0,
+                     ggml_mul(ctx0,
+                         ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
                          cur),
+                     ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
          }

          // self-attention
          {
+             wctx.use_buf(ctx0, 1);
+
+             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                      layer.attn_q_w,
                      cur);

+             Qcur = ggml_add(ctx0,
+                     ggml_repeat(ctx0,
                          layer.attn_q_b,
                          Qcur),
                      Qcur);

+             //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

              // note: no bias for Key
+             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                      layer.attn_k_w,
                      cur);

+             //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

+             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                      layer.attn_v_w,
                      cur);

+             Vcur = ggml_add(ctx0,
+                     ggml_repeat(ctx0,
                          layer.attn_v_b,
                          Vcur),
                      Vcur);

              // ------

+             wctx.use_buf(ctx0, 0);
+
  #ifdef WHISPER_USE_FLASH_ATTN
              struct ggml_tensor * Q =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
                              Qcur,
+                             ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                          0, 2, 1, 3);

              struct ggml_tensor * K =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
                              Kcur,
+                             ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                          0, 2, 1, 3);

              struct ggml_tensor * V =
+                 ggml_cpy(ctx0,
+                         ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
                                  Vcur,
                                  n_state/n_head, n_head, n_ctx),
                              1, 2, 0, 3),
+                         ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
                          );

+             struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
  #else
              struct ggml_tensor * Q =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
                              Qcur,
+                             ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
                          0, 2, 1, 3);

              struct ggml_tensor * K =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
                              Kcur,
+                             ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                          0, 2, 1, 3);

              // K * Q
+             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

              struct ggml_tensor * KQ_scaled =
+                 ggml_scale(ctx0,
                          KQ,
+                         ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
                          );

+             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);

              //struct ggml_tensor * V_trans =
+             //    ggml_permute(ctx0,
+             //            ggml_cpy(ctx0,
              //                Vcur,
+             //                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
              //            1, 2, 0, 3);

+             //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

              struct ggml_tensor * V =
+                 ggml_cpy(ctx0,
+                         ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
                                  Vcur,
                                  n_state/n_head, n_head, n_ctx),
                              0, 2, 1, 3),
+                         ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
                          );

+             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
  #endif
+             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

+             wctx.use_buf(ctx0, 1);

+             cur = ggml_cpy(ctx0,
                      KQV_merged,
+                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
          }

          // projection
          {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_mul_mat(ctx0,
                      layer.attn_ln_1_w,
                      cur);

+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                      cur);
          }

+         wctx.use_buf(ctx0, 2);
+
          // add the input
+         cur = ggml_add(ctx0, cur, inpL);

          struct ggml_tensor * inpFF = cur;

          {
              // norm
              {
+                 wctx.use_buf(ctx0, 0);
+
+                 cur = ggml_norm(ctx0, inpFF);
+
+                 wctx.use_buf(ctx0, 1);

                  // cur = mlp_ln_w*cur + mlp_ln_b
+                 cur = ggml_add(ctx0,
+                         ggml_mul(ctx0,
+                             ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                              cur),
+                         ggml_repeat(ctx0, layer.mlp_ln_b, cur));
              }

  #ifdef WHISPER_USE_FLASH_FF
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_flash_ff(ctx0,
+                     ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
                      layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
  #else
+             wctx.use_buf(ctx0, 0);
+
              // fully connected
+             cur = ggml_mul_mat(ctx0,
                      layer.mlp_0_w,
                      cur);

+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.mlp_0_b, cur),
                      cur);

+             wctx.use_buf(ctx0, 0);
+
              // GELU activation
+             cur = ggml_gelu(ctx0, cur);
+
+             wctx.use_buf(ctx0, 1);

              // projection
+             cur = ggml_mul_mat(ctx0,
                      layer.mlp_1_w,
                      cur);

+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.mlp_1_b, cur),
                      cur);
  #endif
          }

+         wctx.use_buf(ctx0, 3);

+         inpL = ggml_add(ctx0, cur, inpFF);
      }

      cur = inpL;

      // norm
      {
+         wctx.use_buf(ctx0, 0);
+
          cur = ggml_norm(ctx0, cur);

+         wctx.use_buf(ctx0, 1);
+
          // cur = ln_f_g*cur + ln_f_b
          cur = ggml_add(ctx0,
                  ggml_mul(ctx0,

                  ggml_repeat(ctx0, model.e_ln_b, cur));
      }

+     wctx.use_buf(ctx0, -1);
+
      // run the computation
      {
          struct ggml_cgraph gf = {};
 
      for (int il = 0; il < model.hparams.n_text_layer; ++il) {
          auto & layer = model.layers_decoder[il];

+         wctx.use_buf(ctx0, 0);
+
          struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                  layer.cross_attn_k_w,
                  cur);

          Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

+         wctx.use_buf(ctx0, 1);
+
          struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                  layer.cross_attn_v_w,
                  cur);

                      Vcross),
                  Vcross);

+         wctx.use_buf(ctx0, -1);
+
          //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
          //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
          struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));

      ////////////////////////////////////////////////////////////////////////////

+     //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+     //        ggml_used_mem(ctx0)/1024.0/1024.0,
+     //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
+     //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
+     //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
+     //        wctx.get_buf_max_mem(3)/1024.0/1024.0);

      ggml_free(ctx0);
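Uncommenting the printf above reports ggml_used_mem() for the main compute buffer alongside the high-water mark of each of the four scratch buffers (via get_buf_max_mem()), which is the measurement the MEM_REQ_SCRATCH0..3 tables have to cover for a given model.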
 
 

  // evaluate the decoder
  //
+ // given text prompt + audio features -> computes the logits for the next token
  //
  //   - model: the model
  //   - n_threads: number of threads to use
 

      struct ggml_context * ctx0 = ggml_init(params);

+     struct ggml_cgraph gf = {};
+     gf.n_threads = n_threads;
+
      struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
      memcpy(embd->data, tokens, N*ggml_element_size(embd));

 
          ((int32_t *) position->data)[i] = n_past + i;
      }

+     wctx.use_buf(ctx0, 3);
+
      // token encoding + position encoding
      struct ggml_tensor * cur =
          ggml_add(ctx0,
 
      for (int il = 0; il < n_layer; ++il) {
          const auto & layer = model.layers_decoder[il];

          // norm
          {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_norm(ctx0, inpL);

              // cur = ln_0_w*cur + ln_0_b
+             cur = ggml_add(ctx0,
+                     ggml_mul(ctx0,
+                         ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
                          cur),
+                     ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
          }

          // self-attention
          {
+             wctx.use_buf(ctx0, 1);
+
+             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                      layer.attn_q_w,
                      cur);

+             Qcur = ggml_add(ctx0,
+                     ggml_repeat(ctx0,
                          layer.attn_q_b,
                          Qcur),
                      Qcur);

+             Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

              // note: no bias for Key
+             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                      layer.attn_k_w,
                      cur);

+             Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

+             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                      layer.attn_v_w,
                      cur);

+             Vcur = ggml_add(ctx0,
+                     ggml_repeat(ctx0,
                          layer.attn_v_b,
                          Vcur),
                      Vcur);

              // store key and value to memory
              {
+                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+                 struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));

+                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
              }

              // ------

+             wctx.use_buf(ctx0, 0);
+
              struct ggml_tensor * Q =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
                              Qcur,
+                             ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                          0, 2, 1, 3);

              struct ggml_tensor * K =
+                 ggml_permute(ctx0,
+                         ggml_reshape_3d(ctx0,
+                             ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
                              n_state/n_head, n_head, n_past + N),
                          0, 2, 1, 3);

+             wctx.use_buf(ctx0, 1);
+
              // K * Q
+             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+             wctx.use_buf(ctx0, 0);

              //struct ggml_tensor * KQ_scaled =
+             //    ggml_scale(ctx0,
              //            KQ,
+             //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
              //            );

+             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);

+             wctx.use_buf(ctx0, 1);
+
+             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+             wctx.use_buf(ctx0, 0);

              struct ggml_tensor * V_trans =
+                 ggml_permute(ctx0,
+                         ggml_reshape_3d(ctx0,
+                             ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
                              n_state/n_head, n_head, n_past + N),
                          1, 2, 0, 3);

+             wctx.use_buf(ctx0, 1);

+             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

+             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+             cur = ggml_cpy(ctx0,
                      KQV_merged,
+                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
          }

+         // projection
          {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_mul_mat(ctx0,
                      layer.attn_ln_1_w,
                      cur);

+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                      cur);
          }

+         wctx.use_buf(ctx0, 2);
+
          // add the input
+         struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);

          // norm
          {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
+
+             wctx.use_buf(ctx0, 1);

              // cur = ln_0_w*cur + ln_0_b
+             cur = ggml_add(ctx0,
+                     ggml_mul(ctx0,
+                         ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
                          cur),
+                     ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
          }

          // cross-attention
          {
+             wctx.use_buf(ctx0, 0);
+
+             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                      layer.cross_attn_q_w,
                      cur);

+             Qcur = ggml_add(ctx0,
+                     ggml_repeat(ctx0,
                          layer.cross_attn_q_b,
                          Qcur),
                      Qcur);

+             Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

              // Kcross is already scaled
              struct ggml_tensor * Kcross =
+                 ggml_reshape_3d(ctx0,
+                         ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
                          n_state/n_head, n_head, M);

              struct ggml_tensor * Vcross =
+                 ggml_reshape_3d(ctx0,
+                         ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
                          n_state/n_head, n_head, M);

+             struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
+
              // ------

+             wctx.use_buf(ctx0, 1);
+
              struct ggml_tensor * Q =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
                              Qcur,
+                             ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                          0, 2, 1, 3);

+             struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
+
+             wctx.use_buf(ctx0, 0);

              // K * Q
+             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

              //struct ggml_tensor * KQ_scaled =
+             //    ggml_scale(ctx0,
              //            KQ,
+             //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
              //            );

              // no masking for cross-attention
+             //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+             wctx.use_buf(ctx0, 1);

+             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);

+             wctx.use_buf(ctx0, 0);

+             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

+             wctx.use_buf(ctx0, 1);
+
+             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

              // cur = KQV_merged.contiguous().view(n_state, N)
+             cur = ggml_cpy(ctx0,
                      KQV_merged,
+                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
          }

          // projection
          {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_mul_mat(ctx0,
                      layer.cross_attn_ln_1_w,
                      cur);

+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
                      cur);
          }

+         wctx.use_buf(ctx0, 2);
+
          // add the input
+         cur = ggml_add(ctx0, cur, inpCA);

          struct ggml_tensor * inpFF = cur;

          {
              // norm
              {
+                 wctx.use_buf(ctx0, 0);
+
+                 cur = ggml_norm(ctx0, inpFF);
+
+                 wctx.use_buf(ctx0, 1);

                  // cur = mlp_ln_w*cur + mlp_ln_b
+                 cur = ggml_add(ctx0,
+                         ggml_mul(ctx0,
+                             ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                              cur),
+                         ggml_repeat(ctx0, layer.mlp_ln_b, cur));
              }

+             wctx.use_buf(ctx0, 0);
+
              // fully connected
+             cur = ggml_mul_mat(ctx0,
                      layer.mlp_0_w,
                      cur);

+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.mlp_0_b, cur),
                      cur);

+             wctx.use_buf(ctx0, 0);
+
              // GELU activation
+             cur = ggml_gelu(ctx0, cur);
+
+             wctx.use_buf(ctx0, 1);

              // projection
+             cur = ggml_mul_mat(ctx0,
                      layer.mlp_1_w,
                      cur);

+             wctx.use_buf(ctx0, 0);

+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                     cur);
          }

+         wctx.use_buf(ctx0, 3);

+         inpL = ggml_add(ctx0, cur, inpFF);
      }

      cur = inpL;

      // norm
      {
+         wctx.use_buf(ctx0, 0);
+
          cur = ggml_norm(ctx0, cur);

+         wctx.use_buf(ctx0, 1);
+
          cur = ggml_add(ctx0,
                  ggml_mul(ctx0,
                      ggml_repeat(ctx0, model.d_ln_w, cur),

                  ggml_repeat(ctx0, model.d_ln_b, cur));
      }

+     wctx.use_buf(ctx0, 0);
+
+     // compute logits only for the last token
+     // comment this line to compute logits for all N tokens
+     // might be useful in the future
+     cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+
      struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);

+     wctx.use_buf(ctx0, -1);
+
      // run the computation
      {
          ggml_build_forward_expand(&gf, logits);
          ggml_graph_compute       (ctx0, &gf);
      }

+     // extract logits for all N tokens
+     //logits_out.resize(N*n_vocab);
+     //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+     // extract logits only for the last token
+     logits_out.resize(n_vocab);
+     memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);

      if (N > 1) {
+         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+         //        ggml_used_mem(ctx0)/1024.0/1024.0,
+         //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
+         //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
+         //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
+         //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
      }

      ggml_free(ctx0);