ggerganov committed
Commit 411c667 · unverified · 1 parent: 59192b4

talk.wasm : GPT-2 meets Whisper in WebAssembly (#155)


* talk : initial real-time transcription in the browser

* talk : polishing the UI

* talk : ready for beta testing

* talk.wasm : rename example

bindings/javascript/CMakeLists.txt CHANGED
@@ -9,12 +9,13 @@ target_link_libraries(${TARGET} PRIVATE
9
  )
10
 
11
  unset(EXTRA_FLAGS)
 
12
  if (WHISPER_WASM_SINGLE_FILE)
13
  set(EXTRA_FLAGS "-s SINGLE_FILE=1")
14
  message(STATUS "Embedding WASM inside whisper.js")
15
 
16
  add_custom_command(
17
- TARGET libwhisper POST_BUILD
18
  COMMAND ${CMAKE_COMMAND} -E copy
19
  ${CMAKE_BINARY_DIR}/bin/libwhisper.js
20
  ${CMAKE_CURRENT_SOURCE_DIR}/whisper.js
 
9
  )
10
 
11
  unset(EXTRA_FLAGS)
12
+
13
  if (WHISPER_WASM_SINGLE_FILE)
14
  set(EXTRA_FLAGS "-s SINGLE_FILE=1")
15
  message(STATUS "Embedding WASM inside whisper.js")
16
 
17
  add_custom_command(
18
+ TARGET ${TARGET} POST_BUILD
19
  COMMAND ${CMAKE_COMMAND} -E copy
20
  ${CMAKE_BINARY_DIR}/bin/libwhisper.js
21
  ${CMAKE_CURRENT_SOURCE_DIR}/whisper.js
bindings/javascript/whisper.js CHANGED
The diff for this file is too large to render. See raw diff
 
examples/CMakeLists.txt CHANGED
@@ -20,6 +20,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
20
 
21
  if (EMSCRIPTEN)
22
  add_subdirectory(whisper.wasm)
 
23
  else()
24
  add_subdirectory(main)
25
  add_subdirectory(stream)
 
20
 
21
  if (EMSCRIPTEN)
22
  add_subdirectory(whisper.wasm)
23
+ add_subdirectory(talk.wasm)
24
  else()
25
  add_subdirectory(main)
26
  add_subdirectory(stream)
examples/talk.wasm/CMakeLists.txt ADDED
@@ -0,0 +1,46 @@
1
+ #
2
+ # libtalk
3
+ #
4
+
5
+ set(TARGET libtalk)
6
+
7
+ add_executable(${TARGET}
8
+ emscripten.cpp
9
+ )
10
+
11
+ target_link_libraries(${TARGET} PRIVATE
12
+ whisper
13
+ )
14
+
15
+ unset(EXTRA_FLAGS)
16
+
17
+ if (WHISPER_WASM_SINGLE_FILE)
18
+ set(EXTRA_FLAGS "-s SINGLE_FILE=1")
19
+ message(STATUS "Embedding WASM inside talk.js")
20
+
21
+ add_custom_command(
22
+ TARGET ${TARGET} POST_BUILD
23
+ COMMAND ${CMAKE_COMMAND} -E copy
24
+ ${CMAKE_BINARY_DIR}/bin/libtalk.js
25
+ ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/talk.wasm/talk.js
26
+ )
27
+ endif()
28
+
29
+ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
30
+ --bind \
31
+ -s USE_PTHREADS=1 \
32
+ -s PTHREAD_POOL_SIZE=8 \
33
+ -s INITIAL_MEMORY=1400MB \
34
+ -s TOTAL_MEMORY=1400MB \
35
+ -s FORCE_FILESYSTEM=1 \
36
+ -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
37
+ ${EXTRA_FLAGS} \
38
+ ")
39
+
40
+ #
41
+ # talk.wasm
42
+ #
43
+
44
+ set(TARGET talk.wasm)
45
+
46
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
examples/talk.wasm/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # talk
2
+
3
+ Work in progress
4
+
5
+ ref: https://github.com/ggerganov/whisper.cpp/issues/154
6
+
7
+ demo: https://talk.ggerganov.com
examples/talk.wasm/emscripten.cpp ADDED
@@ -0,0 +1,1379 @@
1
+ #include "ggml.h"
2
+ #include "whisper.h"
3
+
4
+ #include <emscripten.h>
5
+ #include <emscripten/bind.h>
6
+
7
+ #include <atomic>
8
+ #include <cassert>
9
+ #include <cmath>
10
+ #include <cstdio>
11
+ #include <cstring>
12
+ #include <fstream>
13
+ #include <map>
14
+ #include <mutex>
15
+ #include <string>
16
+ #include <thread>
17
+ #include <vector>
18
+ #include <regex>
19
+ #include <random>
20
+
21
+ std::string to_timestamp(int64_t t) {
22
+ int64_t sec = t/100;
23
+ int64_t msec = t - sec*100;
24
+ int64_t min = sec/60;
25
+ sec = sec - min*60;
26
+
27
+ char buf[32];
28
+ snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
29
+
30
+ return std::string(buf);
31
+ }
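
An editorial aside on the formatter above: Whisper reports timestamps in units of 10 ms, so the remainder `t - sec*100` is in centiseconds even though it is printed into a three-digit field. A worked example:

```cpp
// Worked example (editorial note, not part of the commit):
//   t    = 12345           // 123.45 s, expressed in 10 ms units
//   sec  = 123, msec = 45  // "msec" actually holds centiseconds
//   min  = 2,   sec  = 3
//   => "02:03.045"         // the 45 centiseconds are zero-padded as if they were milliseconds
```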
32
+
33
+ /////////////////////// GPT-2 BEGIN /////////////////////////
34
+ // TODO: move to a separate file
35
+
36
+ //
37
+ // Vocab utils
38
+ //
39
+
40
+ struct gpt_vocab {
41
+ using id = int32_t;
42
+ using token = std::string;
43
+
44
+ std::map<token, id> token_to_id;
45
+ std::map<id, token> id_to_token;
46
+ };
47
+
48
+ void replace(std::string & str, const std::string & needle, const std::string & replacement) {
49
+ size_t pos = 0;
50
+ while ((pos = str.find(needle, pos)) != std::string::npos) {
51
+ str.replace(pos, needle.length(), replacement);
52
+ pos += replacement.length();
53
+ }
54
+ }
55
+
56
+ std::map<std::string, int32_t> json_parse(const std::string & fname) {
57
+ std::map<std::string, int32_t> result;
58
+
59
+ // read file into string
60
+ std::string json;
61
+ {
62
+ std::ifstream ifs(fname);
63
+ if (!ifs) {
64
+ fprintf(stderr, "Failed to open %s\n", fname.c_str());
65
+ exit(1);
66
+ }
67
+
68
+ json = std::string((std::istreambuf_iterator<char>(ifs)),
69
+ (std::istreambuf_iterator<char>()));
70
+ }
71
+
72
+ if (json[0] != '{') {
73
+ return result;
74
+ }
75
+
76
+ // parse json
77
+ {
78
+ bool has_key = false;
79
+ bool in_token = false;
80
+
81
+ std::string str_key = "";
82
+ std::string str_val = "";
83
+
84
+ int n = json.size();
85
+ for (int i = 1; i < n; ++i) {
86
+ if (!in_token) {
87
+ if (json[i] == ' ') continue;
88
+ if (json[i] == '"') {
89
+ in_token = true;
90
+ continue;
91
+ }
92
+ } else {
93
+ if (json[i] == '\\' && i+1 < n) {
94
+ if (has_key == false) {
95
+ str_key += json[i];
96
+ } else {
97
+ str_val += json[i];
98
+ }
99
+ ++i;
100
+ } else if (json[i] == '"') {
101
+ if (has_key == false) {
102
+ has_key = true;
103
+ ++i;
104
+ while (json[i] == ' ') ++i;
105
+ ++i; // :
106
+ while (json[i] == ' ') ++i;
107
+ if (json[i] != '\"') {
108
+ while (json[i] != ',' && json[i] != '}') {
109
+ str_val += json[i++];
110
+ }
111
+ has_key = false;
112
+ } else {
113
+ in_token = true;
114
+ continue;
115
+ }
116
+ } else {
117
+ has_key = false;
118
+ }
119
+
120
+ ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
121
+ ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
122
+ ::replace(str_key, "\\\"", "\""); // \\\" -> "
123
+
124
+ try {
125
+ result[str_key] = std::stoi(str_val);
126
+ } catch (...) {
127
+ //fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
128
+
129
+ }
130
+ str_key = "";
131
+ str_val = "";
132
+ in_token = false;
133
+ continue;
134
+ }
135
+ if (has_key == false) {
136
+ str_key += json[i];
137
+ } else {
138
+ str_val += json[i];
139
+ }
140
+ }
141
+ }
142
+ }
143
+
144
+ return result;
145
+ }
146
+
147
+ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
148
+ std::vector<std::string> words;
149
+
150
+ // first split the text into words
151
+ {
152
+ std::string str = text;
153
+ std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
154
+
155
+ std::regex re(pat);
156
+ std::smatch m;
157
+
158
+ while (std::regex_search(str, m, re)) {
159
+ for (auto x : m) {
160
+ words.push_back(x);
161
+ }
162
+ str = m.suffix();
163
+ }
164
+ }
165
+
166
+ // find the longest tokens that form the words:
167
+ std::vector<gpt_vocab::id> tokens;
168
+ for (const auto & word : words) {
169
+ if (word.size() == 0) continue;
170
+
171
+ int i = 0;
172
+ int n = word.size();
173
+ while (i < n) {
174
+ int j = n;
175
+ while (j > i) {
176
+ auto it = vocab.token_to_id.find(word.substr(i, j-i));
177
+ if (it != vocab.token_to_id.end()) {
178
+ tokens.push_back(it->second);
179
+ i = j;
180
+ break;
181
+ }
182
+ --j;
183
+ }
184
+ if (i == n) {
185
+ break;
186
+ }
187
+ if (j == i) {
188
+ auto sub = word.substr(i, 1);
189
+ if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
190
+ tokens.push_back(vocab.token_to_id.at(sub));
191
+ } else {
192
+ fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
193
+ }
194
+ ++i;
195
+ }
196
+ }
197
+ }
198
+
199
+ return tokens;
200
+ }
201
+
202
+ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
203
+ printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
204
+
205
+ vocab.token_to_id = ::json_parse(fname);
206
+
207
+ for (const auto & kv : vocab.token_to_id) {
208
+ vocab.id_to_token[kv.second] = kv.first;
209
+ }
210
+
211
+ printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
212
+
213
+ // print the vocabulary
214
+ //for (auto kv : vocab.token_to_id) {
215
+ // printf("'%s' -> %d\n", kv.first.data(), kv.second);
216
+ //}
217
+
218
+ return true;
219
+ }
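
A minimal usage sketch for the two helpers above (editorial addition, not part of the commit). Note that in talk.wasm the vocabulary is actually read from `gpt-2.bin` by `gpt2_model_load()` further below; `gpt_vocab_init()` with a separate JSON vocab file (the path here is hypothetical) comes from the standalone gpt-2 example:

```cpp
void tokenize_example() {
    gpt_vocab vocab;
    if (!gpt_vocab_init("gpt-2-vocab.json", vocab)) { // hypothetical vocab file
        return;
    }

    // greedy longest-match tokenization, as implemented in gpt_tokenize() above
    const auto tokens = gpt_tokenize(vocab, "Hello, how are you?");
    for (const auto id : tokens) {
        printf("%d -> '%s'\n", id, vocab.id_to_token.at(id).c_str());
    }
}
```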
220
+
221
+ gpt_vocab::id gpt_sample_top_k_top_p(
222
+ const gpt_vocab & vocab,
223
+ const float * logits,
224
+ int top_k,
225
+ double top_p,
226
+ double temp,
227
+ std::mt19937 & rng) {
228
+ int n_logits = vocab.id_to_token.size();
229
+
230
+ std::vector<std::pair<double, gpt_vocab::id>> logits_id;
231
+ logits_id.reserve(n_logits);
232
+
233
+ for (int i = 0; i < n_logits; i++) {
234
+ logits_id.push_back(std::make_pair(logits[i], i));
235
+ }
236
+
237
+ // find the top K tokens
238
+ std::partial_sort(
239
+ logits_id.begin(),
240
+ logits_id.begin() + top_k, logits_id.end(),
241
+ [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
242
+ return a.first > b.first;
243
+ });
244
+
245
+ logits_id.resize(top_k);
246
+
247
+ // normalize
248
+ {
249
+ double sum = 0.0f;
250
+ for (int i = 0; i < (int)logits_id.size(); i++) {
251
+ sum += logits_id[i].first;
252
+ }
253
+
254
+ sum = 1.0/sum;
255
+ for (int i = 0; i < (int)logits_id.size(); i++) {
256
+ logits_id[i].first *= sum;
257
+ }
258
+ }
259
+
260
+ if (top_p < 1.0f) {
261
+ {
262
+ double cumsum = 0.0f;
263
+ for (int i = 0; i < top_k; i++) {
264
+ cumsum += logits_id[i].first;
265
+ if (cumsum >= top_p) {
266
+ logits_id.resize(i+1);
267
+ break;
268
+ }
269
+ }
270
+ }
271
+
272
+ // normalize again
273
+ {
274
+ double sum = 0.0f;
275
+ for (int i = 0; i < (int)logits_id.size(); i++) {
276
+ sum += logits_id[i].first;
277
+ }
278
+
279
+ sum = 1.0/sum;
280
+ for (int i = 0; i < (int)logits_id.size(); i++) {
281
+ logits_id[i].first *= sum;
282
+ }
283
+ }
284
+ }
285
+
286
+ //printf("\n");
287
+ //for (int i = 0; i < (int)logits_id.size(); i++) {
288
+ // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
289
+ //}
290
+ //exit(0);
291
+
292
+ // sample from the obtained distribution
293
+ std::vector<double> probs;
294
+ probs.reserve(logits_id.size());
295
+
296
+ for (int i = 0; i < (int) logits_id.size(); i++) {
297
+ probs.push_back(logits_id[i].first);
298
+ }
299
+
300
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
301
+ int idx = dist(rng);
302
+
303
+ return logits_id[idx].second;
304
+ }
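
Two details worth noting about the sampler above: the `temp` parameter is accepted but never applied, and the values it receives from `gpt2_eval()` (defined further below) have already been passed through a soft-max, so they are probabilities rather than raw logits, which is why plain summation suffices for the normalization steps. A self-contained call sketch (editorial addition; the distribution is arbitrary):

```cpp
void sampling_example(const gpt_vocab & vocab) {
    // hand-made "distribution" over the vocabulary, for illustration only
    std::vector<float> probs(vocab.id_to_token.size(), 1e-6f);
    probs[42]  = 0.7f;
    probs[100] = 0.3f;

    std::mt19937 rng(1234);

    const gpt_vocab::id id = gpt_sample_top_k_top_p(
            vocab, probs.data(), /*top_k*/ 40, /*top_p*/ 0.9, /*temp*/ 1.0, rng);

    printf("sampled token id: %d\n", id);
}
```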
305
+
306
+ // default hparams (GPT-2 117M)
307
+ struct gpt2_hparams {
308
+ int32_t n_vocab = 50257;
309
+ int32_t n_ctx = 1024;
310
+ int32_t n_embd = 768;
311
+ int32_t n_head = 12;
312
+ int32_t n_layer = 12;
313
+ int32_t f16 = 1;
314
+ };
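
As a rough sanity check on these defaults (an editorial aside, not part of the commit), the implied parameter count matches GPT-2 "small":

```cpp
// Approximate parameter count for n_embd = 768, n_layer = 12, n_vocab = 50257, n_ctx = 1024:
//   embeddings : 50257*768 + 1024*768                               ~= 39.4M
//   per layer  : 768*2304 + 768*768 + 768*3072 + 3072*768 + biases  ~= 7.1M
//   total      : 39.4M + 12 * 7.1M                                  ~= 124M
```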
315
+
316
+ struct gpt2_layer {
317
+ // normalization
318
+ struct ggml_tensor * ln_1_g;
319
+ struct ggml_tensor * ln_1_b;
320
+
321
+ struct ggml_tensor * ln_2_g;
322
+ struct ggml_tensor * ln_2_b;
323
+
324
+ // attention
325
+ struct ggml_tensor * c_attn_attn_w;
326
+ struct ggml_tensor * c_attn_attn_b;
327
+
328
+ struct ggml_tensor * c_attn_proj_w;
329
+ struct ggml_tensor * c_attn_proj_b;
330
+
331
+ // mlp
332
+ struct ggml_tensor * c_mlp_fc_w;
333
+ struct ggml_tensor * c_mlp_fc_b;
334
+
335
+ struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
336
+ struct ggml_tensor * c_mlp_proj_b;
337
+ };
338
+
339
+ struct gpt2_model {
340
+ gpt2_hparams hparams;
341
+
342
+ // normalization
343
+ struct ggml_tensor * ln_f_g;
344
+ struct ggml_tensor * ln_f_b;
345
+
346
+ struct ggml_tensor * wte; // token embedding
347
+ struct ggml_tensor * wpe; // position embedding
348
+
349
+ std::vector<gpt2_layer> layers;
350
+
351
+ // key + value memory
352
+ struct ggml_tensor * memory_k;
353
+ struct ggml_tensor * memory_v;
354
+
355
+ //
356
+ struct ggml_context * ctx;
357
+ std::map<std::string, struct ggml_tensor *> tensors;
358
+ };
359
+
360
+ // load the model's weights from a file
361
+ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
362
+ printf("%s: loading model from '%s'\n", __func__, fname.c_str());
363
+
364
+ auto fin = std::ifstream(fname, std::ios::binary);
365
+ if (!fin) {
366
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
367
+ return false;
368
+ }
369
+
370
+ // verify magic
371
+ {
372
+ uint32_t magic;
373
+ fin.read((char *) &magic, sizeof(magic));
374
+ if (magic != 0x67676d6c) {
375
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
376
+ return false;
377
+ }
378
+ }
379
+
380
+ // load hparams
381
+ {
382
+ auto & hparams = model.hparams;
383
+
384
+ fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
385
+ fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
386
+ fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
387
+ fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
388
+ fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
389
+ fin.read((char *) &hparams.f16, sizeof(hparams.f16));
390
+
391
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
392
+ printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
393
+ printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
394
+ printf("%s: n_head = %d\n", __func__, hparams.n_head);
395
+ printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
396
+ printf("%s: f16 = %d\n", __func__, hparams.f16);
397
+ }
398
+
399
+ // load vocab
400
+ {
401
+ int32_t n_vocab = 0;
402
+ fin.read((char *) &n_vocab, sizeof(n_vocab));
403
+
404
+ if (n_vocab != model.hparams.n_vocab) {
405
+ fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
406
+ __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
407
+ return false;
408
+ }
409
+
410
+ std::string word;
411
+ for (int i = 0; i < n_vocab; i++) {
412
+ uint32_t len;
413
+ fin.read((char *) &len, sizeof(len));
414
+
415
+ word.resize(len);
416
+ fin.read((char *) word.data(), len);
417
+
418
+ vocab.token_to_id[word] = i;
419
+ vocab.id_to_token[i] = word;
420
+ }
421
+ }
422
+
423
+ // for the big tensors, we have the option to store the data in 16-bit floats
424
+ // in order to save memory and also to speed up the computation
425
+ const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
426
+
427
+ auto & ctx = model.ctx;
428
+
429
+ size_t ctx_size = 0;
430
+
431
+ {
432
+ const auto & hparams = model.hparams;
433
+
434
+ const int n_embd = hparams.n_embd;
435
+ const int n_layer = hparams.n_layer;
436
+ const int n_ctx = hparams.n_ctx;
437
+ const int n_vocab = hparams.n_vocab;
438
+
439
+ ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
440
+ ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
441
+
442
+ ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
443
+ ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
444
+
445
+ ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
446
+ ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
447
+
448
+ ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
449
+ ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
450
+
451
+ ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
452
+ ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
453
+
454
+ ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
455
+ ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
456
+
457
+ ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
458
+ ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
459
+
460
+ ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
461
+ ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
462
+
463
+ ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
464
+ ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
465
+
466
+ ctx_size += (6 + 12*n_layer)*256; // object overhead
467
+
468
+ printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
469
+ }
470
+
471
+ // create the ggml context
472
+ {
473
+ struct ggml_init_params params = {
474
+ .mem_size = ctx_size,
475
+ .mem_buffer = NULL,
476
+ };
477
+
478
+ model.ctx = ggml_init(params);
479
+ if (!model.ctx) {
480
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
481
+ return false;
482
+ }
483
+ }
484
+
485
+ // prepare memory for the weights
486
+ {
487
+ const auto & hparams = model.hparams;
488
+
489
+ const int n_embd = hparams.n_embd;
490
+ const int n_layer = hparams.n_layer;
491
+ const int n_ctx = hparams.n_ctx;
492
+ const int n_vocab = hparams.n_vocab;
493
+
494
+ model.layers.resize(n_layer);
495
+
496
+ model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
497
+ model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
498
+
499
+ model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
500
+ model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
501
+
502
+ // map by name
503
+ model.tensors["model/ln_f/g"] = model.ln_f_g;
504
+ model.tensors["model/ln_f/b"] = model.ln_f_b;
505
+
506
+ model.tensors["model/wte"] = model.wte;
507
+ model.tensors["model/wpe"] = model.wpe;
508
+
509
+ for (int i = 0; i < n_layer; ++i) {
510
+ auto & layer = model.layers[i];
511
+
512
+ layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
513
+ layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
514
+
515
+ layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
516
+ layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
517
+
518
+ layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
519
+ layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
520
+
521
+ layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
522
+ layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
523
+
524
+ layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
525
+ layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
526
+
527
+ layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
528
+ layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
529
+
530
+ // map by name
531
+ model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
532
+ model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
533
+
534
+ model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
535
+ model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
536
+
537
+ model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
538
+ model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
539
+
540
+ model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
541
+ model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
542
+
543
+ model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
544
+ model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
545
+
546
+ model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
547
+ model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
548
+ }
549
+ }
550
+
551
+ // key + value memory
552
+ {
553
+ const auto & hparams = model.hparams;
554
+
555
+ const int n_embd = hparams.n_embd;
556
+ const int n_layer = hparams.n_layer;
557
+ const int n_ctx = hparams.n_ctx;
558
+
559
+ const int n_mem = n_layer*n_ctx;
560
+ const int n_elements = n_embd*n_mem;
561
+
562
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
563
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
564
+
565
+ const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
566
+
567
+ printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
568
+ }
569
+
570
+ // load weights
571
+ {
572
+ size_t total_size = 0;
573
+
574
+ while (true) {
575
+ int32_t n_dims;
576
+ int32_t length;
577
+ int32_t ftype;
578
+
579
+ fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
580
+ fin.read(reinterpret_cast<char *>(&length), sizeof(length));
581
+ fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
582
+
583
+ if (fin.eof()) {
584
+ break;
585
+ }
586
+
587
+ int32_t nelements = 1;
588
+ int32_t ne[2] = { 1, 1 };
589
+ for (int i = 0; i < n_dims; ++i) {
590
+ fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
591
+ nelements *= ne[i];
592
+ }
593
+
594
+ std::string name(length, 0);
595
+ fin.read(&name[0], length);
596
+
597
+ if (model.tensors.find(name.data()) == model.tensors.end()) {
598
+ fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
599
+ return false;
600
+ }
601
+
602
+ auto tensor = model.tensors[name.data()];
603
+ if (ggml_nelements(tensor) != nelements) {
604
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
605
+ return false;
606
+ }
607
+
608
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
609
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
610
+ __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
611
+ return false;
612
+ }
613
+
614
+ const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
615
+
616
+ if (nelements*bpe != ggml_nbytes(tensor)) {
617
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
618
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
619
+ return false;
620
+ }
621
+
622
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
623
+
624
+ //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
625
+ total_size += ggml_nbytes(tensor);
626
+ }
627
+
628
+ printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
629
+ }
630
+
631
+ fin.close();
632
+
633
+ return true;
634
+ }
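
For reference, the on-disk layout that `gpt2_model_load()` expects, reconstructed from the reads above (the converter script that produces `gpt-2.bin` is not part of this diff):

```cpp
// gpt-2.bin layout as read by gpt2_model_load():
//   uint32  magic                         // 0x67676d6c ("ggml")
//   int32   n_vocab, n_ctx, n_embd, n_head, n_layer, f16
//   n_vocab x { uint32 len; char word[len]; }
//   repeated until EOF, one record per tensor:
//     int32  n_dims, name_length, ftype   // ftype: 0 = f32, 1 = f16
//     int32  ne[n_dims]                   // tensor shape
//     char   name[name_length]
//     <raw tensor data, f32 or f16 depending on ftype>
```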
635
+
636
+ // evaluate the transformer
637
+ //
638
+ // - model: the model
639
+ // - n_threads: number of threads to use
640
+ // - n_past: the context size so far
641
+ // - embd_inp: the embeddings of the tokens in the context
642
+ // - embd_w: the predicted probabilities of the next token
643
+ //
644
+ bool gpt2_eval(
645
+ const gpt2_model & model,
646
+ const int n_threads,
647
+ const int n_past,
648
+ const std::vector<gpt_vocab::id> & embd_inp,
649
+ std::vector<float> & embd_w,
650
+ size_t & mem_per_token) {
651
+ const int N = embd_inp.size();
652
+
653
+ const auto & hparams = model.hparams;
654
+
655
+ const int n_embd = hparams.n_embd;
656
+ const int n_layer = hparams.n_layer;
657
+ const int n_ctx = hparams.n_ctx;
658
+ const int n_head = hparams.n_head;
659
+ const int n_vocab = hparams.n_vocab;
660
+
661
+ static size_t buf_size = 512u*1024*1024;
662
+ static void * buf = malloc(buf_size);
663
+
664
+ if (mem_per_token > 0 && mem_per_token*N > buf_size) {
665
+ const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
666
+ printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
667
+
668
+ // reallocate
669
+ buf_size = buf_size_new;
670
+ buf = realloc(buf, buf_size);
671
+ if (buf == nullptr) {
672
+ fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
673
+ return false;
674
+ }
675
+ }
676
+
677
+ struct ggml_init_params params = {
678
+ .mem_size = buf_size,
679
+ .mem_buffer = buf,
680
+ };
681
+
682
+ struct ggml_context * ctx0 = ggml_init(params);
683
+ struct ggml_cgraph gf = { .n_threads = n_threads };
684
+
685
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
686
+ memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
687
+
688
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
689
+ for (int i = 0; i < N; ++i) {
690
+ ((int32_t *) position->data)[i] = n_past + i;
691
+ }
692
+
693
+ // wte + wpe
694
+ struct ggml_tensor * inpL =
695
+ ggml_add(ctx0,
696
+ ggml_get_rows(ctx0, model.wte, embd),
697
+ ggml_get_rows(ctx0, model.wpe, position));
698
+
699
+ for (int il = 0; il < n_layer; ++il) {
700
+ struct ggml_tensor * cur;
701
+
702
+ // norm
703
+ {
704
+ // [ 768, N]
705
+ cur = ggml_norm(ctx0, inpL);
706
+
707
+ // cur = ln_1_g*cur + ln_1_b
708
+ // [ 768, N]
709
+ cur = ggml_add(ctx0,
710
+ ggml_mul(ctx0,
711
+ ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
712
+ cur),
713
+ ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
714
+ }
715
+
716
+ // attn
717
+ // [2304, 768] - model.layers[il].c_attn_attn_w
718
+ // [2304, 1] - model.layers[il].c_attn_attn_b
719
+ // [ 768, N] - cur (in)
720
+ // [2304, N] - cur (out)
721
+ //
722
+ // cur = attn_w*cur + attn_b
723
+ // [2304, N]
724
+ {
725
+ cur = ggml_mul_mat(ctx0,
726
+ ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
727
+ cur);
728
+
729
+ cur = ggml_add(ctx0,
730
+ ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
731
+ cur);
732
+ }
733
+
734
+ // self-attention
735
+ {
736
+ struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
737
+ struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
738
+ struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
739
+
740
+ // store key and value to memory
741
+ if (N >= 1) {
742
+ struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
743
+ struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
744
+
745
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
746
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
747
+ }
748
+
749
+ // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
750
+ // [64, N, 12]
751
+ struct ggml_tensor * Q =
752
+ ggml_permute(ctx0,
753
+ ggml_cpy(ctx0,
754
+ Qcur,
755
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
756
+ 0, 2, 1, 3);
757
+
758
+ // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
759
+ // [64, n_past + N, 12]
760
+ struct ggml_tensor * K =
761
+ ggml_permute(ctx0,
762
+ ggml_reshape_3d(ctx0,
763
+ ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
764
+ n_embd/n_head, n_head, n_past + N),
765
+ 0, 2, 1, 3);
766
+
767
+ // GG: flash attention
768
+ //struct ggml_tensor * V =
769
+ // ggml_cpy(ctx0,
770
+ // ggml_permute(ctx0,
771
+ // ggml_reshape_3d(ctx0,
772
+ // ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
773
+ // n_embd/n_head, n_head, n_past + N),
774
+ // 1, 2, 0, 3),
775
+ // ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
776
+
777
+ //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
778
+
779
+ // K * Q
780
+ // [n_past + N, N, 12]
781
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
782
+
783
+ // KQ_scaled = KQ / sqrt(n_embd/n_head)
784
+ // [n_past + N, N, 12]
785
+ struct ggml_tensor * KQ_scaled =
786
+ ggml_scale(ctx0,
787
+ KQ,
788
+ ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
789
+ );
790
+
791
+ // KQ_masked = mask_past(KQ_scaled)
792
+ // [n_past + N, N, 12]
793
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
794
+
795
+ // KQ = soft_max(KQ_masked)
796
+ // [n_past + N, N, 12]
797
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
798
+
799
+ // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
800
+ // [n_past + N, 64, 12]
801
+ struct ggml_tensor * V_trans =
802
+ ggml_permute(ctx0,
803
+ ggml_reshape_3d(ctx0,
804
+ ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
805
+ n_embd/n_head, n_head, n_past + N),
806
+ 1, 2, 0, 3);
807
+
808
+ // KQV = transpose(V) * KQ_soft_max
809
+ // [64, N, 12]
810
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
811
+
812
+ // KQV_merged = KQV.permute(0, 2, 1, 3)
813
+ // [64, 12, N]
814
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
815
+
816
+ // cur = KQV_merged.contiguous().view(n_embd, N)
817
+ // [768, N]
818
+ cur = ggml_cpy(ctx0,
819
+ KQV_merged,
820
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
821
+ }
822
+
823
+ // projection
824
+ // [ 768, 768] - model.layers[il].c_attn_proj_w
825
+ // [ 768, 1] - model.layers[il].c_attn_proj_b
826
+ // [ 768, N] - cur (in)
827
+ // [ 768, N] - cur (out)
828
+ //
829
+ // cur = proj_w*cur + proj_b
830
+ // [768, N]
831
+ {
832
+ cur = ggml_mul_mat(ctx0,
833
+ ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
834
+ cur);
835
+
836
+ cur = ggml_add(ctx0,
837
+ ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
838
+ cur);
839
+ }
840
+
841
+ // add the input
842
+ cur = ggml_add(ctx0, cur, inpL);
843
+
844
+ struct ggml_tensor * inpFF = cur;
845
+
846
+ // feed-forward network
847
+ {
848
+ // norm
849
+ {
850
+ cur = ggml_norm(ctx0, inpFF);
851
+
852
+ // cur = ln_2_g*cur + ln_2_b
853
+ // [ 768, N]
854
+ cur = ggml_add(ctx0,
855
+ ggml_mul(ctx0,
856
+ ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
857
+ cur),
858
+ ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
859
+ }
860
+
861
+ // fully connected
862
+ // [3072, 768] - model.layers[il].c_mlp_fc_w
863
+ // [3072, 1] - model.layers[il].c_mlp_fc_b
864
+ // [ 768, N] - cur (in)
865
+ // [3072, N] - cur (out)
866
+ //
867
+ // cur = fc_w*cur + fc_b
868
+ // [3072, N]
869
+ cur = ggml_mul_mat(ctx0,
870
+ ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
871
+ cur);
872
+
873
+ cur = ggml_add(ctx0,
874
+ ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
875
+ cur);
876
+
877
+ // GELU activation
878
+ // [3072, N]
879
+ cur = ggml_gelu(ctx0, cur);
880
+
881
+ // projection
882
+ // [ 768, 3072] - model.layers[il].c_mlp_proj_w
883
+ // [ 768, 1] - model.layers[il].c_mlp_proj_b
884
+ // [3072, N] - cur (in)
885
+ // [ 768, N] - cur (out)
886
+ //
887
+ // cur = proj_w*cur + proj_b
888
+ // [768, N]
889
+ cur = ggml_mul_mat(ctx0,
890
+ model.layers[il].c_mlp_proj_w_trans,
891
+ cur);
892
+
893
+ cur = ggml_add(ctx0,
894
+ ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
895
+ cur);
896
+ }
897
+
898
+ // input for next layer
899
+ inpL = ggml_add(ctx0, cur, inpFF);
900
+ }
901
+
902
+ // norm
903
+ {
904
+ // [ 768, N]
905
+ inpL = ggml_norm(ctx0, inpL);
906
+
907
+ // inpL = ln_f_g*inpL + ln_f_b
908
+ // [ 768, N]
909
+ inpL = ggml_add(ctx0,
910
+ ggml_mul(ctx0,
911
+ ggml_repeat(ctx0, model.ln_f_g, inpL),
912
+ inpL),
913
+ ggml_repeat(ctx0, model.ln_f_b, inpL));
914
+ }
915
+
916
+ // inpL = WTE * inpL
917
+ // [ 768, 50257] - model.wte
918
+ // [ 768, N] - inpL
919
+ inpL = ggml_mul_mat(ctx0, model.wte, inpL);
920
+
921
+ // logits -> probs
922
+ inpL = ggml_soft_max(ctx0, inpL);
923
+
924
+ // run the computation
925
+ ggml_build_forward_expand(&gf, inpL);
926
+ ggml_graph_compute (ctx0, &gf);
927
+
928
+ //if (n_past%100 == 0) {
929
+ // ggml_graph_print (&gf);
930
+ // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
931
+ //}
932
+
933
+ //embd_w.resize(n_vocab*N);
934
+ //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
935
+
936
+ // return result for just the last token
937
+ embd_w.resize(n_vocab);
938
+ memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
939
+
940
+ if (mem_per_token == 0) {
941
+ mem_per_token = ggml_used_mem(ctx0)/N;
942
+ }
943
+ //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
944
+
945
+ ggml_free(ctx0);
946
+
947
+ return true;
948
+ }
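
Putting `gpt2_eval()` and the sampler together, a compact generation loop looks roughly like the sketch below (editorial addition; the demo's actual loop, including its stop conditions and prompt handling, is `gpt2_gen_text()` further down):

```cpp
std::string generate(const gpt2_model & model, gpt_vocab & vocab,
                     const std::string & prompt, int n_predict, std::mt19937 & rng) {
    std::vector<gpt_vocab::id> embd = gpt_tokenize(vocab, prompt);
    std::vector<float> probs;

    size_t mem_per_token = 3000000; // initial guess, same value gpt2_gen_text() uses below
    int n_past = 0;
    std::string out;

    for (int i = 0; i < n_predict; ++i) {
        if (!gpt2_eval(model, /*n_threads*/ 4, n_past, embd, probs, mem_per_token)) {
            break;
        }
        n_past += embd.size();

        // probs holds the distribution for the last token only
        const gpt_vocab::id id = gpt_sample_top_k_top_p(vocab, probs.data(), 40, 0.9, 1.0, rng);
        embd = { id };
        out += vocab.id_to_token[id];

        if (id == 50256) break; // GPT-2 end-of-text token
    }

    return out;
}
```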
949
+
950
+ /////////////////////////////// GPT-2 END ////////////////////////////////
951
+
952
+ constexpr int N_THREAD = 8;
953
+
954
+ struct gpt2_state {
955
+ std::string prompt_base = R"(Hello, how are you?
956
+ I'm fine, thanks. How are you?
957
+ Thanks, I'm fine too. What are you doing?
958
+ I'm just sitting here.
959
+ It's a lovely day, isn't it?
960
+ Yes, it is.
961
+ Did you know that I'm a robot?
962
+ I wasn't aware of that.
963
+ )";
964
+
965
+ std::mt19937 rng;
966
+
967
+ gpt_vocab vocab;
968
+ gpt2_model model;
969
+
970
+ int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
971
+ int32_t n_predict = 32; // new tokens to predict
972
+
973
+ // sampling parameters
974
+ int32_t top_k = 40;
975
+ float top_p = 0.9f;
976
+ float temp = 1.0f;
977
+ };
978
+
979
+ struct gpt2_state g_gpt2;
980
+
981
+ std::vector<float> g_pcmf32;
982
+
983
+ std::vector<struct whisper_context *> g_contexts(4, nullptr);
984
+
985
+ std::mutex g_mutex;
986
+ std::thread g_worker;
987
+ std::atomic<bool> g_running(false);
988
+
989
+ bool g_force_speak = false;
990
+ std::string g_text_to_speak = "";
991
+ std::string g_status = "idle";
992
+ std::string g_status_forced = "";
993
+
994
+ std::string gpt2_gen_text(const std::string & prompt) {
995
+ int n_past = 0;
996
+
997
+ std::vector<float> embd_w;
998
+
999
+ // tokenize the prompt
1000
+ std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(g_gpt2.vocab, g_gpt2.prompt_base + prompt);
1001
+
1002
+ g_gpt2.n_predict = std::min(g_gpt2.n_predict, g_gpt2.model.hparams.n_ctx - (int) embd_inp.size());
1003
+
1004
+ std::vector<gpt_vocab::id> embd = embd_inp;
1005
+
1006
+ size_t mem_per_token = 3000000;
1007
+
1008
+ std::string result;
1009
+
1010
+ for (int i = embd.size(); i < embd_inp.size() + g_gpt2.n_predict; i++) {
1011
+ // predict
1012
+ if (embd.size() > 0) {
1013
+ if (!gpt2_eval(g_gpt2.model, g_gpt2.n_threads, n_past, embd, embd_w, mem_per_token)) {
1014
+ printf("gpt-2: failed to generate text\n");
1015
+ return "";
1016
+ }
1017
+ }
1018
+
1019
+ n_past += embd.size();
1020
+ embd.clear();
1021
+
1022
+ {
1023
+ // sample next token
1024
+ const int top_k = g_gpt2.top_k;
1025
+ const float top_p = g_gpt2.top_p;
1026
+ const float temp = g_gpt2.temp;
1027
+
1028
+ const int n_vocab = g_gpt2.model.hparams.n_vocab;
1029
+
1030
+ const gpt_vocab::id id = gpt_sample_top_k_top_p(g_gpt2.vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, g_gpt2.rng);
1031
+
1032
+ // add it to the context
1033
+ embd.push_back(id);
1034
+ }
1035
+
1036
+ result += g_gpt2.vocab.id_to_token[embd[0]];
1037
+
1038
+ // end of text token
1039
+ if (embd.back() == 50256 ||
1040
+ g_gpt2.vocab.id_to_token[embd.back()] == "." ||
1041
+ g_gpt2.vocab.id_to_token[embd.back()] == "!" ||
1042
+ g_gpt2.vocab.id_to_token[embd.back()] == "?") {
1043
+ break;
1044
+ }
1045
+ }
1046
+
1047
+ return result;
1048
+ }
1049
+
1050
+ void talk_set_status(const std::string & status) {
1051
+ std::lock_guard<std::mutex> lock(g_mutex);
1052
+ g_status = status;
1053
+ }
1054
+
1055
+ void talk_main(size_t index) {
1056
+ talk_set_status("loading data ...");
1057
+
1058
+ struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
1059
+
1060
+ wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
1061
+ wparams.offset_ms = 0;
1062
+ wparams.translate = false;
1063
+ wparams.no_context = true;
1064
+ wparams.single_segment = true;
1065
+ wparams.print_realtime = false;
1066
+ wparams.print_progress = false;
1067
+ wparams.print_timestamps = true;
1068
+ wparams.print_special_tokens = false;
1069
+
1070
+ wparams.max_tokens = 32;
1071
+ wparams.audio_ctx = 768;
1072
+
1073
+ wparams.language = "en";
1074
+
1075
+ g_gpt2.rng = std::mt19937(time(NULL));
1076
+
1077
+ // load the model
1078
+ {
1079
+ const int64_t t_start_us = ggml_time_us();
1080
+
1081
+ if (!gpt2_model_load("gpt-2.bin", g_gpt2.model, g_gpt2.vocab)) {
1082
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
1083
+ return;
1084
+ }
1085
+
1086
+ const int64_t t_load_us = ggml_time_us() - t_start_us;
1087
+
1088
+ printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
1089
+ }
1090
+
1091
+ std::vector<float> pcmf32;
1092
+
1093
+ auto & ctx = g_contexts[index];
1094
+
1095
+ const int64_t step_samples = 2*WHISPER_SAMPLE_RATE;
1096
+ const int64_t step_ms = (step_samples*1000)/WHISPER_SAMPLE_RATE;
1097
+ const int64_t window_samples = 9*WHISPER_SAMPLE_RATE;
1098
+
1099
+ auto t_last = std::chrono::high_resolution_clock::now();
1100
+
1101
+ talk_set_status("listening ...");
1102
+
1103
+ while (g_running) {
1104
+
1105
+ const auto t_now = std::chrono::high_resolution_clock::now();
1106
+ if (std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count() < step_ms) {
1107
+ {
1108
+ std::lock_guard<std::mutex> lock(g_mutex);
1109
+ g_pcmf32.clear();
1110
+ }
1111
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
1112
+ continue;
1113
+ }
1114
+
1115
+ talk_set_status("listening ...");
1116
+
1117
+ {
1118
+ std::unique_lock<std::mutex> lock(g_mutex);
1119
+
1120
+ if (g_pcmf32.size() < step_samples) {
1121
+ lock.unlock();
1122
+
1123
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
1124
+
1125
+ continue;
1126
+ }
1127
+
1128
+ pcmf32 = std::vector<float>(g_pcmf32.end() - std::min((int64_t) g_pcmf32.size(), window_samples), g_pcmf32.end());
1129
+ }
1130
+
1131
+ // if the energy during the last second is above the threshold, the user is likely still speaking, so skip
1132
+ {
1133
+ float energy_all = 0.0f;
1134
+ float energy_1s = 0.0f;
1135
+
1136
+ for (size_t i = 0; i < pcmf32.size(); i++) {
1137
+ energy_all += fabsf(pcmf32[i]);
1138
+
1139
+ if (i >= pcmf32.size() - WHISPER_SAMPLE_RATE) {
1140
+ energy_1s += fabsf(pcmf32[i]);
1141
+ }
1142
+ }
1143
+
1144
+ energy_all /= pcmf32.size();
1145
+ energy_1s /= WHISPER_SAMPLE_RATE;
1146
+
1147
+ if (energy_1s > 0.1f*energy_all && !g_force_speak) {
1148
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
1149
+ continue;
1150
+ }
1151
+ }
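
For readers skimming the loop, a short note on the block above (editorial addition, not part of the commit):

```cpp
// Simple energy-based voice-activity heuristic: energy_all is the mean absolute
// amplitude of the whole ~9 s window, energy_1s that of the most recent second.
// If the last second still carries more than 10% of the window's average energy
// (e.g. with energy_all = 0.02, an energy_1s of 0.005 postpones transcription,
// while 0.001 lets it proceed), the speaker is assumed to be mid-sentence and
// the loop waits for a later iteration.
```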
1152
+
1153
+ talk_set_status("processing ...");
1154
+
1155
+ g_force_speak = false;
1156
+
1157
+ t_last = t_now;
1158
+
1159
+ {
1160
+ const auto t_start = std::chrono::high_resolution_clock::now();
1161
+
1162
+ int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
1163
+ if (ret != 0) {
1164
+ printf("whisper_full() failed: %d\n", ret);
1165
+ break;
1166
+ }
1167
+
1168
+ const auto t_end = std::chrono::high_resolution_clock::now();
1169
+
1170
+ printf("whisper_full() returned %d in %f seconds\n", ret, std::chrono::duration<double>(t_end - t_start).count());
1171
+ }
1172
+
1173
+ {
1174
+ std::string text_heard;
1175
+
1176
+ const int n_segments = whisper_full_n_segments(ctx);
1177
+ for (int i = n_segments - 1; i < n_segments; ++i) {
1178
+ const char * text = whisper_full_get_segment_text(ctx, i);
1179
+
1180
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
1181
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
1182
+
1183
+ printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
1184
+
1185
+ text_heard += text;
1186
+ }
1187
+
1188
+ // remove text between brackets using regex
1189
+ {
1190
+ std::regex re("\\[.*?\\]");
1191
+ text_heard = std::regex_replace(text_heard, re, "");
1192
+ }
1193
+
1194
+ // remove text between parentheses using regex
1195
+ {
1196
+ std::regex re("\\(.*?\\)");
1197
+ text_heard = std::regex_replace(text_heard, re, "");
1198
+ }
1199
+
1200
+ // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
1201
+ text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
1202
+
1203
+ // take first line
1204
+ text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
1205
+
1206
+ // remove leading and trailing whitespace
1207
+ text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
1208
+ text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
1209
+
1210
+ talk_set_status("'" + text_heard + "' - thinking how to respond ...");
1211
+
1212
+ const std::vector<gpt_vocab::id> tokens = ::gpt_tokenize(g_gpt2.vocab, text_heard);
1213
+
1214
+ printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
1215
+
1216
+ std::string text_to_speak;
1217
+
1218
+ if (tokens.size() > 0) {
1219
+ text_to_speak = gpt2_gen_text(text_heard + "\n");
1220
+ text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
1221
+ text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
1222
+
1223
+ std::lock_guard<std::mutex> lock(g_mutex);
1224
+
1225
+ // remove first 2 lines of base prompt
1226
+ {
1227
+ const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
1228
+ if (pos != std::string::npos) {
1229
+ g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
1230
+ }
1231
+ }
1232
+ {
1233
+ const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
1234
+ if (pos != std::string::npos) {
1235
+ g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
1236
+ }
1237
+ }
1238
+ g_gpt2.prompt_base += text_heard + "\n" + text_to_speak + "\n";
1239
+ } else {
1240
+ text_to_speak = gpt2_gen_text("");
1241
+ text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
1242
+ text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
1243
+
1244
+ std::lock_guard<std::mutex> lock(g_mutex);
1245
+
1246
+ const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
1247
+ if (pos != std::string::npos) {
1248
+ g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
1249
+ }
1250
+ g_gpt2.prompt_base += text_to_speak + "\n";
1251
+ }
1252
+
1253
+ printf("gpt-2: %s\n", text_to_speak.c_str());
1254
+
1255
+ //printf("========================\n");
1256
+ //printf("gpt-2: prompt_base:\n'%s'\n", g_gpt2.prompt_base.c_str());
1257
+ //printf("========================\n");
1258
+
1259
+ {
1260
+ std::lock_guard<std::mutex> lock(g_mutex);
1261
+ t_last = std::chrono::high_resolution_clock::now();
1262
+ g_text_to_speak = text_to_speak;
1263
+ g_pcmf32.clear();
1264
+ }
1265
+
1266
+ talk_set_status("speaking ...");
1267
+ }
1268
+ }
1269
+
1270
+ if (index < g_contexts.size()) {
1271
+ whisper_free(g_contexts[index]);
1272
+ g_contexts[index] = nullptr;
1273
+ }
1274
+ }
1275
+
1276
+ EMSCRIPTEN_BINDINGS(talk) {
1277
+ emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
1278
+ for (size_t i = 0; i < g_contexts.size(); ++i) {
1279
+ if (g_contexts[i] == nullptr) {
1280
+ g_contexts[i] = whisper_init(path_model.c_str());
1281
+ if (g_contexts[i] != nullptr) {
1282
+ g_running = true;
1283
+ if (g_worker.joinable()) {
1284
+ g_worker.join();
1285
+ }
1286
+ g_worker = std::thread([i]() {
1287
+ talk_main(i);
1288
+ });
1289
+
1290
+ return i + 1;
1291
+ } else {
1292
+ return (size_t) 0;
1293
+ }
1294
+ }
1295
+ }
1296
+
1297
+ return (size_t) 0;
1298
+ }));
1299
+
1300
+ emscripten::function("free", emscripten::optional_override([](size_t index) {
1301
+ if (g_running) {
1302
+ g_running = false;
1303
+ }
1304
+ }));
1305
+
1306
+ emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
1307
+ --index;
1308
+
1309
+ if (index >= g_contexts.size()) {
1310
+ return -1;
1311
+ }
1312
+
1313
+ if (g_contexts[index] == nullptr) {
1314
+ return -2;
1315
+ }
1316
+
1317
+ {
1318
+ std::lock_guard<std::mutex> lock(g_mutex);
1319
+ const int n = audio["length"].as<int>();
1320
+
1321
+ emscripten::val heap = emscripten::val::module_property("HEAPU8");
1322
+ emscripten::val memory = heap["buffer"];
1323
+
1324
+ g_pcmf32.resize(n);
1325
+
1326
+ emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
1327
+ memoryView.call<void>("set", audio);
1328
+ }
1329
+
1330
+ return 0;
1331
+ }));
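
A note on the binding above (editorial addition, not part of the commit): `set_audio` avoids an intermediate copy on the JS side by building a typed-array view directly over the WASM heap.

```cpp
// set_audio: the incoming JS value is a Float32Array, so audio["constructor"] is the
// Float32Array constructor. A new view of that type is created over the module's
// HEAPU8 backing buffer, starting at the address of g_pcmf32.data(), and
// memoryView.set(audio) then copies the samples straight into the C++ vector.
```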
1332
+
1333
+ emscripten::function("force_speak", emscripten::optional_override([](size_t index) {
1334
+ {
1335
+ std::lock_guard<std::mutex> lock(g_mutex);
1336
+ g_force_speak = true;
1337
+ }
1338
+ }));
1339
+
1340
+ emscripten::function("get_text_context", emscripten::optional_override([]() {
1341
+ std::string text_context;
1342
+
1343
+ {
1344
+ std::lock_guard<std::mutex> lock(g_mutex);
1345
+ text_context = g_gpt2.prompt_base;
1346
+ }
1347
+
1348
+ return text_context;
1349
+ }));
1350
+
1351
+ emscripten::function("get_text_to_speak", emscripten::optional_override([]() {
1352
+ std::string text_to_speak;
1353
+
1354
+ {
1355
+ std::lock_guard<std::mutex> lock(g_mutex);
1356
+ text_to_speak = std::move(g_text_to_speak);
1357
+ }
1358
+
1359
+ return text_to_speak;
1360
+ }));
1361
+
1362
+ emscripten::function("get_status", emscripten::optional_override([]() {
1363
+ std::string status;
1364
+
1365
+ {
1366
+ std::lock_guard<std::mutex> lock(g_mutex);
1367
+ status = g_status_forced.empty() ? g_status : g_status_forced;
1368
+ }
1369
+
1370
+ return status;
1371
+ }));
1372
+
1373
+ emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
1374
+ {
1375
+ std::lock_guard<std::mutex> lock(g_mutex);
1376
+ g_status_forced = status;
1377
+ }
1378
+ }));
1379
+ }
examples/talk.wasm/index-tmpl.html ADDED
@@ -0,0 +1,646 @@
1
+ <!doctype html>
2
+ <html lang="en-us">
3
+ <head>
4
+ <title>Talk - GPT-2 meets Whisper in WebAssembly</title>
5
+
6
+ <style>
7
+ #output {
8
+ width: 100%;
9
+ height: 100%;
10
+ margin: 0 auto;
11
+ margin-top: 10px;
12
+ border-left: 0px;
13
+ border-right: 0px;
14
+ padding-left: 0px;
15
+ padding-right: 0px;
16
+ display: block;
17
+ background-color: black;
18
+ color: white;
19
+ font-size: 10px;
20
+ font-family: 'Lucida Console', Monaco, monospace;
21
+ outline: none;
22
+ white-space: pre;
23
+ overflow-wrap: normal;
24
+ overflow-x: scroll;
25
+ }
26
+ </style>
27
+ </head>
28
+ <body>
29
+ <div id="main-container">
30
+ <b>Talk - GPT-2 meets Whisper in WebAssembly</b>
31
+
32
+ <br><br>
33
+
34
+ On this page you can talk with an AI entity. It uses:
35
+
36
+ <ul>
37
+ <li><a href="https://github.com/ggerganov/whisper.cpp">OpenAI's Whisper</a> model to listen to you as you speak into the microphone</li>
38
+ <li><a href="https://github.com/ggerganov/ggml/tree/master/examples/gpt-2">OpenAI's GPT-2</a> model to generate a text response</li>
39
+ <li><a href="https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API">Web Speech API</a> to speak the response to you through the speakers</li>
40
+ </ul>
41
+
42
+ All of this runs <b>locally in your browser</b> using WebAssembly.<br>
43
+ You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">GitHub</a>.
44
+
45
+ <br><br>
46
+
47
+ <hr>
48
+
49
+ <div id="model-whisper">
50
+ <span id="model-whisper-status">Whisper model:</span>
51
+ <button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
52
+ <button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
53
+ <span id="fetch-whisper-progress"></span>
54
+
55
+ <!--
56
+ <input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
57
+ -->
58
+ </div>
59
+
60
+ <br>
61
+
62
+ <div id="model-gpt-2">
63
+ <span id="model-gpt-2-status">GPT-2 model:</span>
64
+ <button id="fetch-gpt-2-small" onclick="loadGPT2('small')">small 117M (240 MB)</button>
65
+ <!--<button id="fetch-gpt-2-medium" onclick="loadGPT2('medium')">medium 345M (720 MB)</button>-->
66
+ <span id="fetch-gpt-2-progress"></span>
67
+
68
+ <!--
69
+ <input type="file" id="file" name="file" onchange="loadFile(event, 'gpt-2.bin')" />
70
+ -->
71
+ </div>
72
+
73
+ <br>
74
+
75
+ <div id="input">
76
+ <button id="start" onclick="onStart()">Start</button>
77
+ <button id="stop" onclick="onStop()" disabled>Stop</button>
78
+ <select id="voice" onchange="onVoiceChange()">
79
+ <option value="0">Default</option>
80
+ </select>
81
+ <button id="speak" onclick="onSpeak('Hello')">Say hello</button>
82
+ <button id="speak" onclick="onSpeakRandom()">Say something</button>
83
+ <button id="speak" onclick="clearCache()">Clear Cache</button>
84
+ </div>
85
+
86
+ <br>
87
+
88
+ <div id="state">
89
+ Status: <b><span id="state-status">idle</span></b>
90
+
91
+ <pre id="state-context">[The text context will be displayed here]</pre>
92
+ </div>
93
+
94
+ <hr>
95
+
96
+ Debug output:
97
+ <textarea id="output" rows="20"></textarea>
98
+
99
+ <br>
100
+
101
+ <b>Troubleshooting</b>
102
+
103
+ <br><br>
104
+
105
+ The page does some heavy computations, so make sure:
106
+
107
+ <ul>
108
+ <li>To use a modern web browser (e.g. Chrome, Firefox)</li>
109
+ <li>To use a fast desktop or laptop computer (e.g. not a mobile phone)</li>
110
+ <li>That your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
111
+ </ul>
112
+
113
+ <br><br>
114
+
115
+ <div class="cell-version">
116
+ <span>
117
+ |
118
+ Build time: <span class="nav-link">@GIT_DATE@</span> |
119
+ Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
120
+ Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
121
+ <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">Source Code</a> |
122
+ </span>
123
+ </div>
124
+ </div>
125
+
126
+ <script type='text/javascript'>
127
+ var printTextarea = (function() {
128
+ var element = document.getElementById('output');
129
+ if (element) element.value = ''; // clear browser cache
130
+ return function(text) {
131
+ if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' ');
132
+ console.log(text);
133
+ if (element) {
134
+ element.value += text + "\n";
135
+ element.scrollTop = element.scrollHeight; // focus on bottom
136
+ }
137
+ };
138
+ })();
139
+
140
+ const kRestartRecording_s = 15;
141
+ const kSampleRate = 16000;
142
+
143
+ window.AudioContext = window.AudioContext || window.webkitAudioContext;
144
+ window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
145
+
146
+ // web audio context
147
+ var context = null;
148
+
149
+ // audio data
150
+ var audio = null;
151
+ var audio0 = null;
152
+
153
+ // the talk instance
154
+ var instance = null;
155
+
156
+ // model names
157
+ var model_whisper = null;
158
+ var model_gpt_2 = null;
159
+
160
+ // speech synthesis
161
+ const synth = window.speechSynthesis;
162
+ var voice = null;
163
+
164
+ var Module = {
165
+ print: printTextarea,
166
+ printErr: printTextarea,
167
+ setStatus: function(text) {
168
+ printTextarea('js: ' + text);
169
+ },
170
+ monitorRunDependencies: function(left) {
171
+ },
172
+ preRun: function() {
173
+ printTextarea('js: Preparing ...');
174
+ },
175
+ postRun: function() {
176
+ printTextarea('js: Initialized successfully!');
177
+
178
+ // populate the voice list
179
+ var voices = synth.getVoices();
180
+ var el = document.getElementById('voice');
181
+
182
+ var n = 0;
183
+ voices.forEach(function(voice, i) {
184
+ if (!voice.lang.startsWith('en')) return;
185
+ var option = document.createElement('option');
186
+ option.value = i;
187
+ option.innerHTML = voice.name + ' (' + voice.lang + ')';
188
+ el.appendChild(option);
189
+ n++;
190
+ });
191
+
192
+ // select random voice
193
+ if (n > 0) {
194
+ for (var k = 0; k < 10; k++) {
195
+ var i = Math.floor(Math.random() * n);
196
+ el.selectedIndex = i;
197
+ voice = voices[document.getElementById('voice').options[i].value];
198
+
199
+ // give preference to Google voices
200
+ if (voice.name.startsWith('Google')) break;
201
+ }
202
+ }
203
+ }
204
+ };
205
+
206
+ // helper function
207
+ function convertTypedArray(src, type) {
208
+ var buffer = new ArrayBuffer(src.byteLength);
209
+ var baseView = new src.constructor(buffer).set(src);
210
+ return new type(buffer);
211
+ }
212
+
213
+ //
214
+ // fetch models
215
+ //
216
+
217
+ function storeFS(fname, buf) {
218
+ // write to WASM file using FS_createDataFile
219
+ // if the file exists, delete it
220
+ try {
221
+ Module.FS_unlink(fname);
222
+ } catch (e) {
223
+ // ignore
224
+ }
225
+
226
+ Module.FS_createDataFile("/", fname, buf, true, true);
227
+
228
+ printTextarea('js: stored model: ' + fname + ' size: ' + buf.length);
229
+
230
+ if (fname == 'whisper.bin') {
231
+                    document.getElementById('model-whisper-status').innerHTML = 'Whisper model: loaded "' + model_whisper + '"!';
232
+ } else if (fname == 'gpt-2.bin') {
233
+                    document.getElementById('model-gpt-2-status').innerHTML = 'GPT-2 model: loaded "' + model_gpt_2 + '"!';
234
+ }
235
+ }
236
+
237
+            let dbVersion = 1;
238
+ let dbName = 'talk.ggerganov.com';
239
+            let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB;
240
+
241
+            // fetch a file from a remote URL using the Fetch API
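+            // the response body is read in chunks so that download progress can be reported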
242
+ async function fetchRemote(url, elProgress) {
243
+ printTextarea('js: downloading with fetch()...');
244
+
245
+ const response = await fetch(
246
+ url,
247
+ {
248
+ method: 'GET',
249
+ headers: {
250
+ 'Content-Type': 'application/octet-stream',
251
+ },
252
+ }
253
+ );
254
+
255
+ if (!response.ok) {
256
+ printTextarea('js: failed to fetch ' + url);
257
+ return;
258
+ }
259
+
260
+ const contentLength = response.headers.get('content-length');
261
+ const total = parseInt(contentLength, 10);
262
+ const reader = response.body.getReader();
263
+
264
+ var chunks = [];
265
+ var receivedLength = 0;
266
+ var progressLast = -1;
267
+
268
+ while (true) {
269
+ const { done, value } = await reader.read();
270
+
271
+ if (done) {
272
+ break;
273
+ }
274
+
275
+ chunks.push(value);
276
+ receivedLength += value.length;
277
+
278
+ if (contentLength) {
279
+ // update progress bar element with the new percentage
280
+ elProgress.innerHTML = Math.round((receivedLength / total) * 100) + '%';
281
+
282
+ var progressCur = Math.round((receivedLength / total) * 10);
283
+ if (progressCur != progressLast) {
284
+ printTextarea('js: fetching ' + 10*progressCur + '% ...');
285
+ progressLast = progressCur;
286
+ }
287
+ }
288
+ }
289
+
290
+ var chunksAll = new Uint8Array(receivedLength);
291
+ var position = 0;
292
+ for (var chunk of chunks) {
293
+ chunksAll.set(chunk, position);
294
+ position += chunk.length;
295
+ }
296
+
297
+ return chunksAll;
298
+ }
299
+
300
+ // load remote data
301
+ // - check if the data is already in the IndexedDB
302
+ // - if not, fetch it from the remote URL and store it in the IndexedDB
303
+ // - store it in WASM memory
304
+ function loadRemote(url, dst, elProgress, size_mb) {
305
+ // query the storage quota and print it
306
+ navigator.storage.estimate().then(function (estimate) {
307
+ printTextarea('js: storage quota: ' + estimate.quota + ' bytes');
308
+ printTextarea('js: storage usage: ' + estimate.usage + ' bytes');
309
+ });
310
+
311
+ // check if the data is already in the IndexedDB
312
+ var request = indexedDB.open(dbName, dbVersion);
313
+
314
+ request.onupgradeneeded = function (event) {
315
+ var db = event.target.result;
316
+ if (db.version == 1) {
317
+ var objectStore = db.createObjectStore('models', { autoIncrement: false });
318
+ printTextarea('js: created IndexedDB ' + db.name + ' version ' + db.version);
319
+ } else {
320
+ // clear the database
321
+ var objectStore = event.currentTarget.transaction.objectStore('models');
322
+ objectStore.clear();
323
+ printTextarea('js: cleared IndexedDB ' + db.name + ' version ' + db.version);
324
+ }
325
+ };
326
+
327
+ request.onsuccess = function (event) {
328
+ var db = event.target.result;
329
+ var transaction = db.transaction(['models'], 'readonly');
330
+ var objectStore = transaction.objectStore('models');
331
+ var request = objectStore.get(url);
332
+
333
+ request.onsuccess = function (event) {
334
+ if (request.result) {
335
+ printTextarea('js: "' + url + '" is already in the IndexedDB');
336
+ storeFS(dst, request.result);
337
+ } else {
338
+ // data is not in the IndexedDB
339
+ printTextarea('js: "' + url + '" is not in the IndexedDB');
340
+
341
+ // alert and ask the user to confirm
342
+ if (!confirm('You are about to download ' + size_mb + ' MB of data.\nThe model data will be cached in the browser for future use.\n\nPress OK to continue.')) {
343
+ document.getElementById('fetch-whisper-tiny-en').style.display = 'inline-block';
344
+ document.getElementById('fetch-whisper-base-en').style.display = 'inline-block';
345
+ document.getElementById('fetch-gpt-2-small').style.display = 'inline-block';
346
+ return;
347
+ }
348
+
349
+ fetchRemote(url, elProgress).then(function (data) {
350
+ if (data) {
351
+ // store the data in the IndexedDB
352
+ var request = indexedDB.open(dbName, dbVersion);
353
+ request.onsuccess = function (event) {
354
+ var db = event.target.result;
355
+ var transaction = db.transaction(['models'], 'readwrite');
356
+ var objectStore = transaction.objectStore('models');
357
+ var request = objectStore.put(data, url);
358
+
359
+ request.onsuccess = function (event) {
360
+ printTextarea('js: "' + url + '" stored in the IndexedDB');
361
+ storeFS(dst, data);
362
+ };
363
+
364
+ request.onerror = function (event) {
365
+ printTextarea('js: failed to store "' + url + '" in the IndexedDB');
366
+ };
367
+ };
368
+ }
369
+ });
370
+ }
371
+ };
372
+
373
+ request.onerror = function (event) {
374
+ printTextarea('js: failed to get data from the IndexedDB');
375
+ };
376
+ };
377
+
378
+ request.onerror = function (event) {
379
+ printTextarea('js: failed to open IndexedDB');
380
+ };
381
+
382
+ request.onblocked = function (event) {
383
+ printTextarea('js: failed to open IndexedDB: blocked');
384
+ };
385
+
386
+ request.onabort = function (event) {
387
+ printTextarea('js: failed to open IndexedDB: abort');
388
+ };
389
+ }
390
+
391
+ function loadWhisper(model) {
392
+ let urls = {
393
+ 'tiny.en': 'https://talk.ggerganov.com/ggml-model-whisper-tiny.en.bin',
394
+ 'base.en': 'https://talk.ggerganov.com/ggml-model-whisper-base.en.bin',
395
+ };
396
+
397
+ let sizes = {
398
+ 'tiny.en': 75,
399
+ 'base.en': 142,
400
+ };
401
+
402
+ let url = urls[model];
403
+ let dst = 'whisper.bin';
404
+ let el = document.getElementById('fetch-whisper-progress');
405
+ let size_mb = sizes[model];
406
+
407
+ model_whisper = model;
408
+
409
+ document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
410
+ document.getElementById('fetch-whisper-base-en').style.display = 'none';
411
+ document.getElementById('model-whisper-status').innerHTML = 'Whisper model: loading "' + model + '" ... ';
412
+
413
+ loadRemote(url, dst, el, size_mb);
414
+ }
415
+
416
+ function loadGPT2(model) {
417
+ let urls = {
418
+ 'small': 'https://talk.ggerganov.com/ggml-model-gpt-2-117M.bin',
419
+ 'medium': 'https://talk.ggerganov.com/ggml-model-gpt-2-345M.bin',
420
+ };
421
+
422
+ let sizes = {
423
+ 'small': 240,
424
+ 'medium': 712,
425
+ };
426
+
427
+ let url = urls[model];
428
+ let dst = 'gpt-2.bin';
429
+ let el = document.getElementById('fetch-gpt-2-progress');
430
+ let size_mb = sizes[model];
431
+
432
+ model_gpt_2 = model;
433
+
434
+ document.getElementById('fetch-gpt-2-small').style.display = 'none';
435
+ document.getElementById('model-gpt-2-status').innerHTML = 'GPT-2 model: loading "' + model + '" ... ';
436
+
437
+ loadRemote(url, dst, el, size_mb);
438
+ }
439
+
440
+ //
441
+ // microphone
442
+ //
443
+
444
+ var mediaRecorder = null;
445
+ var doRecording = false;
446
+ var startTime = 0;
447
+
448
+ function stopRecording() {
449
+ Module.set_status("paused");
450
+ doRecording = false;
451
+ audio0 = null;
452
+ audio = null;
453
+ }
454
+
455
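+            // capture microphone audio with MediaRecorder; the recorded chunks are decoded through a 16 kHz AudioContext and the accumulated samples are passed to the C++ instance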
+ function startRecording() {
456
+ if (!context) {
457
+ context = new AudioContext({sampleRate: 16000});
458
+ }
459
+
460
+ Module.set_status("");
461
+
462
+ document.getElementById('start').disabled = true;
463
+ document.getElementById('stop').disabled = false;
464
+
465
+ doRecording = true;
466
+ startTime = Date.now();
467
+
468
+ var chunks = [];
469
+ var stream = null;
470
+
471
+ navigator.mediaDevices.getUserMedia({audio: true, video: false})
472
+ .then(function(s) {
473
+ stream = s;
474
+ mediaRecorder = new MediaRecorder(stream);
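+                    // on every data chunk, decode the entire recording so far and hand all samples to the C++ side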
475
+ mediaRecorder.ondataavailable = function(e) {
476
+ chunks.push(e.data);
477
+
478
+ var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
479
+ var reader = new FileReader();
480
+
481
+ reader.onload = function(event) {
482
+ var buf = new Uint8Array(reader.result);
483
+
484
+ context.decodeAudioData(buf.buffer, function(audioBuffer) {
485
+ var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
486
+ var source = offlineContext.createBufferSource();
487
+ source.buffer = audioBuffer;
488
+ source.connect(offlineContext.destination);
489
+ source.start(0);
490
+
491
+ offlineContext.startRendering().then(function(renderedBuffer) {
492
+ audio = renderedBuffer.getChannelData(0);
493
+
494
+ //printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
495
+
496
+ var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
497
+ if (audio0 != null) {
498
+ audioAll.set(audio0, 0);
499
+ }
500
+ audioAll.set(audio, audio0 == null ? 0 : audio0.length);
501
+
502
+ if (instance) {
503
+ Module.set_audio(instance, audioAll);
504
+ }
505
+ });
506
+ }, function(e) {
507
+ audio = null;
508
+ });
509
+ }
510
+
511
+ reader.readAsArrayBuffer(blob);
512
+ };
513
+
514
+ mediaRecorder.onstop = function(e) {
515
+ if (doRecording) {
516
+ setTimeout(function() {
517
+ startRecording();
518
+ });
519
+ }
520
+ };
521
+
522
+ mediaRecorder.start(250);
523
+ })
524
+ .catch(function(err) {
525
+ printTextarea('js: error getting audio stream: ' + err);
526
+ });
527
+
528
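+                // poll every 250 ms: stop the recorder when recording has been disabled, or restart it once the buffered audio exceeds kRestartRecording_s seconds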
+ var interval = setInterval(function() {
529
+ if (!doRecording) {
530
+ clearInterval(interval);
531
+ mediaRecorder.stop();
532
+ stream.getTracks().forEach(function(track) {
533
+ track.stop();
534
+ });
535
+
536
+ document.getElementById('start').disabled = false;
537
+ document.getElementById('stop').disabled = true;
538
+
539
+ mediaRecorder = null;
540
+ }
541
+
542
+ // if audio length is more than kRestartRecording_s seconds, restart recording
543
+ if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
544
+ if (doRecording) {
545
+ //printTextarea('js: restarting recording');
546
+
547
+ clearInterval(interval);
548
+ audio0 = audio;
549
+ audio = null;
550
+ mediaRecorder.stop();
551
+ stream.getTracks().forEach(function(track) {
552
+ track.stop();
553
+ });
554
+ }
555
+ }
556
+ }, 250);
557
+ }
558
+
559
+ //
560
+ // speak
561
+ //
562
+
563
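+            // speak the given text with the Web Speech API; recording is paused while speaking and resumed when the utterance finishes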
+ function onSpeak(text) {
564
+ var voices = synth.getVoices();
565
+ var msg = new SpeechSynthesisUtterance(text);
566
+
567
+ if (voice == null) {
568
+ voice = voices[0];
569
+ }
570
+
571
+ msg.voice = voice;
572
+ synth.speak(msg);
573
+
574
+ if (doRecording) {
575
+ Module.set_status("speaking ...");
576
+ printTextarea('js: speaking');
577
+ stopRecording();
578
+ var interval = setInterval(function() {
579
+ if (!synth.speaking) {
580
+ printTextarea('js: done speaking');
581
+ clearInterval(interval);
582
+ startRecording();
583
+ } else {
584
+ Module.set_status("");
585
+ }
586
+ }, 100);
587
+ }
588
+ }
589
+
590
+ function onSpeakRandom() {
591
+ Module.force_speak(instance);
592
+ }
593
+
594
+ async function clearCache() {
595
+ if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
596
+ indexedDB.deleteDatabase(dbName);
597
+ }
598
+ }
599
+
600
+ //
601
+ // main
602
+ //
603
+
604
+ var intervalUpdate = null;
605
+
606
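+            // initialize the C++ instance, start recording and periodically poll it for generated text to speak and for status updates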
+ function onStart() {
607
+ if (!instance) {
608
+ instance = Module.init('whisper.bin');
609
+
610
+ if (instance) {
611
+ printTextarea("js: whisper initialized, instance: " + instance);
612
+ }
613
+ }
614
+
615
+ if (!instance) {
616
+ printTextarea("js: failed to initialize whisper");
617
+ return;
618
+ }
619
+
620
+ startRecording();
621
+
622
+ intervalUpdate = setInterval(function() {
623
+ var textToSpeak = Module.get_text_to_speak();
624
+
625
+ if (textToSpeak != null && textToSpeak.length > 1) {
626
+ onSpeak(textToSpeak);
627
+ }
628
+
629
+ document.getElementById('state-status').innerHTML = Module.get_status();
630
+ document.getElementById('state-context').innerHTML = Module.get_text_context();
631
+ }, 100);
632
+ }
633
+
634
+ function onStop() {
635
+ stopRecording();
636
+ }
637
+
638
+ function onVoiceChange() {
639
+ printTextarea('js: voice changed to: ' + document.getElementById('voice').value);
640
+ voice = synth.getVoices()[document.getElementById('voice').value];
641
+ }
642
+
643
+ </script>
644
+ <script type="text/javascript" src="talk.js"></script>
645
+ </body>
646
+ </html>
whisper.cpp CHANGED
@@ -2750,7 +2750,7 @@ int whisper_full(
2750
  } else {
2751
  text += whisper_token_to_str(ctx, tokens_cur[i].id);
2752
  }
2753
- if (tokens_cur[i].id > whisper_token_beg(ctx)) {
2754
  const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
2755
  if (!text.empty()) {
2756
  const auto tt0 = params.speed_up ? 2*t0 : t0;
 
2750
  } else {
2751
  text += whisper_token_to_str(ctx, tokens_cur[i].id);
2752
  }
2753
+ if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
2754
  const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
2755
  if (!text.empty()) {
2756
  const auto tt0 = params.speed_up ? 2*t0 : t0;